学习报告:FlexMatch:使用课程式伪标签的策略来改进半监督学习

        本篇学习报告《FlexMatch: Boosting Semi-Supervised Learning with C urriculum Pseudo Labeling》来自机器学习国际权威会议NeurIPS 2021,第一作者为日本东京工业大学的张博闻和王一栋,其他作者来自东京工业大学和微软亚洲研究院。

一、算法背景

        近年来,半监督学习(SSL)由于其在利用大量未标记数据方面的优势而越来越受到关注。目前基于伪标签技术的半监督学习算法,往往设定一个高且固定的阈值,如果模型针对无标注样本的置信度超过设定的阈值,才会给其赋予一个伪标签。这种策略确保只有高质量的无标签数据才能参与模型训练,可以滤除大量的噪声数据标签。

        但这种方法忽略了大量的其他未被标记数据,这种情况在早期阶段最为严重,因为只有少数未标记数据的预测置信度高于阈值。此外,不少SSL算法平等地处理所有类,而不考虑它们不同的学习困难。

        本文作者认为,这种高且固定的阈值存在一定的问题,这将导致模型在处理不同训练状态和不同类别的训练难度缺乏区分,进而导致性能较差,因此作者提出了一种课程式学习方法——课程式伪标签(Curriculum Pseudo Labeling, CPL)。

二、主要思想

        为了解决这些问题,作者提出了课程式伪标签(CPL)的学习策略,以在半监督学习中考虑每个类的学习状态。CPL用灵活的阈值代替预定义的阈值,灵活的阈值根据每个类当前的学习状态动态调整。CPL显着提高了几种流行的SSL算法在常见基准上的准确性和收敛性能。将CPL学习策略直接应用于FixMatch,就得到了FlexMatch。

三、算法详解

①CPL方案

        当前的SSL算法只对由预定义阈值所筛选的高置信度未标记数据生成伪标签,而CPL则对不同的类别不同训练来时刻更新阈值并生成伪标签。这个过程是根据模型中每个类的学习状态来调整阈值来实现的。

        然而,根据学习状态来动态调整阈值并非易事。最理想的方法是为每个类别计算评估准确性,并使用它们来调整阈值,如下所示:

        其中,Tt(c)代表了当时间为t时类别c的灵活阈值,at(c)则代表了相应的评估精度,τ代表阈值。这个将鼓励学习更多低精度的样本。

        然而却存在两个问题:首先,在SSL场景下,因为带标签的数据已经很稀缺了,所以这种与训练集分离的带标签的验证集是昂贵的。其次,为了在训练过程中动态调整阈值,必须在每个时间步长t上连续进行精度评估,这将大大降低训练速度。

        而CPL能做到不需要加入额外推理过程、不需要额外验证集,即可改善上述问题。其结构如图所示:

        其基于一个理论:当阈值很高时,一个类的学习效果可以通过其预测超过阈值且属于该类别的样本数量来反映。换言之,当一个类只有较少数量的样本达到置信度则被视为学习难度较大。其公式如下:

        其中σt(c)代表第c类在t时刻的学习效果,也就是所有样本中对高于固定阈值且属于类别c的样本数。pm,t(y|un)是模型在时间步长t对未标记数据un的预测。N是未标记数据的总数。σt(c)越大表明估计的学习效果越好。通过归一化,使其范围在0至1之间,然后可以用它来缩放固定阈值τ,其公式如下:

        从上边归一化式子可以看出,当学习的最好的类别,缩放比例βt(c)为1。比较难学习的类别其阈值将会降低,从而使其更多的样本参与到学习过程。如果在以后的迭代中将未标记的数据分类到不同的类中,阈值也可能会减少。这个新的阈值用于计算FlexMatch中的无监督损失,可以表示为:

        其中qb=pm(y|ω(ub))。这些阈值在每次迭代中都会更新,这里用了Tt(arg max(qb))替代了Fixmatch中的固定阈值,其余均相同。两个损失函数相结合成总的损失函数(λ为权重):

        标注数据的损失函数与Fixmatch相同,如下所示:

②阈值的warm up

        在训练的初始阶段,模型可能把大多数无标签数据预测成一个类,因此要采用阈值的warm up,其公式如下:

        公式将“学习效果最大值”改为“学习效果最大值”(max σt ) 和 “尚未被选择过的样本数“(N- ∑ σt (c)) 二者的最大值。在前期,未被选择过的样本占优势,后项起作用,后期则前项起作用。

③非线性映射

        非线性映射使得阈值的调整可以更加自由。其中凸函数可能更加有效,因为其在自变量较小时因变量的变化也较小,而在自变量大时比较敏感,比较符合预估学习效果的变化特征。其公式如下:

        其中M(·)是一个非线性映射函数,将映射范围设定为从0到1,使灵活阈值范围从0到τ。

④算法流程

        其完整的算法流程如下所示:

四、实验结果

        作者在CIFAR-10,CIFAR-100,STL-10,SVHN等数据集上进行了实验,并和其他算法进行对比,其结果如表:

        观察上表可知CPL在除了SVHN以外大多数数据集上取得了较大的提升。可能的原因是CPL不适合数据分布不平衡且又很简单的任务,对于简单的任务而言,一个固定的高阈值似乎已经足够了。同样可以得出结论,带标记的数据越少,CPL带来的提升越大。反之,任务越难,CPL的提升越大。

        FlexMatch还在收敛速度上有优势,如图所示:

五、感想

        本篇文章对Fixmatch和一些常见的半监督学习算法存在的问题进行了改进,有效提高了准确率和收敛速度,尤其是在标记数据较少时提升更为显著,对改进半监督学习具有参考意义。

撰稿人:陈泉霖

审稿人:李景聪


登录用户可以查看和发表评论, 请前往  登录 或  注册
SCHOLAT.com 学者网
免责声明 | 关于我们 | 用户反馈
联系我们: