# Pytorch学习记录

# 自带数据集介绍

位置在torchvision.datasets下

# 基础图像分类数据集(新手必学)

数据集 核心用途 关键特点
MNIST 手写数字识别(入门中的入门) 60k训练/10k测试,28×28灰度图,10个数字类别(0-9),CV学习者的第一个经典数据集
EMNIST 手写字符识别(MNIST扩展) 包含大小写字母+数字,多套划分方案,适合复杂字符分类任务
FashionMNIST 时尚服饰分类(替代MNIST,难度稍高) 60k训练/10k测试,28×28灰度图,10类服饰(T恤、裤子、鞋子等),避免MNIST过于简单的局限性
KMNIST 手写日文假名识别 与MNIST格式一致,10类日文平假名,适合特定语言字符识别和跨数据集验证
QMNIST 扩展手写数字识别(MNIST升级) 包含更多样本和额外标注信息,兼容MNIST格式,适合需要更丰富数据的手写识别任务
CIFAR10 小尺寸彩色图像分类(入门核心) 60k张32×32彩色图,10类日常物体(飞机、猫、狗等),入门彩色图像分类的必用数据集
CIFAR100 小尺寸彩色图像细分类(CIFAR10进阶) 60k张32×32彩色图,100个细分类别(分20个超类),难度更高,验证模型细分类能力
STL10 中等尺寸图像分类(支持半监督学习) 10类物体,96×96彩色图,包含大量无标注数据,介于CIFAR和ImageNet之间
SVHN 街景门牌号数字识别 从街景中提取的数字,32×32彩色图,带标注框,可兼顾分类和简单目标检测任务
USPS 手写数字识别(邮政系统数据集) 9k训练/2k测试,28×28灰度图,与MNIST类似但数据来源不同,用于交叉验证模型性能
SEMEION 手写数字像素级标注分类 手写数字灰度图,附带像素级标注,可兼顾分类和简单像素级分割任务

# 大规模/细粒度图像分类数据集

数据集 核心用途 关键特点
ImageNet 大规模通用图像分类(CV领域里程碑) 数百万张图像,1000个类别,覆盖日常物体,是ResNet、ViT等预训练模型的核心训练数据
Caltech101 细粒度物体分类(入门级细粒度) 101类物体,每类约70张图,标注清晰,适合细粒度分类任务入门
Caltech256 细粒度物体分类(Caltech101进阶) 256类物体,样本量更多,类内差异更大,提升细粒度分类模型的训练难度
Flowers102 花卉细分类 102类常见花卉,每类固定样本数,标注质量高,是细粒度视觉分类的经典基准
Food101 食物细分类 101类全球常见食物,每类1000张图,类内差异大(同一种食物不同做法),适合生活场景分类
StanfordCars 汽车细分类(品牌+型号+年份) 196类汽车,标注精细,包含车辆年份信息,对模型细节提取能力要求较高
FGVCAircraft 飞机细分类(机型+型号) 100类民航飞机,细粒度标注,类内差异小,专注于交通工具的细分类任务
OxfordIIITPet 宠物细分类(猫+狗品种) 37类宠物(猫/狗不同品种),同时包含语义分割标注,可兼顾分类和分割任务
EuroSAT 卫星遥感图像分类 10类土地覆盖类型(农田、森林、城市等),基于卫星图像,适合遥感视觉任务
Country211 场景图像的国家分类 211个国家的场景图像,通过图像判断拍摄国家,结合场景识别与地理信息
DTD 视觉纹理图像分类 47类视觉纹理(布纹、木纹、石纹等),专注于图像纹理特征,而非物体本身
SUN397 大规模日常场景分类 397类日常场景(客厅、公园、教室等),样本量巨大,是场景分类的核心基准
INaturalist 自然物种超大规模分类 数百万张自然物种图像(动物、植物、真菌等),类别数上万,适合生物多样性相关视觉任务

# 人脸/人物相关数据集

数据集 核心用途 关键特点
CelebA 人脸属性分析、人脸关键点检测 超20万张名人人脸图像,包含40个人脸属性标注(是否戴眼镜、是否微笑等),支持人脸关键点检测,人脸任务入门首选
LFWPeople 野生环境下人脸数据集(人脸识别) Labeled Faces in the Wild,野生环境人脸,用于构建人脸识别模型的基础数据集
LFWPairs 人脸验证(判断是否为同一人) 提供人脸对样本,用于验证模型的人脸匹配能力,是人脸识别任务的经典基准
WIDERFace 野生环境大规模人脸检测 包含各种姿态、尺度、遮挡的人脸,标注了人脸边界框,是人脸检测任务的核心基准

# 目标检测/实例分割数据集

数据集 核心用途 关键特点
CocoDetection 通用目标检测、实例分割(行业标杆) COCO数据集,80类常见物体,标注包含边界框、实例分割、关键点检测,是目标检测/分割竞赛的核心基准
VOCDetection 目标检测入门(经典老基准) PASCAL VOC数据集,20类物体,标注清晰,样本量适中,目标检测入门首选
Kitti 自动驾驶场景目标检测 自动驾驶街景数据,标注了汽车、行人、自行车等物体,包含3D标注,适合自动驾驶视觉任务

# 语义分割/场景分割数据集

数据集 核心用途 关键特点
Cityscapes 城市街景语义分割(行业标杆) 城市道路场景的语义分割数据集,30类标注(道路、建筑、行人、车辆等),标注质量高,是城市场景分割的核心基准
VOCSegmentation 语义分割入门(PASCAL VOC配套) PASCAL VOC的分割版本,21类(含背景),样本量适中,适合语义分割模型入门验证
SBDataset 语义分割扩展(补充VOC) 对PASCAL VOC的分割标注进行了扩展,样本量更多,标注更细致,用于提升分割模型的训练效果
OxfordIIITPet 宠物语义分割(细粒度) 除了宠物分类,还包含宠物的像素级分割标注(前景宠物+背景),适合小目标分割入门

# 小众/进阶CV数据集

数据集 核心用途 关键特点
CocoCaptions 图像字幕生成(跨模态任务) COCO的字幕版本,每张图像对应5句描述性字幕,用于训练图像到文本的生成模型
Flickr8k 图像字幕生成(入门级) 基于Flickr图片的字幕数据集,样本量较小,适合图像字幕任务入门
Flickr30k 图像字幕生成(进阶) Flickr8k的扩展版,样本量更多,字幕质量更高,提升跨模态模型训练效果
KittiFlow 光流估计(自动驾驶场景) 自动驾驶街景的光流数据集,标注了像素级运动轨迹,适合自动驾驶运动分析
Sintel 光流估计(电影场景) 基于电影片段的光流数据集,场景复杂,对模型的鲁棒性要求更高
FlyingChairs 光流估计(入门级) 合成的椅子运动场景,光流标注清晰,适合光流估计模型入门
FlyingThings3D 3D光流/立体视觉 基于3D合成场景的光流数据集,包含3D标注,适合立体视觉和3D运动分析
HMDB51 视频动作识别(入门级) 51类人物动作视频,样本量适中,适合视频动作识别任务入门
UCF101 视频动作识别(进阶) 101类人物动作视频,样本量更大,场景更丰富,是视频动作识别的经典基准
Kinetics400 视频动作识别(大规模) 400类人物动作视频,数百万条视频片段,是大规模视频动作识别的核心基准
CLEVRClassification 视觉推理(逻辑判断) 合成的简单场景图像,用于训练模型的视觉推理能力(如“有多少个红色立方体”)
FakeData 生成假数据(测试用) 自动生成的假图像数据,用于快速测试模型输入输出、数据加载流程,无需下载真实数据
ImageFolder 自定义图像数据集加载(工具类) 按目录结构(类名=目录名)组织私有图像数据,快速加载为PyTorch可处理的数据集
DatasetFolder 自定义通用数据集加载(工具类) 比ImageFolder更通用,支持非图像数据,按目录结构加载自定义私有数据
PCAM 医学图像分类(细胞病理) 医学细胞图像的二分类任务(是否为癌症细胞),适合医学影像分析任务入门
Places365 超大规模场景分类 365类场景图像,样本量巨大,专注于场景识别而非物体识别

# 数据集父类

数据集 核心用途 关键特点
VisionDataset 所有torchvision视觉数据集的基类 定义了基础方法和属性(数据加载、预处理、标签获取等),保证所有数据集接口一致性

# 损失函数

# 分类任务损失函数

# 交叉熵损失(Cross Entropy Loss)

nn.CrossEntropyLoss()

核心特色

  1. 是「对数似然损失(NLLLoss)」+「Softmax 激活函数」的组合,直接接收模型输出的原始得分(logits),无需手动添加 Softmax,避免数值不稳定。
  2. 自动计算真实类别对应得分的负对数概率,惩罚预测错误且置信度高的样本,对分类任务的优化效果更优。
  3. 支持类别权重设置(weight参数),解决类别不平衡问题(比如少数类样本权重设高,提升模型对少数类的关注)。

适用场景单标签多分类任务(最常用,几乎是单标签分类的默认选择)

  • 示例:CIFAR10/100 图像分类(10/100 类选 1)、手写数字识别(10 类选 1)、文本主题分类(多个主题选 1)。

# 二元交叉熵损失(Binary Cross Entropy, BCE)

nn.BCELoss()

核心特色

  1. 针对二分类任务设计,计算每个样本的二元概率(0 或 1)的交叉熵损失。
  2. 要求模型输出经过Sigmoid 激活,将得分映射到[0, 1]区间(表示属于正类的概率)。
  3. 支持样本权重设置(weight参数),解决样本不平衡问题。

适用场景二元分类任务(二选一)

  • 示例:垃圾邮件识别(是 / 否)、情感分析(正面 / 负面)、疾病检测(患病 / 未患病)。

# 带 Sigmoid 的二元交叉熵损失(BCEWithLogitsLoss)

nn.BCEWithLogitsLoss()

核心特色

  1. 是「nn.BCELoss()」+「nn.Sigmoid()」的组合,直接接收模型输出的 logits,无需手动添加 Sigmoid。
  2. 采用数值稳定的计算方式,避免手动 Sigmoid 后可能出现的梯度消失 / 爆炸问题,性能优于nn.BCELoss()
  3. 支持类别权重(pos_weight参数),专门解决二分类中正负样本不平衡(比如正样本极少)的问题。

适用场景二元分类任务(替代nn.BCELoss(),是二分类的首选)、多标签分类任务(一个样本对应多个类别,每个类别独立判断是 / 否)

  • 示例:图片标签标注(一张图可能同时包含 “猫”、“狗”、“户外” 多个标签)、文本情感细粒度标注(同时包含 “开心”、“激动” 多个情绪)。

# 负对数似然损失(NLLLoss)

nn.NLLLoss()

核心特色

  1. 计算真实类别对应预测概率的负对数,要求模型输出先经过LogSoftmax激活(将 logits 映射为对数概率)。
  2. nn.CrossEntropyLoss()的底层组成部分,灵活性更高(可自定义前序激活逻辑)。

适用场景单标签多分类任务(较少直接使用,通常在自定义分类模型时使用)

  • 示例:自定义带有LogSoftmax输出层的分类模型。

# 回归任务损失函数

# 均方误差损失(Mean Squared Error, MSE)

nn.MSELoss()

核心特色

  1. 计算预测值与真实值之间平方差的平均值,对异常值(离群点)非常敏感(平方会放大异常值的损失)。
  2. 梯度平滑,优化过程稳定,是回归任务的默认选择。
  3. 支持不同的归约方式(reduction参数:mean求平均、sum求和、none返回每个样本的损失)。

适用场景一般回归任务(预测连续平稳的数值)

  • 示例:房价预测、温度预测、图像超分重建(像素值回归)。

# 平均绝对误差损失(Mean Absolute Error, MAE)

nn.L1Loss()

核心特色

  1. 计算预测值与真实值之间绝对差的平均值,对异常值不敏感(无平方放大效应),鲁棒性更强。
  2. 损失函数在预测值等于真实值处不可导(梯度不连续),优化速度可能比 MSE 慢,容易陷入局部最优。

适用场景存在异常值的回归任务(希望忽略离群点的影响)

  • 示例:传感器数据预测(容易出现异常噪声)、销售额预测(存在突发极端值)。

# 平滑 L1 损失(Smooth L1 Loss)

nn.SmoothL1Loss()(也叫 Huber Loss 的简化版)

核心特色

  1. 结合了 MSE 和 MAE 的优点:误差小时(|x| < 1)用 MSE,梯度平滑;误差大时(|x| ≥ 1)用 MAE,避免异常值放大损失
  2. 既保证了优化的稳定性,又具备较强的鲁棒性,是目标检测任务的核心损失组成。

适用场景

  1. 存在异常值的回归任务(替代 MAE,优化效果更好)。
  2. 目标检测任务(预测边界框的坐标偏移量,如 YOLO、Faster R-CNN)。

# 常用优化器

# 通用说明

所有 PyTorch 优化器均继承自 torch.optim.Optimizer,拥有两个通用必填参数

  1. params:模型可训练参数,通常传入 model.parameters()
  2. lr:学习率,控制参数更新步长,常用范围 1e-4 ~ 1e-2

# 基础优化器:SGD(随机梯度下降)

核心原理

全称 Stochastic Gradient Descent,随机抽取小批次样本计算梯度并更新参数(区别于全量数据的批量梯度下降 BGD)。引入「动量」可模拟物理惯性,缓解梯度震荡,加速收敛,是最基础的优化器基线。

核心参数(PyTorch:torch.optim.SGD

参数名 类型 默认值 核心作用
params 迭代器 - 模型可训练参数(通用必填)
lr 浮点数 - 学习率(必填,SGD 常用 1e-31e-2
momentum 浮点数 0 动量系数,平滑梯度震荡、加速收敛(推荐设为 0.9
dampening 浮点数 0 动量阻尼系数,与 momentum 配合使用,保持默认 0 即可
weight_decay 浮点数 0 权重衰减(L2 正则化),防止过拟合(推荐 1e-4 ~ 1e-5
nesterov 布尔值 False 是否启用 Nesterov 动量(进阶优化,启用需保证 momentum > 0

# 主流优化器:Adam(自适应矩估计)

核心原理

全称 Adaptive Moment Estimation,工业界/学术界首选通用优化器。结合「SGD 动量」和「RMSprop 自适应学习率」的优点:

  1. 计算梯度一阶矩(均值),模拟动量平滑梯度波动
  2. 计算梯度二阶矩(方差),为不同参数分配自适应学习率
  3. 对矩估计进行偏差修正,提升训练初期稳定性

核心参数(PyTorch:torch.optim.Adam

参数名 类型 默认值 核心作用
params 迭代器 - 模型可训练参数(通用必填)
lr 浮点数 1e-3 学习率(必填,Adam 常用 1e-3,无需过大)
betas 元组 (0.9, 0.999) 一阶矩/二阶矩指数衰减系数:
β₁(第一个值):控制动量效果(默认 0.9,无需修改)
β₂(第二个值):控制二阶矩衰减(默认 0.999,保持默认)
eps 浮点数 1e-8 极小值,防止分母为 0,避免数值计算不稳定(保持默认)
weight_decay 浮点数 0 权重衰减(L2 正则化),防止过拟合(推荐 1e-4
amsgrad 布尔值 False 是否启用 AMSGrad 变体(解决二阶矩估计偏差,通常无需启用)

# Adam 改进版:AdamW

核心原理

Adam 的进阶变体,核心改进是将权重衰减与参数更新解耦

  • Adam 的 weight_decay 对「更新后参数」惩罚,效果受自适应学习率影响
  • AdamW 对「原始参数」进行 L2 正则化,再执行参数更新,防过拟合效果更优,适合大模型

核心参数(PyTorch:torch.optim.AdamW

参数名 类型 默认值 核心作用
params 迭代器 - 模型可训练参数(通用必填)
lr 浮点数 1e-3 学习率(必填,常用 5e-41e-3
betas 元组 (0.9, 0.999) 一阶矩/二阶矩衰减系数,与 Adam 一致(保持默认)
eps 浮点数 1e-8 防止分母为 0,保持默认
weight_decay 浮点数 1e-2 权重衰减(L2 正则化),推荐 1e-4 ~ 1e-3
amsgrad 布尔值 False 是否启用 AMSGrad 变体,保持默认

# 自适应学习率:RMSprop

核心原理

全称 Root Mean Square Propagation,专注自适应调整学习率,解决 SGD 学习率全局统一的问题。通过计算梯度平方的移动平均值,为每个参数分配不同学习率,有效缓解非凸优化中的震荡,收敛速度快于 SGD。

核心参数(PyTorch:torch.optim.RMSprop

参数名 类型 默认值 核心作用
params 迭代器 - 模型可训练参数(通用必填)
lr 浮点数 1e-2 学习率(必填,常用 1e-31e-4
alpha 浮点数 0.99 梯度平方移动平均系数(保持默认,控制平滑程度)
eps 浮点数 1e-8 防止分母为 0,保持默认
momentum 浮点数 0 动量系数,可选启用(推荐 0.9,提升收敛速度)
weight_decay 浮点数 0 权重衰减(L2 正则化),推荐 1e-4
centered 布尔值 False 是否使用中心化梯度(减少偏差,计算量略大,保持默认)

# 稀疏数据适配:Adagrad

核心原理

全称 Adaptive Gradient Algorithm,最早的自适应学习率优化器之一。通过累积每个参数的梯度平方和,为更新频繁的参数减小学习率,为更新稀疏的参数增大学习率。缺点:学习率单调递减,后期可能趋近于 0 导致收敛停滞

核心参数(PyTorch:torch.optim.Adagrad

参数名 类型 默认值 核心作用
params 迭代器 - 模型可训练参数(通用必填)
lr 浮点数 1e-2 学习率(必填,常用 1e-3
lr_decay 浮点数 0 学习率衰减系数,加速后期学习率下降(保持默认 0)
weight_decay 浮点数 0 权重衰减(L2 正则化),推荐 1e-4
eps 浮点数 1e-10 防止分母为 0,保持默认
Last Updated: 1/30/2026, 8:32:15 AM