torchvision.datasets
中包含了以下数据集
Datasets
拥有以下API
:
__getitem__
__len__
由于以上Datasets
都是 torch.utils.data.Dataset
的子类,所以,他们也可以通过torch.utils.data.DataLoader
使用多线程(python的多进程)。
举例说明:
torch.utils.data.DataLoader(coco_cap, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads)
在构造函数中,不同的数据集直接的构造函数会有些许不同,但是他们共同拥有 keyword
参数。
In the constructor, each dataset has a slightly different API as needed, but they all take the keyword args:
transform
: 一个函数,原始图片作为输入,返回一个转换后的图片。(详情请看下面关于torchvision-tranform
的部分)
target_transform
- 一个函数,输入为target
,输出对其的转换。例子,输入的是图片标注的string
,输出为word
的索引。
dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)
参数说明:
processed/training.pt
和 processed/test.pt
的主目录True
= 训练集, False
= 测试集True
= 从互联网上下载数据集,并把数据集放在root
目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed
文件夹下。需要安装COCO API
dset.CocoCaptions(root="dir where images are", annFile="json annotation file", [transform, target_transform])
例子:
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())
print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample
print("Image Size: ", img.size())
print(target)
输出:
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
dset.CocoDetection(root="dir where images are", annFile="json annotation file", [transform, target_transform])
dset.LSUN(db_path, classes='train', [transform, target_transform])
参数说明:
一个通用的数据加载器,数据集中的数据以以下方式组织
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
dset.ImageFolder(root="root folder path", [transform, target_transform])
他有以下成员变量:
This is simply implemented with an ImageFolder dataset.
The data is preprocessed as described here
dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)
参数说明:
cifar-10-batches-py
的根目录True
= 训练集, False
= 测试集True
= 从互联上下载数据,并将其放在root
目录下。如果数据集已经下载,什么都不干。dset.STL10(root, split='train', transform=None, target_transform=None, download=False)
参数说明:
stl10_binary
的根目录True
= 从互联上下载数据,并将其放在root
目录下。如果数据集已经下载,什么都不干。此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。