源码解析:DGL 创建图拓扑的流程

⚠ 转载请注明出处:作者:ZobinHuang,更新日期:June 22 2022


知识共享许可协议

    本作品ZobinHuang 采用 知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议 进行许可,在进行使用或分享前请查看权限要求。若发现侵权行为,会采取法律手段维护作者正当合法权益,谢谢配合。


目录

有特定需要的内容直接跳转到相关章节查看即可。

正在加载目录...

前言

    本文中我们将对 DGL Runtime 创建和维护图拓扑的流程进行分析,我另外的文章 源码解析:DGL 图拓扑存储 介绍了图拓扑在 DGL 中的存储细节,建议作为前置文章进行阅读。

创建同构图拓扑

顶层 API

    DGL 使用一个唯一的整数来表示一个节点,称为点 ID,并用对应的两个端点 ID 表示一条边。同时,DGL 也会根据边被添加的顺序,给每条边分配一个唯一的整数编号,称为边 ID。节点和边的 ID 都是从 $0$ 开始构建的。在 DGL 的图里,所有的边都是有方向的,即边 $(u,v)$ 表示它是从节点 $u$ 指向节点 $v$ 的。

    我们首先考虑同构图这种比较简单的情况。DGL 中使用 DGLGraph 来代表一个同构图,创建一个 DGLGraph 对象的一种方法是使用 dgl.graph 函数。它接受一个边的集合作为输入。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import dgl
import torch as th

# 边 0->1, 0->2, 0->3, 1->3
u, v = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])
g = dgl.graph((u, v))
print(g) # 图中节点的数量是 DGL 通过给定的图的边列表中最大的点ID推断所得出的
Graph(num_nodes=4, num_edges=4,
ndata_schemes={}
edata_schemes={})

# 获取节点的 ID
print(g.nodes())
# tensor([0, 1, 2, 3])

# 获取边的对应端点
print(g.edges())
# (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))

# 获取边的对应端点和边 ID
print(g.edges(form='all'))
# (tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]), tensor([0, 1, 2, 3]))

# 如果具有最大 ID 的节点没有边,在创建图的时候,用户需要明确地指明节点的数量。
g = dgl.graph((u, v), num_nodes=8)

Python 侧对用户传入参数进行处理

    下面我们对 dgl.graph 这个 API 背后的代码细节进行分析。DGL 初始化一个图的 Code Flow 如下所示。

    我们在上面的代码的 Line 6 中 调用的用于创建图对象的 dgl.graph 函数是在 dgl/convert.py dglsrc_convert_py 下定义的 (p.s. 这个文件定义了若干 API,用于将其它形式的图拓扑数据转化为 DGL 图对象)。在上面的代码中,我们向 dgl.graph 函数中传入了由两个 torch.tensor 组成的 tuple 代表了 COO 形式的图拓扑存储方法。dgl.graph 的代码定义摘抄如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def graph(data,
ntype=None, etype=None,
*,
num_nodes=None,
idtype=None,
device=None,
row_sorted=False,
col_sorted=False,
**deprecated_kwargs):

(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(data, idtype)
if num_nodes is not None: # override the number of nodes
if num_nodes < max(urange, vrange):
raise DGLError('The num_nodes argument must be larger than the max ID in the data,'
' but got {} and {}.'.format(num_nodes, max(urange, vrange) - 1))
urange, vrange = num_nodes, num_nodes

g = create_from_edges(sparse_fmt, arrays, '_N', '_E', '_N', urange, vrange,
row_sorted=row_sorted, col_sorted=col_sorted)

return g.to(device)

    对于 COO/CSR 格式的稀疏矩阵存储格式,我们在上一篇文章中有所提及,这里不再赘述。 DGL 推荐使用一个 一维 的整型张量(如,PyTorch 的 Tensor 类,TensorFlow 的 Tensor 类或 MXNet 的 ndarray 类)来保存稀疏矩阵的存储内容。

对输入参数进行整理

    值得注意的是,传入 dgl.graph 的 tuple 类型的参数 data 可以接受以下几种常见格式:

  • ("coo", torch.tensor, torch.tensor): 显式指明 COO 的拓扑表示方法,并通过两个 torch.tensor 对象传入拓扑信息;
  • ("csr", torch.tensor, torch.tensor, torch.tensor): 显式指明 CSR 的拓扑表示方法,并通过三个 torch.tensor 对象传入拓扑信息;
  • ("csc", torch.tensor, torch.tensor, torch.tensor): 显式指明 CSC 的拓扑表示方法,并通过三个 torch.tensor 对象传入拓扑信息;
  • (torch.tensor, torch.tensor): 不指明使用的稀疏矩阵存储方法,会被 DGL 一概当作 COO 格式处理;
  • ("coo", list, list): 显式指明使用的稀疏矩阵存储方法,并且通过两个 Python List 传入拓扑信息
  • (list, list): 不指明使用的稀疏矩阵存储方法,会被 DGL 一概当作 COO 格式处理;

    当然,上面举的例子中我们默认使用的 backend 是 PyTorch,DGL 也支持使用 Tensorflow 和 MXNet 作为 backend。

    另外,DGL 还支持基于 networkxscipy 等第三方库的数据作为 dgl.graph 函数的 data 参数。为了应对 data 不同的信息格式,上面的代码在 Line 11 调用了 utils.graphdata2tensors 函数 (i.e. codeflow_param 中 ①) 对输入的 tuple 进行统一处理,该函数最终返回了以下 4 个信息:

  • sparse_fmt: 字符串,代表稀疏矩阵存储格式,可取值有: "coo", "csc", "csr"
  • arrays: (torch.tensor, torch.tensor) 的 tuple,代表了在对应的稀疏矩阵存储格式下的图拓扑信息;
  • urange: 源节点的个数;
  • vrange: 目的节点的个数

    这里使用的 utils.graphdata2tensors 函数是在 dgl/utils/data.py dglsrc_utils_data_py 中定义的,代码细节摘抄如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from collections import namedtuple

SparseAdjTuple = namedtuple('SparseAdjTuple', ['format', 'arrays'])

def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
# 如果 data是一个元组,
# 那么将 data 元组转化为 ('coo', (row, col)) 或 ('csr', (indptr, indices, eids)) 等形式的 SparseAdjTuple 数据
if isinstance(data, tuple):
if not isinstance(data[0], str):
# (row, col) format, convert to ('coo', (row, col))
data = ('coo', data)
data = SparseAdjTuple(*data)

# 如果没有传入 idType,并且传入的 tuple 的两个元素不是 tensor 时,强行设置 idtype 为 int64。
# 否则,在后面的代码中我们会看到:
# [1] 要么使用传入的 idType;
# [2] 要么传入的 idType 为空,直接从传入的 tensor tuple 中推理出 idtype
if idtype is None and \
not (isinstance(data, SparseAdjTuple) and F.is_tensor(data.arrays[0])):
# preferred default idtype is int64
# if data is tensor and idtype is None, infer the idtype from tensor
idtype = F.int64

# 检查 idtype 是否为 None, int32, int64 中的一种
checks.check_valid_idtype(idtype)

# 如果传入的 tutle 的两个元素不是 tensor 时,将 Iterable 的对象转化为 tensor
# (Iterable, Iterable) type data, convert it to (Tensor, Tensor)
if isinstance(data, SparseAdjTuple) and (not all(F.is_tensor(a) for a in data.arrays)):
if len(data.arrays[0]) == 0:
# 如果发现 Iterable 的对象是一个空表,那么我们在将其转化为 tensor 的时候需要强行指定 tensor 的类型
data = SparseAdjTuple(data.format, tuple(F.tensor(a, idtype) for a in data.arrays))
else:
# 如果 Iterable 的对象不为空表,那我们保持它原本的类型就行
# convert the iterable to tensor and keep its native data type so we can check
# its validity later
data = SparseAdjTuple(data.format, tuple(F.tensor(a) for a in data.arrays))

if isinstance(data, SparseAdjTuple):
# 如果传入的 idType 不为空,那么将我们创建的 tensor 设置为传入的 idType 指定的类型
if idtype is not None:
data = SparseAdjTuple(data.format, tuple(F.astype(a, idtype) for a in data.arrays))

# 推断图中的源节点和目的节点的数目
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)

return data, num_src, num_dst

    如上所述,graphdata2tensors 的功能是将多种支持的图拓扑数据输入格式,统一为 DGL 底层运行的 Backend (i.e. PyTorch, Tensorflow 和 MXNet) 的 Tensor 类型数据。graphdata2tensorsdata 输入参数存在很多种可能性,因此程序对各种可能性进行了处理。

    首先,如果传入的 data 是一个 tuple,那么在 Line 6~12 处程序将其转化为了规范化的形如 ('coo', (row, col)) 或者 ('csr', (indptr, indices, eids)) 的格式。这里的 rowcol 数据可能本身传入的就是 Backend 的 Tensor 类型数据,也可能是 Python 原生的 Iterable 对象 (e.g. list),我们后面需要进行判断和处理。

    在进入判断和处理之前,还有一个值得注意的是,在上面代码的 Line 3 中,程序首先借助 Python 官方的 collections 模块 python_collection_container,调用其中的 namedtuple 工厂函数,创建了一个元组子类 SparseAdjTuple; 在 Line 12 当我们将这个子类应用于 graphdata2tensors 的输入参数 data 的时候,我们在后续就可以通过 data.formatarrays 的方式来访问 data 中的内容。

    另外,在下面的讨论中,我们对于传入的 data 不是一个 tuple 的情况不予以讨论,这种情况下,传入的参数一般是基于 networkxscipy 等第三方库的数据,涉及到第三方库数据像 Backend Tensor 的转化,我们不进行研究。

    来到 Line 18~22 处的代码,这里程序处理的是传入的 idtype —— 也即图拓扑数据中,节点的 Index 的类型。分为两种情况:

  • 当发现传入的 idtype 为空,并且传入的 data 中的 rowcol 不为 Backend 对应的 Tensor 类型的时候,此时程序就需要手动指定 idtype。程序默认使用的是所使用 Backend 所规定的 int64 类型;
  • 如果发现虽然传入的 idtype 为空,但是传入的 data 中的 rowcol 是Backend 对应的 Tensor 类型的时候,此时程序则不着急指定 idtype 类型,因为后续可以直接从 Tensor 中提取出对应的类型。

    此时来到 Line 29~37 处,此时代码完成了从 Python 原生 Iterable 对象 (e.g. list) 到 Backend Tensor 类型的转化。到了 Line 39~42,代码又强制转化了一下 Tensor 的类型,可能是为了防止用户传入的 Tensor 的数据类型和 idtype 声明的数据类型不匹配的情况。

    到了 Line 45,程序调用了 infer_num_nodes 函数对图中包含的源节点和目的节点的数量进行了判断。 infer_num_nodes 函数是在同个文件下定义的,具体如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def infer_num_nodes(data, bipartite=False):
"""Function for inferring the number of nodes.

Parameters
----------
data : graph data
Supported types are:

* SparseTuple ``(sparse_fmt, arrays)`` where ``arrays`` can be either ``(src, dst)`` or
``(indptr, indices, data)``.
* SciPy matrix.
* NetworkX graph.
bipartite : bool, optional
Whether infer number of nodes of a bipartite graph --
num_src and num_dst can be different.

Returns
-------
num_src : int
Number of source nodes.
num_dst : int
Number of destination nodes.

or

None
If the inference failed.
"""
if isinstance(data, tuple) and len(data) == 2:
if not isinstance(data[0], str):
raise TypeError('Expected sparse format as a str, but got %s' % type(data[0]))

if data[0] == 'coo':
# ('coo', (src, dst)) format
u, v = data[1]
nsrc = F.as_scalar(F.max(u, dim=0)) + 1 if len(u) > 0 else 0
ndst = F.as_scalar(F.max(v, dim=0)) + 1 if len(v) > 0 else 0
elif data[0] == 'csr':
# ('csr', (indptr, indices, eids)) format
indptr, indices, _ = data[1]
nsrc = F.shape(indptr)[0] - 1
ndst = F.as_scalar(F.max(indices, dim=0)) + 1 if len(indices) > 0 else 0
elif data[0] == 'csc':
# ('csc', (indptr, indices, eids)) format
indptr, indices, _ = data[1]
ndst = F.shape(indptr)[0] - 1
nsrc = F.as_scalar(F.max(indices, dim=0)) + 1 if len(indices) > 0 else 0
else:
raise ValueError('unknown format %s' % data[0])
elif isinstance(data, sp.sparse.spmatrix):
nsrc, ndst = data.shape[0], data.shape[1]
elif isinstance(data, nx.Graph):
if data.number_of_nodes() == 0:
nsrc = ndst = 0
elif not bipartite:
nsrc = ndst = data.number_of_nodes()
else:
nsrc = len({n for n, d in data.nodes(data=True) if d['bipartite'] == 0})
ndst = data.number_of_nodes() - nsrc
else:
return None
if not bipartite:
nsrc = ndst = max(nsrc, ndst)
return nsrc, ndst

    我们重点关心上面的代码中是如何推断出各种稀疏矩阵存储格式下,源节点和目的节点的个数的:

  • COO (Line 33~37): 源节点的数目取自 row 的最大值; 目的节点的数目取自 col 的最大值;
  • CSR (Line 39~42): 源节点的数目取自 indptr 的长度; 目的节点的数目取自 indices 的最大值;
  • CSC (Line 44~47): 源节点的数目取自 indptr 的长度; 目的节点的数目取自 indices 的最大值;

    值得注意的是,源节点和目的节点的数量并不代表图中的点的数量,因为图中可能有孤立点。

    回到 graphdata2tensors 函数,在获得源节点和目的节点的数量后,函数将转化好的 Tensor 一并打包,进行返回 (i.e. codeflow_param 中 ②)。

处理节点数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def graph(data,
ntype=None, etype=None,
*,
num_nodes=None,
idtype=None,
device=None,
row_sorted=False,
col_sorted=False,
**deprecated_kwargs):
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(data, idtype)
if num_nodes is not None: # override the number of nodes
if num_nodes < max(urange, vrange):
raise DGLError('The num_nodes argument must be larger than the max ID in the data,'
' but got {} and {}.'.format(num_nodes, max(urange, vrange) - 1))
urange, vrange = num_nodes, num_nodes
# ...

    回到 dgl.graph 函数,我们上面讨论的 graphdata2tensors 函数返回了稀疏矩阵格式、Tensors,以及源节点和目的节点。返回后程序的第一件事就是对源节点和目的节点的范围进行重新调整,如上 Line 11~15 所示。我们上面提到过,源节点和目的节点的数量并不代表图中的点的数量,因为图中可能有孤立点,因此上述程序对源节点和目的节点的数量进行了调整。

创建 DGLHeteroGraph 实例

    继续 dgl.graph 函数,完成节点数量的调整后,其在 Line 18 处其调用了位于同一个文件 dgl/convert.py dglsrc_convert_py 下定义的 dgl.create_from_edges 函数 (i.e. codeflow_param 中 ②),基于已知的关于图拓扑的信息,创建出一个 DGLHeteroGraph 实例,该函数的代码细节如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def create_from_edges(sparse_fmt, arrays,
utype, etype, vtype,
urange, vrange,
row_sorted=False,
col_sorted=False):
if utype == vtype:
num_ntypes = 1
else:
num_ntypes = 2

if sparse_fmt == 'coo':
u, v = arrays
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'],
row_sorted, col_sorted)
else: # 'csr' or 'csc'
indptr, indices, eids = arrays
hgidx = heterograph_index.create_unitgraph_from_csr(
num_ntypes, urange, vrange, indptr, indices, eids, ['coo', 'csr', 'csc'],
sparse_fmt == 'csc')
if utype == vtype:
return DGLHeteroGraph(hgidx, [utype], [etype])
else:
return DGLHeteroGraph(hgidx, [utype, vtype], [etype])

    在 Line 6-9 中,函数首先对传入的 utype (源节点类型) 和 vtype (目的节点类型) 进行一致性判断,从而得出图中节点种类数 num_ntypes。在一个 Unit Graph 中,要么点的种类有两种 (i.e., 有向图),要么点的种类有一种 (i.e., 无向图),因此这里的 num_ntypes 的取值范围为 1 或者 2。

    我们在这里调用的 DGLGraph API,在调用 create_from_edges 的时候,传入的 utypevtype 是相同的值 _N,因此 num_ntypes 的取值为 1。

    然后,在 Line 11-20 中,基于稀疏矩阵存储格式的不同,分别调用了 create_unitgraph_from_coo 或者 create_unitgraph_from_csr 函数 (i.e. codeflow_param 中 ③),这两个函数是在 dgl/heterograph_index.py dglsrc_heterograph_index_py 中定义的。以前者为例,我们来看其详细定义:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
formats, row_sorted=False, col_sorted=False):
"""Create a unitgraph graph index from COO format

Parameters
----------
num_ntypes : int
Number of node types (must be 1 or 2).
num_src : int
Number of nodes in the src type.
num_dst : int
Number of nodes in the dst type.
row : utils.Index
Row index.
col : utils.Index
Col index.
formats : list of str.
Restrict the storage formats allowed for the unit graph.
row_sorted : bool, optional
Whether or not the rows of the COO are in ascending order.
col_sorted : bool, optional
Whether or not the columns of the COO are in ascending order within
each row. This only has an effect when ``row_sorted`` is True.

Returns
-------
HeteroGraphIndex
"""
if isinstance(formats, str):
formats = [formats]
return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
int(num_ntypes), int(num_src), int(num_dst),
F.to_dgl_nd(row), F.to_dgl_nd(col),
formats, row_sorted, col_sorted)

    函数 create_unitgraph_from_coo 的传入参数的含义如下所示:

  • num_ntypes: Unit Graph 上点的种类数 (i.e. 1 或 2);
  • num_src: 源节点的数量;
  • num_dst: 目的节点的数量;
  • row: 以 COO 形式存储的 Row Index;
  • col: 以 COO 形式存储的 Column Index;
  • formats: 一个列表,限制了底层 Unit Graph 存储时使用的矩阵存储形式 (i.e., COO, CSR 和 CSC)。用户传入的 rowcol 拓扑信息是使用 COO 形式指定的,但是 DGL 在底层转换为 Unit Graph 存储的时候,可以使用其他的格式;
  • row_sorted: 布尔变量,标识 row 是否进行了排序;
  • col_sorted: 布尔变量,标识 col 是否进行了排序;

    从 create_unitgraph_from_coo 的代码可以发现,create_unitgraph_from_coo 函数实际上调用了动态链接库中的 _CAPI_DGLHeteroCreateUnitGraphFromCOO 函数。在 源码解析:借鉴 TVM 的 Python 和 C++ 的调用机制 一文中我们分析过 DGL 是如何基于 TVM 的 FFI 机制实现 Python 运行时对动态链接库中封装的 API 的调用,我们在这里不再过多赘述。简单来说就是在当前我们分析的 create_unitgraph_from_coo 函数所在的 dgl/heterograph_index.py dglsrc_heterograph_index_py 文件中,有以下代码:

1
2
3
4
5
from ._ffi.function import _init_api

# ......

_init_api("dgl.heterograph_index")

    _init_api 的调用使得当前模块拥有了所有的以 dgl.heterograph_index 为前缀的动态链接库中的 API,其中就包括了我们在这里调用的 _CAPI_DGLHeteroCreateUnitGraphFromCOO

    另外,值得注意的是,create_unitgraph_from_coo 在调用 _CAPI_DGLHeteroCreateUnitGraphFromCOO 的时候,对于第 4 和第 5 个参数,它使用了我们在 源码解析:DGL 的内存管理方法 中介绍的 F.to_dgl_nd API 实现了 PyTorch Tensor 到 DGL NDArray 的转化。

动态链接库侧 Runtime 创建图拓扑实例

    现在让我们来看动态链接库中关于 _CAPI_DGLHeteroCreateUnitGraphFromCOO 这个函数的定义,这个函数是在 src/graph/heterograph_capi.cc dglsrc_src_graph_heterograph_capi_cc 中被定义的,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0]; // 节点种类数
int64_t num_src = args[1]; // 源节点个数
int64_t num_dst = args[2]; // 目的节点个数
IdArray row = args[3]; // Row Index
IdArray col = args[4]; // Col Index
List<Value> formats = args[5]; // Unit Graph 的可用存储格式列表
bool row_sorted = args[6];
bool col_sorted = args[7];

// 获取代表 Unit Graph 底层存储形式的 code
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
const auto code = SparseFormatsToCode(formats_vec);

// 在底层创建 Unit Graph
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col,
row_sorted, col_sorted, code);

// 返回底层 Unit Graph 的指针
*rv = HeteroGraphRef(hgptr);
});

    在上面的代码中,_CAPI_DGLHeteroCreateUnitGraphFromCOO 首先对传入的 Unit Graph 的存储格式列表进行了整理,然后生成了指定存储格式的 code 变量。接着在 Line 21 调用了在 rc/graph/creators.cc dglsrc_src_graph_creators_cc 中定义的函数 CreateFromCOO (i.e. codeflow_dll 中 ①),该函数的定义如下所示:

1
2
3
4
5
6
7
8
HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col,
bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCOO(
num_vtypes, num_src, num_dst, row, col, row_sorted, col_sorted, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}

    可以发现,CreateFromCOO 实际上调用了 UnitCOO 类的 CreateFromCOO 方法 (i.e. codeflow_dll 中 ②),该方法在 dgl/src/graph/unit_graph.cc dglsrc_src_graph_unit_graph_cc 中进行了定义,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
HeteroGraphPtr UnitGraph::CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col,
bool row_sorted, bool col_sorted,
dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, num_src, num_dst, row, col,
row_sorted, col_sorted));

return HeteroGraphPtr(
new UnitGraph(mg, nullptr, nullptr, coo, formats));
}

HeteroGraphPtr UnitGraph::CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, mat));

return HeteroGraphPtr(
new UnitGraph(mg, nullptr, nullptr, coo, formats));
}

    在确认 num_vtypesnum_srcnum_dst 三者的数值关系无误后,Line 9 调用了 CreateUnitGraphMetaGraph 函数 (i.e. codeflow_dll 中 ③),以基于图中点的种类数 num_vtypes 创建当前 Unit Graph 的 MetaGraph。回顾我们在 metagraph 中提到的,对于 Unit Graph 来说,其 Meta Graph 的形式只有两种情况。下面让我们来看 CreateUnitGraphMetaGraph 的具体定义,它位于 src/graph/unit_graph.cc dglsrc_src_graph_unit_graph_cc 中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) {
static GraphPtr mg1 = CreateUnitGraphMetaGraph1();
static GraphPtr mg2 = CreateUnitGraphMetaGraph2();
if (num_vtypes == 1)
return mg1;
else if (num_vtypes == 2)
return mg2;
else
LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2.";
return {};
}

// create metagraph of one node type
inline GraphPtr CreateUnitGraphMetaGraph1() {
// a self-loop edge 0->0
std::vector<int64_t> row_vec(1, 0);
std::vector<int64_t> col_vec(1, 0);
IdArray row = aten::VecToIdArray(row_vec);
IdArray col = aten::VecToIdArray(col_vec);
GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col);
return g;
}

// create metagraph of two node types
inline GraphPtr CreateUnitGraphMetaGraph2() {
// an edge 0->1
std::vector<int64_t> row_vec(1, 0);
std::vector<int64_t> col_vec(1, 1);
IdArray row = aten::VecToIdArray(row_vec);
IdArray col = aten::VecToIdArray(col_vec);
GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col);
return g;
}

    可以发现,CreateUnitGraphMetaGraph 根据 num_vtypes 为 1 或 2 的情况,分别调用 CreateUnitGraphMetaGraph1CreateUnitGraphMetaGraph2,以分别创建带有 1 个或 2 个点的 Meta Graph,Meta Graph 是使用基于 COO 格式的 ImmutableGraph 表示的。

    另外值得注意的是,在上面代码的 Line 2~3 中,我们可以发现 CreateUnitGraphMetaGraph 是以 static 的方式创建这两种情况的 Meta Graph,然后根据 num_vtypes 的值返回对应的 Meta Graph 实例。

    回到 UnitGraph::CreateFromCOO 函数 (i.e. codeflow_dll 中 ④),在获得 Meta Graph 后,在 Line 10 (Line 24) 基于该 Meta Graph 创建了一个 UnitGraph::COO 类,函数最后在 Line 13~14 (Line 26~27) 基于 Meta Graph 和新创建的 UnitGraph::COO 类,创建了 UnitGraph 类,并返回。

    在完成 UnitGraph 的创建后,CreateFromCOO 基于此创建了 HeteroGraph 实例 (i.e. codeflow_dll 中 ⑤),并返回。

    最终,最外层的函数 _CAPI_DGLHeteroCreateUnitGraphFromCOO 在拿到 HeteroGraph 实例后 (i.e. codeflow_dll 中 ⑥),将该实例作为该 C++ API 的返回值进行返回。

    总结来说,在动态链接库 Runtime 侧创建的图拓扑表示形式可以总结为上图。

Python 侧后续处理

    在 Runtime 中完成图拓扑的创建后,让我们回到 Python 侧程序。为了查看方便,上图 codeflow_param_2codeflow_param 的一份拷贝。

    DGL 在 Python 侧使用 HeteroGraphIndex 对 Runtime 侧返回的 HeteroGraph 进行封装,这个类是在 python/dgl/heterograph_index.pydglsrc_heterograph_index_py 中定义的。这个类中对 Runtime 侧的 HeteroGraph 类的部分公有成员函数进行了封装。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def create_from_edges(sparse_fmt, arrays,
utype, etype, vtype,
urange, vrange,
row_sorted=False,
col_sorted=False):
if utype == vtype:
num_ntypes = 1
else:
num_ntypes = 2

if sparse_fmt == 'coo':
u, v = arrays
hgidx = heterograph_index.create_unitgraph_from_coo(
num_ntypes, urange, vrange, u, v, ['coo', 'csr', 'csc'],
row_sorted, col_sorted)
else: # 'csr' or 'csc'
indptr, indices, eids = arrays
hgidx = heterograph_index.create_unitgraph_from_csr(
num_ntypes, urange, vrange, indptr, indices, eids, ['coo', 'csr', 'csc'],
sparse_fmt == 'csc')
if utype == vtype:
return DGLHeteroGraph(hgidx, [utype], [etype])
else:
return DGLHeteroGraph(hgidx, [utype, vtype], [etype])

    回到 create_from_edges 函数 (i.e. codeflow_param_2 中 ④),如上代码所示,在将来自动态链接库 Runtime 侧的异构图实例封装为 HeteroGraphIndex 后,基于此实例化了一个 DGLHeteroGraph 类,DGLHeteroGraph 类的成员变量 _graph 用于存储该 HeteroGraphIndex

创建异构图

    上面我们分析了创建同构图的流程,下面我们对创建异构图的流程进行分析。DGL 使用 dgl.heterograph API 构建一个异构图。如下所示,该 API 传入的参数是一个字典,字典的 Key 值代表了一个 Unit Graph 所对应的关系,也即 (Source Type, Edge Type, Destination Type);字典的 Value 值代表了该 Unit Graph 的拓扑。

1
2
3
4
5
6
7
8
9
10
import dgl
import torch as th

graph_data = {
('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),
('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))
}

g = dgl.heterograph(graph_data)

    下面我们将对异构图的创建流程进行分析,codeflow_hetero 展示了代码流程。

    dgl.heterograph 是在 python/dgl/convert.pydglsrc_convert_py 中定义的,定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def heterograph(data_dict,
num_nodes_dict=None,
idtype=None,
device=None):
# Convert all data to node tensors first
node_tensor_dict = {}
need_infer = num_nodes_dict is None
if num_nodes_dict is None:
num_nodes_dict = defaultdict(int)

# 整理用户输入的各个字典
for (sty, ety, dty), data in data_dict.items():
if isinstance(data, spmatrix):
raise DGLError("dgl.heterograph no longer supports graph construction from a SciPy "
"sparse matrix, use dgl.from_scipy instead.")

if isinstance(data, nx.Graph):
raise DGLError("dgl.heterograph no longer supports graph construction from a NetworkX "
"graph, use dgl.from_networkx instead.")
is_bipartite = (sty != dty)

# 调用 graphdata2tensors 获得对应 Unit Graph 的信息
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
data, idtype, bipartite=is_bipartite)
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)

# 获取以各个类型为源/目的节点的数量信息
if need_infer:
num_nodes_dict[sty] = max(num_nodes_dict[sty], urange)
num_nodes_dict[dty] = max(num_nodes_dict[dty], vrange)
else: # sanity check
if num_nodes_dict[sty] < urange:
raise DGLError('The given number of nodes of node type {} must be larger than'
' the max ID in the data, but got {} and {}.'.format(
sty, num_nodes_dict[sty], urange - 1))
if num_nodes_dict[dty] < vrange:
raise DGLError('The given number of nodes of node type {} must be larger than'
' the max ID in the data, but got {} and {}.'.format(
dty, num_nodes_dict[dty], vrange - 1))
# Create the graph
metagraph, ntypes, etypes, relations = heterograph_index.create_metagraph_index(
num_nodes_dict.keys(), node_tensor_dict.keys())
num_nodes_per_type = utils.toindex([num_nodes_dict[ntype] for ntype in ntypes], "int64")
rel_graphs = []
for srctype, etype, dsttype in relations:
sparse_fmt, arrays = node_tensor_dict[(srctype, etype, dsttype)]
g = create_from_edges(sparse_fmt, arrays, srctype, etype, dsttype,
num_nodes_dict[srctype], num_nodes_dict[dsttype])
rel_graphs.append(g)

# create graph index
hgidx = heterograph_index.create_heterograph_from_relations(
metagraph, [rgrh._graph for rgrh in rel_graphs], num_nodes_per_type)
retg = DGLHeteroGraph(hgidx, ntypes, etypes)

return retg.to(device)

整理用户输入数据

    针对字典中的各个 Unit Graph,函数首先调用了 utils.graphdata2tensors 函数来整理各个 Unit Graph 的信息 (codeflow_hetero 中 ①),以获取拓扑表示的格式,以及对应的拓扑,该函数我们在上面进行了介绍,这里不再赘述。被处理好的各个 Unit Graph 的信息,被放入字典 node_tensor_dictnum_nodes_dict 中,它们的功能和格式分别如下:

字典 形式 功能
node_tensor_dict {(sty,ety,dty): (sparse_fmt, arrays)} 记录每一个 Unit Graph 的拓扑信息
num_nodes_dict {ty: number} 记录每一种类型的节点的数量

创建 Meta Graph

    在完成对输入字典的整理后,基于 node_tensor_dictnum_nodes_dict 两个字典,进一步调用了在 python/dgl/heterograph_index.pydglsrc_heterograph_index_py 中定义的 create_metagraph_index 函数 (codeflow_hetero 中 ②),用于创建当前异构图的 Meta Graph。具体定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def create_metagraph_index(ntypes, canonical_etypes):
ntypes = list(sorted(ntypes))
relations = list(sorted(canonical_etypes))
ntype_dict = {ntype: i for i, ntype in enumerate(ntypes)}
meta_edges_src = []
meta_edges_dst = []
etypes = []
for srctype, etype, dsttype in relations:
meta_edges_src.append(ntype_dict[srctype])
meta_edges_dst.append(ntype_dict[dsttype])
etypes.append(etype)
# metagraph is DGLGraph, currently still using int64 as index dtype
metagraph = from_coo(len(ntypes), meta_edges_src, meta_edges_dst, True)
return metagraph, ntypes, etypes, relations

    create_metagraph_index 首先对节点的类型进行了排序,存储在 ntypes 中,以及对点之间的关系进行了排序,存储在 relation 中。

    接着,create_metagraph_index 对传入的节点类型 ntypes 以及所有的边类型 canonical_etypes 进行处理: 在 Line 4 先将节点类型转化为具体的数字 index,然后在 Line 5~11 使用转化好的数字 index 来记录所有的边类型。并把各种类型的边类型记录在 etypes 中。

    随后,create_metagraph_index 在 Line 13 调用了在 python/dgl/graph_index dglsrc_python_dgl_graph_index_py 中定义的函数 from_coo (codeflow_hetero 中 ③),该函数定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def from_coo(num_nodes, src, dst, readonly):
src = utils.toindex(src)
dst = utils.toindex(dst)
if readonly:
gidx = _CAPI_DGLGraphCreate(
src.todgltensor(),
dst.todgltensor(),
int(num_nodes),
readonly)
else:
gidx = _CAPI_DGLGraphCreateMutable()
gidx.add_nodes(num_nodes)
gidx.add_edges(src, dst)
return gidx

    可见,from_coo 通过调用动态链接库里的 API,在 Runtime 中维护了一个当前异构图对应的 Meta Graph 的 ImmutableGraph 实例。

    在完成 Meta Graph 的创建后,create_metagraph_index将 Meta Graph,以及 ntypes, etypesrelations 等内容进行返回 (codeflow_hetero 中 ④)。

为每种关系创建 Unit Graph

    完成 Meta Graph 的创建后,函数接下来的任务是为每一种关系创建 Unit Graph。针对在 relation 中存储的异构图中的每一种关系,函数调用了我们熟悉的 create_from_edges 函数 (codeflow_hetero 中 ⑤),以在 Runtime 中维护各个 Unit Graph 的 HeteroGraph 实例,并将这些实例存储在 rel_graphs 列表中。

创建异构图拓扑

    获取存储着所有 Unit Graph 实例的 rel_graphs 后,结合上面创建的 Meta Graph 实例,函数调用了 create_heterograph_from_relations 函数 (codeflow_hetero 中 ⑥),以创建异构图实例。这个函数是在 python/dgl/heterograph_index.pydglsrc_heterograph_index_py 中被定义的,定义如下所示:

1
2
3
4
5
6
def create_heterograph_from_relations(metagraph, rel_graphs, num_nodes_per_type):
if num_nodes_per_type is None:
return _CAPI_DGLHeteroCreateHeteroGraph(metagraph, rel_graphs)
else:
return _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes(
metagraph, rel_graphs, num_nodes_per_type.todgltensor())

    可以看到其实际上调用了动态链接库侧的函数 _CAPI_DGLHeteroCreateHeteroGraphWithNumNodes,后者是在 src/graph/heterograph_capi.ccdglsrc_src_graph_heterograph_capi_cc 中定义的,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1];
IdArray num_nodes_per_type = args[2];
std::vector<HeteroGraphPtr> rel_ptrs;
rel_ptrs.reserve(rel_graphs.size());
for (const auto& ref : rel_graphs) {
rel_ptrs.push_back(ref.sptr());
}
auto hgptr = CreateHeteroGraph(
meta_graph.sptr(), rel_ptrs, num_nodes_per_type.ToVector<int64_t>());
*rv = HeteroGraphRef(hgptr);
});

    后者首先创建了一个用于存储 HeteroGraphPtr 的 STL vector 容器,然后将传入的所有 Unit Graph 的指针都存储该容器中,接着调用了 CreateHeteroGraph 用于创建 HeteroGraph 实例。函数 CreateHeteroGraph 是在 src/graph/creators.ccdglsrc_src_graph_creators_cc 中被定义的,如下所示:

1
2
3
4
5
6
HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));
}

    我们对上面涉及的 HetroGraph 的构造函数进行回顾 (在 src/graph/heterograph.ccdglsrc_src_graph_heterograph_cc 中被定义):

1
2
3
4
5
6
7
8
9
10
11
HeteroGraph::HeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) {
if (num_nodes_per_type.size() == 0)
num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
else
num_verts_per_type_ = num_nodes_per_type;
HeteroGraphSanityCheck(meta_graph, rel_graphs);
relation_graphs_ = CastToUnitGraphs(rel_graphs);
}

    完成异构图 HetroGraph 实例的创建后,动态链接库侧 Runtime API 将该实例的指针进行返回,至此完成异构图实例在动态链接库 Runtime 中的创建。

    回到 Python 侧,基于返回的 HetroGraph 实例指针,和同构图一样,Python 侧初始化了一个 DGLHeteroGraph 实例,至此完成异构图的创建过程 (codeflow_hetero 中 ⑦)。

节点和边的特征

    DGLGraph 对象的节点和边可具有多个用户定义的、可命名的特征,以储存图的节点和边的属性。通过 ndataedata 接口可访问这些特征。实例代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import dgl
import torch as th
g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6个节点,4条边

print(g)
# Graph(num_nodes=6, num_edges=4,
# ndata_schemes={}
# edata_schemes={})

g.ndata['x'] = th.ones(g.num_nodes(), 3) # 长度为3的节点特征
g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32) # 标量整型特征

print(g)
# Graph(num_nodes=6, num_edges=4,
# ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
# edata_schemes={'x' : Scheme(shape=(,), dtype=torch.int32)})

# 不同名称的特征可以具有不同形状
g.ndata['y'] = th.randn(g.num_nodes(), 5)

# 获取节点1的特征
print(g.ndata['x'][1])
# tensor([1., 1., 1.])

# 获取边0和3的特征
print(g.edata['x'][th.tensor([0, 3])])
# tensor([1, 1], dtype=torch.int32)

    关于 ndataedata 接口,它们有如下特征:

  • 仅允许使用数值类型 (如单精度浮点型、双精度浮点型和整型) 的特征,这些特征可以是标量、向量或多维张量;
  • 每个节点特征具有唯一名称,每个边特征也具有唯一名称。节点和边的特征可以具有相同的名称;
  • 通过张量分配创建特征时,DGL 会将特征赋给图中的每个节点和每条边。该张量的第一维必须与图中节点或边的数量一致,不能将特征赋给图中节点或边的子集;
  • 相同名称的特征必须具有相同的维度和数据类型;
  • 特征张量使用 "行优先" 的原则,即每个行切片储存 $1$ 个节点或 $1$ 条边的特征

    另外,对于加权图,用户可以将权重储存为一个边特征,例子如下所示:

1
2
3
4
5
6
7
8
9
10
# 边 0->1, 0->2, 0->3, 1->3
edges = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])
weights = th.tensor([0.1, 0.6, 0.9, 0.7]) # 每条边的权重
g = dgl.graph(edges)
g.edata['w'] = weights # 将其命名为 'w'

print(g)
# Graph(num_nodes=4, num_edges=4,
# ndata_schemes={}
# edata_schemes={'w' : Scheme(shape=(,), dtype=torch.float32)})

异构图

    相比同构图,异构图里可以有不同类型的节点和边。这些不同类型的节点和边具有独立的 ID 空间和特征。在 DGL 中,一个异构图由一系列子图构成,一个子图对应一种关系。每个关系由一个字符串三元组定义 (源节点类型, 边类型, 目标节点类型)。由于这里的关系定义消除了边类型的歧义,DGL 称它们为 canonical edge types。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import dgl
import torch as th

# Create a heterograph with 3 node types and 3 edges types.
graph_data = {
('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
('drug', 'interacts', 'gene'): (th.tensor([0, 1]), th.tensor([2, 3])),
('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))
}
g = dgl.heterograph(graph_data)

print(g.ntypes)
# ['disease', 'drug', 'gene']

print(g.etypes)
# ['interacts', 'interacts', 'treats']

print(g.canonical_etypes)
# [('drug', 'interacts', 'drug'),
# ('drug', 'interacts', 'gene'),
# ('drug', 'treats', 'disease')]

    这么看来,同构图和二分图只是一种特殊的异构图,它们只包括一种关系。

1
2
3
4
5
# 一个同构图
dgl.heterograph({('node_type', 'edge_type', 'node_type'): (u, v)})

# 一个二分图
dgl.heterograph({('source_type', 'edge_type', 'destination_type'): (u, v)})

    与异构图相关联的 metagraph 就是图的模式。它指定节点集和节点之间的边的类型约束。 metagraph 中的一个节点 $u$ 对应于相关异构图中的一个节点类型。metagraph 中的边 $(u,v)$ 表示在相关异构图中存在从 $u$ 型节点到 $v$ 型节点的边。

1
2
3
4
5
6
7
8
9
10
11
print(g)
# Graph(num_nodes={'disease': 3, 'drug': 3, 'gene': 4},
# num_edges={('drug', 'interacts', 'drug'): 2,
# ('drug', 'interacts', 'gene'): 2,
# ('drug', 'treats', 'disease'): 1},
# metagraph=[('drug', 'drug', 'interacts'),
# ('drug', 'gene', 'interacts'),
# ('drug', 'disease', 'treats')])

print(g.metagraph().edges())
# OutMultiEdgeDataView([('drug', 'drug'), ('drug', 'gene'), ('drug', 'disease')])

    当引入多种节点和边类型后,用户在调用 DGLGraph API 以获取特定类型的信息时,需要指定具体的节点和边类型。此外,不同类型的节点和边具有单独的 ID。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 获取图中所有节点的数量
print(g.num_nodes())
# 10

# 获取drug节点的数量
print(g.num_nodes('drug'))
# 3

# 不同类型的节点有单独的ID。因此,没有指定节点类型就没有明确的返回值。
print(g.nodes())
# DGLError: Node type name must be specified if there are more than one node types.

print(g.nodes('drug'))
# tensor([0, 1, 2])

    为了设置/获取特定节点和边类型的特征,DGL 提供了两种新类型的语法:

  • 获取特定点类型的特征: g.nodes[‘node_type’].data[‘feat_name’]
  • 获取特定边类型的特征: g.edges[‘edge_type’].data[‘feat_name’]

    示例代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
# 设置/获取"drug"类型的节点的"hv"特征
g.nodes['drug'].data['hv'] = th.ones(3, 1)
print(g.nodes['drug'].data['hv'])
# tensor([[1.],
# [1.],
# [1.]])

# 设置/获取"treats"类型的边的"he"特征
g.edges['treats'].data['he'] = th.zeros(1, 1)
print(g.edges['treats'].data['he'])
# tensor([[0.]])

    如果图里只有一种节点或边类型,则不需要指定节点或边的类型。

1
2
3
4
5
6
7
8
9
10
g = dgl.heterograph({
('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
('drug', 'is similar', 'drug'): (th.tensor([0, 1]), th.tensor([2, 3]))
})

print(g.nodes())
# tensor([0, 1, 2, 3])

# 设置/获取单一类型的节点或边特征,不必使用新的语法
g.ndata['hv'] = th.ones(4, 1)

在 GPU 上使用 DGLGraph

    用户可以通过在构造图的过程中传入两个 GPU 张量来创建 GPU 上的 DGLGraph。另一种方法是使用 to API 将 DGLGraph 复制到 GPU,这会将图结构和特征数据都拷贝到指定的设备。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import dgl
import torch as th
u, v = th.tensor([0, 1, 2]), th.tensor([2, 3, 4])
g = dgl.graph((u, v))
g.ndata['x'] = th.randn(5, 3) # 原始特征在CPU上

print(g.device)
# device(type='cpu')

cuda_g = g.to('cuda:0') # 接受来自后端框架的任何设备对象
print(cuda_g.device)
# device(type='cuda', index=0)

print(cuda_g.ndata['x'].device) # 特征数据也拷贝到了GPU上
device(type='cuda', index=0)

# 由GPU张量构造的图也在GPU上
u, v = u.to('cuda:0'), v.to('cuda:0')
g = dgl.graph((u, v))
print(g.device)
# device(type='cuda', index=0)

    任何涉及在 GPU 上存储的图的操作都是在 GPU 上运行的。因此,这要求所有张量参数都已经放在GPU上,其结果(图或张量)也将在 GPU 上。此外,在 GPU 上存储的图只接受 GPU 上的特征数据。

1
2
3
4
5
6
7
8
9
10
11
12
print(cuda_g.in_degrees())
# tensor([0, 0, 1, 1, 1], device='cuda:0')

print(cuda_g.in_edges([2, 3, 4])) # 可以接受非张量类型的参数
(tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))

print(cuda_g.in_edges(th.tensor([2, 3, 4]).to('cuda:0'))) # 张量类型的参数必须在GPU上
(tensor([0, 1, 2], device='cuda:0'), tensor([2, 3, 4], device='cuda:0'))

print(cuda_g.ndata['h'] = th.randn(5, 4)) # ERROR! 特征也必须在GPU上!
# DGLError: Cannot assign node feature "h" on device cpu to a graph on device
# cuda:0. Call DGLGraph.to() to copy the graph to the same device.