«上一篇 下一篇»
  计算机工程  2022, Vol. 48 Issue (2): 72-78  DOI: 10.19678/j.issn.1000-3428.0060046
0

引用本文  

吴鹏翔, 李凡长. 一种基于群等变卷积的度量元学习算法[J]. 计算机工程, 2022, 48(2), 72-78. DOI: 10.19678/j.issn.1000-3428.0060046.
WU Pengxiang, LI Fanzhang. A Metric Meta-Learning Algorithm Based on Group Equivariant Convolution[J]. Computer Engineering, 2022, 48(2), 72-78. DOI: 10.19678/j.issn.1000-3428.0060046.

基金项目

国家重点研发计划“变革性技术关键科学问题”重点专项(2018YFA0701700,2018YFA0701701)

通信作者

李凡长(通信作者), 教授、博士

作者简介

吴鹏翔(1995-), 男, 硕士研究生, 主研方向为深度学习、元学习

文章历史

收稿日期:2020-11-18
修回日期:2021-02-09
一种基于群等变卷积的度量元学习算法
吴鹏翔 , 李凡长     
苏州大学 计算机科学与技术学院, 江苏 苏州 215006
摘要:传统机器学习方法泛化性能不佳,需要通过大规模数据训练才能得到较好的拟合结果,因此不能快速学习训练集外的少量数据,对新种类任务适应性较差,而元学习可实现拥有类似人类学习能力的强人工智能,能够快速适应新的数据集,弥补机器学习的不足。针对传统机器学习中的自适应问题,利用样本图片的局部旋转对称性和镜像对称性,提出一种基于群等变卷积神经网络(G-CNN)的度量元学习算法,以提高特征提取能力。利用G-CNN构建4层特征映射网络,根据样本图片中的局部对称信息,将支持集样本映射到合适的度量空间,并以每类样本在度量空间中的特征平均值作为原型点。同时,通过同样的映射网络将查询机映射到度量空间,根据查询集中样本到原型点的距离完成分类。在Omniglot和miniImageNet数据集上的实验结果表明,该算法相比孪生网络、关系网络、MAML等传统4层元学习算法,在平均识别准确率和模型复杂度方面均具有优势。
关键词元学习    群等变卷积    深度学习    自适应性    度量学习    
A Metric Meta-Learning Algorithm Based on Group Equivariant Convolution
WU Pengxiang , LI Fanzhang     
School of Computer Science and Technology, Soochow University, Suzhou, Jiangsu 215006, China
Abstract: Traditional machine learning methods have poor generalization performance and need large-scale data training to get better fitting results.They cannot quickly learn a small amount of data outside the training set, and have poor adaptability to new types of tasks.Meta-learning can realize strong artificial intelligence systems with similar learning ability to human beings.These artificial intelligence systems can quickly adapt to new data sets and make up for the shortcomings of machine learning.In order to solve the adaptibility problem of traditional machine learning, a metric meta-learning algorithm based on Group equivariant Convolution Neural Network(G-CNN) is proposed by using the local rotation symmetry and mirror symmetry of sample images to improve feature extraction ability.The G-CNN is used to form a 4-layer feature mapping network.According to the local symmetry information in the sample picture, the support set samples are mapped to the appropriate metric space, and the average feature of each kind of samples in the metric space is used as the prototype point.At the same time, the query machine is mapped to the metric space through the same mapping network, so as to complete the classification according to the distance between the sample in the query set and the prototype point.Experimental results on Omniglot and miniImageNet data sets show that the proposed algorithm displays an obvious advantage in average recognition accuracy and model complexity compared with traditional 4-layer meta learning algorithms, such as Siamese network, relational network and MAML.
Key words: meta-learning    group equivariant convolution    deep learning    adaptability    metric learning    

开放科学(资源服务)标志码(OSID):

0 概述

随着计算设备并行计算性能的大幅提升,以及近年来深度神经网络在各个领域不断取得重大突破,由深度神经网络模型衍生而来的多个机器学习新领域逐渐成型,如强化学习、深度监督学习等[1-2]。在大量训练数据的加持下,深度神经网络技术已经在机器翻译、机器人控制、大数据分析、智能推送、模式识别等方面得到了广泛应用[3-4]。深度学习在完成这些任务时需要在大量数据上进行训练才能拟合出一个好的结果,一旦需要被识别物体类别不在训练集中,便无法进行正确识别。但是在实际的许多任务中,要求在少量数据上进行快速学习和适应[5]

元学习的提出为上述问题提供了一个解决方案,其目的是解决传统神经网络模型泛化能力不足、对新种类任务适应性较差的问题。快速学习的能力是人类区别于人工智能的一个关键特征[6],人类能够有效地利用以前的知识和经验来快速学习新的技能。元学习的训练和测试可类比为人类在掌握一些基本技能后快速学习并适应新的任务[7]。例如:人类可以根据一张从未见过的动物的照片辨认出该动物,而不是需要大量该动物的照片才能辨认。人类在幼儿阶段掌握的对世界的大量基础知识和对行为模式的认知基础便对应元学习中的“元”概念[8-9]。元学习的最终目标是实现拥有类似人类学习能力的强人工智能,这在当前阶段体现为对新数据集的快速适应以得到较高的准确度,因此,目前元学习目标主要表现为提高泛化性能、获取好的初始参数,以及通过少量计算和新训练数据即可在模型上实现和海量训练数据一样的识别准确度[10]。受当前计算资源与算法能力限制,元学习往往以小样本学习以及对新任务的快速适应作为切入点,因此,当前研究也多以在小样本数据集上的识别准确率作为实验衡量标准[11]

基于度量的元学习方法是一种可行的元学习方法。KOCH等于2015年提出了一种用于解决单样本学习图像分类问题的方法:孪生网络(Siamese network)[12],通过训练集学习一个卷积孪生网络,利用该网络计算待测试图像与所有单标注样本的相似度,相似度最高的单标注样本所对应的类别即是待测试图像的类别。VINYALS于2016年提出了匹配网络模型[13],其主要创新体现在建模过程和训练过程。对于建模过程的创新,该文通过设计基于记忆和注意力机制的匹配网络,使得模型能够对参与训练的样本进行快速学习。对于训练过程的创新,该文基于传统机器学习的一个重要原则,即训练和测试应在同样条件下进行,提出在训练时每次仅使用每一类任务的少量样本参与网络的训练,与测试过程保持一致。SNELL于2017年提出了原型网络[14],该网络模型基于一个基本假设,即在数据集中,对于每种不同的类型都存在一个原型点。数据集中距离该原型点越近的样本,其标签与该原型点对应的标签相同的概率就越大。文献[15]提出了由嵌入模块和关系模块组成的关系网络,其中嵌入模块用于提取输入图像的特征,关系模块用于得到输入特征的相似度。

传统基于度量的元学习算法采用卷积神经网络(Convolutional Neural Network,CNN)提取特征,但是元学习问题中的某些样本图片特征不仅具有平移对称性[16],而且还具有旋转对称性和镜像对称性[17],但是CNN只具有平移不变性,不存在对后两者的不变性,这就使得传统的元学习算法不能利用具有对称性的特征。常用的解决方法是数据增强[18],即对样本进行随机变换。此类方法虽然在一定程度上增强了泛化性,但是并不能保留局部对称性[19],更不能保证在每一层卷积上的等变性。群等变卷积神经网络(Group equivariant CNN,G-CNN)则能较好地解决这一问题[20],其不仅具有平移不变性,而且还具有旋转和镜像不变性。

为有效利用样本图片的局部旋转对称性和镜像对称性,提高特征提取能力,本文提出一种基于G-CNN的度量元学习算法。通过由群等变卷积构成的4层映射网络学习一个合适的度量空间,根据查询集中样本离原型点的距离完成分类。

1 元任务

元学习的目标是跨任务的泛化。考虑一个任务分布PT),即该模型所适配的数据的全体,目的是使这个模型可以适应这个任务分布PT)。与传统机器学习不同,元学习不是根据每个样本来优化,而是根据元任务来优化。每个元任务包含一个支持集和一个对应的查询集。在n-way k-shot元学习问题中,对于每个元任务定义支持集S和查询集Q,支持集和查询集中包含n个类别的样本,支持集中每类样本只存在k个,查询集中每类样本个数不定,支持集S定义如式(1)所示:

$ S=\left\{\right({x}_{1}, {y}_{1}), ({x}_{2}, {y}_{2}), \cdots , ({x}_{N}, {y}_{N}\left)\right\} \\ x_i \in \mathbb{R}^D, y_i \in \{1,2, \cdots , n\} $ (1)

其中:xi表示样本的D维向量表示;yi表示样本对应的标签;n表示样本类别总数。查询集Q取自数据集中和支持集S同类别但不同的样本,不带标签。图 1给出了5-way 1-shot元学习问题中训练时所使用的的支持集和查询集示例。

Download:
图 1 5-way 1-shot元学习问题中元训练使用的支持集和查询集示例 Fig. 1 Example of support set and query set using in meta-training for 5-way 1-shot meta-learning problems

在训练阶段,从PT)的训练数据集上采样训练元任务Titrain,通过元任务对损失函数进行最小化,从而优化模型参数。在训练结束后,从同取自PT)未参与训练的测试数据集(测试集中的样本和训练集中的样本类别不同)中采样测试元任务Titest,对训练好的模型进行测试。

2 映射网络

尽管现阶段的神经网络研究缺少理论支撑,但是大量经验证据表明,卷积权值共享和网络深度对于神经网络的效果起到了重要作用[21-22]。卷积权值共享的有效性依赖于其在多数感知任务中都具有平移不变性,即预测标签的函数和数据分布对于平移变换都近似于不变。由于平移不变性,共享权重的卷积核可以从图像的局部区域提取特征,并且参数量远少于全连接网络[23],同时能够学到更多有效的变换信息[24-25]。卷积层可以有效地应用于深度网络中,因为这种网络中的所有层都具有平移不变性:将图片平移后再送入若干卷积层得到的结果,与将原图直接送入相同卷积层再对特征图进行平移所得到的结果相同[26]。因此,为提高特征提取能力,本文使用G-CNN来构建映射网络,使映射网络对具有旋转对称的特征和镜像对称的特征也能保持不变性。映射网络使用4层G-CNN构建,每层由卷积核、batch-norm、relu激活函数和最大池化层组成。

2.1 群等变卷积

对于输入的2维图片,卷积是不断平移卷积核和特征图做点积运算的过程,以群G上的函数代替平移就得到了群卷积,如式(2)所示:

$ \left[f\mathrm{*}\phi \right]\left(g\right)=\sum\limits_{y\in {Z}^{2}}\sum\limits_{k}{f}_{k}\left(y\right){\phi }_{k}^{}({g}^{-1}y) $ (2)

其中:Z2是2维图片上的整数平移群;群运算是加法(nm)+(pq)=(n+pm+q);f是输入的特征图;φ是卷积核。fφ都是Z2上的函数,只适用于群卷积的第1层,但由于卷积输出的结果是离散群G上的函数,因此第1层后的卷积如式(3)所示:

$ \left[f\mathrm{*}\phi \right]\left(g\right)=\sum\limits_{h\in G}\sum\limits_{k}{f}_{k}\left(h\right){\phi }_{k}\left({g}^{-1}h\right) $ (3)

其中:输入的特征图f是群G上的函数。

h=uh,等变性证明如式(4)所示:

$ \begin{array}{l}\left[\right[{L}_{u}f\left]\mathrm{*}\phi \right]\left(g\right)=\sum\limits_{h\in G}\sum\limits_{k}{f}_{k}\left({u}^{-1}h\right)\phi \left({g}^{-1}h\right)=\\ \qquad\qquad \sum\limits_{h\in G}\sum\limits_{k}f\left(h\right)\phi \left({g}^{-1}uh\right)=\\ \qquad\qquad \sum\limits_{h\in G}\sum\limits_{k}f\left(h\right)\phi \left(\right({u}^{-1}{g)}^{-1}h)=\left[{L}_{u}\right[f\mathrm{*}\phi \left]\right]\left(g\right)\end{array} $ (4)
2.2 非线性单元

映射网络中的非线性单元包括激活函数,可以将非线性单元看作一个映射:$ v:\mathbb{R}\to \mathbb{R} $,非线性单元作用于特征图f可以视为一系列操作算子的组合,如式(5)所示:

$ v\circ [f\circ {h}^{-1}]=[v\circ f]\circ {h}^{-1} $ (5)

因此,使用非线性单元处理特征图后依然能保持等变性。

2.3 池化层

池化可以分解为不带步长的池化和下采样[27]两部分。对于不带步长的池化,定义池化操作为P,作用于特征图f的最大池化如式(6)所示(平均池化同理):

$ Pf\left(g\right)=\underset{k\in gU}{\mathrm{m}\mathrm{a}\mathrm{x}}f\left(k\right) $ (6)

其中:gUG的子群U上的一个g变换。在G-CNN中,下采样表示在G的子群H上下采样。例如:对输入2维图片做步长为2的最大池化,等价于先进行不带步长的池化,再在Z2的子群H={(2i,2j)|(ij)∈Z2}上进行下采样。

2.4 网络实现

对于具有90°旋转对称特征的图片,群G使用p4群;对于具有90°旋转对称和镜像对称的特征,群G使用p4m群[28]。p4群的群元定义如式(7)所示:

$ g(r, u, v)=\left[\begin{array}{ccc}\mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& -\mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& u\\ \mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& \mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& v\\ 0& 0& 1\end{array}\right] $ (7)

其中:0≤r < 4,r=0表示无旋转,r=1表示旋转90°;(uv)∈Z2,表示在二维平面上的水平和垂直移动。群运算为矩阵乘法。对于输入的特征图上的某点(xy),p4群作用于点(xy)的运算如式(8)所示:

$ g(x, y)=\left[\begin{array}{ccc}\mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& -\mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& u\\ \mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& \mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& v\\ 0& 0& 1\end{array}\right]\left[\begin{array}{c}x\\ y\\ 1\end{array}\right] $ (8)

p4m群的群元定义如式(9)所示:

$ g(m, r, u, v)=\left[\begin{array}{ccc}{(-1)}^{m}\mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& -{(-1)}^{m}\mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& u\\ \mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& \mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& v\\ 0& 0& 1\end{array}\right] $ (9)

其中:m=0或1,1表示镜像翻转,其余定义与p4群相同,群运算为矩阵乘法。作用于输入特征图上某点(x,y)的运算如式(10)所示:

$ g(x, y)=\left[\begin{array}{ccc}{(-1)}^{m}\mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& -{(-1)}^{m}\mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& u\\ \mathrm{s}\mathrm{i}\mathrm{n}\left(\frac{r\mathrm{\pi }}{2}\right)& \mathrm{c}\mathrm{o}\mathrm{s}\left(\frac{r\mathrm{\pi }}{2}\right)& v\\ 0& 0& 1\end{array}\right]\left[\begin{array}{c}x\\ y\\ 1\end{array}\right] $ (10)

当群G使用p4群时,第1层的G-CNN是Z2-p4卷积层,操作如图 2所示,依次将卷积核旋转90°,得到4组卷积核,分别与输入图片做卷积,得到4组映射特征。第一层后面的G-CNN是p4-p4卷积,操作如图 3所示,对于前层输入的4组映射特征,卷积核依次旋转90°得到4组卷积核,然后每组卷积核依次和输入的4组特征做卷积,将得到的结果求和得到输出特征。使用p4m群构建映射网络时,卷积核需要额外进行镜像翻转,因此,卷积核的数目是8组,得到的输出特征也是8组,操作与使用p4群类似。

Download:
图 2 Z2-p4卷积层示意图 Fig. 2 Schematic diagram of Z2-p4 convolution layer
Download:
图 3 p4-p4卷积层示意图 Fig. 3 Schematic diagram of p4-p4 convolution layer
3 特征度量

本文算法基于以下基本假设:存在一个空间,在这个空间中,属于相同类别的样本距离近,不同类别的样本距离远,这样就可以通过简单度量函数进行分类。本文算法是通过学习一个映射网络将样本映射到合适的度量空间,然后通过简单度量方法完成分类。在n-way k-shot元学习问题中,对于每个元任务,支持集中每类有k个样本,支持集经过映射网络映射到度量空间后,每一类就有k个表示,取每类k个表示的均值作为该类在度量空间中的代表。每个类在度量空间的代表称为该类的原型点cj,计算公式如式(11)所示:

$ {c}_{j}=\frac{1}{k}\sum\limits_{({x}_{i}, {y}_{i})\in S, {y}_{i}=j}{f}_{\theta }\left({x}_{i}\right) $ (11)

其中:k表示支持集中每类样本的个数;fθ表示映射网络;(xiyj)表示输入的样本和对应的标签。查询集经过同样的映射网络映射到度量空间中,利用距离计算函数d来计算查询集中待分类样本到每类原型点的距离,再利用softmax函数计算属于每个类的概率,如式(12)所示:

$ {p}_{\theta }(y=j|x)=\frac{\mathrm{e}\mathrm{x}\mathrm{p}(-d({f}_{\theta }\left(x\right), {c}_{j}\left)\right)}{\sum\limits_{j\text{'}}\mathrm{e}\mathrm{x}\mathrm{p}(-d({f}_{\theta }\left(x\right), {c}_{j\text{'}}\left)\right)} $ (12)

最后,使用交叉熵作为损失函数,如式(13)所示:

$ J\left(\theta \right)=-\sum\limits_{j}{y}_{j}lb\left({p}_{\theta }\right(y=j\left|x\right)) $ (13)

通过Adam优化器来最小化损失函数,从而优化映射网络的参数,不断从训练集中抽取样本组成元任务来训练模型,直到得到一个能很好地将训练样本映射到合适度量空间的模型。

本文提出的基于群等变卷积的度量元学习算法(Metric Meta-learning algorithm Based on Group Equivariant Convolution,MMBOGEC)如算法1所示。

算法1  MMBOGEC

输入  训练集D={(x1y1),(x2y2),…,(xNyN)}

输出  模型在测试集上的分类准确率

1)在训练集中随机选取n个类,对于选取的每个类,取k个样本组成支持集,取Nq个样本组成查询集。

2)通过映射网络得到支持集样本在度量空间中的表示,取每个类所有样本在度量空间中特征表示的均值作为该类的原型点。

3)利用同样的映射网络得到查询集样本在度量空间中的表示,利用距离计算公式计算查询集样本在度量空间中的表示到每个类原型点的距离,利用softmax函数计算属于每个类的概率,将概率最大的类别作为预测类别。

4)使用交叉熵作为损失函数更新损失J

5)使用Adam优化器最小化损失J来更新网络参数。

6)重复步骤1~步骤5,直到损失J不再下降。

7)在测试集中生成若干个元任务,每个元任务随机选取n个类,对于选取的每个类,取k个样本组成支持集,取Nq个样本组成查询集,将这些元任务输入训练好的模型,得到分类准确率,最后将分类准确率的平均值作为输出结果。

4 实验结果与分析

本文在常用的小样本数据集miniImageNet和Omniglot上进行实验。

4.1 miniImageNet数据集

miniImageNet数据集包含60 000张彩色图片,分为100个类,每个类600张。首先将所有图片处理成84像素×84像素大小,将其中的64类作为训练集,16类作为验证集,剩下的20类作为测试集。本文使用64类来训练模型,验证集仅仅用来判断模型泛化性的好坏,不参与模型的参数优化。

输入的样本图片经过映射网络得到其在度量空间中的特征表示,映射网络包含4层由G-CNN构成的卷积,每一层使用64个3×3卷积核,包含batch-norm、relu激活函数以及3×3的最大池化层。最后将得到的特征表示展开成一维向量,利用距离计算函数计算其到各个原型点的距离,将距离最近的类别作为预测标签。以交叉熵作为损失函数,不添加正则项损失,学习率设置为10-3,使用Adam优化器对网络参数进行优化。

针对miniImageNet数据集常用的有两种训练方法,分别是5-way 1-shot和5-way 5-shot。5-way 1-shot训练方法先任意地从训练集中选5个类别,每个类别包含1个样本,总计5个样本作为支持集,再从前面5类中每类选取若干个不同的样本(本文实验中设置为15个)作为查询集,使模型根据支持集来分类查询集。5-way 5-shot训练方法将支持集每类选取样本数改为5,其余和前面一致。当验证集上的验证损失不再下降时,停止训练模型,在测试集上验证模型的效果,测试方法和训练方法保持一致,测试使用随机产生的600个元任务,以平均准确率作为评估指标。

4.1.1 不同距离计算公式对实验结果的影响

不同距离的度量公式会对算法的实验结果产生影响,本文使用常用的4种距离计算公式进行测试,分别是欧式距离、余弦距离、切比雪夫距离和曼哈顿距离,测试结果对比如表 1所示。可以看出,在miniImageNet数据集5-way 1-shot和5-way 5-shot方法中,欧氏距离作为距离计算公式最有效,其次是曼哈顿距离,切比雪夫距离最差。

下载CSV 表 1 使用不同距离计算公式的实验结果对比 Table 1 Comparison of experimental results using different distance calculation formulas 
4.1.2 消融实验

为验证本文算法的有效性,分别使用p4群、p4m群和普通CNN构建映射网络行实验,对比实验结果如表 2所示。可以看出:不使用群等变卷积的方法,实验结果最差;使用p4群的方法,实验结果优于使用普通CNN的方法,表明在本实验中,具有旋转不变性的方法比不具有旋转不变性的方法更有效;使用p4m群的方法,实验效果最好,表明利用旋转不变性和镜像对称不变性能有效提高元学习的自适应性。

下载CSV 表 2 消融实验结果对比 Table 2 Comparison of ablation experimental results 
4.1.3 G-CNN层数对实验结果的影响

为进一步验证群等变卷积的有效性,在部分卷积层上使用群等变卷积进行实验,实验结果如表 3所示,其中第1列表示使用群等变卷积的卷积层,如1表示仅在第1层使用,其余层使用普通CNN。可以看出,在5-way 1-shot和5-way 5-shot的实验中,仅仅在单层中使用群等变卷积,不论是在哪一层使用,实验结果都相差不大,表明仅在某一层具有等变性不能很好地提升元学习的自适应性。随着使用群等变卷积层数的增加,实验效果随之提升,完整的4层群等变卷积网络效果最好,表明整个网络都具有等变性才能更好地适用于元学习问题。

下载CSV 表 3 在不同卷积层使用G-CNN的实验结果对比 Table 3 Comparison of experimental results using G-CNN in different convolutional layers 
4.1.4 与4层元学习算法的实验结果对比

将本文算法与传统4层元学习算法进行对比,实验结果如表 4所示(加粗数据表示最优数据)。可以看出,无论是5-way 1-shot还是5-way 5-shot,本文算法性能都优于传统4层元学习算法。

下载CSV 表 4 不同算法在miniImageNet数据集上的实验结果对比 Table 4 Comparison of experimental results of different algorithms on miniImageNet dataset 
4.2 Omniglot数据集

Omniglot数据集包含50种不同语言,共计1 623种手写字符,每种字符包含20个样本,每个样本由不同人书写。本文将样本图片大小统一为28像素×28像素,使用其中的1 028类作为训练集,423类作为测试集,剩下的作为验证集。

输入的样本图片经过映射网络得到其在度量空间中的特征表示,映射网络包含4层由G-CNN构成的卷积层,每层使用64个3×3卷积核、batch-norm、relu激活函数以及3×3的最大池化层。在度量空间中使用欧氏距离计算查询集到原型点的距离,将距离最短的原型点对应的标签作为预测标签,以交叉熵作为损失函数,不添加正则项损失,学习率设置为10-3,使用Adam优化器对网络参数进行优化。

Omniglot数据集常用的有4种训练方法,分别是5-way 1-shot、5-way 5-shot、20-way 1-shot和20-way 5-shot,测试时同样使用对应的方法。测试使用随机产生的1 000个元任务,以平均准确率作为最后的结果。

本文算法与传统4层元学习算法在Omniglot数据集上实验结果对比如表 5所示(加粗数据表示最优数据),可以看出,在5-way 1-shot、5-way 5-shot实验中,本文算法性能都优于其他算法。

下载CSV 表 5 不同算法在Omniglot数据集上的实验结果对比 Table 5 Comparison of experimental results of different algorithms on Omniglot dataset 
4.3 模型复杂度分析

本文算法针对n-way k-shot元学习问题,对于每个元任务,需要n类支持集样本,每类样本包含k个实例,对q个支持集样本进行分类,因此每个元任务的平均复杂度为On×k×q)。

MMBOGEC算法与传统4层元学习算法的参数量对比如表 6所示(加粗数据表示最优数据)。可以看出,MMBOGEC算法参数量只比原型网络算法多,而少于其他4种算法。

下载CSV 表 6 不同算法的参数量对比 Table 6 Comparison of the number of parameters of different algorithms
5 结束语

针对传统机器学习的自适应性问题,本文提出一种基于群等变卷积的度量元学习算法,使用群等变卷积神经网络构建映射网络,充分利用样本图片的局部旋转对称性和镜像对称性,将样本图片映射到合适的度量空间,根据所提取特征到每个类原型点的距离远近来实现分类。在Omniglot数据集和miniImageNet数据集上的实验结果表明,该算法对于元学习问题的学习性能优于传统4层元学习算法。下一步将对本文算法进行改进,探索更有效的特征映射网络和特征距离比较方法。

参考文献
[1]
JIN C, YANG Z, WANG Z, et al. Provably efficient reinforcement learning with linear function approximation[C]//Proceedings of the 33rd Annual Conference on Learning Theory. Graz, Austria: [s. n. ], 2020: 2137-2143.
[2]
SILVER D, SCHRITTWIESER J, SIMONYAN K, et al. Mastering the game of go without human knowledge[J]. Nature, 2017, 550(7676): 354-359. DOI:10.1038/nature24270
[3]
ERIA K, JAYABALAN M. Neural machine translation: a review of the approaches[J]. Journal of Computational and Theoretical Nanoscience, 2019, 16(8): 3596-3602. DOI:10.1166/jctn.2019.8331
[4]
SETHI A, GU M, GUMUSGOZ E, et al. Supervised enhancer prediction with epigenetic pattern recognition and targeted validation[J]. Nature Methods, 2020, 17(8): 807-814. DOI:10.1038/s41592-020-0907-8
[5]
LAI N, KAN M, HAN C, et al. Learning to learn adaptive classifier-predictor for few-shot learning[J]. IEEE Transactions on Neural Networks and Learning Systems, 2021, 32(8): 3458-3470. DOI:10.1109/TNNLS.2020.3011526
[6]
MISHRA N, ROHANINEJAD M, CHEN X, et al. A simple neural attentive meta-learner[EB/OL]. (2017-07-11)[2020-11-10]. https://arxiv.org/pdf/1707.03141v3.pdf.
[7]
LI F F, FERGUS R, PERONA P. One-shot learning of object categories[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2006, 28(4): 594-611. DOI:10.1109/TPAMI.2006.79
[8]
SANTORO A, BARTUNOV S, BOTVINICK M, et al. Meta-learning with memory-augmented neural networks[C]//Proceedings of International Conference on Machine Learning. New York, USA: IEEE Press, 2016: 1842-1850.
[9]
KAISER Ł, NACHUM O, ROY A, et al. Learning to remember rare events[EB/OL]. (2017-03-09)[2020-11-10]. https://arxiv.org/pdf/1703.03129.pdf.
[10]
FINN C, ABBEEL P, LEVINE S. Model-agnostic meta-learning for fast adaptation of deep networks[EB/OL]. (2017-03-09)[2020-11-10]. https://arxiv.org/pdf/1703.03400v3.pdf
[11]
CHEN Z, FU Y, WANG Y X, et al. Image deformation meta-networks for one-shot learning[C]//Proceedings of 2019 IEEE Conference on Computer Vision and Pattern Recognition. Washington D.C., USA: IEEE Press, 2019: 8680-8689.
[12]
KOCH G, ZEMEL R, SALAKHUTDINOV R. Siamese neural networks for one-shot image recognition[C]//Proceedings of the 32nd International Conference on Machine Learning. Lille, France: [s. n. ], 2015: 1-8.
[13]
VINYALS O, BLUNDELL C, LILLICRAP T, et al. Matching networks for one shot learning[C]//Proceedings of NIPS'16. Berlin, Germany: Springer, 2016: 3630-3638.
[14]
SNELL J, SWERSKY K, ZEMEL R. Prototypical networks for few-shot learning[C]//Proceedings of NIPS'17. Berlin, Germany: Springer, 2017: 4077-4087.
[15]
SUNG F, YANG Y, ZHANG L, et al. Learning to compare: relation network for few-shot learning[C]//Proceedings of 2018 IEEE Conference on Computer Vision and Pattern Recognition. Washington D.C., USA: IEEE Press, 2018: 1199-1208.
[16]
LING Y, XIAN Z, ZHOU Z. Holographic shear viscosity in hyperscaling violating theories without translational invariance[J]. Journal of High Energy Physics, 2016(11): 7.
[17]
JIANG Y, LAI X, WATANABE K, et al. Charge order and broken rotational symmetry in magic-angle twisted bilayer graphene[J]. Nature, 2019, 573(7772): 91-95. DOI:10.1038/s41586-019-1460-4
[18]
NAMOZOV A, CHO Y I. An improvement for medical image analysis using data enhancement techniques in deep learning[C]//Proceedings of 2018 International Conference on Information and Communication Technology Robotics. Washington D.C., USA: IEEE Press, 2018: 1-3.
[19]
DOMMASCHK M, ECHAVARREN J, LEIGH D A, et al. Dynamic control of chiral space Through local symmetry breaking in a rotaxane organocatalyst[J]. Angewandte Chemie International Edition, 2019, 58(42): 14955-14958. DOI:10.1002/anie.201908330
[20]
COHEN T, WELLING M. Group equivariant convolutional networks[C]//Proceedings of International Conference on Machine Learning. New York, USA: ACM Press, 2016: 2990-2999.
[21]
VIALATTE J C, GRIPON V, MERCIER G. Generalizing the convolution operator to extend CNNs to irregular domains[EB/OL]. (2016-06-03)[2020-11-10]. https://arxiv.org/pdf/1606.01166.pdf.
[22]
ULLRICH K, MEEDS E, WELLING M. Soft weight-sharing for neural network compression[EB/OL]. (2017-02-13)[2020-11-10]. https://arxiv.org/pdf/1702.04008.pdf.
[23]
LI J, LI B, XU J, et al. Fully connected network-based intra prediction for image coding[J]. IEEE Transactions on Image Processing, 2018, 27(7): 3236-3247. DOI:10.1109/TIP.2018.2817044
[24]
JIANG W, HUANG C, DENG X. A new probability transformation method based on a correlation coefficient of belief functions[J]. International Journal of Intelligent Systems, 2019, 34(6): 1337-1347. DOI:10.1002/int.22098
[25]
CLEMONS E K, DEWAN R M, KAUFFMAN R J, et al. Understanding the information-based transformation of strategy and society[J]. Journal of Management Information Systems, 2017, 34(2): 425-456. DOI:10.1080/07421222.2017.1334474
[26]
LIU M Y, HUANG X, MALLYA A, et al. Few-shot unsupervised image-to-image translation[C]//Proceedings of 2019 IEEE International Conference on Computer Vision. Washington D.C., USA: IEEE Press, 2019: 10551-10560.
[27]
ETIKAN I, BALA K. Sampling and sampling methods[J]. Biometrics & Biostatistics International Journal, 2017, 5(6): 1-5.
[28]
DEY N, CHEN A, GHAFURIAN S. Group equivariant generative adversarial networks[EB/OL]. (2020-05-04)[2020-11-10]. https://arxiv.org/pdf/2005.01683v2.pdf.