在 Oneflow 中开发 Gather Primitive

前言

Gather 的行为

    Gather 是一种常见的对 Tensor 元素进行提取和部分重排的操作,以 TensorFlow 为例,其 Gather 接口 tf_gather 的函数原型的简化版本如下所示:

1
gather(params, indices, axis=0)

    在参数列表中,params 是操作的源 Tensor,indices 是用于指示排列顺序的整型 Tensor,axis 是用于指示 Gather 维度的整型变量。举个简单的例子来说:

1
2
3
4
params = tf.constant([10.38, 16.19, 19.54, 15.39, 17.21, 8.13])
indices = tf.constant([2, 3])
output = tf.gather(params, indices, axis=0)
# output: ([19.54, 15.39])

    上述程序调用了 Gather 接口,在第 0 维对 params 张量按照 indices 张量所指示的顺序进行了重排。对于 params 张量来说,它也可以是多维的,此时 axis 参数的意义就在于告知 Gather 接口在 params 张量的哪个维度上进行重排,举例如下所示,当在 2 维 params 张量的第 0 维上进行重排时:

1
2
3
4
5
6
7
8
9
10
params = tf.constant([[0, 1.0, 2.0],
[10.0, 11.0, 12.0],
[20.0, 21.0, 22.0],
[30.0, 31.0, 32.0]])
indices = tf.constant([2, 1])

# 在第 0 维上进行重排
output = tf.gather(params, indices, axis=0)
# output: ([[20.0, 21.0, 22.0],
# [10.0, 11.0, 12.0]])

    当在 2 维 params 张量的第 1 维上进行重排时:

1
2
3
4
5
6
7
8
9
10
11
12
params = tf.constant([[0, 1.0, 2.0],
[10.0, 11.0, 12.0],
[20.0, 21.0, 22.0],
[30.0, 31.0, 32.0]])
indices = tf.constant([2, 1])

# 在第 1 维上进行重排
output = tf.gather(params, indices, axis=1)
# output: ([[2.0, 1.0],
# [12.0, 11.0],
# [22.0, 21.0],
# [32.0, 31.0]])

    除了 params 张量可以是多维的以外,indices 张量也可以是多维的,举例如下所示:

1
2
3
4
5
params = tf.constant([10.38, 16.19, 19.54, 15.39, 17.21, 8.13])
indices = tf.constant([[2, 0], [2, 5]])
output = tf.gather(params, indices, axis=0)
# output: ([[19.54, 10.38]
# [19.54, 8.13]])

    下面是一个 params 张量和 indices 张量都是多维张量的复杂例子,当在 2 维 params 张量的第 0 维上使用 2 维 indices 张量进行重排时:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
params = tf.constant([[0, 1.0, 2.0],
[10.0, 11.0, 12.0],
[20.0, 21.0, 22.0],
[30.0, 31.0, 32.0]])
indices = tf.constant([[2, 0], [0, 1]])

# 在第 0 维上进行重排
output = tf.gather(params, indices, axis=0)
# output: (
# [
# [
# [20.0, 21.0, 22.0], [0, 1.0, 2.0]
# ],
# [
# [0, 1.0, 2.0], [10.0, 11.0, 12.0]
# ]
# ]
# )

    当在 2 维 params 张量的第 1 维上使用 2 维 indices 张量进行重排时:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
params = tf.constant([[0, 1.0, 2.0],
[10.0, 11.0, 12.0],
[20.0, 21.0, 22.0],
[30.0, 31.0, 32.0]])
indices = tf.constant([[2, 0], [0, 1]])

# 在第 1 维上进行重排
output = tf.gather(params, indices, axis=1)
# output: (
# [
# [ [2.0, 0], [0, 1.0] ],
# [ [12.0, 10.0], [10.0, 11.0] ],
# [ [22.0, 20.0], [20.0, 21.0] ],
# [ [32.0, 30.0], [30.0, 31.0] ]
# ]
# )

    观察上述例子,从 Tensor 形状来说,输入的 params 张量的形状是 $[\underbrace{4}_{\text{第 0 维}},\underbrace{3}_{\text{第 1 维}}]$,indices 的形状是 $[2,2]$。如果在 params 张量的第 0 维上进行重排,则输出的张量的形状是 $[2,2,3]$; 如果在 params 张量的第 1 维上进行重排,则输出的张量的形状就是 $[4,2,2]$。细心的读者可以发现规律,实际上 Gather 对形状的处理就是将 paramsaxis 指定维度上的形状替换为 indices 的形状,用代码表示即:

1
2
def result_shape(p_shape, i_shape, axis=0):
return p_shape[:axis] + i_shape + p_shape[axis+1:]

Batch Gather 的行为

    在理解了 Gather 接口的行为后,我们再来看 Batch Gather。Batch Gather 允许我们指定 params 张量中有多少维是 Batch 维,指定后,Batch Gather 应用一个拥有相同 Batch 维形状的 indices 张量,对 params 张量中的各个 Batch 元素在指定 axis 上进行重排。还是以 TensorFlow 为例,还是同样的 tf.gather 接口,其支持 Batch Gather 的函数原型如下所示:

1
gather(params, indices, axis=0, batch_dims=0)

    其中,新加入的参数 batch_dims 用于指定 params 拥有的 Batch 维的数目。举个简单的例子,我们可以指定一个形状为 $[4,3]$ 的 params 张量拥有一个维度的 Batch 维 (i.e. 该维度宽度为 $4$),然后应用一个拥有相同 Batch 维宽度的,整体形状为 $[4,2]$ 的 indices 张量,在 params 张量的第 1 维上对其进行重排,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
params = tf.constant([ [0,1.0,2.0], [10.0,11.0,12.0], [20.0,21.0,22.0], [30.0,31.0,32.0] ])
indices = tf.constant([ [2,1], [0,2], [1,1], [1,0] ])

output = tf.gather(params, indices, axis=1, batch_dims=1)
# output: (
# [
# [2.0, 1.0],
# [10.0, 12.0],
# [21.0, 21.0],
# [31.0, 30.0]
# ]
# )

    Batch Gather 也可以更复杂,如 batch_gather_exp_2 所示,我们可以指定一个形状为 $[\underbrace{4,2}_{\text{Batch 维}},3,2]$ 的 params 张量拥有 2 个 Batch 维度 (i.e. batch_dims=2,也即第 0 维和第 1 维是 Batch 维),然后应用一个形状为 $[\underbrace{4,2}_{\text{Batch 维}},5]$ 的 indices 张量对其进行重排,并且指定在 params 张量上进行重排的维度为第 3 维(i.e. axis=3),最终得到一个形状为 $[4,2,3,5]$ 的张量。

    通过上面的例子我们可以发现,实际上 Batch Gather 的工作过程就是: 首先按照指定的 Batch 维度的个数,将 params 张量和 indices 张量分别进行切分,例如在上面的例子中,当指定 Batch 维度数目为 2 时,params 可以切分出 $[4,2]$ 个形状为 $[3,2]$ 的元素,indices 可以切分出 $[4,2]$ 个形状为 $[5]$ 的元素,然后基于参数 axis 指定的 Gather 维度,将这些元素一一对应进行重排操作。

    上面的说法暗示了,Batch Gather 接口对输入参数将有如下 约束:

  • batch_dims 的值不能大于 axis,例如指定 params 有 1 个 Batch 维时,不能在第 0 维上进行 Batch Gather 操作;
  • paramsindices 在 Batch 维需要有相同的形状,否则不能形成一一对应的关系;

    与 Gather 类似,我们也可以总结出 Batch Gather 输出张量的形状的规律:

1
2
def batched_result_shape(params_shape, indices_shape, axis=0, batch_dims=0):
return params_shape[:axis] + indices_shape[batch_dims:] + params_shape[axis+1:]

    另外值得注意的是,上述的 TensorFlow 的接口中,其支持由用户自行指定 Batch 维的数目,而在目前 OneFlow 实现的 Batch Gather Operator of_batch_gather_op 仅支持固定的 Batch 维数目,该数目由输入的 indices 张量决定: indices 张量除了最后一维剩余维度都是 Batch 维,params 的 Batch 维与之相对应,也即 Gather 操作在 indices 的最后一维上进行。

在 OneFlow 中,我们是否需要开发 与 TF 类似的 Batch Gather 逻辑?

OneFlow 如何对 Gather / Batch Gather 进行 SBP

若您对 SBP 相关的概念并不熟悉,可参考 OneFlow 官方文档 of_sbp

对 Gather 进行 SBP

    对于 Gather Operator 来说,可以进行 SBP 的 Tensor 一共有两个: paramsindices。我们下面分为两种情况 (i.e. 将 Broadcast 和 Split 操作分别应用在这两个张量上) 进行讨论。

params 进行 Broadcast,对 indices 进行 Split

    当 Gather 的维度为 axis=0 时,indices 在不同维度下的 Split 的情况如 gather_sbp_exp_1gather_sbp_exp_2 所示:

    上述的例子比较简单,是常规的 SBP 操作。值得注意的是,由于 indices 张量是作用在 params 张量的指定维度 axis 上的,因此当我们思考输出张量的 SBP 形状时,实际上需要将 Gather 指定的维度考虑其中,输出张量的 Split 的维度应该是 Gather 维度 axisindices 的 Split 维度之和。举例如 gather_sbp_exp_6 所示,当指定的 Gather 维度为 1 时,对 indices 进行 S0 的切分,输出张量 output 实际上会形成 S2 的切分.

    综上,我们首先定义了对 params 张量进行 Broadcast,对 indices 张量进行 Split 的 SBP,OneFlow 源码 of_gather_op 如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
auto GatherOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {
const int64_t in_num_axes =
ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes();
const int64_t indices_num_axes =
ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes();
const int64_t gather_axis = ctx->Attr<int64_t>("axis");
CHECK_GE_OR_RETURN(gather_axis, 0);
CHECK_LT_OR_RETURN(gather_axis, in_num_axes);
FOR_RANGE(int64_t, i, 0, indices_num_axes) {
ctx->NewBuilder()
.Split(user_op::OpArg("indices", 0), i)
.Broadcast(user_op::OpArg("in", 0))
.Split(user_op::OpArg("out", 0), gather_axis + i)
.Build();
}
/* ... */
}
params 进行 Split,对 indices 进行 Broadcast

    这种情况比较复杂,因为 params 是 Gather 的对象,对 params 张量进行 Split 的维度 $d_s$ 和 Gather 的维度 $g_s$ 之间的前后关系将会影响到输出张量的 SBP 形状,下面我们分情况进行讨论。

    如 gather_sbp_exp_4 所示,当 Split 的维度 $d_s=0$ 在 Gather 维度 $g_s=1$ 之前时 ($d_s < g_s$),实际上就可以理解为 Split 动作发生在 Batch 维,此时情况就比较简单,输出张量的 Split 维度 $d'_s$ 与 params 张量 $d_s$ 相同,也即 $d'_s = d_s$。

    如 gather_sbp_exp_3 所示,当 Split 的维度 $d_s=1$ 在 Gather 维度 $g_s=0$ 之后时 ($d_s > g_s$),由于 Split 所在的维度后续会受到 Gather 操作的影响 (i.e. 如果 indices 的维度大于 1,则将在 $g_s$ 维后插入新的维度),因此输出张量对应的 Split 维度 $d'_s$ 将会被推高,即 $d'_s = d_s+g_s-1$。

    如 gather_sbp_exp_5 所示,当 Split 的维度 $d_s=0$ 与 Gather 维度 $g_s=0$ 相同时 ($d_s = g_s$),此时 Split 操作将造成部分 Gather 维度部分元素的缺失,此时缺失部分可以采用补 0 的方式进行填充,这样一来,输出张量的 SBP 形状将是 PartialSum。

    综上,我们首先定义了对 params 张量进行 Split,对 indices 张量进行 Broadcast 的 SBP,OneFlow 源码 of_gather_op 如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
auto GatherOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {
const int64_t in_num_axes =
ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes();
const int64_t indices_num_axes =
ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes();
const int64_t gather_axis = ctx->Attr<int64_t>("axis");
CHECK_GE_OR_RETURN(gather_axis, 0);
CHECK_LT_OR_RETURN(gather_axis, in_num_axes);
FOR_RANGE(int64_t, i, 0, in_num_axes) {
if (i == gather_axis) {
ctx->NewBuilder()
.Broadcast(user_op::OpArg("indices", 0))
.Split(user_op::OpArg("in", 0), i)
.PartialSum(user_op::OpArg("out", 0))
.Build();
} else {
ctx->NewBuilder()
.Broadcast(user_op::OpArg("indices", 0))
.Split(user_op::OpArg("in", 0), i)
.Split(user_op::OpArg("out", 0), i < gather_axis ? i : i + indices_num_axes - 1)
.Build();
}
}
return Maybe<void>::Ok();
}

对 Batch Gather 进行 SBP

    对于 Batch Gather 来说,由于 OneFlow 中的实现保证了 Gather Dim 永远是 indices 的最后一维,因此 SBP 的情况比较简单。

    首先有一种基于 PartialSum 的 SBP 方法,如 batch_gather_sbp_exp_1 所示,将 params 进行 PartialSum,将 indices 进行 Broadcast 处理,得到的输出张量将与 params 保持相同的 PartialSum 形状。

    同时也有基于对 paramsindices 在同一个维度上进行 Split 的方法,当 Split 的维度 $d_s=0$ 在 Gather 维度 $g_s=1$ 之前时,此时的 Split 不会影响 Gather 在指定维度上的操作,因此输出张量拥有和输入张量相同的 Split 格式。

    当 Split 的维度 $d_s=1$ 与 Gather 维度 $g_s=1$ 相同时,和 Gather 中阐述过的逻辑类似,此时首先我们需要将 indices 张量改为进行 Broadcast,其次对 params 在 Gather 维度上进行 Split 将导致 Gather 在指定维度上的部分元素缺失,此时可以将输出张量设置为 PartialSum 来补齐缺失元素。

OneFlow 源码目前没有最后一种 SBP 形式,是否需要补齐?

    综上所述,对应到 OneFlow 源码 of_batch_gather_op,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
Maybe<void> BatchGatherOp::GetSbp(user_op::SbpContext* ctx) {
const int64_t indices_num_axes =
ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes();
if (indices_num_axes > 1) {
FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) {
ctx->NewBuilder()
.Split(user_op::OpArg("indices", 0), i)
.Split(user_op::OpArg("in", 0), i)
.Split(user_op::OpArg("out", 0), i)
.Build();
}
ctx->NewBuilder()
.Split(user_op::OpArg("indices", 0), indices_num_axes - 1)
.Broadcast(user_op::OpArg("in", 0))
.PartialSum(user_op::OpArg("out", 0))
.Build();
}
ctx->NewBuilder()
.Broadcast(user_op::OpArg("indices", 0))
.PartialSum(user_op::OpArg("in", 0))
.PartialSum(user_op::OpArg("out", 0))
.Build();
return Maybe<void>::Ok();
}

Oneflow 中 Primitive 是什么?

    本文记录的是将 Gather / Batch Gather 功能开发为一个 OneFlow Primitive 的过程,因此我们需要花一些篇幅来阐述 Primitive 的概念。

Primitive 概念

    comp_graph_op_impt 简单说明了 Computation Graph、Operator 和 Implementation 之间的关系。首先 Computation Graph (计算图) 用于描述程序计算流,各个节点代表的是 Operator (算子),各条边描述的是 Operator 为算子之间的依赖关系,Operator 之间相互依赖的是输入输出 Tensor。对于 Operator 来说,它底层可以有多种实现,例如编写 CUDA Kernel 在 GPGPU 硬件上进行处理,编写 CPU Kernel 在多核处理器上进行处理等。在 OneFlow 中,Primitive 可以理解为供 CUDA/CPU Kernel 程序调用的函数接口,例如 OneFlow 中定义了一系列的 UnaryFunctor of_unary_functor 用于实现各种激活函数的前向传播运算,定义了一系列的 BinaryFunctor of_binary_functor 用于实现各种激活函数的反向传播运算等。

    本文把实际完成计算任务的程序称为 Implementation,这里指的是一个广义概念,包括 CUDA Kernel, CPU Kernel,CUBLAS 等厂商库和 OneFlow Primitive 等和底层实际计算相关的程序。

文件目录

    在 OneFlow 中,与 primitive 相关的有几个文件目录:

文件路径 描述
oneflow/core/ep/include/primitive of_include_primitive 对各个 Primitive 的定义
oneflow/core/ep/common/primitive of_common_primitive 对各个 Primitive 的定义的实现
oneflow/core/ep/cuda/primitive of_cuda_primitive 实现各个 Primitive 的 CUDA Kernel
oneflow/core/ep/cpu/primitive of_cpu_primitive 实现各个 Primitive 的 CPU Kernel

简化 Tensor 形状

    在进入对 Gather 和 Batch Gather 具体实现的讨论之前,我们还需要补充一个背景知识,首先我们提出一个这样的问题:

    用户传入的 Tensor,在符合规则的条件下,其形状是任意的,底层的 CUDA Kernel、CPU Kernel 和 Primitive 等 Implementation 程序是如何适配不同维度数、不同维度大小的输入 Tensor 的呢?

    例如对于本文要实现的 Gather / Batch Gather Primitive 来说,用户传入的 paramsindices 张量形状将是任意的,指定的 axisbatch_dims 也将是任意的。简单一想,如何实现处理任意输入形状的 Implementation 似乎是个棘手的问题。Implementation 是否真的需要关心输入张量的形状呢?我们下面进行分析。

Tensor 包含的两种信息: 形状与存储

    通常来说,我们把 Tensor 理解为高维的矩阵。给定一个 Tensor,我们可以获取两方面的信息:

  • 形状信息: Tensor 的形状通常可以使用一组数字来表示,我们下文称之为形状向量 (Shape Vector),这些数字分别代表着 Tensor 在各个维度上的宽度,如 img_tensor_shape_data 所示,$\text{Tensor_A}$ 的维度向量为 $[5,4,3]$, $\text{Tensor_B}$ 的维度向量为 $[5,12]$。总结来说,"形状" 信息描述了我们看待 Tensor 的方式,其精确地描述了这组数据的内部边界;
  • 存储信息: Tensor 实际上是一组数的组合,这些数在计算机中通常使用一段连续内存用于存储。为了追踪这段内存,程序一般需要 ① 一个指针变量来描述数据存储的起始地址和 ② 一个整形变量来描述数据存储的规模。不同形状的 Tensor 可能拥有相同的底层存储,如 img_tensor_shape_data 所示,虽然 $\text{Tensor_A}$ 和 $\text{Tensor_B}$ 拥有不同的形状,是两个不同的 Tensor,但是它们底层的数据存储是完全相同的。总结来说,"数据" 信息描述了 Tensor 实际存储的内容;

Operator 和 Implementation 分别关心的 Tensor 信息

    对于构成计算图的 Operator 来说,它主要关心输入/输出 Tensor 的 形状信息,Opeartor 除了负责调用底层实现完成数据计算之外,还负责对输入的 Tensor 的形状、数据类型等信息进行合法性检查,以及推导输出 Tensor 的形状。

    而对于 Implementation 来说,它主要关心输入/输出 Tensor 的 存储信息。实际上,Implementation 的任务就是:

  1. 从指定的内存空间中取到操作数;
  2. 完成相应的计算过程;
  3. 将结果写回指定的内存空间;

    Implementation 并不需要关心 Tensor 的具体形状,其工作就是从正确的内存地址取数、将计算结果写回指定的内存地址。不论 Tensor 有多少维度,每个维度的 Size 是多大,在 Implementation 看来都是一段连续内存。因此,Tensor 的形状信息之于 Implementation 是可以被模糊化的,从而简化 Implementation 的实现难度。下面我们进行举例说明。

简化 Implementation: 合并边界信息,简化输入输出形状

例子 1

    举个例子,对于 img_tensor_shape_data 中的 $\text{Tensor_A}$ 来说,如果 Operator 是在 Dim_0 上对其进行操作 (i.e. 不会在 Dim_1 或 Dim_2 上对部分元素进行修改),则 Operator 实际上可以把 $\text{Tensor_A}$ 从形状向量为 $[5,3,4]$ 的三维张量简化为形状向量为 $[5,12]$ 的二维张量 (i.e. $\text{Tensor_B}$) 交给 Implementation 进行处理,这样一来:

  • 如果按照原始 $\text{Tensor_A}$ 的形状 $[5,3,4]$,则 Implementation 单次操作的粒度为 4 个元素 (i.e. $\text{Tensor_A}$ 在 Dim_2 的宽度),操作重复的次数为 15 次 (i.e. $\text{Tensor_A}$ 在 Dim_0 和 Dim_1 的宽度之积);
  • 简化后,Implementation 单次操作的粒度为 12 个元素 (i.e. $\text{Tensor_B}$ 在 Dim_1 的宽度),操作重复的次数为 5 次 (i.e. $\text{Tensor_B}$ 在 Dim_0 的宽度);

    通常来说,后者的 Implementation 往往更加简单,因为其合并了大部分无用的内部边界信息 (i.e. 合并了 Dim_1 和 Dim_2),仅保留了指定处理维度 (i.e. Dim_0) 的形状信息。这样一来,对于高维张量的处理,可以设计一个统一的 Implementation,基于一套处理低维张量的逻辑进行处理。

例子 2

    再举个例子,如 img_matmul 所示,考虑将一个形状向量为 $[2,4,3]$ 的三维张量 $\text{Tensor_A}$ 和形状向量为 $[3,5]$ 的二维张量 $\text{Tensor_W}$ 进行乘法操作的 Operator,其实际上是在 Dim_2 上对 $\text{Tensor_A}$ 进行操作,因此 Operator 在调用底层 Implementation 进行处理时,可以合并 $\text{Tensor_A}$ 在 Dim_0 和 Dim_1 上的内部边界信息,将 $\text{Tensor_A}$ 视为一个形状向量为 $[8,3]$ 的二维张量 $\text{Tensor_A}'$。

    在上述对 Tensor 形状的简化的例子中,第一个例子实现了指定处理维度 (i.e. Dim_0) 之后的边界信息的合并 (i.e. 合并 Dim_1 和 Dim_2); 第二个例子实现了指定处理维度 (i.e. Dim_2) 之前的边界信息的合并 (i.e. 合并 Dim_0 和 Dim_1)。通用地说,在给定处理维度的情况下,Operator 在调用 Implementation 进行处理之前,可以模糊掉对于 Implementation 来说无用的张量形状信息,仅保留:

  1. Processed Dim:指定处理维度的边界信息,指示了 Implementation 处理的维度上有多少个元素;
  2. Outer Dim:指定处理维度之前的合并边界信息,指示了 Implementation 的处理需要重复的次数;
  3. Inner Dim:指定处理维度之后的合并边界信息,指示了 Implementation 处理的维度上各个元素的规模;

    如果还需要举例说明,那么本文的主角 — Gather 就是最好的例子。考虑对一个 $[\underbrace{4}_{\text{Dim_0}},\underbrace{2}_{\text{Dim_1}},\underbrace{3}_{\text{Dim_2}},\underbrace{2}_{\text{Dim_3}},\underbrace{6}_{\text{Dim_4}}]$ 的 params 张量应用 Gather 操作,指定在 Dim_2 上进行,则我们首先可以把 params 张量的 Dim_0 和 Dim_1 模糊掉,视为统一的 Outer Dim (Size 为 $4 \times 2 = 8$),把 Dim_0 和 Dim_1 视为统一的 Inner Dim (Size 为 $2 \times 6 = 12$),把 Dim_2 视为 Processed Dim (Size 为 $3$),最终形成一个 $[8,3,12]$ 的输入张量交给 Implementation 进行处理。可能有读者好奇,那么对于 indices 的形状该如何处理呢?实际上,不论 indices 是什么形状,在 Implementation 看来,它都是一个一维的向量 (i.e. 被化简为 $[\underbrace{1}_{\text{Outer Dim}},\underbrace{x}_{\text{Processed Dim}},\underbrace{1}_{\text{Inner Dim}}]$ 的形状),Implementation 只需要依次从 indices 取出 Gather Index,然后将 params 在其 Processed Dim 上的相应元素搬运到对应位置即可。

    综上所述,这样一来不论输入的张量是什么形状,Implementation 都可以统一地用一套处理三维张量的逻辑进行 Gather 逻辑的处理,这也就回答了本节最开始提出的问题。

NdIndexOffsetHelper: OneFlow 提供的「Offset」和「张量坐标」互转工具

    我们上面说过,张量在内存中是一段连续的内存空间,那么每个张量元素在这段连续内存空间中就会有一个 Offset; 我们上面还说过,在实现 Implementation 时可以将输入 Tensor 的形状简化为 $[d_{\text{Outer}}, d_{\text{Processed}}, d_{\text{Inner}}]$。我们在下面 Gather 和 Batch Gather 的具体实现中将看到两个重要的需求:

  • 给定一个 Offset $l_o$,输出高位张量坐标 $[d_{\text{Outer}}, d_{\text{Processed}}, d_{\text{Inner}}]$;
  • 给定一个 高位张量坐标 $[d_{\text{Outer}}, d_{\text{Processed}}, d_{\text{Inner}}]$,输出该元素在内存空间中存储相应的 Offset $l_o$;

    在 OneFlow 中提供了一个 NdIndexOffsetHelperof_nd_index_offset_helper 用于提供上述的功能 of_ndindex_offset_helper_post。在实例化该类时,我们可以首先在参数列表中传入对应的形状向量:

1
2
#include "oneflow/core/common/nd_index_offset_helper.h"
NdIndexOffsetHelper<int32_t, 3> in_helper(outer_dim_size, gather_dim_size, inner_dim_size);

    然后把这个对象传入 CUDA Kernel,或者 Device Function 中,我们就可以通过它实现 Offset 和 高维张量坐标的互转了:

1
2
3
4
5
6
7
8
__device__ void TestDeviceFunction(NdIndexOffsetHelper<int32_t, 3> in_helper, int32_t elem_cnt) {
int32_t index[3];
CUDA_1D_KERNEL_LOOP_T(int32_t, i, elem_cnt) {
in_helper.OffsetToNdIndex(i, index); // Offset -> 高维张量坐标
int32_t offset = in_helper.NdIndexToOffset(index) // 高维张量坐标 -> Offset
assert(i == offset);
}
}

开发

    下面我们对 Gather 和 Batch Gather Primitive 的具体实现进行说明。

Gather 过程 的实现思路

简单情况

    首先我们来看 Gather 的 CUDA 实现思路,我们在本节中将使用 gather_impt_exp_1 所示的例子进行说明,这个例子中使用了一个形状为 $[2,3]$ 的 indices 张量,在一个形状为 $[2,5,2]$ 的 params 张量的第 1 维上进行 Gather 操作。基于我们在 section_simplied_tensor 中讨论的内容,我们可以把上述张量按照 Outer Dim, Processed Dim 和 Inner Dim 进行排列,化简为如 gather_impt_exp_2 所示的形状。

    设计 CUDA Kernel / Device Function 的一个窍门是: 通常来说,我们会在输出张量的每一个元素上应用一条 Thread 进行处理,因此我们可以从输出张量的单个元素出发,反推为了计算出这个元素,应该使用到哪些输入张量的哪些元素,从而完成整个 Implementation。这里我们应用相同的思路。如 gather_impt_exp_2 所示,我们以输出张量偏移量为 $18$ 的元素为例进行说明。

    首先,我们可以调用上面 section_nd_index_offset_helper 所提到的 NdIndexOffsetHelper 类提供的 OffsetToNdIndex 函数,将 Offset (i.e. $18$) 转化为输出张量的坐标 (i.e. $[1,3,0]$),也即上图中的步骤 ①。

    在获取坐标 $[1,3,0]$ 后,我们就可以根据其 Processed Dim 的值 $3$,知悉当前输出的元素是来源于 indices 中偏移量为 $3$ 的坐标,也即上图中的步骤 ②。

    注意到完成通用形状简化后,output 中的元素所在的 Outer Dim 和 Inner Dim 和它们在 params 中是一致的,因此我们只需要将元素在 output 中的高维坐标 $[d_{\text{Outer}}, d_{\text{Processed}}, d_{\text{Inner}}]$ 中的 Processed Dim 替换为 indices 中偏移量为 $3$ 的元素的值 $0$,就可以得到该输出元素在 params 中的坐标 (i.e. $[1,0,0]$),也即上图中的步骤 ③。

    最后调用 NdIndexOffsetHelper 类提供的 NdIndexToOffset 函数,我们就可以得到对应输入元素在连续内存空间中的 Offset (i.e. $10$),也即上图中的步骤 ④。在知悉元素在输入张量 params 存储空间中的偏移量后,Implementation 就可以从输入张量 params 偏移 $10$ 位置取出元素,放入到输出张量 output 偏移 $18$ 位置上去了。

    针对输出张量的每一个元素,其推导源输入张量偏移量的逻辑都是一样的。我们把上述过程总结为如下 CUDA 程序:

1
// TODO

    基于上述逻辑的 CPU 处理程序如下所示:

1
// TODO

考虑 SBP 切分 Gather 维的情况

    回顾我们在 section_gather_sbp 中讨论过的,当对 Gather Operator 采用对 params 输入张量进行 Split,indices 进行 Broadcast 的 SBP 形式时,如 gather_sbp_exp_5 所示,一旦对 params 输入张量进行 Split 的维度恰好是 Gather 的维度,那么将会导致 indices 中指定的部分 Gather 序号对应的元素的缺失,此时的输出张量 output 将呈现 PartialSum 的 SBP 形式。在这种情况下,底层 Implementation 在处理的时候就需要引入对当前机器是否拥有 Gather 所需元素的判断,如果判断到当前机器没有某个 Index 所指定的元素,则需要往对应的输出元素写 0,以保证 PartialSum 的正确性。下面我们对如何实现这层判断进行讨论。

    我们还是复用上面的例子,并且我们让 params 张量在 Gather 维进行 SBP Split 切分,如 gather_impt_exp_3 所示:Gather 维中偏移量为 0 和 1 的元素在第 0 号机器上,偏移量为 2、3 和 4 的元素在第 1 号机器上。此时我们首先使用一个变量 tensor_offset,来指示每台机器所拥有的张量在 Split 的维度上的偏移量: 第 0 号机器的偏移量 tensor_offset_0 为 0,第 1 号机器的偏移量 tensor_offset_1 为 2。每台机器上的这个变量将用于后续判断当前机器是否拥有某个元素。

    我们使用上面一样的思路,针对每一个输出元素反推出它在输入张量中的 Offset。如 gather_impt_exp_3 所示,对于第 0 号机器来说,流程照常,首先 ① 根据输出元素在内存中的偏移量反推出张量坐标; 然后 ② 根据 Inner Dim 的坐标得到 Index 在 indices 中的偏移量; 完成后,接下来的流程与上面稍有所不同,③ 我们让 Index 在 indices 中的偏移量减去 tensor_offset, 此时如果同时满足以下两个条件:

  • 减法结果大于 0;
  • 减法结果小于当前机器所持有的 params 张量部分在 Processed Dim 上的 Size;

    那么说明当前机器拥有指定的元素,并且减法结果就是对应元素在 params 张量中在 Processed Dim 的偏移量。

    举例如 gather_impt_exp_3 所示,对于第 0 号机器输出张量中偏移量为 18 的元素 (i.e. 值为 50),在进行到 ③ 时,计算得到

$$\text{offset}_{(\text{Processed Dim})} = 0 - \text{tensor_offset_0} = 0$$

    因此可以使用坐标 $[1,\underbrace{0}_{\text{Processed Dim}},0]$ 对输入张量 params 的对应元素进行访问。

    对于第 1 号机器输出张量中偏移量为 17 的元素 (i.e. 值为 71) 来说,在进行到 ③ 时,计算得到:

$$\text{offset}_{(\text{Processed Dim})} = 2 - \text{tensor_offset_1} = 0$$

    因此可以使用坐标 $[1,\underbrace{0}_{\text{Processed Dim}},1]$ 对输入张量 params 的对应元素进行访问。

    如果 $\text{offset}_{(\text{Processed Dim})}$ 求出来小于 $0$,那么说明请求的元素超出了当前机器所持有的张量部分在 Processed Dim 上的下界,因此需要对输出元素填 0。除了判断下界,当然需要判断上界,如果 $\text{offset}_{(\text{Processed Dim})}$ 大于等于当前机器所持有的张量部分在 Processed Dim 上的 Size,则同理需要对输出元素填 0。

    补充上述对 SBP 逻辑的处理,更新后的 CUDA 程序如下所示:

1
// TODO

    CPU 处理程序如下所示:

1
// TODO

Batch Gather 的实现思路

    我们使用 batch_gather_impt_exp_1 所示的例子来说明 Batch Gather 的实现过程。上述的例子使用了一个形状为 $[2,5,2,3]$ 的 indices 张量对形状为 $[2,5,2,3]$ 的 params 张量在第 $2$ 维上进行了 Batch Gather,且指定了 Batch 维度一共有两维,也即第 $0$ 和 $1$ 维。在完成 Batch Gather 后,将输出一个形状为 $[2,5,2,3,3]$ 的 output 张量。

预先计算值

    在进行实际的输出元素计算逻辑的推导前,我们先计算几个后面会使用到的值。

    首先,我们可以确定上述例子中 Batch 的数目:

$\text{Nb_Batch} = \underbrace{2}_{\text{params 第 0 维}} \times \underbrace{5}_{\text{params 第 1 维}} = 10$

    每个 Batch 在被 Gather 之前的 Size 为:

$\text{Size_Origin_Batch} = \underbrace{2}_{\text{params 第 2 维}} \times \underbrace{3}_{\text{params 第 3 维}} = 6$

    并且结合 Gather 前每个 Batch 的形状 $[2,3]$,可以得到每个 Batch 的通用形状简化为:

$\left\{ \begin{aligned} & \text{params}_{\text{Outer Dim}} &= 1 \\ & \text{params}_{\text{Processed Dim}} &= 2 \\ & \text{params}_{\text{Inner Dim}} &= 3 \end{aligned} \right. $

    另外,每个 Batch 在被 Gather 之后的 Size 为:

$\text{Size_Gathered_Batch} = \underbrace{2}_{\text{output 第 2 维}} \times \underbrace{3}_{\text{output 第 3 维}} \times \underbrace{3}_{\text{output 第 4 维}} = 18$

    回顾我们上面对 Gather 逻辑的阐述: 由于我们在针对每个 Batch 进行 Gather 的过程中,我们把当前 Batch 所使用的 indices 视为一条向量,因此 Batch 输入输出的形状拥有相同的 Outer Dim 和 Inner Dim。结合 Gather 后每个 Batch 的形状 $[2,3,3]$,可以得到每个 Batch 的通用形状简化为:

$\left\{ \begin{aligned} & \text{output}_{\text{Outer Dim}} &= 1 \\ & \text{output}_{\text{Processed Dim}} &= 6 \\ & \text{output}_{\text{Inner Dim}} &= 3 \end{aligned} \right. $

    同时,在 indices 中,Per-batch indices 的 Size 为:

$\text{Size_Per_Batch_Indices} = \underbrace{2}_{\text{indices 第 2 维}} \times \underbrace{3}_{\text{indices 第 3 维}} = 6$

输出元素计算逻辑推导

    给定一个输出张量 output 中的元素在内存中的偏移量,如 batch_gather_impt_exp_1 所示我们以偏移量为 $58$ 的元素 (i.e. 值为 $97.3$) 为例,基于 eq_size_gathered_batch,我们可以知道当前这个元素所属的 Batch 的序号为:

$\text{Batch_Id}_{58} = \frac{58}{\text{Size_Gathered_Batch}} = \frac{58}{18} = 3$

    并且这个元素在所在 Batch 在内存中存储的偏移量为:

$\text{Offset_Within_Per_Batch_Output}_{58} = 58 \% \text{Size_Gathered_Batch} = 58 \% 18 = 4$

    有了 Batch Id,结合 $\text{Size_Origin_Batch}$ (eq_size_origin_batch), $\text{Size_Gathered_Batch}$ (eq_size_gathered_batch), $\text{Size_Per_Batch_Indices}$ (eq_size_per_batch_indices),我们就可以以 Batch 为粒度在 params, indicesoutput 张量中进行访问,因此我们可以单独把一个 Batch 拎出来,然后按照上面介绍过的 Gather 逻辑进行推导,然后最后加上当前 Batch 相应的 Offset 即可。按照 eq_params_gerneal_shapeeq_output_gerneal_shape 所指示的通用形状简化,可以得到如 batch_gather_impt_exp_2 所示的 Gather 过程:

    如 batch_gather_impt_exp_2 所示,我们可以再次通过设置 NdIndexOffsetHelper (详见 section_nd_index_offset_helper) 的方法,将 $\text{Offset_Within_Per_Batch_Output}_{58} = 4$ 转为张量坐标 $[0,1,1]$,然后采用和上述 Gather 过程相同的思路,以 $[0,1,1]$ 的 Processed Dim 值 $1$ 作为对应 Per-batch indices 的下标,利用这个下标以及 $\text{Size_Per_Batch_Indices}$ (eq_size_per_batch_indices) 和 $\text{Batch_Id}_{58}$ (eq_batch_id),我们就可以得到对应的 Index 在 indices 张量中的偏移量:

$\text{Offset_Within_Per_Batch_Indices}_{58} = \underbrace{3}_{\text{Batch_Id}_{58}} \times \underbrace{6}_{\text{Size_Per_Batch_Indices}} + 1 = 19$

    基于 $\text{Offset_Within_Per_Batch_Indices}_{58}=19$ 取出 Index 值 $1$ 后,将其替换 $[0,1,1]$ 中的 Processed Dim,最终得到对应元素在 params 中的张量坐标 $[0,1,1]$,然后再次使用 NdIndexOffsetHelper 将其转化为内存偏移量 $4$,也即

$\text{Offset_Within_Per_Batch_Params}_{58} = 4$

    拿到 $\text{Offset_Within_Per_Batch_Params}$ 后,我们就可以根据当前 Batch 对应的 Offset,最终计算得到相应元素在输入张量 params 中的偏移量:

$\begin{aligned} & \text{Offset_Within_Params}_{58} \\ & = \underbrace{3}_{\text{Batch_Id}_{58}} \times \underbrace{6}_{\text{Size_Origin_Batch}} + \underbrace{4}_{\text{Offset_Within_Per_Batch_Params}_{58}} \\ & = 22 \end{aligned} $

    基于上述逻辑的 CUDA 代码程序如下所示:

1
// TODO

    CPU 代码如下所示:

1
// TODO

性能测试