pytorch数据加载常用torch.utils.data.Dataset和torch.utils.data.DataLoader两个类进行实现,简单来说:
- torch.utils.data.Dataset:完成数据的初步读取和加载,其内的每一条数据是"零散"的
- torch.utils.data.DataLoader:对torch.utils.data.Dataset中"零散"的数据进行打包,同时也可以进行一些后处理操作和采样操作。
下面就通过代码的方式详细介绍上面的两个类。
torch.utils.data.Dataset
Dataset类简单来说就是完成数据的读取操作【当然也可以做一些简单操作】,pytorch中也内置了很多常用的计算机视觉的数据集【如如MNIST、CIFAR10、ImageNet】,通常是通过torchvision.datasets块来实现。
import torch
from torchvision import datasets, transforms
# 加载MNIST数据集
= datasets.MNIST(root='./data', train=True, download=True)
train_dataset # 加载CIFAR10数据集
= datasets.CIFAR10(root='./data', train=True, transform=None)
train_dataset ...
除了上面torchvision中已经有的数据库,用Dataset类来完成自己的数据集的读取也是很重要的,比如现在手里手里有一批用于训练ControlNet模型的数据,每一组内包含【RGB图像,Conditioning图像,promot】,数据的目录结构如下:
读取的示例如下:
'''
data.Dataset是pytorch数据加载的基准,一般我们要重写其三个方法:
__init__方法:进行类的初始化,一般是用来读取原始数据。
__getitem__方法:根据下标对每一组数据进行处理。return:对dataset[index]处理后的一组数据
__len__方法:return:数据集的数量(int)
'''
class Custom_Dataset(data.Dataset):
def __init__(self, data_root):
self.image_dir = os.path.join(data_root, "images")
self.conditioning_image_dir = os.path.join(data_root, "conditioning_images")
self.prompt_dir = os.path.join(data_root, "prompts")
self.paired_data_paths = self.get_paired_data_paths()
# 进行一些简单的数据预处理,比如这里的缩放和灰度化
self.image_transforms = transforms.Compose([
512, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(
transforms.Grayscale()
] )
def get_paired_data_paths(self):
= []
pair_data_paths = glob("{}/*.png".format(self.image_dir))
image_paths
image_paths.sort()for img_path in image_paths:
= os.path.basename(img_path)
img_name = os.path.join(self.conditioning_image_dir, img_name)
cond_img_path = os.path.join(self.prompt_dir, img_name.replace(".png", ".txt"))
prompt_path
pair_data_paths.append((img_path, cond_img_path, prompt_path))return pair_data_paths
# item_idx表示整个数据集中当前读到的数据下标
def __getitem__(self, item_idx):
= self.paired_data_paths[item_idx]
img_path, cond_img_path, prompt_path = self.image_transforms(Image.open(img_path).convert("RGB"))
image = self.image_transforms(Image.open(cond_img_path).convert("RGB"))
cond_image with open(prompt_path, "r") as fr:
= json.loads(fr.read())["text"]
prompt # 1.这里的返回形式没有要求,可以是tensor、np.darray、numbers、dicts、lists、PIL.Image等各种形式
# 2.【后面会说可先跳过】若包在DataLoader中且collate_fn为None,则这里不能有PIL.Image类型,比如可以换成np.darray
# return np.array(image), np.array(cond_image), prompt
return image, cond_image, prompt
def __len__(self):
return len(self.paired_data_paths)
if __name__ == "__main__":
= "./fill50k_imagefolder"
data_root = Custom_Dataset(data_root)
custom_dataset # 可以直接对Dataset进行一个个样本的遍历
for image, cond_image, prompt in custom_dataset:
"img.png")
image.save("cond_img.png")
cond_image.save(print(prompt)
torch.utils.data.DataLoader
DataLoader是对Dataset类的进一层封装,其将零散的数据进行打包,同时有一些常用的参数更近一步规范数据形式:
- batch_size:每批数据共有多少组零散的数据
- shuffle:是否将dataset中的元素打乱
- drop_last:当batch_size无法整除dataset的数量,那么是否要舍掉最后一个不完整的batch
- collate_fn:取出batch_size组元素然后组成一个batch送到collate_fn函数中进行后处理,最终DataLoader每次迭代返回的batch就是collate_fn的返回值
- sampler:表示从dataset中采样的规则,比如逆序、只要前80%等...
- ...
比如我们要对上面零散的ControlNet训练数据进行“打包”,示例如下:
# 这里定义了一个collate_fn的传入参数(是一个函数),第一个参数表示从dataset中取出的一个batch的数据
def centercrop_collate_fn(batch_item, crop_size=256):
# 这里模拟了一个对batch中每组数据进行CenterCrop的后处理操作
= transforms.CenterCrop(crop_size)
gray_trans = []
image_batch = []
cond_image_batch = []
prompt_batch for image, cond_image, prompt in batch_item:
image_batch.append(gray_trans(image))
cond_image_batch.append(gray_trans(cond_image))
prompt_batch.append(prompt)# 返回数据的类型只能是tensors, numpy arrays, numbers, dicts 或 lists
return image_batch, cond_image_batch, prompt_batch
# return {"gray_image": image_batch, "cond_image": cond_image_batch, "prompt": cond_image_batch}
# collate_fn的传入参数也可以是一个类
class CenterCrop_Collater():
def __init__(self, crop_size=256):
self.crop_size = crop_size
# 在__call__函数中实现对batch的后处理,这里也是实现同样的CenterCrop操作
def __call__(self, batch_item):
= transforms.CenterCrop(self.crop_size)
gray_trans = []
image_batch = []
cond_image_batch = []
prompt_batch for image, cond_image, prompt in batch_item:
image_batch.append(gray_trans(image))
cond_image_batch.append(gray_trans(cond_image))
prompt_batch.append(prompt)return image_batch, cond_image_batch, prompt_batch
# return {"gray_image": image_batch, "cond_image": cond_image_batch, "prompt": cond_image_batch}
# 这里定义了一个sampler的传入参数(是一个类):
class my_sampler(data.Sampler):
# data_source是一个可迭代的数组类型, 可以是data.Dataset类型也可以是一般的list类型
def __init__(self, data_source):
self.data_source = data_source
# __iter__是Sampler的主要函数,其返回一个可迭代对象【是一个数值类型的下标集合】,即表示每次采样的下标
def __iter__(self):
# 比如这里自己规定了四个简单的采样:
return iter(range(len(self.data_source))[::-1]) # 逆序采样
# return iter(range(len(self.data_source))) # 顺序采样,也是默认的采样方式
# return iter(range(len(self.data_source))[:10]) # 只采样前10个(这种切片形式其实就可以从一整个数据集中划分训练/测试集合等)
# return iter(torch.randperm(len(self.data_source)).tolist()) # 随机不重复采样
# 此外还可以根据每个类别的比例进行加权采样....
def __len__(self):
return len(self.data_source)
if __name__ == "__main__":
# 2.对DataLoader一组的图像进行遍历
= torch.utils.data.DataLoader(
custom_dataloader # 表示要对哪个dataset进行打包,【复用了上面的custom_dataset】
custom_dataset, =2,
batch_size=False,
shuffle=True,
drop_last# collate_fn=functools.partial(centercrop_collate_fn, crop_size=384),
=CenterCrop_Collater(crop_size=384),
collate_fn# sampler=my_sampler(custom_dataset),
# batch_sampler = None
)
for custom_data in custom_dataloader:
# images, cond_images, prompts = custom_data["gray_image"], custom_data["cond_image"], custom_data["prompt"] # collate_fn的返回类型是dict则可以用这一行
= custom_data
images, cond_images, prompts 0].save("img.png")
images[0].save("cond_img.png")
cond_images[print(prompts[0])
其他要注意的事项:
DataLoader最终返回的每个batch中的元素必须是在
tensors, numpy arrays, numbers, dicts or lists
中,而不能包含PIL.Image.Image
等其他类型,所以'要么Dataset的__getitem__
返回就是符合要求的,要么collate_fn的返回类型是符合要求""的。比如把collate_fn置为None,那么就会报错"TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>",这时候如果再将Dataset的
__getitem__
的返回从PIL.Image变成np.darray就又ok了但是可以把其他不符合要求的类型打包在dicts或者lists中,比如上面collate_fn中返回的tuple或者dict
如果每组数据之间没有关联,那么其实在dataset的
__getitem__
中把数据都单独处理完就行了(比如resize、crop、to_tensor等),那么collate_fn写成None就行DataLoader其实还有
batch_sampler
参数,其是对sampler生成的indices再打包分组,得到一个又一个batch的index,不过一般用不到这么复杂的sampler参数定义每次从dataset中采样的规则,在其不为None时候则要求shuffle必须为False
pytorch中也提供了多种采样方式:
SequentialSampler: 顺序采样
RandomSampler: 随机采样
WeightedSampler: 权重采样
SubsetRandomSampler: 子集部分采样
...
dataset库
除了上面介绍的pytorch中自带的数据加载API,随着HuggingFace在Transformers和Diffusers的影响力,其下的dataset库也逐渐受到学术界和工业界的重视。
自动下载社区已有数据集
dataset库允许我们使用load_dataset
函数直接通过一句话在线下载HuggingFace社区已经有的数据集,一般有两种形式:
- 直接指定数据集名称。(一般是HuggingFace上的“用户名/数据集名”)
- 通过自定义的脚本,其实本质还是脚本内指定了数据集名称。(至于该脚本怎么写后面会进一步介绍,这里先用现成的)
比如下面一句话下载可用于ControlNet训练的数据集fusing/fill50k:
from datasets import load_dataset, load_from_disk
# 1.使用数据集名下载
= load_dataset(
my_dataset ='fusing/fill50k' # 指定数据集名称
path='train', # 指定下在训练集合还是测试集合,不指定则全部下载
split# 指定cache_dir下载到缓存路径,且下次在加载的时候优先从缓存中加载,避免重复下载
# 否则默认下载到hugging_face数据根目录下:~/.cache/huggingface/datasets
="./dataset_demo/test/fill50k_auto_down")
cache_dir# 2.使用自定义脚本下载
= load_dataset(path="./dataset_demo/test/fill50k_git_down/fill50k.py", # 指定加载脚本
git_down_dataset ='train',
split="./dataset_demo/test/fill50k_auto_down_script") cache_dir
当然我们也可以通过save_to_disk
将数据集保存到我们指定的位置,并通过load_from_disk
再加载进来:
"./dataset_demo/test/fill50k_auto_down")
auto_down_dataset.save_to_disk(= load_from_disk("./dataset_demo/test/fill50k_auto_down") auto_down_dataset
其中save_to_disk
会生成一些额外的数据和文件【因为该数据集只有训练集,所以没有test文件夹】,上面两中方式下载和添加后的文件是一致的,如下图所示:
Git下载好后本地加载
预先通过Git下载好数据集好,可以直接加载数据集
# 先git下载数据集:git clone https://huggingface.co/datasets/fusing/fill50k
= load_dataset("./dataset_demo/test/fill50k_git_down")
git_down_dataset "./dataset_demo/test/fill50k_git_down")
git_down_dataset.save_to_disk(print(git_down_dataset)
ImageFolder加载
ImageFolder允许我们通过很简单的方式加载我们自定义的数据集,比如我们现在拥有一下目录结构的数据
= load_dataset( path="imagefolder",
image_imagefolder ="./dataset_demo/test/fill50k_imagefolder",
data_dir=False # 设置为False,则train/test下根据文件夹分类别,得到"label"字段)
drop_labelsprint(image_imagefolder)
print(image_imagefolder["train"][0])
'''
输出:
DatasetDict({
train: Dataset({
features: ['image', 'label', 'prompt'],
num_rows: 8
})
test: Dataset({
features: ['image', 'label', 'prompt'],
num_rows: 8
})
})
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x2A12C0040>,
'label': 0,
'prompt': 'pale golden rod circle with old lace background'}
'''
其中"./dataset_demo/test/fill50k_imagefolder"
的目录结构如下:
Dataset.from_dict加载
上面的ImageFolder支持我们快速加载数据,但是无法支持更加复杂的数据,而
from datasets import Dataset, Image
from torchvision import transforms
# 按照下面的形式定义数据,如果是图片数据则传入图像路径
# 沿用上面fill50k_imagefolder的数据
= {
data_dict "pixel_values":[
"./dataset_demo/test/fill50k_imagefolder/train/images/0.png",
"./dataset_demo/test/fill50k_imagefolder/train/images/1.png",
"./dataset_demo/test/fill50k_imagefolder/train/images/2.png",
"./dataset_demo/test/fill50k_imagefolder/train/images/3.png",
],"conditioning_pixel_values":[
"./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/0.png",
"./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/1.png",
"./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/2.png",
"./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/3.png"
],"prompts":["pale golden rod circle with old lace background",
"light coral circle with white background",
"aqua circle with light pink background",
"cornflower blue circle with light golden rod yellow background"]
}
= Dataset.from_dict(data_dict) # 加载数据
my_vis_dataset print(my_vis_dataset)
'''
输出:
Dataset({
features: ['pixel_values', 'conditioning_pixel_values', 'prompts'],
num_rows: 4
})
'''
print(my_vis_dataset[0]["pixel_values"]) # 这时候仍然是str类型的图像路径
= my_vis_dataset.cast_column("pixel_values", Image()) # 利用cast_column将路径转为PIL图像数据
my_vis_dataset print(my_vis_dataset[0]["pixel_values"]) # 转为PIL的图像形式了(当然是并未做过任何归一化等处理)
= my_vis_dataset.cast_column("conditioning_pixel_values", Image())
my_vis_dataset
def custom_mapping(examples):
# 这里使用map对数据集进行操作,比如这里将图像缩放到256
"pixel_values"] = [image.convert("RGB").resize((256, 256)) for image in examples["pixel_values"]]
examples["conditioning_pixel_values"] = [image.convert("RGB").resize((256, 256)) for image in examples["conditioning_pixel_values"]]
examples["prompts"] = ["add_"+prompt for prompt in examples["prompts"]]
examples[return examples
# 如果有必要可以使用map()函数对数据集中再进行处理(比如对这里将图像resize,或者修改prompt)
= my_vis_dataset.map(custom_mapping, batched=True)
my_vis_dataset
# Dataset.from_dict返回的结果是可以直接作为DataLoader的dataset
# 当然也要符合基本的要求,比如上面提到的,如果dataset返回形式有PIL,则需要转为tensor或numpy等符合要求的形式
# 比如这里可以提前使用with_transform进行一些变换
def prepare_train_dataset(dataset):
= transforms.Compose(
image_transforms
[512, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(256),
transforms.CenterCrop(
transforms.ToTensor(),0.5], [0.5]),
transforms.Normalize([
]
)def preprocess_train(examples):
= [image.convert("RGB") for image in examples["pixel_values"]]
images = [image_transforms(image) for image in images]
images "pixel_values"] = images
examples[
= [image.convert("RGB") for image in examples["conditioning_pixel_values"]]
conditioning_images = [image_transforms(image) for image in conditioning_images]
conditioning_images "conditioning_pixel_values"] = conditioning_images
examples[return examples
return dataset.with_transform(preprocess_train)
= prepare_train_dataset(my_vis_dataset)
my_vis_dataset = torch.utils.data.DataLoader(my_vis_dataset, shuffle=True, batch_size=1)
my_data_loader
for step, batch in enumerate(my_data_loader):
= batch["pixel_values"].permute(0, 2, 3, 1).numpy()[0]
pixel_values = (pixel_values * 0.5) + 0.5
pixel_values = (pixel_values * 255.0).astype(np.uint8)
pixel_values = batch["conditioning_pixel_values"].permute(0, 2, 3, 1).numpy()[0]
conditioning_pixel_values = (conditioning_pixel_values * 0.5) + 0.5
conditioning_pixel_values = (conditioning_pixel_values * 255.0).astype(np.uint8)
conditioning_pixel_values = batch["prompts"]
prompt print("prompt:", prompt)
"pixel_values",
cv2.imshow(
cv2.cvtColor(np.concatenate([pixel_values, =1),
conditioning_pixel_values], axis
cv2.COLOR_RGB2BGR)) cv2.waitKey()
自定义脚本加载
其实上面Dataset.from_dict
已经能够满足我们大多数的需求了,但是HuggingFace中很多有通过脚本下载数据的例子,包括上面通过脚本自动下载社区数据。
这里我们定义一个脚本来处理下面这种结构的数据:
= load_dataset(
dataset_script # 1.通过指定的py脚本进行数据生成
="/Users/bytedance/Documents/Code/SelfStudy/dataset_demo/test/fill50k_script/fill50k_custom.py",
path# 2.如果是指定的文件夹路径,那么会在该文件夹下寻找和文件夹同名的py文件作为生成脚本。
# 比如这里将上面fill50k_custom.py复制一份得到fill50k_script.py,并简单修改其中的一些路径,就可以直接通过下面只通过指定文件夹的形式加载数据
# path="/Users/bytedance/Documents/Code/SelfStudy/dataset_demo/test/fill50k_script",
="./dataset_demo/test/fill50k_script")
cache_dirprint(dataset_script)
'''
输出:
DatasetDict({
train: Dataset({
features: ['image', 'conditioning_image', 'text'],
num_rows: 4
})
})
'''
"train"][2]["image"].save("./tmp.png")
dataset_script["./dataset_demo/test/fill50k_script") dataset_script.save_to_disk(
其中fill50k_custom.py中主要是定义了继承datasets.GeneratorBasedBuilder
的类,该类要重写三个方法:
- 函数_info(self) : 用于描述该数据集,比如homepage、license等,
- 函数_split_generators(self, dl_manager) : 用于划分数据集(训练/测试等)
- 函数_generate_examples(self, dl_manager): 对每个数据集进行处理(比如读取图片等)
完整的fill50k_custom.py
代码如下
from PIL import Image
import pandas as pd
import datasets
import os
import logging
= datasets.Version("0.0.2")
_VERSION # 数据集路径设置(因为这里是从本地读取数据,所以这里要指定数据集路径)
= "train.jsonl"
META_DATA_PATH = "./"
IMAGE_DIR = "./"
CONDITION_IMAGE_DIR
# 定义数据集中有哪些特征,及其类型
= datasets.Features(
_FEATURES
{# 定义数据名称和类型,注意这里数据名称不一定要和train.jsonl中的一致,
# 但是这个key是最为最终dataset的column_names
"image": datasets.Image(),
"conditioning_image": datasets.Image(),
"text": datasets.Value("string"),
},
)
= datasets.BuilderConfig(name="default", version=_VERSION)
_DEFAULT_CONFIG # 定义数据集
class My_Fill50k(datasets.GeneratorBasedBuilder):
= [_DEFAULT_CONFIG]
BUILDER_CONFIGS = "default"
DEFAULT_CONFIG_NAME
def _info(self):
return datasets.DatasetInfo(
="None",
description=_FEATURES,
features=None,
supervised_keys="None",
homepage="None",
license="None",
citation
)
def _split_generators(self, dl_manager):
= dl_manager.download(META_DATA_PATH)
metadata_path = dl_manager.download(IMAGE_DIR)
images_dir = dl_manager.download(CONDITION_IMAGE_DIR)
conditioning_images_dir
# 上面的操作其实就是利用dl_manager.download下载hf_hub_url,并且拼接一些路径
# 所以这里如果就是从本地下载可以抛弃dl_manager.download,直接赋值三个路径就行:
'''
metadata_path = "./dataset_demo/test/fill50k_script/train.jsonl"
images_dir = "./dataset_demo/test/fill50k_script"
conditioning_images_dir = "./dataset_demo/test/fill50k_script"
'''
return [
datasets.SplitGenerator(=datasets.Split.TRAIN,
name# These kwargs will be passed to _generate_examples
# 这里是将图像全部用于训练集合了,当然也可以读取每个路径下的文件路径,然后划分训练/测试,重新传参数
# 下面的参数是和_generate_examples函数的入口参数对齐的
={
gen_kwargs"metadata_path": metadata_path,
"images_dir": images_dir,
"conditioning_images_dir": conditioning_images_dir,
},
),
]
# 输入参数是和_split_generators返回的dict中的gen_kwargs参数对应的
def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
= pd.read_json(metadata_path, lines=True)
metadata
for _, row in metadata.iterrows():
= row["text"] # 这里的key是和train.jsonl文件中对应的
text
= os.path.join(images_dir, row["image"])
image_path = open(image_path, "rb").read() # 保存成二进制形式
image
= os.path.join(conditioning_images_dir, row["conditioning_image"])
conditioning_image_path = open(conditioning_image_path, "rb").read()
conditioning_image
'''
当然我们也可以对已有的数据做一些操作后再作为真正dataset数据
1.先读取pillow图像数据
image_pil = Image.open(image_path).convert("RGB")
2.然后对image_pil做一些自定义的操作,比如取反...
image_pil = Image.eval(image_pil, lambda x: 255 - x)
3.最后再保存成bytes类型
image_bytes = io.BytesIO() # 注意:这里不要使用cv2.imencode保存成bytes类型!!!不然会导致RGB顺序错乱
image_pil.save(image_bytes, format="JPEG")
image = image_bytes
# with open("./show_img.png", "wb") as fw:
# fw.write(image_bytes.getvalue()) # 这里可以保存后查看操作后的数据是不是符合预期的
'''
yield row["image"], {
"text": text, # 这里的key要和_FEATURES保持一直,也是用来作为最终dataset的column_names
"image": {
"path": "", # 如果原始数据是经过操作的了,那么这里可以不指定'path'
"bytes": image,
},"conditioning_image": {
"path": conditioning_image_path, # 同理这里也可以不指定'path'
"bytes": conditioning_image,
}, }
webdataset加载
webdataset是一种tar压缩包形式的数据存储形式,比较适合数据量比较大的情况下:
比如我们有以下一组数据,其中image
后缀和conditioning_image
后缀就是直接将png图片后缀从png改过来的,prompt
后缀也是从txt/jsonl
后缀改过来的,每一组的数据使用相同的文件名,相同特征的数据用相同的后缀名:
我们先用tar命令将数据打包:
tar --sort=name -cf tar1.tar tar1 # --sort=name命令一定要加
将数据集打包后就可以直接用load_dataset("webdataset",...)
进行读取了:
= load_dataset("webdataset",
webdataset_data ={"train": ["./dataset_demo/test/fill50k_webdataset/tar1.tar"]},
data_files="train",
split=True
streaming
)print(webdataset_data) # IterableDataset类型,也是可以直接用DataLoader包装的
'''
输出:
IterableDataset({
features: ['__key__', '__url__', 'conditioning_image', 'image', 'prompt'],
n_shards: 1
})
'''
for item in webdataset_data:
# 所以数据都会保存为bytes形式,这里进行decode还原数据
= cv2.imdecode(np.asarray(bytearray(item["image"]), dtype="uint8"), cv2.IMREAD_COLOR)
image = cv2.imdecode(np.asarray(bytearray(item["conditioning_image"]), dtype="uint8"), cv2.IMREAD_COLOR)
conditioning_image = item["prompt"].decode('utf-8')
json_info print(json_info)
"result", np.concatenate([image, conditioning_image], axis=1))
cv2.imshow( cv2.waitKey()
总结
本篇文章主要通过具体的实例,介绍了PyTorch中数据加载的机制,同时也介绍了目前dataset库常用的几种数据加载方式。这些强大的数据读取与加载机制,能够让我们极大程度减少在数据加载上的精力,而更加专注模型结构的设计与训练。