Vega在Dataset类中提供了数据转换和采样相关的接口和公共方法,用户数据处理类可继承自Dataset类,使用这些公共能力。
Vega提供了常用的数据集类,包括Cifar10
、Cifar100
、ImageNet
、Coco
、FMnist
、Mnist
、Cityscapes
、Div2K
等,具体描述,可参考配置参考。
以下以Cifar10
为例,来说明如何使用Dataset
,使用步骤如下:
调整缺省配置,比如要调整数据文件中训练集的位置为本地文件,如下:
dataset:
type: Cifar10
train:
data_path: "/data/dataset/"
在程序中,使用ClassFactory
来创建Dataset
,model
来初始化训练集或测试集,并使用Dataloader来加载数据,如下:
dataset = ClassFactory.get_cls(Classtype.Dataset)
train_data, test_data = dataset(model='train'), dataset(model='test')
data_loader = train_data.dataloader
for input, target in data_loader:
process_data(input, target)
Vega的所有数据集类都继承自基类Dataset
,Dataset
基类定义了数据集所需的接口, 并提供了dataloader
、transforms
、sampler
等属性,并提供了缺省的实现,派生类可以根据需要来重载这些缺省实现,以下会介绍如何自定义一个 Dataset。
假设用户数据为100张图片,放在一个文件夹中,我们需要实现一个名为 MyDataset
的数据集类,我们需要按照如下步骤进行:
Dataloader
。Transform
。如上所述,类 MyDataset
继承自 Dataset
,如下:
from vega.datasets import Dataset
from vega.core.common import ClassFactory, ClassType
@ClassFactory.register(ClassType.DATASET)
class MyDataset(Dataset):
def __init__(self):
super(MyDataset, self).__init__()
以上代码中,@ClassFactory.register(ClassType.DATASET)
是将 MyDataset
注册到Vega
库中。
将数据集分为训练集和测试集,训练集用于训练模型,测试集用于验证模型。假设示例中的图片都用于训练,则需要指定一个文件位置的配置参数 data_path
。
在模型训练过程中,一般也会动态的将数据集划分为训练集和验证集,需要确定采样方式,顺序采样,还是随机采样,需要增加一个配置参数 shuffle
。配置信息如下:
dataset:
type: MyDataset
train:
data_path: "/data/"
shuffle: false
valid:
data_path: "/data/"
shuffle: false
假定我们从数据集中每次加载1张图片,每次都从文件加载,使用cv2来加载图片,代码如下:
import cv2
class MyDataset(Dataset):
def __len__(self):
return len(self.file)
def __getitem__(self, idx):
img_file = self.file[idx]
img = cv2.imread(img_file)
return img
当前 Vega
已提供了多种 Transform
供参考。
假设 MyDataset
需要实现一个把图片翻转的 Transform
,输入为一张原始图片,输出为翻转后的图片,假设 Vega
并未提供该 Transform
,我们需要调用 ImageOps
的翻转函数来实现,代码如下:
import ImageOps
@TransformFactory.register()
class MyTransform():
def __call__(self, img):
return ImageOps.invert(img.convert('RGB'))
使用时只需在配置文件中加入该transform即可,如下:
dataset:
type: MyDataset
train:
data_path: "/data/dataset/"
transforms:
- type: MyTransform
若在模型训练过程中调整 Transfroms
,可参考调整Transforms。
以下是调测新实现的 MyDataset
类,代码如下:
import unittest
import torchvision.transforms as tf
from roma.env import register_roma_env
from vega.core.pipeline.pipe_step import PipeStep
from vega.core.common.class_factory import ClassFactory, ClassType
import vega
@ClassFactory.register(ClassType.PIPE_STEP)
class FakePipeStep(PipeStep, unittest.TestCase):
def __init__(self):
PipeStep.__init__(self)
unittest.TestCase.__init__(self)
def do(self):
dataset = ClassFactory.get_cls(ClassType.DATASET)(mode="train")
train = dataset.dataloader
self.assertEqual(len(train), 100)
for input, target in train:
self.assertEqual(len(input), 1)
break
class TestDataset(unittest.TestCase):
def test_cifar10(self):
vega.run('./dataset.yml')
if __name__ == "__main__":
unittest.main()
若运行成功,会有如下类似的信息输出:
Ran 1 test in 12.119s
OK
配置文件:
pipeline: [fake]
fake:
pipe_step:
type: FakePipeStep
dataset:
type: MyDataset
train:
data_path: "/data/dataset/train/"
shuffle: false
transform:
- type: MyTransform
valid:
data_path: "/data/dataset/valid/"
shuffle: false
代码:
import cv2
class MyDataset(Dataset):
def __init__(self, **kwargs):
"""Construct the MyDataset class."""
Dataset.__init__(self, **kwargs)
self.args.data_path = FileOps.download_dataset(self.args.data_path)
def __len__(self):
"""Get the length of the dataset."""
return len(self.file)
def __getitem__(self, idx):
"""Get an item of the dataset according to the index."""
img_file = self.file[idx]
img = cv2.imread(img_file)
return img
初始化 dataset
时指定Transforms
dataset = ClassFactory.get_cls(ClassType.DATASET)(
mode="train",
transforms=[tf.RandomCrop(32, padding=4), tf.RandomHorizontalFlip()]
)
在模型训练过程中动态调整 Transforms
提供了 append()
, insert()
, remove()
, replace()
等方法,分别提供了追加、插入、删除和替换方法,如下:
dataset.transforms.append(tf.ToTensor())
dataset.transforms.insert(2, "Color", level=2)
dataset.transforms.remove("Color")
dataset.transforms.replace(
[tf.RandomCrop(32, padding=4), tf.RandomHorizontalFlip()]
)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。