元学习:让 Agent 学会快速学习
元学习:让 Agent 学会快速学习
嘿,各位正在搞大模型应用、智能体、强化学习或者终身学习的朋友们!你是否见过这样的场景?
你花了三个月,调了上万次超参数,好不容易训练出一个能在围棋AlphaGo复刻环境下打赢业余三段的智能体;转头你想让它玩一局简单到爆的五子棋,结果它愣在那里只会在棋盘角落乱点,比刚上手的三岁小孩还菜——这就是传统深度学习/强化学习的致命痛点:“数据饥渴症+过拟合性迁移能力差”,就像只会背一本《唐诗三百首》的机器人,你让它写一句“五言绝句写夏天”,它可能直接把“床前明月光”贴上去,完全不知道“举一反三”“触类旁通”是啥。
而今天我们要聊的元学习(Meta-Learning),就是解决这个问题的“终极钥匙”之一!它的核心目标不是让智能体“学会某一个具体任务”,而是让它“学会学习(Learn to Learn)”:通过在成百上千个“训练任务集”上的“元训练(Meta-Training)”,智能体能够掌握一套通用的“学习策略”或者“初始状态”;当它遇到全新的“测试任务”时,只需要用极少量的样本(比如1-10张图片、1-5次交互反馈)进行“快速适配(Few-Shot Adaptation)”,就能达到甚至超过传统模型在成百上千样本训练后的表现。
听起来是不是有点科幻?但别担心,我不会一开始就给你甩一堆复杂的数学公式(当然后面会有,但都是循序渐进拆解的)。在这篇10000字左右的技术博客里,我会带你从零入门元学习,理解它的核心原理、分类体系、经典算法(MAML、Reptile、Prototypical Networks、Relation Networks这些都得安排上),还会用Python和PyTorch写一个可运行的Few-Shot图像分类实战项目;最后,我们会聊聊元学习的进阶应用、最佳实践、未来趋势——甚至,我会给你留一个能让你思考三天的开放性问题!
准备好了吗?系好安全带,我们的“元学习之旅”正式出发!
二、基础知识/背景铺垫(Foundational Concepts)
哦不对,原目录里的第一部分是引言对吧?刚才的内容是引言的钩子+为什么要聊元学习,但得严格按照系统给的通用目录结构来,不然文章会乱。让我先把引言补全,再进入基础知识铺垫。
一、引言(Introduction)
1.1 钩子(The Hook):从人类学习 vs 机器学习的“天差地别”说起
先做一个小实验——请你闭上眼睛回忆一下:你第一次见到“猫”是什么时候?大概是你刚学会爬或者刚学会说话的时候,你的妈妈/爸爸指着一只毛茸茸的、喵喵叫的动物说:“宝贝,这是猫!”,然后第二次、第三次见到不同颜色、不同品种的猫(布偶、橘猫、狸花猫),你只需要几秒钟就能认出来:“哦,这也是猫!”
那如果换成一个传统的图像分类模型(比如ResNet-18)呢?假设你想让它学会识别“猫”和“狗”,你至少得准备10000张标注好的猫和狗的图片(ImageNet的猫狗子数据集就差不多这个量级),然后用GPU跑上几个小时甚至几天,调几十次超参数(学习率、批量大小、正则化系数等等),才能让它的准确率达到90%以上;但如果你现在给它一张之前完全没见过的动物图片——比如一只“薮猫”(比普通猫大一圈,耳朵特别长),你只给它看1张薮猫的图片,它大概率会把薮猫识别成“狗”或者“狐狸”,准确率可能连10%都不到。
为什么会有这么大的差距?因为人类大脑天生就有“元认知能力”和“类比推理能力”:我们在学会识别“猫”之前,已经通过观察其他动物、物体,掌握了一套通用的“视觉特征提取规则”——比如“毛茸茸的有尾巴的大概率是动物”“耳朵形状、眼睛大小、毛发颜色是区分不同动物的关键”“叫声是辅助判断的依据”;当我们第一次见到“薮猫”时,我们会自动把它和之前见过的“猫”的通用特征做类比:“哦,它有猫的耳朵形状(虽然长一点)、猫的眼睛、猫的毛发纹理、猫的尾巴,叫声也像猫——所以它应该是一种特殊的猫!”
而传统的深度学习/强化学习模型呢?它们本质上是“数据拟合机器”:它们的训练过程就是“在一个固定的数据集上,通过反向传播调整参数,最小化损失函数”,它们学到的只是“这个特定数据集的统计规律”,而不是“通用的学习策略”;当你把它们放到一个全新的任务或者全新的数据集上时,之前学到的“特定统计规律”就完全没用了——它们必须重新开始“数据拟合”的过程,这就是所谓的“灾难性遗忘(Catastrophic Forgetting)”和“过拟合性迁移能力差”。
1.2 定义问题/阐述背景(The “Why”):传统AI的三大“死穴”倒逼元学习的诞生
刚才的小实验只是元学习诞生的“感性原因”,但从“理性的技术发展角度”来看,元学习的诞生其实是为了解决传统AI(主要是深度学习和强化学习)面临的三大核心“死穴”:
1.2.1 死穴一:数据饥渴症(Data Scarcity/Annotation Cost)
在深度学习的黄金时代(2012-2020年左右),我们的口号是“数据为王”——谁拥有更多的标注数据,谁就能训练出更好的模型;但随着AI应用场景的不断拓展,“标注数据”逐渐变成了一种“稀缺资源”:
- 医疗影像领域:标注一张肺部CT切片的成本可能高达几十甚至上百元,而且需要专业的医生才能完成标注;
- 自动驾驶领域:标注一段1分钟的自动驾驶场景视频(包含行人、车辆、交通标志、障碍物等等)的成本可能高达几千元;
- 少样本个性化推荐领域:每个新用户注册的时候,只有1-5条浏览记录/购买记录,我们根本无法用传统的协同过滤或者深度学习推荐模型给他们做精准推荐;
- 强化学习落地领域:比如让机器人学会开门、倒水、叠衣服,在真实环境中训练机器人不仅成本高(机器人可能会损坏),而且效率极低(机器人可能需要尝试几万次甚至几十万次才能学会一个简单的动作)。
1.2.2 死穴二:灾难性遗忘(Catastrophic Forgetting)
什么是灾难性遗忘?简单来说,就是当你让一个已经学会任务A的模型去学习任务B时,它会很快忘记任务A的知识。
举个经典的例子:假设你有一个ResNet-18模型,先在ImageNet的“猫狗数据集”上训练,让它的准确率达到95%;然后你把它放到ImageNet的“汽车飞机数据集”上训练,让它的准确率也达到95%;这时候你再把它放回“猫狗数据集”上测试,你会发现它的准确率可能连50%都不到——它已经完全忘记了“猫和狗长什么样”!
为什么会发生这种情况?因为深度学习模型的参数是“共享的”:当你训练任务B时,反向传播会调整所有的参数,这些参数中包含了之前任务A学到的“知识”——调整之后,任务A的“知识”就被“覆盖”或者“擦除”了。
1.2.3 死穴三:过拟合性迁移能力差(Overfitting to Source Task & Poor Cross-Task Transfer)
刚才我们提到的“ResNet-18识别薮猫”的例子,就是过拟合性迁移能力差的典型表现;还有一个更极端的例子:DeepMind的AlphaGo Zero虽然能在围棋上打败人类世界冠军,但它连最简单的“井字棋(Tic-Tac-Toe)”都不会玩——因为AlphaGo Zero是专门为围棋设计的,它的神经网络架构、损失函数、强化学习算法都是针对围棋优化的,根本无法迁移到其他棋类游戏上。
那有没有一种方法,能够同时解决这三大死穴?——答案就是元学习!
1.3 亮明观点/文章目标(The “What” & “How”):这篇文章你能学到什么?
好的,既然元学习这么重要,那这篇文章我会带你系统地、深入浅出地掌握元学习的核心知识,具体来说,你读完这篇文章后,能够:
- 理解元学习的核心定义、本质、分类体系,知道什么是“元训练”“元测试”“Few-Shot Learning”“One-Shot Learning”“Zero-Shot Learning”;
- 掌握元学习的四大核心技术路线——基于优化的元学习(MAML、Reptile)、基于度量的元学习(Prototypical Networks、Relation Networks)、基于记忆的元学习(MANN、NTM)、基于模型的元学习(Meta-Learning with Hypernetworks),并且能够理解每个技术路线的核心原理、数学公式、优缺点;
- 用Python和PyTorch手写一个可运行的Few-Shot图像分类实战项目——我们会用经典的Omniglot数据集(被称为“Few-Shot Learning的MNIST”)作为训练集和测试集,实现基于MAML的Few-Shot分类算法和基于Prototypical Networks的Few-Shot分类算法,并且会对比这两种算法的效果;
- 了解元学习的进阶应用场景——比如大模型微调(LoRA其实可以看作是一种简化的元学习!)、终身学习、强化学习快速适配、少样本个性化推荐、医疗影像诊断、自动驾驶少样本场景适应;
- 掌握元学习的最佳实践——比如如何选择元训练任务集、如何设计元学习的损失函数、如何避免元学习的过拟合、如何加速元学习的训练过程;
- 了解元学习的发展历史和未来趋势——比如从早期的“学习学习的规则”到现在的“大模型元学习”,未来元学习可能会和“AGI(通用人工智能)”“世界模型”“因果推断”结合起来;
- 得到一个能让你思考三天的开放性问题——并且我会给你提供一些思考的方向。
为了让你更好地理解这些内容,我会在文章中加入大量的直观例子、类比、图解、数学公式(用LaTeX格式)、算法流程图(用Mermaid格式)、Python源代码(带详细注释)、实验结果对比图——我保证,即使你只有“机器学习入门级”的水平(知道什么是神经网络、反向传播、损失函数),你也能读懂这篇文章!
好的,引言部分就到这里!接下来,我们进入基础知识/背景铺垫部分——在这一部分,我会给你解释元学习中最核心的10个概念,并且会用表格和Mermaid架构图对比元学习和传统机器学习的区别,最后还会简单介绍一下元学习中最常用的几个数据集。
二、基础知识/背景铺垫(Foundational Concepts)
2.1 元学习的核心定义与本质
2.1.1 核心定义
在正式给出元学习的核心定义之前,我们先回忆一下传统监督学习的定义:
传统监督学习(Traditional Supervised Learning):给定一个固定的任务TTT(比如“识别猫和狗的图像分类任务”),以及一个标注好的训练数据集Dtrain={(x1,y1),(x2,y2),…,(xN,yN)}D_{\text{train}} = \{(x_1, y_1), (x_2, y_2), \dots, (x_N, y_N)\}Dtrain={(x1,y1),(x2,y2),…,(xN,yN)}(其中xix_ixi是输入样本,yiy_iyi是对应的标签),我们的目标是学习一个映射函数fθ:X→Yf_\theta: \mathcal{X} \rightarrow \mathcal{Y}fθ:X→Y(其中θ\thetaθ是模型的参数,X\mathcal{X}X是输入空间,Y\mathcal{Y}Y是标签空间),使得这个映射函数在未见过的测试数据集DtestD_{\text{test}}Dtest上的损失函数值最小(或者准确率最高)。
用更通俗的话来说,传统监督学习就是“针对一个任务,学一个模型”。
那元学习的定义是什么呢?目前学术界有很多种定义,但最被广泛接受的是Yann LeCun的学生、Meta AI(原Facebook AI Research)的研究员Chelsea Finn在她2017年的博士论文《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》中给出的定义:
元学习(Meta-Learning),也被称为“学会学习(Learn to Learn)”:给定一个任务分布p(T)p(T)p(T)(也就是“成百上千个相关的训练任务的集合”,比如Omniglot数据集中的“识别1-964个不同字符的分类任务”),以及每个任务TiT_iTi对应的支持集(Support Set)Si={(xi,1,yi,1),…,(xi,K,yi,K)}S_i = \{(x_{i,1}, y_{i,1}), \dots, (x_{i,K}, y_{i,K})\}Si={(xi,1,yi,1),…,(xi,K,yi,K)}(也就是“每个训练任务的少量标注样本”,KKK通常取1-20)和查询集(Query Set)Qi={(xi,1′,yi,1′),…,(xi,M,yi,M′)}Q_i = \{(x_{i,1}', y_{i,1}'), \dots, (x_{i,M}, y_{i,M}')\}Qi={(xi,1′,yi,1′),…,(xi,M,yi,M′)}(也就是“每个训练任务的未见过的测试样本”,MMM通常取10-50),我们的目标是学习一个元参数(Meta-Parameter)θ\thetaθ(或者一套通用的学习策略),使得这个元参数θ\thetaθ在未见过的测试任务Ttest∼p(T)T_{\text{test}} \sim p(T)Ttest∼p(T)上,只用支持集StestS_{\text{test}}Stest的少量样本进行快速适配(得到适配后的参数θtest\theta_{\text{test}}θtest)后,在查询集QtestQ_{\text{test}}Qtest上的损失函数值最小(或者准确率最高)。
用更通俗的话来说,元学习就是“针对一堆相关任务,学一个‘快速学习器’——这个快速学习器拿到新任务后,看一眼(Few-Shot)甚至零眼(Zero-Shot)就能学会”。
2.1.2 本质理解
为了让你更好地理解元学习的本质,我给你打两个非常直观的类比:
类比一:“准备期末考试的学生” vs “准备研究生复试的学生”
- 传统监督学习的模型:就像“一个只会背期末考试范围的学生”——老师给了他一个“固定的期末考试范围”(对应固定的任务TTT),他花了三个月的时间,把这个范围内的所有知识点(对应标注好的训练数据集DtrainD_{\text{train}}Dtrain)都背得滚瓜烂熟,期末考试(对应测试数据集DtestD_{\text{test}}Dtest)考了95分;但转头老师让他参加“研究生复试”(对应全新的任务TtestT_{\text{test}}Ttest),复试范围完全不在之前的期末考试范围内,而且老师只给了他“10页复试参考资料”(对应支持集StestS_{\text{test}}Stest的少量样本),结果他复试考了30分——因为他只会“背固定范围的知识点”,不会“快速学习新知识”。
- 元学习的模型:就像“一个掌握了‘通用学习方法’的准备研究生复试的学生”——这个学生之前已经参加过“成百上千次不同科目的小测验”(对应任务分布p(T)p(T)p(T)中的成百上千个训练任务),每次小测验老师只给了他“10页参考资料”(对应每个训练任务的支持集SiS_iSi),然后让他参加小测验(对应每个训练任务的查询集QiQ_iQi);通过这成百上千次小测验的训练,这个学生掌握了一套“通用的学习方法”——比如“如何快速抓住知识点的重点”“如何通过类比旧知识学习新知识”“如何快速做笔记”“如何快速调整心态应对新考试”;现在老师让他参加“全新的研究生复试”(对应测试任务TtestT_{\text{test}}Ttest),同样只给了他“10页复试参考资料”(对应支持集StestS_{\text{test}}Stest),结果他用这套“通用的学习方法”,只花了1天时间就把参考资料学完了,复试考了90分!
类比二:“只会做一种菜的厨师” vs “掌握了‘通用烹饪方法’的厨师”
- 传统监督学习的模型:就像“只会做‘鱼香肉丝’的厨师”——他花了三个月的时间,跟着师父学做“鱼香肉丝”,师父给了他“10000斤猪肉、10000斤胡萝卜、10000斤木耳、10000斤鱼香调料”(对应标注好的训练数据集DtrainD_{\text{train}}Dtrain),他反复练习了10000次,终于做出了“师父满意的鱼香肉丝”(对应测试数据集DtestD_{\text{test}}Dtest上的高准确率);但转头客人点了“宫保鸡丁”(对应全新的任务TtestT_{\text{test}}Ttest),他根本不会做——因为他只会“按照固定的菜谱做固定的菜”,不会“根据不同的食材和调料快速调整烹饪方法”。
- 元学习的模型:就像“掌握了‘通用烹饪方法’的厨师”——这个厨师之前已经跟着“成百上千个不同菜系的师父”学过做菜(对应任务分布p(T)p(T)p(T)中的成百上千个训练任务),每个师父只教了他“1-5道菜的做法”(对应每个训练任务的支持集SiS_iSi),然后让他给客人做“同菜系的另一道菜”(对应每个训练任务的查询集QiQ_iQi);通过这成百上千个师父的教导,这个厨师掌握了一套“通用的烹饪方法”——比如“如何根据食材的特性选择烹饪方式(炒、煮、蒸、炸)”“如何根据客人的口味调整调料的用量”“如何快速切菜配菜”“如何快速处理突发情况(比如盐放多了怎么办)”;现在客人点了“他之前完全没听过的‘泰式咖喱虾’”(对应测试任务TtestT_{\text{test}}Ttest),客人只给他看了“1张泰式咖喱虾的图片和1份简单的菜谱”(对应支持集StestS_{\text{test}}Stest),结果他用这套“通用的烹饪方法”,只花了30分钟就做出了“客人非常满意的泰式咖喱虾”!
好的,通过这两个类比,我相信你已经对元学习的本质有了非常直观的理解!接下来,我们来梳理一下元学习中最核心的10个概念——这10个概念是你后续学习元学习算法的基础,必须牢牢掌握!
2.2 元学习的10个核心概念(必须牢牢掌握!)
2.2.1 任务(Task)TTT
在元学习中,任务是最基本的单元——我们不是直接训练模型,而是训练模型在“一堆任务”上的表现。
那什么是“任务”呢?在监督学习场景下,一个任务TTT通常由三个部分组成:
- 输入空间(Input Space)X\mathcal{X}X:比如图像分类任务中的“所有28x28像素的灰度图的集合”;
- 标签空间(Label Space)Y\mathcal{Y}Y:比如图像分类任务中的“所有类别的集合”;
- 条件分布(Conditional Distribution)p(y∣x)p(y|x)p(y∣x):也就是“输入xxx对应的标签yyy的概率分布”——比如在图像分类任务中,p(y∣x)p(y|x)p(y∣x)就是“输入一张图片xxx,它属于类别yyy的概率”。
不过在实际的元学习训练中,我们通常不需要显式地定义条件分布p(y∣x)p(y|x)p(y∣x),而是用数据集来表示一个任务TTT——也就是“支持集SiS_iSi + 查询集QiQ_iQi”,这个我们后面会详细讲。
2.2.2 任务分布(Task Distribution)p(T)p(T)p(T)
在传统监督学习中,我们只有一个固定的任务TTT;但在元学习中,我们有一个任务分布p(T)p(T)p(T)——也就是“成百上千个相关的任务的集合,每个任务都有一个出现的概率”。
那什么是“相关的任务”呢?举个例子:
- 如果我们的目标是“让智能体学会快速识别手写字符”,那么“识别Omniglot数据集中的1-964个不同字符的分类任务”就是“相关的任务”——因为它们都是“手写字符识别任务”,输入空间都是“28x28像素的灰度图的集合”,标签空间的结构都是“CCC个互斥类别的集合”;
- 但如果我们把“识别手写字符的分类任务”和“预测股票价格的回归任务”放在一起,那它们就不是“相关的任务”——因为它们的输入空间、标签空间、条件分布都完全不同,元学习模型根本无法从这些任务中学到“通用的学习策略”。
任务分布p(T)p(T)p(T)的选择对元学习模型的效果影响非常大——这也是元学习的最佳实践之一,我们后面会详细讲。
2.2.3 元训练集(Meta-Training Set)Dmeta-train\mathcal{D}_{\text{meta-train}}Dmeta-train
元训练集就是从任务分布p(T)p(T)p(T)中采样出来的“成百上千个训练任务的集合”——也就是我们用来“训练元参数θ\thetaθ”的数据集。
每个训练任务Ti∈Dmeta-trainT_i \in \mathcal{D}_{\text{meta-train}}Ti∈Dmeta-train都由支持集(Support Set)SiS_iSi和查询集(Query Set)QiQ_iQi组成,这个我们后面会详细讲。
2.2.4 元验证集(Meta-Validation Set)Dmeta-val\mathcal{D}_{\text{meta-val}}Dmeta-val
元验证集就是从任务分布p(T)p(T)p(T)中采样出来的“几十个验证任务的集合”——也就是我们用来“调整元学习的超参数(比如元学习率、适配学习率、批量大小、支持集样本数KKK等等)”的数据集。
和传统监督学习一样,我们在元训练过程中,会定期在元验证集上测试元参数θ\thetaθ的效果,然后根据测试结果调整超参数,避免元学习模型在元训练集上过拟合。
2.2.5 元测试集(Meta-Test Set)Dmeta-test\mathcal{D}_{\text{meta-test}}Dmeta-test
元测试集就是从任务分布p(T)p(T)p(T)中采样出来的“几十个测试任务的集合”——也就是我们用来“最终评估元学习模型效果”的数据集。
注意:元测试集中的任务必须是元训练集和元验证集中完全没有见过的!——否则我们的评估结果就是“作弊”的,因为元学习模型已经见过这些任务了。
2.2.6 支持集(Support Set)SSS
支持集就是每个任务(训练任务、验证任务、测试任务)中的“少量标注样本”——也就是我们用来“让元学习模型快速适配这个任务”的样本,KKK通常取1-20(如果K=1K=1K=1,就叫One-Shot Learning;如果K=5K=5K=5,就叫5-Shot Learning;如果K=10K=10K=10,就叫10-Shot Learning;以此类推)。
在监督学习场景下,支持集SSS通常表示为:
S={(x1,y1),(x2,y2),…,(xK,yK)}S = \{(x_1, y_1), (x_2, y_2), \dots, (x_K, y_K)\}S={(x1,y1),(x2,y2),…,(xK,yK)}
其中xi∈Xx_i \in \mathcal{X}xi∈X是输入样本,yi∈Yy_i \in \mathcal{Y}yi∈Y是对应的标签。
2.2.7 查询集(Query Set)QQQ
查询集就是每个任务(训练任务、验证任务、测试任务)中的“未见过的测试样本”——也就是我们用来“评估元学习模型在这个任务上的适配效果”的样本,MMM通常取10-50。
在监督学习场景下,查询集QQQ通常表示为:
Q={(x1′,y1′),(x2′,y2′),…,(xM′,yM′)}Q = \{(x_1', y_1'), (x_2', y_2'), \dots, (x_M', y_M')\}Q={(x1′,y1′),(x2′,y2′),…,(xM′,yM′)}
其中xi′∈Xx_i' \in \mathcal{X}xi′∈X是输入样本,yi′∈Yy_i' \in \mathcal{Y}yi′∈Y是对应的标签。
注意:支持集SSS和查询集QQQ中的样本必须是完全不重叠的!——否则我们的评估结果就是“作弊”的,因为元学习模型已经见过这些查询样本的标签了。
2.2.8 元参数(Meta-Parameter)θ\thetaθ
元参数就是元学习模型在元训练过程中学到的“通用的初始状态”或者“通用的学习策略”——这是元学习模型的核心!
不同的元学习算法,元参数θ\thetaθ的含义是不同的:
- 在**基于优化的元学习(比如MAML、Reptile)**中,元参数θ\thetaθ就是“神经网络的通用初始参数”——当元学习模型遇到新任务时,只需要用支持集SSS的少量样本,对这个通用初始参数θ\thetaθ进行“几步梯度下降(或上升)”,就能得到适配后的参数θ′\theta'θ′,然后用θ′\theta'θ′在查询集QQQ上进行预测;
- 在**基于度量的元学习(比如Prototypical Networks、Relation Networks)**中,元参数θ\thetaθ就是“神经网络的通用特征提取器的参数”——当元学习模型遇到新任务时,先用这个通用特征提取器把支持集SSS和查询集QQQ中的样本映射到“特征空间”,然后在特征空间中用“某种度量方式(比如欧氏距离、余弦相似度、关系网络)”来判断查询样本的标签;
- 在**基于记忆的元学习(比如MANN、NTM)**中,元参数θ\thetaθ就是“神经网络的控制器和记忆读写头的参数”——当元学习模型遇到新任务时,先用支持集SSS的少量样本“写入”记忆矩阵,然后用查询集QQQ中的样本“读取”记忆矩阵,从而得到预测结果;
- 在**基于模型的元学习(比如Meta-Learning with Hypernetworks)**中,元参数θ\thetaθ就是“超网络(Hypernetwork)的参数”——超网络的作用是“根据新任务的支持集SSS,生成适配这个任务的神经网络的参数θ′\theta'θ′”。
不管元参数θ\thetaθ的含义是什么,它的核心目标都是一样的:让元学习模型在遇到新任务时,只用极少量的样本就能快速适配,并且在查询集上取得好的效果。
2.2.9 元训练(Meta-Training)
元训练就是**“训练元参数θ\thetaθ”的过程**——这是元学习中最耗时的过程,通常需要用GPU跑上几个小时甚至几天。
元训练的大致流程是这样的(不同的元学习算法流程略有不同,但核心思想是一样的):
- 从元训练集Dmeta-train\mathcal{D}_{\text{meta-train}}Dmeta-train中采样一个批量的训练任务B={T1,T2,…,TB}\mathcal{B} = \{T_1, T_2, \dots, T_B\}B={T1,T2,…,TB}(BBB是批量大小,通常取4-32);
- 对每个训练任务Ti∈BT_i \in \mathcal{B}Ti∈B,执行以下操作:
a. 从TiT_iTi的支持集SiS_iSi中采样样本,对元参数θ\thetaθ进行“快速适配”,得到适配后的参数θi′\theta_i'θi′;
b. 用适配后的参数θi′\theta_i'θi′在TiT_iTi的查询集QiQ_iQi上计算损失函数值Li(θi′)\mathcal{L}_i(\theta_i')Li(θi′); - 计算批量损失函数值L(θ)=1B∑i=1BLi(θi′)\mathcal{L}(\theta) = \frac{1}{B} \sum_{i=1}^B \mathcal{L}_i(\theta_i')L(θ)=B1∑i=1BLi(θi′);
- 用反向传播更新元参数θ\thetaθ——注意:这里的反向传播是“通过快速适配的过程”进行的,也就是所谓的“二阶导数(Second-Order Derivative)”或者“内循环-外循环(Inner Loop-Outer Loop)”结构(这个我们后面讲MAML的时候会详细拆解);
- 重复步骤1-4,直到元参数θ\thetaθ收敛(或者达到预设的训练轮数)。
2.2.10 元测试(Meta-Testing)
元测试就是**“最终评估元学习模型效果”的过程**——这个过程通常很快,只需要用GPU跑上几分钟。
元测试的大致流程是这样的(不同的元学习算法流程略有不同,但核心思想是一样的):
- 从元测试集Dmeta-test\mathcal{D}_{\text{meta-test}}Dmeta-test中采样一个测试任务TtestT_{\text{test}}Ttest;
- 从TtestT_{\text{test}}Ttest的支持集StestS_{\text{test}}Stest中采样样本,对元参数θ\thetaθ进行“快速适配”,得到适配后的参数θtest′\theta_{\text{test}}'θtest′;
- 用适配后的参数θtest′\theta_{\text{test}}'θtest′在TtestT_{\text{test}}Ttest的查询集QtestQ_{\text{test}}Qtest上计算损失函数值和准确率;
- 重复步骤1-3,对元测试集中的所有测试任务都进行评估;
- 计算所有测试任务的平均损失函数值和平均准确率——这就是元学习模型的最终效果。
好的,元学习的10个核心概念就讲到这里!接下来,我们用表格和Mermaid架构图对比一下元学习和传统监督学习的区别——这能让你更清晰地理解元学习的核心思想。
2.3 元学习 vs 传统监督学习:核心区别对比
2.3.1 核心区别对比表(Markdown格式)
| 对比维度 | 传统监督学习(Traditional Supervised Learning) | 元学习(Meta-Learning) |
|---|---|---|
| 核心目标 | 学会一个具体的任务(比如“识别猫和狗的图像分类任务”) | 学会学习的方法(Learn to Learn),也就是掌握一套通用的“快速适配策略” |
| 训练数据 | 一个固定的、标注好的大数据集(比如ImageNet的120万张标注图片) | 一个任务分布p(T)p(T)p(T),也就是成百上千个相关的小任务的集合,每个小任务只有少量标注样本(支持集SSS) |
| 核心参数 | 一个针对具体任务的模型参数ϕ\phiϕ | 一个通用的元参数θ\thetaθ(可以是初始模型参数、特征提取器参数、超网络参数等等) |
| 训练流程 | 单循环:在固定的大数据集上,通过反向传播调整模型参数ϕ\phiϕ,最小化损失函数值 | 双循环(内循环+外循环): 1. 内循环:在每个小任务的支持集SSS上,快速适配得到参数ϕi′\phi_i'ϕi′ 2. 外循环:在所有小任务的查询集QQQ上,通过反向传播调整元参数θ\thetaθ,最小化平均损失函数值 |
| 测试流程 | 直接用训练好的模型参数ϕ\phiϕ在测试集上进行预测 | 先用测试任务的支持集StestS_{\text{test}}Stest对元参数θ\thetaθ进行快速适配,得到ϕtest′\phi_{\text{test}}'ϕtest′,再用ϕtest′\phi_{\text{test}}'ϕtest′在查询集QtestQ_{\text{test}}Qtest上进行预测 |
| 对新任务的适配能力 | 极差:必须重新收集大量标注数据,重新训练模型参数ϕ\phiϕ,甚至需要重新设计模型架构 | 极强:只需要用测试任务的1-20个标注样本进行快速适配,就能取得好的效果 |
| 数据效率 | 极低:需要大量标注数据 | 极高:只需要少量标注数据(每个小任务的支持集只有1-20个样本) |
| 是否解决灾难性遗忘 | 否:当学习新任务时,会很快忘记旧任务的知识 | 是(部分元学习算法):比如基于记忆的元学习、基于正则化的元学习(和终身学习结合的算法) |
2.3.2 核心区别对比架构图(Mermaid格式)
传统监督学习的架构图
元学习的架构图
成百上千个相关小任务] -----------------------^ Expecting 'SQE', 'DOUBLECIRCLEEND', 'PE', '-)', 'STADIUMEND', 'SUBROUTINEEND', 'PIPE', 'CYLINDEREND', 'DIAMOND_STOP', 'TAGEND', 'TRAPEND', 'INVTRAPEND', 'UNICODE_TEXT', 'TEXT', 'TAGSTART', got 'PS'
好的,元学习和传统监督学习的区别就讲到这里!接下来,我们简单介绍一下元学习中最常用的几个数据集——这些数据集是你后续做元学习实战项目的基础,必须了解!
2.4 元学习中最常用的几个数据集
元学习中的数据集和传统监督学习中的数据集有很大的不同:传统监督学习中的数据集是“一个大的数据集,包含很多样本”;而元学习中的数据集是“一个大的数据集,被分成很多小的任务”。
接下来,我们介绍元学习中最常用的四个图像分类数据集和一个强化学习数据集:
2.4.1 Omniglot数据集(Few-Shot Learning的MNIST)
Omniglot数据集是元学习中最经典、最常用的图像分类数据集,被称为“Few-Shot Learning的MNIST”——因为它和MNIST一样简单,但更适合用来做元学习实验。
基本信息
- 发布者:Brenden M. Lake等人(2015年)
- 官网:https://github.com/brendenlake/omniglot
- 下载地址:可以从官网下载,也可以用PyTorch的
torchvision.datasets.Omniglot直接下载 - 样本数量:共1623个不同的字符(来自50个不同的字母表,比如拉丁字母、希腊字母、阿拉伯字母、中文汉字、日文假名等等),每个字符有20个不同的样本(由20个不同的人手写而成)
- 图像尺寸:105x105像素的灰度图(通常会被缩放到28x28像素,方便和MNIST对比)
- 预处理:通常会做“数据增强”(比如旋转90度、180度、270度——这样可以把字符数量从1623个增加到6492个)
如何划分成元训练集、元验证集、元测试集?
在元学习实验中,Omniglot数据集通常被这样划分:
- 元训练集:取前1200个字符(或者数据增强后的前4800个字符),每个字符作为一个“类别”;
- 元验证集:取接下来的100个字符(或者数据增强后的接下来的400个字符),每个字符作为一个“类别”;
- 元测试集:取最后323个字符(或者数据增强后的最后1292个字符),每个字符作为一个“类别”。
然后,对于每个任务(无论是训练任务、验证任务还是测试任务),我们会采用**NNN-Way KKK-Shot的划分方式**:
- NNN-Way:每个任务包含NNN个不同的类别(也就是NNN个不同的字符),NNN通常取5或者20;
- KKK-Shot:每个类别有KKK个样本作为支持集SSS,KKK通常取1、5或者10;
- 查询集QQQ:每个类别剩下的(20−K)(20-K)(20−K)个样本(或者数据增强后的剩下的(80−K)(80-K)(80−K)个样本)作为查询集QQQ,或者每个类别随机取10-20个样本作为查询集QQQ。
举个例子:如果我们做的是5-Way 1-Shot的实验,那么每个任务包含:
- 5个不同的字符(来自元训练集/元验证集/元测试集);
- 每个字符取1个样本作为支持集SSS,所以支持集SSS的总样本数是5×1=55 \times 1 = 55×1=5;
- 每个字符取10个样本作为查询集QQQ,所以查询集QQQ的总样本数是5×10=505 \times 10 = 505×10=50。
2.4.2 mini-ImageNet数据集(Few-Shot Learning的ImageNet简化版)
mini-ImageNet数据集是Omniglot数据集的“升级版”——因为Omniglot数据集的样本太简单了(都是手写字符),而mini-ImageNet数据集的样本是“真实世界的图片”,更接近实际应用场景。
基本信息
- 发布者:Vinyals等人(2016年)
- 官网:https://github.com/yaoyao-liu/mini-imagenet-tools
- 下载地址:可以从官网下载,也可以从Kaggle下载
- 样本数量:共100个不同的类别(来自ImageNet数据集,比如“狗”“猫”“汽车”“飞机”“鸟”“花”等等),每个类别有600个不同的样本
- 图像尺寸:84x84像素的RGB图
- 预处理:通常会做“归一化”(把像素值从0-255缩放到0-1,或者用ImageNet的均值和标准差进行归一化)和“数据增强”(比如随机裁剪、随机翻转、颜色抖动等等)
如何划分成元训练集、元验证集、元测试集?
在元学习实验中,mini-ImageNet数据集通常被这样划分(Vinyals等人2016年的划分方式):
- 元训练集:取前64个类别,每个类别有600个样本;
- 元验证集:取接下来的16个类别,每个类别有600个样本;
- 元测试集:取最后20个类别,每个类别有600个样本。
和Omniglot数据集一样,mini-ImageNet数据集也采用**NNN-Way KKK-Shot的划分方式**,NNN通常取5或者20,KKK通常取1、5或者10。
2.4.3 tiered-ImageNet数据集(更符合元学习任务分布假设的数据集)
tiered-ImageNet数据集是mini-ImageNet数据集的“进一步升级版”——因为mini-ImageNet数据集的元训练集、元验证集、元测试集的类别是“随机划分”的,而tiered-ImageNet数据集的类别是“按照ImageNet的语义层次结构划分”的,更符合元学习的“任务分布假设”(也就是“训练任务和测试任务是相关的,来自同一个语义领域”)。
基本信息
- 发布者:Ren等人(2018年)
- 官网:https://github.com/renmengye/few-shot-ssl-public
- 下载地址:可以从官网下载
- 样本数量:共608个不同的类别(来自ImageNet数据集的34个高级语义类别,比如“动物”“植物”“交通工具”“电子产品”等等),每个类别有600个不同的样本
- 图像尺寸:84x84像素的RGB图
- 预处理:和mini-ImageNet数据集一样
如何划分成元训练集、元验证集、元测试集?
在元学习实验中,tiered-ImageNet数据集通常被这样划分(Ren等人2018年的划分方式):
- 元训练集:取20个高级语义类别,对应的351个低级类别,每个类别有600个样本;
- 元验证集:取6个高级语义类别,对应的97个低级类别,每个类别有600个样本;
- 元测试集:取8个高级语义类别,对应的160个低级类别,每个类别有600个样本。
和Omniglot、mini-ImageNet数据集一样,tiered-ImageNet数据集也采用**NNN-Way KKK-Shot的划分方式**。
2.4.4 CIFAR-FS数据集(适合小模型的Few-Shot图像分类数据集)
CIFAR-FS数据集是基于CIFAR-100数据集构建的——因为CIFAR-100数据集的图像尺寸小(32x32像素),适合用来训练小模型,加快元学习的实验速度。
基本信息
- 发布者:Bertinetto等人(2019年)
- 官网:https://github.com/bertinetto/r2d2
- 下载地址:可以从官网下载,也可以用PyTorch的
torchvision.datasets.CIFAR100自己构建 - 样本数量:共100个不同的类别(来自CIFAR-100数据集),每个类别有600个不同的样本
- 图像尺寸:32x32像素的RGB图
- 预处理:通常会做“归一化”(用CIFAR-100的均值和标准差进行归一化)和“数据增强”(比如随机裁剪、随机翻转、颜色抖动等等)
如何划分成元训练集、元验证集、元测试集?
在元学习实验中,CIFAR-FS数据集通常被这样划分(Bertinetto等人2019年的划分方式):
- 元训练集:取前64个类别,每个类别有600个样本;
- 元验证集:取接下来的16个类别,每个类别有600个样本;
- 元测试集:取最后20个类别,每个类别有
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)