在模型蒸馏过程中,Kullback-Leibler (KL) 散度用于衡量两个概率分布之间的差异。在这个情境下,我们通常会计算学生模型输出的概率分布与教师模型输出的概率分布之间的KL散度。
KL 散度定义
对于两个概率分布 PPP 和 QQQ,KL 散度 DKL(P∥Q)D_{KL}(P \| Q)DKL(P∥Q) 定义为:
DKL(P∥Q)=∑iP(i)logP(i)Q(i)D_{KL}(P \| Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)}DKL(P∥Q)=∑iP(i)logQ(i)P(i)
在模型蒸馏中, PPP 是教师模型输出的概率分布, QQQ 是学生模型输出的概率分布。
在 PyTorch 中实现 KL 散度
在 PyTorch 中,可以使用 torch.nn.functional.kl_div
函数来计算 KL 散度。注意,在使用这个函数时,输入需要是对数概率(logits)。通常情况下,教师模型的输出会先通过 softmax 函数得到概率分布,然后再取对数。
以下是一个示例,展示如何在模型蒸馏过程中定义和使用 KL 散度:
python复制代码import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义 KL 散度损失函数
def kl_divergence(student_logits, teacher_logits, temperature):
# 计算软化后的概率分布
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
# 使用 PyTorch 内置的 kl_div 计算 KL 散度
kl_div = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
return kl_div
# 示例训练过程
temperature = 2.0 # 温度参数,通常选择大于1的值
alpha = 0.5 # 权衡参数,控制交叉熵损失和 KL 散度损失的比例
# 假设我们有教师模型和学生模型的输出 logits
teacher_logits = torch.randn(32, 10000) # 教师模型输出 (batch_size, vocab_size)
student_logits = torch.randn(32, 10000) # 学生模型输出 (batch_size, vocab_size)
# 计算 KL 散度损失
kl_loss = kl_divergence(student_logits, teacher_logits, temperature)
# 计算交叉熵损失
labels = torch.randint(0, 10000, (32,)) # 假设我们有真实标签 (batch_size,)
ce_loss = F.cross_entropy(student_logits, labels)
# 综合损失
loss = alpha * ce_loss + (1 - alpha) * kl_loss
# 反向传播和优化
loss.backward()
optimizer.step()
解释
- 温度参数 (temperature):用于控制 softmax 输出的平滑程度。较高的温度会使概率分布更加平滑,更容易学习到细微的概率信息。
- 软化后的概率分布:通过
F.softmax(teacher_logits / temperature, dim=-1)
计算教师模型的概率分布,通过F.log_softmax(student_logits / temperature, dim=-1)
计算学生模型的对数概率分布。 - KL 散度损失计算:使用
torch.nn.functional.kl_div
函数计算学生模型和教师模型之间的 KL 散度。 - 综合损失:最终的损失函数是交叉熵损失和 KL 散度损失的加权和,其中
alpha
是权衡参数,控制这两部分损失的比例。
通过上述步骤,学生模型可以通过优化损失函数,学习到教师模型的知识,从而实现模型蒸馏的目的。