pytorch数据加载常用torch.utils.data.Dataset和torch.utils.data.DataLoader两个类进行实现,简单来说:

  • torch.utils.data.Dataset:完成数据的初步读取和加载,其内的每一条数据是"零散"的
  • torch.utils.data.DataLoader:对torch.utils.data.Dataset中"零散"的数据进行打包,同时也可以进行一些后处理操作和采样操作。

下面就通过代码的方式详细介绍上面的两个类。

torch.utils.data.Dataset

Dataset类简单来说就是完成数据的读取操作【当然也可以做一些简单操作】,pytorch中也内置了很多常用的计算机视觉的数据集【如如MNIST、CIFAR10、ImageNet】,通常是通过torchvision.datasets块来实现。

import torch
from torchvision import datasets, transforms

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
# 加载CIFAR10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=None)
...

除了上面torchvision中已经有的数据库,用Dataset类来完成自己的数据集的读取也是很重要的,比如现在手里手里有一批用于训练ControlNet模型的数据,每一组内包含【RGB图像,Conditioning图像,promot】,数据的目录结构如下:

数据集结构示例

读取的示例如下:

'''
data.Dataset是pytorch数据加载的基准,一般我们要重写其三个方法:
__init__方法:进行类的初始化,一般是用来读取原始数据。
__getitem__方法:根据下标对每一组数据进行处理。return:对dataset[index]处理后的一组数据
__len__方法:return:数据集的数量(int)
'''
class Custom_Dataset(data.Dataset):
    def __init__(self, data_root):
        self.image_dir = os.path.join(data_root, "images")
        self.conditioning_image_dir = os.path.join(data_root, "conditioning_images")
        self.prompt_dir = os.path.join(data_root, "prompts")
        self.paired_data_paths = self.get_paired_data_paths()
                
        # 进行一些简单的数据预处理,比如这里的缩放和灰度化
        self.image_transforms = transforms.Compose([
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Grayscale()
        ] )

    def get_paired_data_paths(self):
        pair_data_paths = []
        image_paths = glob("{}/*.png".format(self.image_dir))
        image_paths.sort()
        for img_path in image_paths:
            img_name = os.path.basename(img_path)
            cond_img_path = os.path.join(self.conditioning_image_dir, img_name)
            prompt_path = os.path.join(self.prompt_dir, img_name.replace(".png", ".txt"))
            pair_data_paths.append((img_path, cond_img_path, prompt_path))
        return pair_data_paths
        
    # item_idx表示整个数据集中当前读到的数据下标
    def __getitem__(self, item_idx):
        img_path, cond_img_path, prompt_path = self.paired_data_paths[item_idx]
        image = self.image_transforms(Image.open(img_path).convert("RGB"))
        cond_image = self.image_transforms(Image.open(cond_img_path).convert("RGB"))
        with open(prompt_path, "r") as fr:
            prompt = json.loads(fr.read())["text"]
        # 1.这里的返回形式没有要求,可以是tensor、np.darray、numbers、dicts、lists、PIL.Image等各种形式
        # 2.【后面会说可先跳过】若包在DataLoader中且collate_fn为None,则这里不能有PIL.Image类型,比如可以换成np.darray
        # return np.array(image), np.array(cond_image), prompt
        return image, cond_image, prompt

    def __len__(self):
        return len(self.paired_data_paths)
      
if __name__ == "__main__":
    data_root = "./fill50k_imagefolder"
    custom_dataset = Custom_Dataset(data_root)
    # 可以直接对Dataset进行一个个样本的遍历
    for image, cond_image, prompt in custom_dataset:
    image.save("img.png")
    cond_image.save("cond_img.png")
    print(prompt)

torch.utils.data.DataLoader

DataLoader是对Dataset类的进一层封装,其将零散的数据进行打包,同时有一些常用的参数更近一步规范数据形式:

  • batch_size:每批数据共有多少组零散的数据
  • shuffle:是否将dataset中的元素打乱
  • drop_last:当batch_size无法整除dataset的数量,那么是否要舍掉最后一个不完整的batch
  • collate_fn:取出batch_size组元素然后组成一个batch送到collate_fn函数中进行后处理,最终DataLoader每次迭代返回的batch就是collate_fn的返回值
  • sampler:表示从dataset中采样的规则,比如逆序、只要前80%等...
  • ...

比如我们要对上面零散的ControlNet训练数据进行“打包”,示例如下:

# 这里定义了一个collate_fn的传入参数(是一个函数),第一个参数表示从dataset中取出的一个batch的数据
def centercrop_collate_fn(batch_item, crop_size=256):
    # 这里模拟了一个对batch中每组数据进行CenterCrop的后处理操作
    gray_trans = transforms.CenterCrop(crop_size)
    image_batch = []
    cond_image_batch = []
    prompt_batch = []
    for image, cond_image, prompt in batch_item:
        image_batch.append(gray_trans(image))
        cond_image_batch.append(gray_trans(cond_image))
        prompt_batch.append(prompt)
    # 返回数据的类型只能是tensors, numpy arrays, numbers, dicts 或 lists
    return image_batch, cond_image_batch, prompt_batch
    # return {"gray_image": image_batch, "cond_image": cond_image_batch, "prompt": cond_image_batch}

# collate_fn的传入参数也可以是一个类
class CenterCrop_Collater():
    def __init__(self, crop_size=256):
        self.crop_size = crop_size

    # 在__call__函数中实现对batch的后处理,这里也是实现同样的CenterCrop操作
    def __call__(self, batch_item):
        gray_trans = transforms.CenterCrop(self.crop_size)
        image_batch = []
        cond_image_batch = []
        prompt_batch = []
        for image, cond_image, prompt in batch_item:
            image_batch.append(gray_trans(image))
            cond_image_batch.append(gray_trans(cond_image))
            prompt_batch.append(prompt)
        return image_batch, cond_image_batch, prompt_batch
        # return {"gray_image": image_batch, "cond_image": cond_image_batch, "prompt": cond_image_batch}

# 这里定义了一个sampler的传入参数(是一个类):
class my_sampler(data.Sampler):
    # data_source是一个可迭代的数组类型, 可以是data.Dataset类型也可以是一般的list类型
    def __init__(self, data_source):
        self.data_source = data_source

    # __iter__是Sampler的主要函数,其返回一个可迭代对象【是一个数值类型的下标集合】,即表示每次采样的下标
    def __iter__(self):
        # 比如这里自己规定了四个简单的采样:
        return iter(range(len(self.data_source))[::-1])  # 逆序采样
        # return iter(range(len(self.data_source)))  # 顺序采样,也是默认的采样方式
        # return iter(range(len(self.data_source))[:10]) # 只采样前10个(这种切片形式其实就可以从一整个数据集中划分训练/测试集合等)
        # return iter(torch.randperm(len(self.data_source)).tolist())   # 随机不重复采样
        # 此外还可以根据每个类别的比例进行加权采样....

    def __len__(self):
        return len(self.data_source)
      
if __name__ == "__main__":
    
    # 2.对DataLoader一组的图像进行遍历
    custom_dataloader = torch.utils.data.DataLoader(
        custom_dataset,  # 表示要对哪个dataset进行打包,【复用了上面的custom_dataset】
        batch_size=2,            
        shuffle=False,        
        drop_last=True,        
        # collate_fn=functools.partial(centercrop_collate_fn, crop_size=384),
        collate_fn=CenterCrop_Collater(crop_size=384),   
        # sampler=my_sampler(custom_dataset),          
        # batch_sampler = None 
    )

    for custom_data in custom_dataloader:
        # images, cond_images, prompts = custom_data["gray_image"], custom_data["cond_image"], custom_data["prompt"]  # collate_fn的返回类型是dict则可以用这一行
        images, cond_images, prompts = custom_data
        images[0].save("img.png")
        cond_images[0].save("cond_img.png")
        print(prompts[0])

其他要注意的事项:

  • DataLoader最终返回的每个batch中的元素必须是在tensors, numpy arrays, numbers, dicts or lists中,而不能包含PIL.Image.Image等其他类型,所以'要么Dataset的__getitem__返回就是符合要求的,要么collate_fn的返回类型是符合要求""的。

    • 比如把collate_fn置为None,那么就会报错"TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>",这时候如果再将Dataset的__getitem__的返回从PIL.Image变成np.darray就又ok了

    • 但是可以把其他不符合要求的类型打包在dicts或者lists中,比如上面collate_fn中返回的tuple或者dict

  • 如果每组数据之间没有关联,那么其实在dataset的__getitem__中把数据都单独处理完就行了(比如resize、crop、to_tensor等),那么collate_fn写成None就行

  • DataLoader其实还有batch_sampler参数,其是对sampler生成的indices再打包分组,得到一个又一个batch的index,不过一般用不到这么复杂的

  • sampler参数定义每次从dataset中采样的规则,在其不为None时候则要求shuffle必须为False

  • pytorch中也提供了多种采样方式:

    • SequentialSampler: 顺序采样

    • RandomSampler: 随机采样

    • WeightedSampler: 权重采样

    • SubsetRandomSampler: 子集部分采样

    • ...

dataset库

除了上面介绍的pytorch中自带的数据加载API,随着HuggingFace在TransformersDiffusers的影响力,其下的dataset库也逐渐受到学术界和工业界的重视。

自动下载社区已有数据集

dataset库允许我们使用load_dataset函数直接通过一句话在线下载HuggingFace社区已经有的数据集,一般有两种形式:

  • 直接指定数据集名称。(一般是HuggingFace上的“用户名/数据集名”)
  • 通过自定义的脚本,其实本质还是脚本内指定了数据集名称。(至于该脚本怎么写后面会进一步介绍,这里先用现成的)

比如下面一句话下载可用于ControlNet训练的数据集fusing/fill50k:

from datasets import load_dataset, load_from_disk
# 1.使用数据集名下载
my_dataset = load_dataset(
                                                path='fusing/fill50k'  # 指定数据集名称
                          split='train',   # 指定下在训练集合还是测试集合,不指定则全部下载
                          # 指定cache_dir下载到缓存路径,且下次在加载的时候优先从缓存中加载,避免重复下载
                          # 否则默认下载到hugging_face数据根目录下:~/.cache/huggingface/datasets
                          cache_dir="./dataset_demo/test/fill50k_auto_down")
# 2.使用自定义脚本下载
git_down_dataset = load_dataset(path="./dataset_demo/test/fill50k_git_down/fill50k.py",  # 指定加载脚本
                                split='train',
                                cache_dir="./dataset_demo/test/fill50k_auto_down_script")

当然我们也可以通过save_to_disk将数据集保存到我们指定的位置,并通过load_from_disk再加载进来:

auto_down_dataset.save_to_disk("./dataset_demo/test/fill50k_auto_down")  
auto_down_dataset = load_from_disk("./dataset_demo/test/fill50k_auto_down")

其中save_to_disk会生成一些额外的数据和文件【因为该数据集只有训练集,所以没有test文件夹】,上面两中方式下载和添加后的文件是一致的,如下图所示:

save_to_disk后文件变化

Git下载好后本地加载

预先通过Git下载好数据集好,可以直接加载数据集

# 先git下载数据集:git clone https://huggingface.co/datasets/fusing/fill50k
git_down_dataset = load_dataset("./dataset_demo/test/fill50k_git_down")
git_down_dataset.save_to_disk("./dataset_demo/test/fill50k_git_down")   
print(git_down_dataset)

ImageFolder加载

ImageFolder允许我们通过很简单的方式加载我们自定义的数据集,比如我们现在拥有一下目录结构的数据

image_imagefolder = load_dataset( path="imagefolder", 
                                                                data_dir="./dataset_demo/test/fill50k_imagefolder", 
                                                                drop_labels=False # 设置为False,则train/test下根据文件夹分类别,得到"label"字段)
print(image_imagefolder)
print(image_imagefolder["train"][0])
  
'''
输出:
DatasetDict({
    train: Dataset({
        features: ['image', 'label', 'prompt'],
        num_rows: 8
    })
    test: Dataset({
        features: ['image', 'label', 'prompt'],
        num_rows: 8
    })
})
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x2A12C0040>, 
 'label': 0, 
 'prompt': 'pale golden rod circle with old lace background'}
'''

其中"./dataset_demo/test/fill50k_imagefolder"的目录结构如下:

imagefolder数据结构示例

Dataset.from_dict加载

上面的ImageFolder支持我们快速加载数据,但是无法支持更加复杂的数据,而

from datasets import Dataset, Image
from torchvision import transforms
# 按照下面的形式定义数据,如果是图片数据则传入图像路径
# 沿用上面fill50k_imagefolder的数据
data_dict = {
    "pixel_values":[
        "./dataset_demo/test/fill50k_imagefolder/train/images/0.png",
        "./dataset_demo/test/fill50k_imagefolder/train/images/1.png",
        "./dataset_demo/test/fill50k_imagefolder/train/images/2.png",
        "./dataset_demo/test/fill50k_imagefolder/train/images/3.png",

    ],
    "conditioning_pixel_values":[
        "./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/0.png",
        "./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/1.png",
        "./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/2.png",
        "./dataset_demo/test/fill50k_imagefolder/train/conditioning_images/3.png"
    ],
    "prompts":["pale golden rod circle with old lace background",
               "light coral circle with white background",
               "aqua circle with light pink background",
               "cornflower blue circle with light golden rod yellow background"]
}

my_vis_dataset = Dataset.from_dict(data_dict)  # 加载数据
print(my_vis_dataset)
'''
输出:
Dataset({
    features: ['pixel_values', 'conditioning_pixel_values', 'prompts'],
    num_rows: 4
})
'''
print(my_vis_dataset[0]["pixel_values"])  # 这时候仍然是str类型的图像路径
my_vis_dataset = my_vis_dataset.cast_column("pixel_values", Image())  # 利用cast_column将路径转为PIL图像数据
print(my_vis_dataset[0]["pixel_values"])   # 转为PIL的图像形式了(当然是并未做过任何归一化等处理)
my_vis_dataset = my_vis_dataset.cast_column("conditioning_pixel_values", Image())


def custom_mapping(examples):
    # 这里使用map对数据集进行操作,比如这里将图像缩放到256
    examples["pixel_values"] = [image.convert("RGB").resize((256, 256)) for image in examples["pixel_values"]]
    examples["conditioning_pixel_values"] = [image.convert("RGB").resize((256, 256)) for image in examples["conditioning_pixel_values"]]
    examples["prompts"] = ["add_"+prompt for prompt in examples["prompts"]]
    return examples

# 如果有必要可以使用map()函数对数据集中再进行处理(比如对这里将图像resize,或者修改prompt)
my_vis_dataset = my_vis_dataset.map(custom_mapping, batched=True)

# Dataset.from_dict返回的结果是可以直接作为DataLoader的dataset
# 当然也要符合基本的要求,比如上面提到的,如果dataset返回形式有PIL,则需要转为tensor或numpy等符合要求的形式
# 比如这里可以提前使用with_transform进行一些变换
def prepare_train_dataset(dataset):
    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples["pixel_values"]]
        images = [image_transforms(image) for image in images]
        examples["pixel_values"] = images

        conditioning_images = [image.convert("RGB") for image in examples["conditioning_pixel_values"]]
        conditioning_images = [image_transforms(image) for image in conditioning_images]
        examples["conditioning_pixel_values"] = conditioning_images
        return examples
    return dataset.with_transform(preprocess_train)

my_vis_dataset = prepare_train_dataset(my_vis_dataset)
my_data_loader = torch.utils.data.DataLoader(my_vis_dataset, shuffle=True, batch_size=1)

for step, batch in enumerate(my_data_loader):
    pixel_values = batch["pixel_values"].permute(0, 2, 3, 1).numpy()[0]
    pixel_values = (pixel_values * 0.5) + 0.5
    pixel_values = (pixel_values * 255.0).astype(np.uint8)
    conditioning_pixel_values = batch["conditioning_pixel_values"].permute(0, 2, 3, 1).numpy()[0]
    conditioning_pixel_values = (conditioning_pixel_values * 0.5) + 0.5
    conditioning_pixel_values = (conditioning_pixel_values * 255.0).astype(np.uint8)
    prompt = batch["prompts"]
    print("prompt:", prompt)
    cv2.imshow("pixel_values", 
               cv2.cvtColor(np.concatenate([pixel_values, 
                                            conditioning_pixel_values], axis=1), 
                            cv2.COLOR_RGB2BGR))
    cv2.waitKey()

自定义脚本加载

其实上面Dataset.from_dict已经能够满足我们大多数的需求了,但是HuggingFace中很多有通过脚本下载数据的例子,包括上面通过脚本自动下载社区数据。

这里我们定义一个脚本来处理下面这种结构的数据:

脚本下载的数据结构示例
dataset_script = load_dataset(
        # 1.通过指定的py脚本进行数据生成
        path="/Users/bytedance/Documents/Code/SelfStudy/dataset_demo/test/fill50k_script/fill50k_custom.py",
        # 2.如果是指定的文件夹路径,那么会在该文件夹下寻找和文件夹同名的py文件作为生成脚本。
        # 比如这里将上面fill50k_custom.py复制一份得到fill50k_script.py,并简单修改其中的一些路径,就可以直接通过下面只通过指定文件夹的形式加载数据
            # path="/Users/bytedance/Documents/Code/SelfStudy/dataset_demo/test/fill50k_script",
      cache_dir="./dataset_demo/test/fill50k_script")
print(dataset_script)
'''
输出:
DatasetDict({
    train: Dataset({
        features: ['image', 'conditioning_image', 'text'],
        num_rows: 4
    })
})
'''
dataset_script["train"][2]["image"].save("./tmp.png")
dataset_script.save_to_disk("./dataset_demo/test/fill50k_script")

其中fill50k_custom.py中主要是定义了继承datasets.GeneratorBasedBuilder的类,该类要重写三个方法:

  • 函数_info(self) : 用于描述该数据集,比如homepage、license等,
  • 函数_split_generators(self, dl_manager) : 用于划分数据集(训练/测试等)
  • 函数_generate_examples(self, dl_manager): 对每个数据集进行处理(比如读取图片等)

完整的fill50k_custom.py代码如下

from PIL import Image
import pandas as pd
import datasets
import os
import logging

_VERSION = datasets.Version("0.0.2")
# 数据集路径设置(因为这里是从本地读取数据,所以这里要指定数据集路径)
META_DATA_PATH = "train.jsonl"
IMAGE_DIR = "./"
CONDITION_IMAGE_DIR = "./"

# 定义数据集中有哪些特征,及其类型
_FEATURES = datasets.Features(
    {
        # 定义数据名称和类型,注意这里数据名称不一定要和train.jsonl中的一致,
        # 但是这个key是最为最终dataset的column_names
        "image": datasets.Image(),   
        "conditioning_image": datasets.Image(),
        "text": datasets.Value("string"),
    },
)

_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
# 定义数据集
class My_Fill50k(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
    DEFAULT_CONFIG_NAME = "default"

    def _info(self):
        return datasets.DatasetInfo(
            description="None",
            features=_FEATURES,
            supervised_keys=None,
            homepage="None",
            license="None",
            citation="None",
        )

    def _split_generators(self, dl_manager):
        metadata_path = dl_manager.download(META_DATA_PATH)
        images_dir = dl_manager.download(IMAGE_DIR)
        conditioning_images_dir = dl_manager.download(CONDITION_IMAGE_DIR)

        # 上面的操作其实就是利用dl_manager.download下载hf_hub_url,并且拼接一些路径
        # 所以这里如果就是从本地下载可以抛弃dl_manager.download,直接赋值三个路径就行:
        '''
        metadata_path = "./dataset_demo/test/fill50k_script/train.jsonl"
        images_dir = "./dataset_demo/test/fill50k_script"
        conditioning_images_dir = "./dataset_demo/test/fill50k_script"
        '''

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                # 这里是将图像全部用于训练集合了,当然也可以读取每个路径下的文件路径,然后划分训练/测试,重新传参数
                # 下面的参数是和_generate_examples函数的入口参数对齐的
                gen_kwargs={
                    "metadata_path": metadata_path,
                    "images_dir": images_dir,
                    "conditioning_images_dir": conditioning_images_dir,
                },
            ),
        ]

    # 输入参数是和_split_generators返回的dict中的gen_kwargs参数对应的
    def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
        metadata = pd.read_json(metadata_path, lines=True)

        for _, row in metadata.iterrows():
            text = row["text"]  # 这里的key是和train.jsonl文件中对应的

            image_path = os.path.join(images_dir, row["image"])
            image = open(image_path, "rb").read() # 保存成二进制形式

            conditioning_image_path = os.path.join(conditioning_images_dir, row["conditioning_image"])
            conditioning_image = open(conditioning_image_path, "rb").read()

            '''
            当然我们也可以对已有的数据做一些操作后再作为真正dataset数据
            1.先读取pillow图像数据
            image_pil = Image.open(image_path).convert("RGB")  
            2.然后对image_pil做一些自定义的操作,比如取反...
            image_pil = Image.eval(image_pil, lambda x: 255 - x)
            3.最后再保存成bytes类型
            image_bytes = io.BytesIO()  # 注意:这里不要使用cv2.imencode保存成bytes类型!!!不然会导致RGB顺序错乱
            image_pil.save(image_bytes, format="JPEG")
            image = image_bytes
            
            # with open("./show_img.png", "wb") as fw:
            #     fw.write(image_bytes.getvalue())  # 这里可以保存后查看操作后的数据是不是符合预期的
            '''
            
            yield row["image"], {
                "text": text,  # 这里的key要和_FEATURES保持一直,也是用来作为最终dataset的column_names
                "image": {
                    "path": "",  # 如果原始数据是经过操作的了,那么这里可以不指定'path'
                    "bytes": image,
                },
                "conditioning_image": {
                    "path": conditioning_image_path,  # 同理这里也可以不指定'path'
                    "bytes": conditioning_image,
                },
            }

webdataset加载

webdataset是一种tar压缩包形式的数据存储形式,比较适合数据量比较大的情况下:

比如我们有以下一组数据,其中image后缀和conditioning_image后缀就是直接将png图片后缀从png改过来的,prompt后缀也是从txt/jsonl后缀改过来的,每一组的数据使用相同的文件名,相同特征的数据用相同的后缀名:

webdataset数据结构示例

我们先用tar命令将数据打包:

tar --sort=name -cf tar1.tar tar1   # --sort=name命令一定要加

将数据集打包后就可以直接用load_dataset("webdataset",...)进行读取了:

webdataset_data = load_dataset("webdataset",
             data_files={"train": ["./dataset_demo/test/fill50k_webdataset/tar1.tar"]},
             split="train",
             streaming=True
             )
print(webdataset_data)  # IterableDataset类型,也是可以直接用DataLoader包装的
'''
输出:
IterableDataset({
    features: ['__key__', '__url__', 'conditioning_image', 'image', 'prompt'],
    n_shards: 1
})
'''
for item in webdataset_data:
    # 所以数据都会保存为bytes形式,这里进行decode还原数据
    image = cv2.imdecode(np.asarray(bytearray(item["image"]), dtype="uint8"), cv2.IMREAD_COLOR)
    conditioning_image = cv2.imdecode(np.asarray(bytearray(item["conditioning_image"]), dtype="uint8"), cv2.IMREAD_COLOR)
    json_info = item["prompt"].decode('utf-8')
    print(json_info)
    cv2.imshow("result", np.concatenate([image, conditioning_image], axis=1))
    cv2.waitKey()

总结

本篇文章主要通过具体的实例,介绍了PyTorch中数据加载的机制,同时也介绍了目前dataset库常用的几种数据加载方式。这些强大的数据读取与加载机制,能够让我们极大程度减少在数据加载上的精力,而更加专注模型结构的设计与训练。