对比 Categorical vs. Multinominal

对比点 Categorical 分布 Multinomial 分布
抽样单位 抽一次,得到一个类别 抽多次,得到每类出现的次数
是谁的特例? 是 Multinomial 的特例(n=1) 是更广义的分布
应用 单标签分类、one-hot 文本计数、生成模型、多次采样
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torch.distributions import Categorical, Multinomial

p = torch.tensor([0.2, 0.5, 0.3]) # 类别 A、B、C 的概率

# Categorical:一次只采样一个类别
cat = Categorical(probs=p)
# 采样一次
cat.sample() # 可能输出 tensor(2)
# 多次采样,得到多个标签
cat.sample((10,)) # 可能输出 tensor([1, 1, 2, 0, 2, 2, 2, 0, 0, 1])

# Multinomial:采样10次,统计每个类别出现几次
multi = Multinomial(total_count=10, probs=p)
multi.sample() # 可能输出 tensor([0., 3., 7.])

Few-Shot-learning

任务描述:我现在有一个小样本识别的任务:通过gelsight感知6种条状物体(三个文具(铅笔pencil,中性笔gel_pen,走珠笔roller_pen)三个工具(小螺丝刀small_screwdriver,大螺丝刀big_screwdriver,扎带cable_tie)),输入是物体接触gelsight得到的图片,输出物体类别。每类的训练数据约为50张,希望易于扩展新类别。请在以下方法中为我选择,并说明理由

方法 底层机制 是否可学习距离 优势 劣势
Prototypical Networks 类中心均值 简单高效 类内差异不适配
Relation Networks 学习相似度函数 灵活泛化强 模型复杂
Matching Networks 注意力+最近邻 可建模上下文 对顺序敏感
Contrastive Learning 特征拉近推远 适合无监督 需要大量负样本,训练技巧敏感
MAML 参数初始化优化 无需特定结构 快速适应新任务 优化不稳定

Metric-based

Prototypical Networks

核心思想:每类样本通过其支持集(support set)的特征均值生成一个“原型”,然后对查询样本(query)进行最近邻分类。

  • 过程

    • 将每类的支持样本嵌入到特征空间。

    • 对每一类计算原型向量(类中心):

    • 查询样本通过与各类原型的欧氏距离决定其类别:

  • 优点:简单高效,计算代价低。

  • 缺点:假设每类数据可用一个原型表示,不适合类内差异大的情况。

Relation Networks

核心思想:通过一个可学习的“关系模块”来度量支持样本与查询样本之间的相似度。

  • 结构
    • 用一个共享的编码器提取支持样本和查询样本的特征。
    • 构造特征对(support-query pairs)并输入一个关系网络(通常是小型CNN)进行关系打分。
    • 学习输出一个[0,1]之间的关系分数,表示是否为同一类。
  • 优点:关系度量是可学习的,比固定距离函数更灵活。
  • 缺点:训练和推理计算量较大。

Matching Network (metric learning )

核心思想:基于注意力机制的最近邻分类方法。结合支持集和查询样本的相似度来预测类别标签。

  • 关键机制

    • 使用带注意力的KNN:对于查询样本,按注意力权重计算其与支持样本的相似性加权结果:

    • 相似度函数一般使用余弦相似度或欧氏距离。

    • 使用BiLSTM等对支持集建模上下文信息。

  • 优点:建模支持集上下文,适合变结构任务。

  • 缺点:训练较复杂,对支持集顺序敏感。

Contrastive Learning

核心思想:通过构造正负样本对,使模型在特征空间中将相似样本拉近、不同样本推远。

  • 方法流程

    • 给定一对正样本(同类)和一对负样本(异类)。

    • 训练模型学习一个嵌入空间,使得:

    • 常用损失函数包括对比损失(Contrastive Loss)、InfoNCE、Triplet Loss等。

  • 优点:无需大量标签,适合自监督+小样本联合训练。

  • 缺点:对样本对的选择敏感,训练时间长。

Optimization-based

MAML (Model-Agnostic Meta-Learning)

核心思想:学习一个通用初始化参数,使得模型可以在少量梯度更新后快速适应新任务。

  • 流程

    1. 内循环:在任务的支持集上进行1-5步梯度下降,得出任务特定参数。
    2. 外循环:在任务的查询集上计算损失,回传以更新通用初始化参数。
  • 数学形式

  • 优点:适用于各种模型结构,不限制任务形式。

  • 缺点:对梯度敏感,优化不稳定,训练代价高。