神经网络知识蒸馏初探


知识蒸馏(Knowledge Distillation)

知识蒸馏通常用于模型压缩,用一个已经训练好的模型A去“教”另外一个模型B。这两个模型称为老师-学生模型。

通常模型A比模型B更强。在模型A的帮助下,模型B可以突破自我,学得更好。

PyTorch相关函数

•Softmax:将一个数值序列映射到概率空间

•log_softmax:在softmax的基础上取对数

•NLLLoss:对log_softmax与one-hot进行计算

•CrossEntropy:衡量两个概率分布的差别

Softmax

image-20231027220753202

例如:

image-20231027220840530

注意:此处的 output 仅仅是一个得分,即得分越大,神经网络认为其更属于该类。

image-20231027221022874

log_softmax

这个很好理解,其实就是对softmax处理之后的结果执行一次对数运算。

可以理解为 log(softmax(output))

image-20231027221209144

注意:由于softmax返回的是概率,即0-1之间的数,那么0-1之间的数取对数当然就是负数了。

那么我们什么时候需要用到这个函数呢?

一种常见的情况是在多类别分类任务中使用交叉熵损失函数(例如下文提及到的NLLLoss负对数似然损失函数)。交叉熵损失函数需要输入为对数概率,而不是原始概率。因此,在训练模型时,经常需要将模型的输出转换为对数概率。

NLLLoss

image-20231027222353355

image-20231027222711821

例如上图中的 torch.tensor([0]) 代表这张图片真实的结果所在的标签,即它是属于第0类的,那么该函数就会返回一个损失值。

其实计算的方法非常简单,就是在 [-1.2, -2, -3] 中找到下标为0的数,然后取负数。

通常我们结合 log_softmax 和 nll_loss一起用

image-20231027223749825

例如上图,神经网络认为输出中索引为2的值得分最高,如果真实数据的标签也是索引为2,那么神经网络就认为输出属于索引为2的这一类的概率是最大的,所以损失就是最小的。

CrossEntropy

在分类问题中,CrossEntropy等价于log_softmax 结合 nll_loss

image-20231027224501903

例如上图,我们已经知道 NLLLoss 返回的结果越大,说明此数据和真实数据相差的程度就越大,从而它们两个的乘积就会越大;如果二者相等,那么乘积就会变成0。

one-hot编码形式:

在One-hot编码中,如果一个分类变量有n个可能的取值,那么该变量将被编码为一个长度为n的二进制向量,其中只有一个元素为1,表示当前取值的索引位置,其他元素都为0。

例如,假设有一个颜色变量,可能的取值为[“红色”, “蓝色”, “绿色”]。使用One-hot编码,可以将这个变量编码为以下三个向量之一:

- 红色: [1, 0, 0]

- 蓝色: [0, 1, 0]

- 绿色: [0, 0, 1]

总结

log_softmax + NLLLoss 等价于 CrossEntropy

output = torch.tensor([[1.2, 2, 3]])
target = torch.tensor([0])

log_sm_output = F.log_softmax(output, dim=1)
nll_loss_of_log_sm_output = F.nll_loss(log_sm_output, target)
print(nll_loss_of_log_sm_output)
output = torch.tensor([[1.2, 2, 3]])
target = torch.tensor([0])

ce_loss = F.cross_entropy(output, target)
print(ce_loss)

两者计算结果都是一样的。

Softmax相关知识

  • 初始实现

    import numpy as np
    
    def softmax(x):
        x_exp = np.exp(x)
        return x_exp / np.sum(x_exp)
    
    output = np.array([0.1, 1.6, 3.6])
    print(softmax(output))
    
    # [0.02590865 0.11611453 0.85797681]
  • 改造升级

    def softmax_t(x, t):
        x_exp = np.exp(x / t)
        return x_exp / np.sum(x_exp)
    
    output = np.array([0.1, 1.6, 3.6])
    print(softmax_t(output, 5))
    
    # [0.22916797 0.3093444  0.46148762]

    观察输出结果可以看到,加入了 t 后,输出的概率分布更平滑了,大的概率变小,小的概率变大了,t 越大,概率直接的差距就越小。

老师网络

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data


torch.manual_seed(0)
torch.cuda.manual_seed(0)

class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.3)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output

注意:在前向函数中没有经过上述提到的任何函数加工,而是直接输出。

def train_teacher(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')


def test_teacher(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

注意:cross_entropy的参数中也是使用了直接输出后的 output

根据我们之前的理论:log_softmax + NLLLoss 等价于 CrossEntropy,所以如果我们在前向函数中使用了log_softmax,那么在训练和测试中使用的损失函数就要是 NLLLoss。

def teacher_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)

    model = TeacherNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    
    teacher_history = []

    for epoch in range(1, epochs + 1):
        train_teacher(model, device, train_loader, optimizer, epoch)
        loss, acc = test_teacher(model, device, test_loader)
        
        teacher_history.append((loss, acc))

    torch.save(model.state_dict(), "teacher.pt")
    return model, teacher_history
# 训练教师网络

teacher_model, teacher_history = teacher_main()

# 输出
Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0510, accuracy: 9832/10000 (98%)
Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0364, accuracy: 9882/10000 (99%)
Train epoch 3: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0354, accuracy: 9883/10000 (99%)
Train epoch 4: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0325, accuracy: 9902/10000 (99%)
Train epoch 5: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0314, accuracy: 9897/10000 (99%)
Train epoch 6: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0330, accuracy: 9898/10000 (99%)
Train epoch 7: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0325, accuracy: 9908/10000 (99%)
Train epoch 8: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0322, accuracy: 9906/10000 (99%)
Train epoch 9: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0330, accuracy: 9910/10000 (99%)
Train epoch 10: 60000/60000, [-------------------------------------------------->] 100%
Test: average loss: 0.0302, accuracy: 9906/10000 (99%)

那么如何把老师网络中的知识蒸馏出来呢?

通过蒸馏老师网络中的知识也被称为老师网络的暗知识:

import numpy as np
from matplotlib import pyplot as plt

def softmax_t(x, t):
    x_exp = np.exp(x / t)
    return x_exp / np.sum(x_exp)

test_loader_bs1 = torch.utils.data.DataLoader(
    datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=1, shuffle=True) # batch_size = 1,每次只拿一张图片出来
teacher_model.eval()
with torch.no_grad():
    data, target = next(iter(test_loader_bs1))
    data, target = data.to('cuda'), target.to('cuda')
    output = teacher_model(data)

test_x = data.cpu().numpy()
y_out = output.cpu().numpy()
y_out = y_out[0, ::]
print('Output (NO softmax):', y_out) # 先将原输出直接打印出来,为10个实数

# 输出
Output (NO softmax): [-31.14481   -30.600847   -3.2787514 -20.624037  -31.863455  -37.684086
 -35.177486  -22.72263   -16.028662  -26.460657 ]
# 可以观察到索引为2的值是最大的


# 将这个图片展示出来
plt.subplot(3, 1, 1)
plt.imshow(test_x[0, 0, ::])
# 绘制直方图,此时是经过加强版的softmax函数处理过的,t = 1
plt.subplot(3, 1, 2)
plt.bar(list(range(10)), softmax_t(y_out, 1), width=0.3)
# 绘制直方图,此时是经过加强版的softmax函数处理过的,t = 10
plt.subplot(3, 1, 3)
plt.bar(list(range(10)), softmax_t(y_out, 10), width=0.3)
plt.show()

上述代码对应的图:

image-20231027231657077

可以观察得出,t = 1时其他标签的相似度都是接近于0,而t = 10时其他标签的相似度就有体现了,尤其在下图更为明显:

image-20231027233137704

那么接下来就是学生的工作:学生会明白原来这个数字和3和5都有些相似。

所称知识蒸馏,原来参数 t 就是温度,温度越高,暗知识就被蒸馏出的越多。

老师和学生的关系

image-20231027233533986

解释一下:

  • 如果没有老师的帮助,此时学生网络通过前向迭代,需要用Loss来衡量预测值和真实值之间的损失,就要使用softmax这个函数进行计算,得出的q是一个概率分布,最终求出p和q之间的Loss,当然我们希望这个Loss越小越好。

    此处的Loss称为HARD Loss,因为p是one-hot编码形式。

  • 如果有老师的帮助,那我们最终得出的Loss就不只是一个HARD Loss 了,而是有一部分还来自于老师蒸馏出来的SOFT Loss,二者相互结合,才得到最终的Loss。

定义学生网络

class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output

注意:此处学生网络比较简单,而老师网络比较复杂,为神经网络。

定义汇总的Loss

def distillation(y, labels, teacher_scores, temp, alpha):
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (
            temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)

distillation接收5个参数:

  • 学生没有经过softmax的输出
  • 真实标签
  • 老师没有经过softmax的输出
  • 温度
  • 全能因子
  • +号前面的是 SOFT Loss,后面的是HARD Loss。
  • F.softmax(teacher_scores / temp, dim=1)):老师输出经过蒸馏,平滑过的softmax,提取出暗知识。
  • F.log_softmax(y / temp, dim=1):学生输出经过蒸馏,平滑过的softmax。
  • nn.KLDivLoss()

通过KD的学生网络训练和测试过程

def train_student_kd(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        teacher_output = teacher_model(data)
        teacher_output = teacher_output.detach()  # 切断老师网络的反向传播
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')


def test_student_kd(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

注意:为什么我们要切断老师网络的反向传播?切断反向传播的目的有以下几个方面的考虑:

  1. 模型压缩:知识蒸馏的目标之一是将一个复杂的模型(老师网络)的知识传递给一个简化的模型(学生网络),以减少模型的复杂性和计算资源的消耗。如果学生网络直接利用老师的梯度信息来学习,学生网络可能会过度拟合老师网络的复杂结构,导致学生网络无法真正学习到简化的知识表示。

  2. 提高泛化能力:通过切断反向传播,学生网络只能通过老师网络的输出结果来学习,而不是直接利用梯度信息。这样做可以帮助学生网络更好地泛化到新的样本,而不是过度依赖于老师网络的特定梯度信息。

  3. 独立性:切断反向传播可以使学生网络独立地学习,不再依赖于老师网络的参数更新。这样,学生网络可以根据自身的特点和目标进行优化,而不受老师网络的限制。

def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)

    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    
    student_history = []
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student_kd(model, device, test_loader)
        student_history.append((loss, acc))

    torch.save(model.state_dict(), "student_kd.pt")
    return model, student_history
student_kd_model, student_kd_history = student_kd_main()

没有通过KD的学生网络训练和测试过程

## 让学生自己学,不使用KD
def train_student(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')


def test_student(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)
def student_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)

    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    
    student_history = []
    
    for epoch in range(1, epochs + 1):
        train_student(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student(model, device, test_loader)
        student_history.append((loss, acc))

    torch.save(model.state_dict(), "student.pt")
    return model, student_history
student_simple_model, student_simple_history = student_main()

对比结果

import matplotlib.pyplot as plt

epochs = 10
x = list(range(1, epochs+1))

plt.subplot(2, 1, 1)
plt.plot(x, [teacher_history[i][1] for i in range(epochs)], label='teacher')
plt.plot(x, [student_kd_history[i][1] for i in range(epochs)], label='student with KD')
plt.plot(x, [student_simple_history[i][1] for i in range(epochs)], label='student without KD')

plt.title('Test accuracy')
plt.legend()


plt.subplot(2, 1, 2)
plt.plot(x, [teacher_history[i][0] for i in range(epochs)], label='teacher')
plt.plot(x, [student_kd_history[i][0] for i in range(epochs)], label='student with KD')
plt.plot(x, [student_simple_history[i][0] for i in range(epochs)], label='student without KD')

plt.title('Test loss')
plt.legend()

image-20231028000559484

终于,我们看到了,在老师的帮助下,学生的预测准确率比自学要高很多(当然没老师高),而且最关键的是学生的网络更加轻量化,从代码就可以看出,老师网络是一个相当复杂的神经网络,而学生网络只是一个线性网络——通过牺牲一定的准确度,来换取轻量化。

KD函数的原理

image-20231028001140019

首先是L2正则化计算(假设为分类问题),公式为预测值和真实值的交叉熵 + 正则化项(权重的L2范数 -> 参数向量W的平方和的平方根)。

通过增加正则化项,模型的优化过程会倾向于选择较小的参数值(因为参数会根据训练进行调整),因为我们的目的是为了最小化损失函数,首先我们想要交叉熵越小越好,因为我们期望预测值和真实值越接近越好;那么正则化项当然也要尽量最小化才能让最后的结果Loss最小,所以导致模型的参数被约束在了一个较小的值,不会变得过大。

那么为什么要让参数较小一些比较好呢?

  1. 过拟合:当模型的参数过大时,模型可能会过度拟合训练数据,而无法很好地泛化到未见过的数据。过大的参数值会导致模型过于复杂,过于敏感地去拟合训练数据的细节,从而在训练数据上表现很好,但在新数据上表现较差。
  2. 训练不稳定:参数过大可能导致训练过程不稳定。在优化过程中,梯度下降算法会尝试通过更新参数来最小化损失函数。当参数过大时,梯度下降算法可能会遇到梯度爆炸的问题,导致训练过程不收敛或收敛速度很慢。
  3. 过大的计算开销:参数过大会增加模型的计算开销。大量的参数需要更多的内存来存储和计算,这可能会增加模型的训练和推理时间,降低模型的效率。
  4. 过大的模型尺寸:参数过大会导致模型的尺寸变大。大型模型需要更多的存储空间来保存参数,这可能会对模型的部署和使用造成不便。

当然也不意味着L2正则化会选择非常小的参数值,可以通过正则化超参数λ控制正则化的强度,如何找到一个平衡点,又得需要其他技术了。

那么KD函数其实就是将这个权重的L2范数变成了SOFT Loss,第一个参数为学生网络的输出,后一个参数为老师网络的输出。

那么这样做的好处也是显而易见,SOFT Loss 比起学生自学的 HARD Loss 是要小的,所以保证了最终的 loss 也变得小,达成了我们的目的。


文章作者: QT-7274
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 QT-7274 !
评论
  目录