学习报告:原型网络用于小样本学习
本篇学习报告基于论文《Prototypical Networks for Few-shot Learning》,作者是来自多伦多大学的Jake Snell、推特用户Kevin Swersky和来自多伦多大学和矢量研究所的Richard S. Zemel。该论文中介绍了一种基于原型的神经网络模型,该模型在小样本学习任务上取得了良好的性能,并且为解决实际应用中样本稀缺的问题提供了一种有效的方法。
小样本学习是指在只有很少样本数据的情况下进行学习和分类任务。传统的深度学习算法在这种情况下表现较差,因为它们需要大量的训练样本来建立准确的模型。而原型网络通过学习类别的原型向量来解决这个问题。在原型网络中,每个类别都被表示为原型向量,该向量是该类别所有样本特征向量的平均值。在训练阶段,网络通过优化损失函数来调整各个类别的原型向量,使得在特征空间中不同类别之间的距离最大化。
一、背景
在许多现实场景中,我们经常面临只有很少样本数据的情况下需要进行学习和分类任务的挑战。传统的深度学习算法在这种情况下往往表现不佳,因为它们需要大量的训练样本来建立准确的模型。
为了解决小样本学习问题,文章提出了一种名为"Prototypical Networks"的神经网络模型,用于小样本学习。该模型基于原型的思想,通过学习每个类别的原型向量来进行分类。通过优化原型向量之间的距离,该模型能够在特征空间中有效地区分不同的类别,从而实现准确的分类。
二、方法
在原型网络中,使用一个神经网络来学习将输入数据映射到特定嵌入空间中。通过在嵌入空间中将类别聚类在原型周围,可以实现在小样本情况下的准确分类。同时,在零样本学习中,通过学习元数据的嵌入表示,可以在没有标签样本的情况下进行分类。该嵌入空间如图所示。通过找到嵌入查询点x最近的类别原型实现分类任务,在实验中距离的选择非常重要,欧氏距离明显优于更常用的余弦相似度。
在上图中,每个原型ck是其所属类别的嵌入支持点的均值向量。其计算公式如下:
其中Sk为支持集,(xi,yi)为样本的特征向量和标记,fΦ为嵌入函数。
原型网络通过对嵌入空间中到原型的距离进行 softmax 处理得出概率pΦ,公式如下:
其中d为距离函数,采用的是欧氏距离。公式中求出x到每个原型的距离后归一化得出x属于各个类的概率分布。
学习过程通过使用随机梯度下降(SGD)最小化真实类别 k 的负对数概率 J(Φ) = -log pΦ(y = k | x) 来进行。计算训练集的损失 J(Φ) 的伪代码如下:
从伪代码中可以看出,算法输入训练集D后输出损失J。计算过程为选择训练集类别、选择支持集、选择查询集、从支持集中计算原型、初始化损失为0、更新损失。通过执行以上伪代码,可以计算出一个随机生成的训练集的损失J,用于在训练过程中更新模型参数以最小化损失。
三、实验
针对少样本学习,作者在Omniglot和 ILSVRC-2012的miniImageNet版本上进行了实验,同时还在Caltech UCSD鸟类数据集的2011版本上进行了零样本实验。
1.Omnilot少样本分类
Omniglot是一个由50个字母表中的1623个手写字符组成的数据集。每个字符有20个示例,每个示例由不同的人绘制。作者采用的嵌入式架构由四个卷积块组成,每个块包括一个64个过滤器的3×3卷积层、批量归一化层、ReLU非线性激活函数和一个2×2最大池化层。使用欧式距离在1-shot和5-shot的场景下训练了原型网络并计算了模型在测试集上通过1000个随机生成的episode进行分类准确率的平均值。结果如下所示。
2.miniImageNet少样本分类
miniImageNet数据集最初由Vinyals等人提出,是基于较大的ILSVRC-12数据集产生的。为了与最先进的少样本学习算法进行直接比较,作者使用了Ravi和Larochelle引入的划分方法。实验对比结果如下,很明显原型网络取得了最先进的结果。
作者进一步分析了距离度量和每个episode中训练类别数量对原型网络和匹配网络性能的影响。实验结果如下。
由图象可以看出20-way的准确性高于5-way,可以猜想是由于难度增加使得模型做出更细致的决策,这有助于网络更好的泛化。也可以看出使用欧式距离的效果好于余弦距离,因此本文选择欧式距离作为距离函数。
3.CUB零样本分类
CUB数据集包含11,788张来自200种鸟类的图像。在此数据集上原型网络运用于零样本学习的实验结果如下。可以看出原型网络可以很好的运用到零样本学习上并取得最先进的结果。
四、总结
原型网络的基本思想是通过神经网络在学习到的嵌入空间中,用每个类别的样本均值来表示该类别。该方法的突出优点是更简单、更高效,并能产生领先的结果。还能推广到零样本学习上并取得业界领先的结果。总体而言,原型网络的简单性和有效性使其成为一种有前景的少样本学习方法。
撰稿人:马一鸣
审稿人:李景聪