Pytorch学习笔记15-图片数据建模流程范例
由于时效问题,该文某些代码、技术可能已经过期,请注意!!!本文最后更新于:3 年前
图片数据
1 |
|
一,准备数据
cifar2数据集为cifar10数据集的子集,只包括前两种类别airplane和automobile。
训练集有airplane和automobile图片各5000张,测试集有airplane和automobile图片各1000张。
cifar2任务的目标是训练一个模型来对飞机airplane和机动车automobile两种图片进行分类。
在Pytorch中构建图片数据管道通常有两种方法。
第一种是使用 torchvision中的datasets.ImageFolder来读取图片然后用 DataLoader来并行加载。
第二种是通过继承 torch.utils.data.Dataset 实现用户自定义读取逻辑然后用 DataLoader来并行加载。
第二种方法是读取用户自定义数据集的通用方法,既可以读取图片数据集,也可以读取文本数据集。
本篇我们介绍第一种方法。
1 |
|
1 |
|
1 |
|
1 |
|
二,定义模型
使用Pytorch通常有三种方式构建模型:使用nn.Sequential按层顺序构建模型,继承nn.Module基类构建自定义模型,继承nn.Module基类构建模型并辅助应用模型容器(nn.Sequential,nn.ModuleList,nn.ModuleDict)进行封装。
此处选择通过继承nn.Module基类构建自定义模型。
1 |
|
1 |
|
三,训练模型
Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。
有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。
此处介绍一种较通用的函数形式训练循环。
1 |
|
1 |
|
1 |
|
1 |
|
四,评估模型
1 |
|
五,使用模型
1 |
|
六,保存模型
1 |
|
搬运自:
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!