基本概念
在本文中,我们将对 torch.utils.data
模块中的 Dataset,BatchSampler 和 DataLoader 三个用于 PyTorch 框架数据加载的关键实体进行分析。
Python 中的迭代
理解 Python 中与迭代相关的概念,是理解 torch.utils.data
模块中的 Dataset,Sampler 和 DataLoader 三个实体的关键,因此如果读者朋友对 Python 中与迭代相关的概念尚不熟悉,建议先对我的另一篇文章 Python 中的迭代 进行阅读。
Python 中的多进程
我们在后面介绍 DataLoader 实体的相关内容的时候将会涉及到如何基于 Python 的多进程机制加速数据加载的过程,因此需要读者对 Python 的多进程机制有一定的了解。如果您对此不是特别熟悉,建议先对我的另一篇文章 Python 的多进程 进行阅读。
数据加载
首先,模型的训练基于数据集,因此在训练过程中,有一部分的工作是要关心数据集如何从磁盘中被加载到 Host Memory / GPU Memory,数据集如何进行采样生成 Mini-batch,生成的 Mini-batch 如何被依次送进模型中进行训练,这也就是本文要关心的训练过程中被称为

__getitem__(self)
魔法方法; Iterable-style Dataset 则提供了 __iter__(self)
魔法方法)。在训练过程中,会有多轮 Epoches,每一轮 Epoch 会有多轮 Iterations,每一轮 Iteration 一个 Mini-batch 送入模型进行前向传播、损失值计算、参数梯度计算和参数更新,在每一轮 Epoch 中会完成一次对训练集 (Mini-batches) 的遍历。为了更好的实现对 Mini-batches 的遍历的编程抽象,for mini-batch in DataLoader
的范式完成对 Mini-batches 的遍历。
DataLoader 在为某轮 Epoch 生成 Mini-batches 的过程中,是按照什么规则生成 Mini-batches 的呢?可以是按顺序在 Dataset 中进行采样,也可以是按照随机的规则进行采样,生成 Mini-batches 的采样规则就是由
综上,Dataset,BatchSampler 和 DataLoader 可以总结为:
Dataset : 将原始存储在磁盘中的数据集读取到内存中并封装为 Python 对应的对象,并且暴露提取接口;BatchSampler : 在每次迭代时输出当前 Iteration 所使用的 Samples 的索引;DataLoader : 基于设置好的 Dataset 和 BatchSampler 实体,在每次迭代时,DataLoader 将首先迭代 BatchSampler 实体以获得当前 Iteration 使用的 Samples 的索引,然后调用其 Fetcher 实体完成对应 Samples 的特征和标签数据的组装,最后进行输出。

下面我们分别对上述的功能性实体进行介绍。
Dataset 功能性介绍
Dataset 实体负责对来源于磁盘的 Raw Dataset 进行封装,将其封装成 Python 可识别的数据结构。torch.utils.data
模块提供了多种形式的 Dataset 实体抽象,并且分别提供了对应的接口类完成对这些类型的数据集的抽象,用户需要实现接口类中的相应接口,以完成对自定义数据集的封装。我们下面对这些数据集抽象分别进行分析。


Map-style Dataset
torch.utils.data
模块使用 Dataset
接口类抽象 Map-style 的数据集,Map-style 顾名思义就是数据集中的每一条 Sample 都拥有一个索引,用户可以通过索引的方式来获取 Samples。Dataset
的定义如下所示:
1
2
3
4
5
6class Dataset(Generic[T_co]):
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
可以发现,Map-style 的数据集通过定义 __getitem__()
魔法方法,实现了从索引到 Sample 的映射,使得用户可以使用 dataset[idx]
就可以访问 idx
对应的 Sample。
细心的读者会发现在上面展示的 Dataset
接口类中并没有规定实现 __len()__
接口,原因是 return NotImplemented
或者 raise NotImplementedError()
之类的默认实现都会存在各自的问题,因此 Dataset
接口类把对 __len()__
接口的实现留给了子类。
Iterable-style Dataset
Iterable-style 的数据集是一种通过实现 __iter__()
来获取数据的 Dataset,这种类型的数据集特别适用于以下情况: ① 像 Map-style 数据集一样基于索引随机读取的代价很大甚至不大可能;或者 ② 每轮 Iteration 所使用的 Batch Size 并不固定,而取决于获取的数据。torch.utils.data
模块使用 IterableDataset
接口类抽象 Iterable-style 的数据集,其定义如下所示:
1
2
3
4
5
6class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
其它 Dataset
上述的 Map-style 和 Iterable-style 的 Dataset 是 PyTorch 中两种最主要的 Dataset,下面我们介绍几种基于它们的其它 Datasets。
Concat Dataset
ConcatDataset
被用于级联多个数据集类,使得级联出来的数据集就像是一个统一大的数据集一样,可以基于索引/关键字对 Samples 进行访问,具体定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41class ConcatDataset(Dataset[T_co]):
datasets: List[Dataset[T_co]]
cumulative_sizes: List[int]
def cumsum(sequence):
r, s = [], 0
for e in sequence:
l = len(e)
r.append(l + s)
s += l
return r
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ConcatDataset, self).__init__()
self.datasets = list(datasets)
assert len(self.datasets) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type]
for d in self.datasets:
assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
self.cumulative_sizes = self.cumsum(self.datasets)
def __len__(self):
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx]
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
Chain Dataset
ChainDataset
被用于包含多个 IterableDataset
数据集,在 IterableDataset
的 __add__()
方法中被调用,具体定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17class ChainDataset(IterableDataset):
def __init__(self, datasets: Iterable[Dataset]) -> None:
super(ChainDataset, self).__init__()
self.datasets = datasets
def __iter__(self):
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
for x in d:
yield x
def __len__(self):
total = 0
for d in self.datasets:
assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
total += len(d) # type: ignore[arg-type]
return total
Subset Dataset
Subset
被用于将原有数据集指定下标的 Samples 封装为一个新的数据集,具体定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15class Subset(Dataset[T_co]):
dataset: Dataset[T_co]
indices: Sequence[int]
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
if isinstance(idx, list):
return self.dataset[[self.indices[i] for i in idx]]
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
Tensor Dataset
TensorDataset
用于获取封装成 Tensor 的数据集,每一个样本都通过索引张量来获得,具体代码如下所示:
1
2
3
4
5
6
7
8
9
10
11
12class TensorDataset(Dataset[Tuple[Tensor, ...]]):
tensors: Tuple[Tensor, ...]
def __init__(self, *tensors: Tensor) -> None:
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
BatchSampler 功能性介绍
现在我们定义好了 Dataset 实体抽象,已经拥有了一些接口可以获取 Dataset 中各条 Samples。在训练的各轮 Epoches 中,我们需要在每轮 Iteration 中从 Dataset 中读取单条 Sample (或多条 Samples 以形成 Mini-batch) 对模型进行训练,那么应该按照什么顺序来读取 Dataset 中的内容以形成这些 Mini-batches 呢?这也就是我们本节要讨论的 BatchSampler 实体抽象的工作。
对于 Map-style Dataset 来说,如
而对于 Iteration-style Dataset 来说,如 __next__(self)
魔法方法定义的访问顺序决定,因此本节讨论的 BatchSampler 实体并不对 Iteration-style Dataset 有效,因此
回顾我们在 dataloading 中描述的,Sampler 实体每次迭代输出一条 Sample 的索引,而 BatchSampler 实体每次迭代则输出多条 Samples 的索引。在 PyTorch 的实现中,通常实现定义出 Sampler 实体,以确定采样规则 (e.g. 随机采样,顺序采样,etc.),然后再基于定义好的 Sampler 实体结合 Batch Size 等设置生成 BatchSampler 实体。
Sampler 实体
我们下面首先来看 Sampler 实体的实现。在 torch.utils.data
模块中,Sampler
类定义了 Sampler 实体所应该拥有的接口,定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17class Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
从上面的代码和注释中可以看到,Sampler 实体需要定义 __iter__(self)
魔法方法,以可以通过 iter()
的方法来获得 Sampler 的 Iterator,对 Iterator 进行的每次迭代可以得到的采样得到的单条 Sample 的序号。
基于 Sampler
接口类,torch.utils.data
模块中提供了多种内置的 Sampler 实现,下面我们展示了其中的 RandomSampler
的相关定义。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55class RandomSampler(Sampler[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify :attr:`num_samples` to draw.
Args:
data_source (Dataset): dataset to sample from
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples (int): number of samples to draw, default=`len(dataset)`.
generator (Generator): Generator used in sampling.
"""
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
if not isinstance(self.replacement, bool):
raise TypeError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
for _ in range(self.num_samples // n):
yield from torch.randperm(n, generator=generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
def __len__(self) -> int:
return self.num_samples
在 RandomSampler
的 __iter__(self)
魔法方法中我们可以看到,其返回的实际上是一个 Generator,采用 on-the-fly 的方式,随机地,可支持有放回地生成 Sample 的采样序号。
BatchSampler 实体
在 torch.utils.data
模块中,BatchSampler
类实现了 BatchSampler 实体所应该拥有的接口,定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63class BatchSampler(Sampler[List[int]]):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler or Iterable): Base sampler. Can be any iterable object
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self) -> Iterator[List[int]]:
# Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
if self.drop_last:
sampler_iter = iter(self.sampler)
while True:
try:
batch = [next(sampler_iter) for _ in range(self.batch_size)]
yield batch
except StopIteration:
break
else:
batch = [0] * self.batch_size
idx_in_batch = 0
for idx in self.sampler:
batch[idx_in_batch] = idx
idx_in_batch += 1
if idx_in_batch == self.batch_size:
yield batch
idx_in_batch = 0
batch = [0] * self.batch_size
if idx_in_batch > 0:
yield batch[:idx_in_batch]
def __len__(self) -> int:
# Can only be called if self.sampler has __len__ implemented
# We cannot enforce this condition, so we turn off typechecking for the
# implementation below.
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
if self.drop_last:
return len(self.sampler) // self.batch_size # type: ignore[arg-type]
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type]
从上面的代码中可以看到,BatchSampler
定义了三个初始化参数:
sampler
: 当前创建的BatchSampler
基于的 Sampler,用于确定采样的具体规则;batch_size
: 每一次对BatchSampler
进行迭代 (i.e. 每一轮 Iteration) 时,BatchSampler
要输出的 Samples 的索引的数目;drop_last
: 布尔变量,指示是否丢弃最后一个规模达不到batch_size
的 Mini-batch;
DataLoader 功能性介绍
铺垫完了 Dataset 和 BatchSampler 两个实体后,现在我们来到了数据加载流程的核心 —— DataLoader。torch.utils.data
提供的 DataLoader
是 PyTorch 加载数据的核心,其支持 Map-style 和 Iterable-style 的 Dataset 的数据加载,支持单进程/多进程数据加载,还可以设置 loading order, batch size, pin memory 等加载参数的设置,可以认为是数据加载流程的统一入口。
DataLoader 接口
DataLoader
的接口定义如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19DataLoader(
dataset: Dataset[T_co],
batch_size: Optional[int] = 1,
shuffle: Optional[bool] = None,
sampler: Union[Sampler, Iterable, None] = None,
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
num_workers: int = 0,
collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False,
drop_last: bool = False,
timeout: float = 0,
worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None,
generator=None,
*,
prefetch_factor: int = 2,
persistent_workers: bool = False,
pin_memory_device: str = ""
)
接口参数整理如下所示:
Attribute | Description | Default Value | Type |
---|---|---|---|
dataset |
要加载的 Dataset 实体 | Dataset |
|
batch_size |
每轮 Iteration 要加载的 Samples 的数目 | $1$ | int |
shuffle |
设置为 True 时,将调用 RandomSampler 进行随机索引 |
False |
bool |
sampler |
用户指定的 Sampler 实体,定义从 yield sampler , 则上述 shuffle 参数必须为 False ,否则会和 RandomSampler 互斥 |
None |
Sampler , Iterable |
batch_sampler |
用户指定的 BatchSampler 实体,定义从 yield |
None |
Sampler , Iterable |
num_workers |
要用于数据加载的子进程数,$0$ 表示仅在主进程中加载数据 | $0$ | int |
collate_fn |
在将 Map-style Dataset 取出的数据整合成最终 Mini-batch 时使用 | None |
callable |
pin_memory |
如果为 True ,则 DataLoader 在将张量返回之前将其复制到 CUDA 的锁页内存中 |
False |
bool |
drop_last |
若设置为 True ,则当该数据集大小 (i.e. len(dataset) ) 不能被该批次大小 (i.e. batch_size ) 整除时,删除最后一个不完整的批次;如果为 False 并且数据集的大小不能被批次大小整除,那么最后一批将较小 |
False |
bool |
timeout |
如果为正数,则为从 Worker 构建 Mini-batch 的超时值,应始终为非负数。超过这个时间还没从 Worker 读取到数据的话就会报错 | $0$ | numeric |
worker_init_fn |
如果不为 None ,它将会被每个 Worker 子进程调用。该函数将以 Worker Index (i.e. [0, num_workers - 1] 内的整形) 作为输入 |
None |
callable |
prefetch_factor |
每个 Worker 提前加载的 Sample 数量 | $2$ | int |
persistent_workers |
如果为 True ,则 DataLoader 将不会终止 Worker 进程,直到对 Dataset 的迭代完成 |
False |
bool |
自动化批处理 (Automatic Batching)

DataLoader
非常方便地为用户提供了
设置 BatchSampler
基于 SGD 的训练思路,通常在一轮 Iteration 中会使用多个 Samples 组成的 Mini-batch 进行训练,而通常不会只有一条 Sample 参与训练,DataLoader
提供了相关的参数进行 BatchSampler 相关的设置:
对于 Map-style Dataset 和 Iterable-style Dataset 来说,当 batch_size
(默认为 $1$) 不为 None
时,生成的 DataLoader 在每一次被迭代时将 yield
出一批 Samples,而不只是一个单独的 Sample,该参数与 drop_last
和 shuffle
参数配合,将决定每一轮 Iteration 所使用的 Samples 的顺序和数目。
在 DataLoader
的构造函数中,相关代码会基于 batch_size
和 drop_last
参数,结合用户指定的 Sampler 实体构造出对应的 BatchSampler 实体。那么 Sampler 实体是如何被指定的呢?对于 Map-style Dataset 来说,可以通过两种方式被指定: 一种是通过 sampler
参数显式指定 Sampler,一种是通过 shuffle
参数,当它为 True
时使用 RandomSampler
; 而对于 Iterable-style Dataset 来说,Sampler 实体并无意义,为了实现代码的兼容性,DataLoader
的构造函数会使用专门针对于 Iterable-style Dataset
提供的 Dummy Infinite Sampler _InfiniteConstantSampler
以充当 Sampler 实体,该 Sampler 可以无限地被迭代,其实质上是调用了 Iterable-style Dataset 的 __iter__(self)
魔法方法。
单独对于 Map-style Dataset 来说,用户还可以使用 batch_sampler
参数来直接设置 Mini-batch 的 BatchSampler。
设置基于 Sampled Indices 合成 Mini-batch 的方法
在基于指定的 BatchSampler 实体 yield
出一批某轮 Iteration 使用的 Samples 的 Indices 后,接下来就需要由 collate_fn
参数指定的
对于 Map-style Dataset 来说,Collate Function 流程可以抽象如下:
1
2for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
对于 Iteration-style Dataset 来说,Collate Function 流程可以抽象如下:
1
2
3dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
具体的 Collate Function 的详细设置细节可见 collate。
关闭自动批处理

当用户想用 Dataset 的代码手动处理 Batch,或每轮 Iteration 仅基于单条 Sample 进行训练时,可将 batch_size
和 batch_sampler
两个参数同时设为 None
, 此时 DataLoader
将关闭自动批处理,对 DataLoader
的迭代将使得其通过 Sampler 实体获得单条 Sample 的索引,然后将该索引对应的 Sample 数据交给 collate_fn
处理,以获得最终的 DataLoader
输出。
对于 Map-style Dataset 来说,Collate Function 流程可以抽象如下:
1
2for index in sampler:
yield collate_fn(dataset[index])
对于 Iteration-style Dataset 来说,Collate Function 流程可以抽象如下:
1
2for data in iter(dataset):
yield collate_fn(data)
Collate Function
上面说到,Collate Function 的输入是从 Dataset 实体中获取的 Sample(s) 的数据,其输出就是对 DataLoader
迭代所获得的输出,在开启/关闭自动批处理时,它的运行逻辑稍有不同。
关闭自动批处理时的情况

如 collate_fn
仅作用于单个 Sample,其工作就是简单地将 NumPy arrays
转化为 PyTorch 的 Tensor
,在转换过程中保留了 Sample 原有的数据结构,图中展示了当 Sample 的数据结构是 dict
时的情况。
开启自动批处理时的情况

而当开启自动批处理时,如 collate_fn
作用于多个 Samples,其将输入样本整理为一个 Batch,并将其 yield
回当前轮次的 Itertaion 以供训练。为了将输入样本整理为一个 Batch,collate_fn
的默认值 default_collate()
做了下面 $3$ 件事情:
- 追加 (Prepend) 一个新的维度作为 Batch Dimension (长度即为 Batch 的大小);
- 将 NumPy
arrays
和 Python Numberical Values 转化为 PyTorch 的Tensor
; - 保留输入的 Samples 中各条 Sample 的数据结构,如
img_batch_collate 所示,比如各条 Sample 是dict
时,default_collate()
将输出具有相同 Keys,且处理过的 Batched Tensor (或 Batched List,当无法转化为 Tensor 的时候) 作为值的dict
。当各条 Sample 是list
、tuple
和namedtuple
等时同理;
单进程数据加载
当设置 DataLoader
的 num_workers
为 $0$ (默认值) 时,则 DataLoader
的初始化进程和读取数据的进程是一样的,此时数据加载可能会导致主进程阻塞。当用于在进程之间共享数据的资源 (例如共享内存,文件描述符) 有限时,或者当整个数据集很小并且可以完全加载到内存中时,此模式可能是首选。此外,单进程加载通常显示更多可读的错误跟踪,因此对于调试很有用。
多进程数据加载
在多进程模式下,每轮 Iteration 开始时对 DataLoader
进行迭代时,都会创建 num_workers
个 dataset
, collate_fn
, worker_init_fn
等参数都会被传到各个 Worker 中,各个 Worker 使用这些参数进行初始化和数据的读取和整理。
torch.utils.data
模块提供了 get_work_info
函数用于在各个 Worker 进程中获取各个进程相关的信息,包括 Worker 的 ID,Dataset 的 Replica,以及 Initial Seed 等,在主进程而不是 Worker 进程中调用 get_work_info
函数将返回 None
。
对于 Map-style 的 Dataset 来说,主进程将会利用 Sampler (BatchSampler) 实体生成 Indices,然后将生成的 Indices 发送给各个 Worker 进程,也即「选择 Sample」这件事情是在主进程完成的,而「根据选择结果加载数据」这件事情是在各个 Worker 进程中完成的。
对于 Iterable-style 的 Dataset 来说,主进程中将不会进行生成 Indices 的操作,在各个 Worker 进程中会有一份 Dataset 的 Replica,各个 Worker 进程可以基于 get_work_info
函数获得的信息,对各份 Replica 进行不同的操作。
Pinned Memory (锁页内存)
我的另一篇文章 CUDA 内存管理 对 CUDA Pinned Memory 相关内容进行了介绍,如果您相关内容不熟悉,可以先移步阅读。
Host Memory 中的 Memory Page 有两种存在方式,一是 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25def pin_memory(data, device=None):
if isinstance(data, torch.Tensor):
return data.pin_memory(device)
elif isinstance(data, string_classes):
return data
elif isinstance(data, collections.abc.Mapping):
try:
return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
except TypeError:
# The mapping type may not support `__init__(iterable)`.
return {k: pin_memory(sample, device) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample, device) for sample in data))
elif isinstance(data, tuple):
return [pin_memory(sample, device) for sample in data] # Backwards compatibility.
elif isinstance(data, collections.abc.Sequence):
try:
return type(data)([pin_memory(sample, device) for sample in data]) # type: ignore[call-arg]
except TypeError:
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
return [pin_memory(sample, device) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
return data
PyTorch 为存储在 Host Memory 中的 Tensor 提供了 pin_memory()
方法,其定义如上所示,该方法返回操作的 Tensor 的副本,并将数据放在 Pinned Memory 中。对于放在 Pinned Memory 中的 Tensor,用户可以使用 to()
方法中加上参数 non_blocking=True
,以实现数据传输和 Host 计算两者的 Overlapping torch_pinned_memory。
对于 DataLoader
来说,我们可以设置传入参数 pin_memory=True
,以设置 DataLoader
将每次迭代返回的 Tensor 都放置到 Pinned Memory 中,以缩减数据在 Host Memory 和 GPU Memory 之间的拷贝时间。
另外,从上面关于 pin_memory()
方法的代码定义中可以看到,如果传入该函数的是一个自定义的数据类型,则该函数会直接返回该数据,而不做任何 Pinning 相关的处理。对于 DataLoader
来说,当我们使用 collate_fn
指定了自定义的整理函数并且该整理函数返回了自定义类型的 Batch 数据,则当我们指定 DataLoader
的 pin_memory=True
时,则会导致 Memory Pinning 的操作并不会生效的情况。为了解决这种情况,我们需要手动地为传入 collate_fn
的数据添加与 Memory Pinning 相关的代码,具体示例代码如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
DataLoader 源码解析
基于上一节对 PyTorch 提供的 DataLoader
有了功能性的认识后,本节我们将对其源码按顺序进行分析。
对 DataLoader
进行迭代
1 | for data, label in train_loader: |
首先在主进程中,我们会使用如上所示的代码对 DataLoader
进行遍历,此时会调用它的 __iter__(self)
魔法方法,该方法定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14class DataLoader(Generic[T_co]):
def __iter__(self) -> '_BaseDataLoaderIter':
if self.persistent_workers and self.num_workers > 0:
# 对于多进程数据加载的情况
if self._iterator is None:
# 第一次发起遍历,则创建 iterator
self._iterator = self._get_iterator()
else:
# 不是第一次发起遍历,则重置 iterator
self._iterator._reset(self)
return self._iterator
else:
# 对于多进程数据加载的情况
return self._get_iterator()
在上面的代码中可以看见其调用了 DataLoader
类下的 _get_iterator
获取 Iterator,具体代码如下所示:
1
2
3
4
5
6
7class DataLoader(Generic[T_co]):
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
可以看到其根据单/多进程数据读取的不同情况,返回了不同的 Iterator,我们下面分情况进行讨论。
DataLoader
的 Iterator
Iterator 基类: _BaseDataLoaderIter
首先,不论是单进程数据加载所使用的迭代器 _SingleProcessDataLoaderIter
,还是多进程数据加载所使用的迭代器 _MultiProcessingDataLoaderIter
,他们都继承自 _BaseDataLoaderIter
基类,其源码如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
self._dataset = loader.dataset
self._shared_seed = loader._get_shared_seed()
if isinstance(self._dataset, IterDataPipe):
shared_rng = torch.Generator()
shared_rng.manual_seed(self._shared_seed)
self._dataset = torch.utils.data.graph_settings.apply_shuffle_seed(self._dataset, shared_rng)
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
self._prefetch_factor = loader.prefetch_factor
# for other backends, pin_memory_device need to set. if not set
# default behaviour is CUDA device. if pin_memory_device is selected
# and pin_memory is not set, the default behaviour false.
if (len(loader.pin_memory_device) == 0):
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._pin_memory_device = None
else:
if not loader.pin_memory:
warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
"please set pin_memory to true, if you need to use the device pin memory")
warnings.warn(warn_msg)
self._pin_memory = loader.pin_memory
self._pin_memory_device = loader.pin_memory_device
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
def __iter__(self) -> '_BaseDataLoaderIter':
return self
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._shared_seed = loader._get_shared_seed()
if isinstance(self._dataset, IterDataPipe):
shared_rng = torch.Generator()
shared_rng.manual_seed(self._shared_seed)
self._dataset = torch.utils.data.graph_settings.apply_shuffle_seed(self._dataset, shared_rng)
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
raise NotImplementedError
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
# TODO(https://github.com/pytorch/pytorch/issues/76750)
self._reset() # type: ignore[call-arg]
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
self._num_yielded)
if self._num_workers > 0:
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
"IterableDataset replica at each worker. Please see "
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
warnings.warn(warn_msg)
return data
next = __next__ # Python 2 compatibility
def __len__(self) -> int:
return len(self._index_sampler)
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
在 _BaseDataLoaderIter
中定义了 __next__(self)
函数,我们在主进程中使用 for
循环迭代 DataLoader
时,首先其会调用 DataLoader
的 __iter__(self)
魔法方法以获得 Iterator,而 DataLoader
的 Iterators 实现都继承自 _BaseDataLoaderIter
,因此在获取完 Iterator 后 for
循环实际上就是通过不断调用 _BaseDataLoaderIter
的 __next__(self)
以获得下一批用于训练的 Batched Tensor。从上面的代码中可以看到,__next__(self)
则是调用 _next_data()
以获取相关数据,而后者留给继承自 _BaseDataLoaderIter
的子类予以实现。
单进程加载 Iterator: _SingleProcessDataLoaderIter
类
在单进程数据加载的设定下,代表 Iterator 的类是 _SingleProcessDataLoaderIter
,其定义如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data
首先从 _SingleProcessDataLoaderIter
的初始化参数可以看到,其在父类 _BaseDataLoaderIter
的基础上定义了 _dataset_fetcher
, 并传入 _dataset
, _auto_collation
, _collate_fn
等参数,该类用于根据指定的 Indices 来 Fetch 对应的 Samples,事实上就是我们前文提到过的 Fetcher 实体,我们在后面会对其源码进行分析。
其次可以看见 _SingleProcessDataLoaderIter
实现了具体的 _next_data()
方法,其需要 next_index()
来获取要 Fetch 的 Samples 的 Indices,并将 Indices 传入 _dataset_fetcher
中以获取对应样本。
多进程加载 Iterator: _MultiProcessingDataLoaderIter
类
当用户创建 DataLoader
时传入的 num_workers
大于 $1$ 的时候,对 DataLoader
的迭代操作就会基于 _MultiProcessingDataLoaderIter
类进行。我们将在 mp 中对其进行具体分析。
单进程数据加载

现在我们来看用于单进程数据加载的基本流程,其基本调用关系如
当我们在主程序的 for
循环中对 DataLoader
进行迭代时,其会 ① 首先调用 DataLoader
的 __iter__(self)
魔法方法以获得 Iterator,从上面展示过的程序中我们可以知道,在单进程数据加载的设定下,代表 Iterator 的类是 _SingleProcessDataLoaderIter
;然后 ② 调用 _SingleProcessDataLoaderIter
的 __next__
以获取当前 Iteration 使用的 Sample Batch,从上一小节的分析我们知道:
_SingleProcessDataLoaderIter
的__next__
方法继承自父类_BaseDataLoaderIter
;- 父类
_BaseDataLoaderIter
的__next__
方法实际上调用了_next_data
方法来实现迭代逻辑,而后者留给子类实现; -
_SingleProcessDataLoaderIter
的_next_data
方法过程可以分为三步:- 调用
_SingleProcessDataLoaderIter
定义的next_index()
从 BatchSampler (Sampler) 实体中获取 Sampled Indices (Index); - 基于 Sampled Indices (Index),调用 Fetcher 实体从 Dataset 中获取并处理对应 Samples 的数据;
- 将处理后的 Sample(s) 数据转移至 Pinned Memory (如果
pin_memory==True
);
- 调用
下面我们就 _SingleProcessDataLoaderIter
类定义的 _next_data
方法所实现的三步逻辑进行分析:
获取 Indices
上面看到的 next_index()
方法是在 _SingleProcessDataLoaderIter
的父类 _BaseDataLoaderIter
中定义的,我们把相关代码整理如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
self._index_sampler = loader._index_sampler
self._sampler_iter = iter(self._index_sampler)
# ...
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
class DataLoader(Generic[T_co]):
def _auto_collation(self):
return self.batch_sampler is not None
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
从上面的代码可以看出,根据 DataLoader
中是否启用了了 batch_sampler
,next_index()
将对应地从 DataLoader
的 batch_sampler
或者 sampler
中迭代出 Indices。
Samples 加载
1
2
3
4
5
6
7
8
9
10
class _DatasetKind(object):
Map = 0
Iterable = 1
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
1 | class _DatasetKind(object): |
现在来看 Samples 加载的部分,也即 Fetcher 实体。在 _SingleProcessDataLoaderIter
的构造函数中可以看到其调用了 _DatasetKind.create_fetcher
创建了 Fetcher 实体,相关代码如上所示。根据数据集类型的不同,该函数会创建出 _MapDatasetFetcher
类型或者 _IterableDatasetFetcher
类型的 Fetcher 实体,分别对应 Map-style 或者 Iterable-style 的数据集。
1
2
3
4
5
6
7
8
9
10class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
对于 _MapDatasetFetcher
类型 Fetcher 来说,如上所示,其定义的 fetch()
函数直接输入 Indices,作为 Map 的 Key,获得对应的样本,然后将获取的样本交给 collate_fn
指定的 Collate Function 进行输出前的整理。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
self.ended = False
def fetch(self, possibly_batched_index):
if self.ended:
raise StopIteration
if self.auto_collation:
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
self.ended = True
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
data = next(self.dataset_iter)
return self.collate_fn(data)
对于 _IterableDatasetFetcher
类型 Fetcher 来说,如上所示,其在构造函数内设置了对应 Dataset 初始的迭代器,在 fetch()
方法内利用该迭代器获取元素,输入 fetch()
的 Indices 其实已经没有多大作用了。
转移至 Pinned Memory
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data
1 | class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): |
从上面 _SingleProcessDataLoaderIter
的 _next_data
方法中可以看见,在 Fetcher 实体完成 Sample(s) 的加载和处理后,最后一步就是根据用户是否指定将 Tensor 数据转移至 Pinned Memory,调用 PyTorch 官方提供的 _utils.pin_memory.pin_memory
方法进行 Pinned Memory 的转移,这个函数我们在 pinned_memory 中进行了说明。
多进程数据加载
我们下面首先对 PyTorch 多进程加载数据的逻辑进行概述,然后再结合具体代码分析 _MultiProcessingDataLoaderIter
的行为。
如

在上图中,我们将多进程数据加载分为三个部分 —— Initialization (初始化),Fetching (读取处理) 和 Iteration (迭代)。我们下面对着三个部分进行说明。
Initialization
流程说明
首先,当 _MultiProcessingDataLoaderIter
被创建时 (i.e. 其构造函数被运行时),将会创建 num_workers
条 Worker 进程,一条用于将数据转移至 Pinned Memory 的线程,以及它们之间用于数据传输的若干异步队列。这些异步队列包括:
index_queue
: 每个 Worker 进程一条,主进程使用该队列用于通告 BatchSampler 输出的 Sampled Indices 给各个 Worker 进程;_worker_result_queue
: 全局唯一一条,用于存储各个 Worker 进程读取和处理好的 Mini-batches;_data_queue
: 全局唯一一条,用于 Pin Memory Thread 存储完成锁页内存转移的数据;
Initialization 部分除了创建多条 Worker 进程以外,如果 DataLoader
的 pin_memory
参数被使能的话,_MultiProcessingDataLoaderIter
的构造函数还会创建出一条 Python Thread —— Pinned Memory Thread,用于从 _worker_result_queue
中获取各条 Worker 进程读取并处理好的 Mini-batches,然后将数据转移到锁页内存中,最后再把数据放入 _data_queue
中。如果 DataLoader
的 pin_memory
参数没有被使能,则 _worker_result_queue
中的数据就将直接作为 DataLoader
的处理结果。
在完成 Worker 进程,Pin Memory 处理线程和异步队列的创建和初始化后,主进程将会向各个 Worker 进程发送它们分别的首次读取的 Indices,以启动它们的第一次数据读取。
相关代码
上述流程的具体代码体现在 _MultiProcessingDataLoaderIter
的构造函数中,如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
assert self._num_workers > 0
assert self._prefetch_factor > 0
# 选择 Multiprocessing 模块的来源:
# [1] Python 官方 Multiprocessing 模块;或
# [2] PyTorch 提供的 Multiprocessing 模块;
if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
else:
multiprocessing_context = loader.multiprocessing_context
self._worker_init_fn = loader.worker_init_fn
# No certainty which module multiprocessing_context is
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
self._worker_pids_set = False
self._shutdown = False
self._workers_done_event = multiprocessing_context.Event()
self._index_queues = []
self._workers = []
# 创建 _num_workers 条 Worker 进程
for i in range(self._num_workers):
# 创建 Per-worker 的 Index Queue
# No certainty which module multiprocessing_context is
index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
# Need to `cancel_join_thread` here!
# See sections (2) and (3b) above.
index_queue.cancel_join_thread()
# 创建和启动 Worker 进程
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed, self._worker_init_fn, i, self._num_workers,
self._persistent_workers, self._shared_seed))
w.daemon = True
# NB: Process.start() actually take some time as it needs to
# start a process and pass the arguments over via a pipe.
# Therefore, we only add a worker to self._workers list after
# it started, so that we do not call .join() if program dies
# before it starts, and __del__ tries to join but will get:
# AssertionError: can only join a started process.
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
# 如果 DataLoader 使能了 _pin_memory
# 则创建出一条 Python Thread 用于处理数据的锁页内存转移
# 以及对应的异步队列 _data_queue
if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
# Queue is not type-annotated
self._data_queue = queue.Queue() # type: ignore[var-annotated]
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
self._pin_memory_thread_done_event, self._pin_memory_device))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
# In some rare cases, persistent workers (daemonic processes)
# would be terminated before `__del__` of iterator is invoked
# when main process exits
# It would cause failure when pin_memory_thread tries to read
# corrupted data from worker_result_queue
# atexit is used to shutdown thread and child processes in the
# right sequence before main process exits
if self._persistent_workers and self._pin_memory:
import atexit
for w in self._workers:
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
# .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
# 重置 _MultiProcessingDataLoaderIter
# 在这里第一次被调用,实际上是初始化
self._reset(loader, first_iter=True)
上面代码的最后一行调用的 _reset
函数如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def _reset(self, loader, first_iter=False):
super()._reset(loader, first_iter)
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
# \ (worker_id, data) if data is already fetched (out-of-order)
self._task_info = {}
self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
# A list of booleans representing whether each worker still has work to
# do, i.e., not having exhausted its iterable dataset object. It always
# contains all `True`s if not using an iterable-style dataset
# (i.e., if kind != Iterable).
# Not that this indicates that a worker still has work to do *for this epoch*.
# It does not mean that a worker is dead. In case of `_persistent_workers`,
# the worker will be reset to available in the next epoch.
self._workers_status = [True for i in range(self._num_workers)]
# Reset the worker queue cycle so it resumes next epoch at worker 0
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
# We resume the prefetching in case it was enabled
if not first_iter:
for idx in range(self._num_workers):
self._index_queues[idx].put(_utils.worker._ResumeIteration(self._shared_seed))
resume_iteration_cnt = self._num_workers
while resume_iteration_cnt > 0:
return_idx, return_data = self._get_data()
if isinstance(return_idx, _utils.worker._ResumeIteration):
assert return_data is None
resume_iteration_cnt -= 1
# prime the prefetch loop
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()
在上面代码的最后两行,_reset
函数通过一个 for
循环,向各条 Worker 进程的 index_queue
中放入了 Batch Indices,放入的次数是 _prefetch_factor
$\times$ _num_workers
,也即当 _prefetch_factor
$= 1$ 时,每个 Worker 进程就只会收到 $1$ 个 Batch Indices,_prefetch_factor
$= 2$ 时则收到 $2$ 个。
上面代码最后一行所调用的 _try_put_index
函数定义如下所示,其用于向下一个 Active 的 Worker 进程的 _index_queues
中放入 $1$ 个 Batch Indices。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def _try_put_index(self):
assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
try:
index = self._next_index()
except StopIteration:
return
for _ in range(self._num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]:
break
else:
# not found (i.e., didn't break)
return
self._index_queues[worker_queue_idx].put((self._send_idx, index))
self._task_info[self._send_idx] = (worker_queue_idx,)
self._tasks_outstanding += 1
self._send_idx += 1
在上面的代码中,我们可以看到:
-
调用了
_next_index
函数用于获取下一个 Batch Indices,该函数实际上就是对 Sampler 的一次迭代:1
2
3class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration _MultiProcessingDataLoaderIter
使用_task_info
字典用于记录每一个 Batch Indices 被处理的具体信息,字典的键是处理这个 Batch Indices 的 Worker 的 Index,字典的值是 Worker 进程读取和处理完成后的数据;_MultiProcessingDataLoaderIter
使用_tasks_outstanding
变量用于记录尚未被交付的 (Outstanding) 任务的个数;_MultiProcessingDataLoaderIter
使用_send_idx
变量用于记录当前已经完成部署的 Batch Indices 的批次号;
在完成 _reset
函数的运行后,_MultiProcessingDataLoaderIter
就完成了 Initialization 部分的工作。
Fetching
现在我们来关心在每个 Worker 进程内部发生的事情。Worker 进程实际上就是负责根据主进程的 Sampler 输出的 Indices,在 Dataset 中完成数据的提取和处理,其核心工作流程可以类比单进程数据加载的流程。
在主进程初始化 Worker 进程的代码中,我们注意到 Worker 进程的运行函数是 _utils.worker._worker_loop
函数,在该函数中最关键的程序是其处理循环,我们将这段代码摘抄如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95def _worker_loop(...):
# ...
# 运行 _worker_init_fn,以及创建 fetcher
try:
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
except Exception:
init_exception = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
# When using Iterable mode, some worker can exit earlier than others due
# to the IterableDataset behaving differently for different workers.
# When such things happen, an `_IterableDatasetStopIteration` object is
# sent over to the main process with the ID of this worker, so that the
# main process won't send more tasks to this worker, and will send
# `None` to this worker to properly exit it.
#
# Note that we cannot set `done_event` from a worker as it is shared
# among all processes. Instead, we set the `iteration_end` flag to
# signify that the iterator is exhausted. When either `done_event` or
# `iteration_end` is set, we skip all processing step and just wait for
# `None`.
iteration_end = False
watchdog = ManagerWatchdog()
while watchdog.is_alive():
try:
# 从 index_queue 中接收来自主进程的消息
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
# 收到 _ResumeIteration —— 重新开始迭代的消息
if isinstance(r, _ResumeIteration):
# Acknowledge the main process
data_queue.put((r, None))
iteration_end = False
if isinstance(dataset, IterDataPipe):
assert r.seed is not None
shared_rng.manual_seed(r.seed)
dataset = apply_shuffle_seed(dataset, shared_rng)
# Recreate the fetcher for worker-reuse policy
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collation, collate_fn, drop_last)
continue
# 收到 None —— 通知结束当前 Worker 进程的消息
elif r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
# 如果 Done Event 已经被置位,或者迭代已经结束,则不执行后续的读取操作
# 转而不断重复上面的流程,直到收到 None
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
# 解析收到的消息
idx, index = r
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
# 从 Dataset 中提取数据
data = fetcher.fetch(index)
except Exception as e:
# 迭代 Iterable Dataset 时遇到 StopIteration 异常,代表迭代结束
if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
data = _IterableDatasetStopIteration(worker_id)
# Set `iteration_end`
# (1) to save future `next(...)` calls, and
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
iteration_end = True
else:
# It is important that we don't store exc_info in a variable.
# `ExceptionWrapper` does the correct thing.
# See NOTE [ Python Traceback Reference Cycle Problem ]
data = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
# 将提取并整理好的数据放入 data_queue 中
data_queue.put((idx, data))
del data, idx, index, r # save memory
可以看到,Worker 进程将从 index_queue
中接收指示其进行对应操作的消息 (Line 34),正常来说将从主进程接收到 idx, index
格式的消息,其中 idx
代表了主进程让当前 Worker 进程取出的 Mini-batch 的编号,index
则包含了具体的 Samples 的索引,在利用 Fetcher
实体完成数据的提取和整理后 (Line 77),Worker 进程会把数据放入 _worker_result_queue
中 (Line 94)。
对于 Iterable-style Dataset,Worker 进程设置了 iteration_end
指示变量来标识迭代的结束: 当从 Fetcher
实体中得到 StopIteration
异常时 (Line 80),将会置位 iteration_end
,并且 Worker 进程会通过 _worker_result_queue
将迭代结束的消息通告给主进程 (Line 81)。当 iteration_end
被置位时,后续的处理循环将不会再进行任何的数据获取和处理操作 (Line 62)。当 Worker 进程从 index_queue
中接收到 None
消息时,说明主进程通告当前 Worker 进程结束,Worker 进程跳出处理循环,进程结束 (Line 55)。
Iteration
下面我们来看 _MultiProcessingDataLoaderIter
的 _next_data
方法:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def __init__(...):
#...
self._reset(loader, first_iter=True)
def _reset(self, loader, first_iter=False):
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
# ...
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()
def _try_put_index(self):
# ...
self._send_idx += 1
# ...
def _next_data(self):
while True:
# If the worker responsible for `self._rcvd_idx` has already ended
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
# we try to advance `self._rcvd_idx` to find the next valid index.
#
# This part needs to run in the loop because both the `self._get_data()`
# call and `_IterableDatasetStopIteration` check below can mark
# extra worker(s) as dead.
while self._rcvd_idx < self._send_idx:
info = self._task_info[self._rcvd_idx]
worker_id = info[0]
if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
break
del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
# no valid `self._rcvd_idx` is found (i.e., didn't break)
if not self._persistent_workers:
self._shutdown_workers()
raise StopIteration
# Now `self._rcvd_idx` is the batch index we want to fetch
# Check if the next sample has already been generated
if len(self._task_info[self._rcvd_idx]) == 2:
data = self._task_info.pop(self._rcvd_idx)[1]
return self._process_data(data)
assert not self._shutdown and self._tasks_outstanding > 0
idx, data = self._get_data()
self._tasks_outstanding -= 1
if self._dataset_kind == _DatasetKind.Iterable:
# Check for _IterableDatasetStopIteration
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
if self._persistent_workers:
self._workers_status[data.worker_id] = False
else:
self._mark_worker_as_unavailable(data.worker_id)
self._try_put_index()
continue
if idx != self._rcvd_idx:
# store out-of-order samples
self._task_info[idx] += (data,)
else:
del self._task_info[idx]
return self._process_data(data)