torchvision中数据集使用
前面主要讲解了自定义数据集,自定义数据集可以用于自己的一些需求
同时讲解了一些比较常见的transforms,对图片进行一些处理,但是讲解的有一些不足,比如transforms只是显示如何对单个图片,真正使用时需要对数据集中每一个图片进行处理
所以以下学习如何将数据集和transforms结合在一起,同时介绍一些科研或毕设中一些标准的数据集该如何去下载、如何去组织、如何去查看、如何去使用
官方文档(DOCs),PyTorch被分为了不同的块,PyTorch可以被认为是核心模块,Domains里有一些语音的、视觉的、文本的等等
torchvision中左上角选择0.9.0便可跟讲者网站一样
里面有好几个模块,比如说torchvision.datasets,这就是PyTorch为我们提供的数据集的API文档,就是我们写代码的时候指定这些相应的数据集,给它设置一些参数,它就能自己去下载去使用这些标准的数据集、或者在科研当中使用的数据集
COCO数据集,一般用于一些目标检测或者是语义分割中
MNIST数据集,一个入门数据集,手写文字数据集
CIFAR数据集,一般用于物体识别,其中有不同的图片有不同的物体\
torchvision.io模块,一般不常用,不做讲解
torchvision.models模块,会提供一些比较常见的神经网络,这些神经网络有的预训练好了,这个模块后面会使用到,相对来说比较重要,比如说有用于分类的模型、语义分割的模型、目标检测、视频分类,这个在毕设或者科研中有可能会使用到
torchvision.ops模块,会提供一些比较少见的特殊的操作,基本上用不到
torchvision.transforms模块,已经讲解过了
torchvision.utils模块,之前的tensorboard就是来自这个模块,提供常见的小工具
这次主要讲解torchvision.datasets,以及datasets如何跟transforms进行联合的使用
首先看一下如何去使用torchvision提供的一些标准数据集,可以看一下官方文档CIFAR,这些参数都比较相近,而且相对来说比较简单
root,数据集在什么位置
train,为True就是训练集,为False就是测试集
transform,如果想对数据集当中的所有数据进行一个什么样的变化,把transform写在这里就可以
target_transform,对于target进行一个transform
download,为True就自动下载数据集,为False就不会下载,就是我们在网上如果需要下载数据集需要自己去搜索,但是如果使用torchvision提供的代码,直接把download设置为True的话就能省去很多麻烦
import torchvision
train_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=False, download=True)
运行就会下载,有的时候下载比较慢,可以复制下载地址在迅雷中下载
迅雷下载的好处就是它不一定从原地址进行下载, 比如说P2P加速、或者有其他人下载好了进行共享速度、或者来自镜像,我们可以从属性中查看,有一部分来着原始资源、一部分来自镜像加速、也有来自P2P加速、甚至会员加速,这种方式可以一定程度上加快下载速度
下载完成也会进行校验,首先下载了压缩文件,之后解压,其中有相应的数据集
查看数据集中的第一个print(test_set[0])
返回(<PIL.Image.Image image mode=RGB size=32x32 at 0x1FA28CA0BB0>, 3)
3是什么?是target
在print(test_set[0])
处打一个断点,debug一下
可以看到test_set里有一个属性是classes,数字是0时就是airplane,这个就是把真实的类别对应成了一个具体的数字
我们知道classes是test_set的一个属性,可以print(test_set.classes)
查看
返回['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
我们知道test_set的第一个组成是
我们就可以写,img, target = test_set[0]
,我们就获得了图片和target,可以验证一下
print(img)
返回<PIL.Image.Image image mode=RGB size=32x32 at 0x234CA7FA920>
print(target)
返回3
建议大家使用PyTorch中代码给的数据集的时候,download=True,img就是PIL格式,target就是3,3对应的就是0、1、2、3的cat
可以验证一下print(test_set.classes[target])
返回cat
想展示这个图片,对应PIL格式图片,直接img.show()
这个数据集像素比较小,32x32,我们可以隐约看出来是一只猫
接下来讲解CIFAR10 Dataset
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
CIFAR-10 数据集由 60000 张 32x32 彩色图像组成,分为 10 类,每类 6000 张。其中有 50000 张训练图像和 10000 张测试图像。
已经讲解了torchvision中常见的dataset的使用,接下来和transforms进行联动,原始图片是PIL格式,要给PyTorch使用时要转为tensor数据类型
import torchvision
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=True, transform= dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=False, transform= dataset_transform, download=True)
print(test_set[0])
Compose中可以添加Resize、Crop等操作,但此处图片比较小,只进行一个操作
dataset_transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Resize
])
在数据集中设置transform这个属性
train_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=True, transform= dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=False, transform= dataset_transform, download=True)
其中print(test_set[0])
查看图片是否转为了tensor数据类型
返回:
(tensor([[[0.6196, 0.6235, 0.6471, ..., 0.5373, 0.4941, 0.4549],
[0.5961, 0.5922, 0.6235, ..., 0.5333, 0.4902, 0.4667],
[0.5922, 0.5922, 0.6196, ..., 0.5451, 0.5098, 0.4706],
...,
[0.2667, 0.1647, 0.1216, ..., 0.1490, 0.0510, 0.1569],
[0.2392, 0.1922, 0.1373, ..., 0.1020, 0.1137, 0.0784],
[0.2118, 0.2196, 0.1765, ..., 0.0941, 0.1333, 0.0824]],
[[0.4392, 0.4353, 0.4549, ..., 0.3725, 0.3569, 0.3333],
[0.4392, 0.4314, 0.4471, ..., 0.3725, 0.3569, 0.3451],
[0.4314, 0.4275, 0.4353, ..., 0.3843, 0.3725, 0.3490],
...,
[0.4863, 0.3922, 0.3451, ..., 0.3804, 0.2510, 0.3333],
[0.4549, 0.4000, 0.3333, ..., 0.3216, 0.3216, 0.2510],
[0.4196, 0.4118, 0.3490, ..., 0.3020, 0.3294, 0.2627]],
[[0.1922, 0.1843, 0.2000, ..., 0.1412, 0.1412, 0.1294],
[0.2000, 0.1569, 0.1765, ..., 0.1216, 0.1255, 0.1333],
[0.1843, 0.1294, 0.1412, ..., 0.1333, 0.1333, 0.1294],
...,
[0.6941, 0.5804, 0.5373, ..., 0.5725, 0.4235, 0.4980],
[0.6588, 0.5804, 0.5176, ..., 0.5098, 0.4941, 0.4196],
[0.6275, 0.5843, 0.5176, ..., 0.4863, 0.5059, 0.4314]]]), 3)
使用tensorboard进行显示,显示测试数据集中的前10张图片
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=True, transform= dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="P14_torchvision_transforms/dataset", train=False, transform= dataset_transform, download=True)
writer = SummaryWriter(log_dir='logs')
for i in range(10):
img, target = test_set[0]
writer.add_image("test_set", img, i)
writer.close()
在终端输入tensorboard --logdir=logs
即可在tensorboard中查看
这部分主要介绍了一些torchvision中一些数据集的使用方式,也介绍了datasets中transforms
其它的数据集使用方式都比较简单,比如COCO数据集
使用前可以注意Parameters的设置,也有一些Example
数据集下载很慢怎么办?之前说了使用迅雷下载,下载完之后呢?
可以新建一个目录(比如叫dataset),将下载好的数据集放在新建的目录下,将download设置为True
它会使用已经下载好的,同时进行校验
返回:
Using downloaded and verified file: ./dataset\cifar-10-python.tar.gz
Files already downloaded and verified
所以download一直设置为True会比较方便,下载慢的话可以使用迅雷进行下载
有时候想查看数据集,有的数据集没有显示下载地址怎么办?
以CIFAR10数据集为例,查看源代码,会看到:
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
原始资料地址:
torchvision中数据集使用
如有侵权联系删除 仅供学习交流使用