Source Code Analysis: Graph Topology Storage in DGL

⚠ Please credit the source when reposting. Author: ZobinHuang, last updated: July 21, 2022


Creative Commons License

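    This work by ZobinHuang is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License. Please check the permission requirements before using or sharing it. Any infringement will be addressed through legal means to protect the author's legitimate rights. Thank you for your cooperation.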



Preface

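    In this article, we analyze from the bottom up how DGL represents and stores graph topology in memory. The figure map outlines the path we will follow. Below we go through it module by module, starting from the lowest layer.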

Underlying Array Storage

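    DGL uses mainstream deep learning frameworks (e.g., PyTorch, TensorFlow, MXNet) as backends, so it also relies on each backend's own definitions to hold the underlying tensor data. In other words, DGL does not define how tensor data is laid out and stored at the lowest level. Instead, it wraps the tensors defined by the backend. This section is the first part of the article and examines how DGL wraps backend tensors.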

Definitions in the Shared Library

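    Concretely, DGL organizes backend data as arrays. On the shared-library (C++) side, DGL wraps backend arrays in a family of array types (e.g., IdArray, DegreeArray). The figure above shows the whole flow. Below we walk through the corresponding code.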

    First, include/dgl/aten/types.h dglsrc_include_dgl_aten_type_h contains the following typedefs:

typedef NDArray IdArray;
typedef NDArray DegreeArray;
typedef NDArray BoolArray;
typedef NDArray IntArray;
typedef NDArray FloatArray;
typedef NDArray TypeArray;

    In include/dgl/graph_interface.h dglsrc_include_dgl_graph_interface_h, EdgeArray is defined as follows:

typedef struct {
/* \brief the two endpoints and the id of the edge */
IdArray src, dst, id;
} EdgeArray;

    The central type here is NDArray, which is defined in include/dgl/runtime/ndarray.h dglsrc_include_dgl_runtime_ndarray_h. An excerpt of its definition follows:

/*!
* \brief Managed NDArray.
* The array is backed by reference counted blocks.
*/
class NDArray {
public:
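// Data container struct, declared inside NDArray
struct Container;

// Default constructor
NDArray() {}

// Construct from a Container pointer (increments its reference count)
explicit inline NDArray(Container* data);

// Copy constructor (shares the Container and increments its reference count)
inline NDArray(const NDArray& other);

// Move constructor (takes over the Container without changing the count)
NDArray(NDArray&& other) // NOLINT(*)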
: data_(other.data_) {
other.data_ = nullptr;
}

/* ...... */

private:
// Internal Data content
Container* data_{nullptr};
};

inline NDArray::NDArray(Container* data)
: data_(data) {
if (data_)
data_->IncRef();
}

inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
if (data_)
data_->IncRef();
}

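    The constructors above show that the data held by an NDArray actually lives behind the private member data_, a pointer to an NDArray::Container struct. NDArray::Container is defined as follows: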

struct NDArray::Container {
public:
// NOTE: the first part of this structure is the same as
// DLManagedTensor, note that, however, the deleter
// is only called when the reference counter goes to 0
/*!
* \brief The corresponding dl_tensor field.
* \note it is important that the first field is DLTensor
* So that this data structure is DLTensor compatible.
* The head ptr of this struct can be viewed as DLTensor*.
*/
DLTensor dl_tensor;
/*!
* \brief addtional context, reserved for recycling
* \note We can attach additional content here
* which the current container depend on
* (e.g. reference to original memory when creating views).
*/
void* manager_ctx{nullptr};
/*!
* \brief Customized deleter
*
* \note The customized deleter is helpful to enable
* different ways of memory allocator that are not
* currently defined by the system.
*/
void (*deleter)(Container* self) = nullptr;
/*! \brief default constructor */
Container() {
dl_tensor.data = nullptr;
dl_tensor.ndim = 0;
dl_tensor.shape = nullptr;
dl_tensor.strides = nullptr;
dl_tensor.byte_offset = 0;
}
/*! \brief pointer to shared memory */
std::shared_ptr<SharedMemory> mem;
/*! \brief developer function, increases reference counter */
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
/*! \brief developer function, decrease reference counter */
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter != nullptr) {
(*this->deleter)(this);
}
}
}

private:
friend class NDArray;
friend class RPCWrappedFunc;
/*!
* \brief The shape container,
* can be used for shape data.
*/
std::vector<int64_t> shape_;
/*!
* \brief The stride container,
* can be used for stride data.
*/
std::vector<int64_t> stride_;
/*! \brief The internal array object */
std::atomic<int> ref_counter_{0};

bool pinned_by_dgl_{false};
};

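    The actual data of an NDArray is stored in NDArray::Container. The IncRef and DecRef APIs of NDArray::Container let developers maintain a reference count on the struct. When DecRef drops the count to zero, the custom deleter is invoked if one has been set.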
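
    The following minimal Python sketch illustrates this reference-counting contract. The class RefCountedContainer and its method names are hypothetical and are not DGL APIs. The sketch mimics only the counting logic of IncRef/DecRef: the caller that drops the count from 1 to 0 runs the deleter, and it runs exactly once. It does not model the std::atomic operations and memory orders (relaxed increment, release decrement, acquire fence) that make the C++ version thread-safe.

```python
from typing import Callable, Optional


class RefCountedContainer:
    """Hypothetical stand-in for the counting logic of NDArray::Container."""

    def __init__(self, deleter: Optional[Callable[[object], None]] = None) -> None:
        self.ref_counter = 0  # like `std::atomic<int> ref_counter_{0}`
        self.deleter = deleter

    def inc_ref(self) -> None:
        self.ref_counter += 1

    def dec_ref(self) -> None:
        # Like `fetch_sub(1) == 1`: whoever releases the last reference
        # is the one that runs the deleter.
        previous = self.ref_counter
        self.ref_counter -= 1
        if previous == 1 and self.deleter is not None:
            self.deleter(self)


freed = []
holder = RefCountedContainer(deleter=freed.append)
holder.inc_ref()  # first handle, e.g. NDArray(Container*)
holder.inc_ref()  # second handle, e.g. the copy constructor
holder.dec_ref()  # one handle released: still alive, freed == []
holder.dec_ref()  # last handle released: deleter runs, freed == [holder]
```

In the C++ excerpt, the Container* constructor and the copy constructor perform the IncRef step. The move constructor transfers data_ and leaves the count unchanged.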

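    NDArray::Container in turn keeps its data in the member dl_tensor, whose type DLTensor comes from the DLPack library. The code comments point out a trick. Because dl_tensor is the very first member of NDArray::Container, the address of a Container instance is also the address of its DLTensor. This lets other parts of the DGL code refer to an NDArray::Container instance through a DLTensor* pointer.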
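
    Python's ctypes lays out structures the same way C does, so the "first member" trick can be reproduced with it. In the sketch below, DLTensorLike and ContainerLike are hypothetical, trimmed-down stand-ins; the real DLTensor has more fields, such as device, dtype, and shape. Because dl_tensor sits at offset 0, a pointer to the container can be reinterpreted as a pointer to its DLTensor, and both refer to the same memory.

```python
import ctypes


class DLTensorLike(ctypes.Structure):
    # Hypothetical, trimmed-down stand-in for DLTensor (the real one has more fields).
    _fields_ = [("data", ctypes.c_void_p), ("ndim", ctypes.c_int32)]


class ContainerLike(ctypes.Structure):
    # As in NDArray::Container, the DLTensor is the FIRST member (offset 0).
    _fields_ = [("dl_tensor", DLTensorLike), ("manager_ctx", ctypes.c_void_p)]


container = ContainerLike()
container.dl_tensor.ndim = 2

# Reinterpret "pointer to container" as "pointer to DLTensor", like a C cast
# from Container* to DLTensor*.
tensor_view = ctypes.cast(ctypes.pointer(container), ctypes.POINTER(DLTensorLike)).contents

print(tensor_view.ndim)                                               # 2
print(ctypes.addressof(tensor_view) == ctypes.addressof(container))  # True
print(ContainerLike.dl_tensor.offset)                                 # 0
```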

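    This brings us to DLTensor from the DLPack library. DLPack aims to unify the tensor definitions of different frameworks (e.g., PyTorch, TensorFlow). It does not define how tensors are stored or operated on. Instead, it provides a common interface that lets one framework use tensor data produced by another (p.s., by sharing the same memory rather than copying it). See the official DLPack documentation dlpack and the official repository dlpack_repo for details.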

    Next, let us look at DLTensor, which is defined in include/dlpack/dlpack.h dlpacksrc_include_dlpack_dlpack_h of the DLPack library:

/*!
* \brief Plain C Tensor object, does not manage memory.
*/
typedef struct {
/*!
* \brief The data pointer points to the allocated data. This will be CUDA
* device pointer or cl_mem handle in OpenCL. It may be opaque on some device
* types. This pointer is always aligned to 256 bytes as in CUDA. The
* `byte_offset` field should be used to point to the beginning of the data.
*
* Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
* TVM, perhaps others) do not adhere to this 256 byte aligment requirement
* on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed
* (after which this note will be updated); at the moment it is recommended
* to not rely on the data pointer being correctly aligned.
*
* For given DLTensor, the size of memory required to store the contents of
* data is calculated as follows:
*
* \code{.c}
* static inline size_t GetDataSize(const DLTensor* t) {
* size_t size = 1;
* for (tvm_index_t i = 0; i < t->ndim; ++i) {
* size *= t->shape[i];
* }
* size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
* return size;
* }
* \endcode
*/
void* data;
/*! \brief The device of the tensor */
DLDevice device;
/*! \brief Number of dimensions */
int32_t ndim;
/*! \brief The data type of the pointer*/
DLDataType dtype;
/*! \brief The shape of the tensor */
int64_t* shape;
/*!
* \brief strides of the tensor (in number of elements, not bytes)
* can be NULL, indicating tensor is compact and row-majored.
*/
int64_t* strides;
/*! \brief The offset in bytes to the beginning pointer to data */
uint64_t byte_offset;
} DLTensor;

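    In the DLTensor definition, the void* member data points to the memory buffer that actually holds the array data. That buffer is allocated by whichever framework produced the tensor. The remaining members (device, ndim, dtype, shape, strides, byte_offset) are metadata that describe the buffer.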
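
    The doc comment above already contains a GetDataSize helper, and a direct Python port makes its arithmetic easy to check. This is a sketch. The DType class is a hypothetical stand-in for DLDataType that keeps only bits and lanes. Like the original snippet, the function computes the size of a compact tensor and ignores strides and byte_offset.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class DType:
    # Hypothetical stand-in for DLDataType; only `bits` and `lanes` matter here.
    bits: int
    lanes: int = 1


def get_data_size(shape: Sequence[int], dtype: DType) -> int:
    """Port of the GetDataSize snippet: bytes needed by a compact tensor."""
    size = 1
    for dim in shape:  # number of elements = product of the shape
        size *= dim
    # Bytes per element: round the bit width (bits * lanes) up to whole bytes.
    size *= (dtype.bits * dtype.lanes + 7) // 8
    return size


print(get_data_size([3, 4], DType(bits=64)))  # 96
print(get_data_size([10], DType(bits=1)))     # 10
```

For example, an int64 tensor of shape (3, 4) needs 3 × 4 × 8 = 96 bytes. A 1-bit type still occupies one byte per element, because the rounding is applied per element rather than to the whole buffer.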

Python-side API

    On the Python side, DGL provides APIs that convert tensors defined by the backend framework (e.g., PyTorch, TensorFlow) into NDArray. We analyze them below.

    With PyTorch as the backend, the code that calls the shared-library API _CAPI_DGLHeteroCreateUnitGraphFromCOO looks like this:

from . import backend as F

def create_unitgraph_from_coo(num_ntypes, num_src, num_dst, row, col,
                              formats, row_sorted=False, col_sorted=False):
    """Create a unitgraph graph index from COO format

    Parameters
    ----------
    num_ntypes : int
        Number of node types (must be 1 or 2).
    num_src : int
        Number of nodes in the src type.
    num_dst : int
        Number of nodes in the dst type.
    row : utils.Index
        Row index.
    col : utils.Index
        Col index.
    formats : list of str.
        Restrict the storage formats allowed for the unit graph.
    row_sorted : bool, optional
        Whether or not the rows of the COO are in ascending order.
    col_sorted : bool, optional
        Whether or not the columns of the COO are in ascending order within
        each row. This only has an effect when ``row_sorted`` is True.

    Returns
    -------
    HeteroGraphIndex
    """
    if isinstance(formats, str):
        formats = [formats]
    return _CAPI_DGLHeteroCreateUnitGraphFromCOO(
        int(num_ntypes), int(num_src), int(num_dst),
        F.to_dgl_nd(row), F.to_dgl_nd(col),
        formats, row_sorted, col_sorted)

    _CAPI_DGLHeteroCreateUnitGraphFromCOO is defined as follows:

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
IdArray row = args[3];
IdArray col = args[4];
List<Value> formats = args[5];
bool row_sorted = args[6];
bool col_sorted = args[7];
std::vector<SparseFormat> formats_vec;
for (Value val : formats) {
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
const auto code = SparseFormatsToCode(formats_vec);
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col,
row_sorted, col_sorted, code);
*rv = HeteroGraphRef(hgptr);
});

    The arguments args[3] and args[4] of _CAPI_DGLHeteroCreateUnitGraphFromCOO (row and col) are of type IdArray, which is really NDArray. At the Python call site, F.to_dgl_nd converts the torch.Tensor data into NDArray. F.to_dgl_nd is defined in python/dgl/backend/__init__.py dglsrc_python_dgl_backend_init_py as follows:

def to_dgl_nd(data):
    return zerocopy_to_dgl_ndarray(data)

    The implementation of zerocopy_to_dgl_ndarray depends on the backend framework. With PyTorch, zerocopy_to_dgl_ndarray is defined in backend/pytorch/tensor.py dglsrc_python_dgl_backend_pytorch_tensor_py as follows:

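from torch.utils import dlpack

# th (torch), nd (DGL's ndarray module) and LooseVersion are imported
# elsewhere in tensor.py; only the relevant part is excerpted here.
if LooseVersion(th.__version__) >= LooseVersion("1.10.0"):
    def zerocopy_to_dgl_ndarray(data):
        # On PyTorch >= 1.10, bool tensors are converted to uint8 before export.
        if data.dtype == th.bool:
            data = data.byte()
        return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))
else:
    def zerocopy_to_dgl_ndarray(data):
        return nd.from_dlpack(dlpack.to_dlpack(data.contiguous()))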

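    zerocopy_to_dgl_ndarray first calls torch.utils.dlpack.to_dlpack to export the PyTorch tensor as a framework-neutral DLPack capsule without copying memory. It then calls DGL's nd.from_dlpack to turn that capsule into a DGL NDArray. The conversion is zero-copy only for contiguous tensors: data.contiguous() returns the tensor itself if it is already contiguous and otherwise makes a contiguous copy. nd.from_dlpack is implemented by the shared library underneath, and we will not dig into it here.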

Representation and Storage of Sparse Matrices

Efficient Representations of Sparse Matrices

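    This section looks at how graph topology is stored on a computer. The topologies in most graph datasets are very sparse. A simple graph with $n$ nodes has at most $O(n^2)$ edges, which is also the number of entries in its adjacency matrix. Most real-world graphs have far fewer edges than that, so their adjacency matrices are mostly zeros. Sparse topologies are therefore better stored in formats that are more storage-efficient than a dense adjacency matrix. Several typical sparse-matrix formats are introduced below.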

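    The first is the Coordinate list (COO) format. As shown in the figure above, COO records only the coordinates (row, column) and the values of the non-zero entries of the sparse matrix, which reduces the storage needed.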
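
    The following small sketch shows the idea, with plain Python lists standing in for arrays. dense_to_coo is a hypothetical helper, not a DGL function. Converting a dense matrix to COO amounts to collecting the (row, column, value) triple of every non-zero entry:

```python
from typing import List, Tuple


def dense_to_coo(matrix: List[List[float]]) -> Tuple[List[int], List[int], List[float]]:
    """Collect (row, col, value) of every non-zero entry, scanning row by row."""
    rows, cols, vals = [], [], []
    for i, line in enumerate(matrix):
        for j, v in enumerate(line):
            if v != 0:
                rows.append(i)
                cols.append(j)
                vals.append(v)
    return rows, cols, vals


A = [
    [0, 5, 0, 0],
    [0, 0, 0, 0],
    [3, 0, 0, 7],
    [0, 0, 2, 0],
]
row, col, val = dense_to_coo(A)
# row == [0, 2, 2, 3], col == [1, 0, 3, 2], val == [5, 3, 7, 2]
```

An unweighted graph does not need the value array. DGL's COOMatrix, discussed later, keeps the row and col arrays plus an optional data array of entry IDs.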

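    Next is Compressed Sparse Row (CSR). As shown in the figure above, CSR improves further on COO. In COO the row array is redundant, because it repeats the same row index for every non-zero entry in that row. CSR replaces the row array with an array of row pointers. For each row, this array records only where that row's entries begin in the column array, plus one extra trailing entry that marks the end of the last row. Removing this redundancy improves storage efficiency further.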
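
    A common way to build CSR from COO is a counting sort. Count the entries of each row, turn the counts into row pointers with a prefix sum, and then scatter each entry into its row's slot. The sketch below uses coo_to_csr, a hypothetical helper; it is not a claim about how DGL's own conversion kernels work. For each CSR position, it also records which COO entry landed there, which plays the role of the data array of entry IDs discussed later.

```python
from typing import List, Tuple


def coo_to_csr(num_rows: int, row: List[int], col: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Convert COO (row, col) to CSR (indptr, indices, data) via counting sort.

    data[k] is the index of the COO entry (e.g. the edge ID) stored at position k.
    """
    # 1. Count the non-zeros of each row (shifted by one slot).
    indptr = [0] * (num_rows + 1)
    for r in row:
        indptr[r + 1] += 1
    # 2. Prefix sum: indptr[i] is where row i starts; indptr[-1] == nnz.
    for i in range(num_rows):
        indptr[i + 1] += indptr[i]
    # 3. Scatter entries in input order (stable within each row).
    nnz = len(row)
    indices = [0] * nnz
    data = [0] * nnz
    cursor = indptr[:-1]  # copy: next free slot of each row
    for eid, (r, c) in enumerate(zip(row, col)):
        pos = cursor[r]
        indices[pos] = c
        data[pos] = eid
        cursor[r] += 1
    return indptr, indices, data


indptr, indices, data = coo_to_csr(4, row=[2, 0, 2, 3], col=[3, 1, 0, 2])
# indptr == [0, 1, 1, 3, 4]; indices == [1, 3, 0, 2]; data == [1, 0, 2, 3]
```

Row 2's column indices come out as [3, 0], which is unsorted within the row. CSR does not require sorted column indices within a row, which is why DGL tracks a separate sorted flag.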

    Finally, Compressed Sparse Column (CSC) follows exactly the same idea as CSR with the roles of rows and columns swapped, so we will not go into details.
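
    The sketch below puts rough numbers on the savings. It counts how many entries each layout stores for a matrix with num_rows rows, num_cols columns and nnz non-zeros, assuming one stored value per non-zero. storage_entries is a hypothetical helper. CSR uses less space than COO once nnz exceeds num_rows + 1, and for a sparse graph both are far smaller than the dense matrix.

```python
from typing import Dict


def storage_entries(num_rows: int, num_cols: int, nnz: int) -> Dict[str, int]:
    """Count stored entries per layout, assuming one stored value per non-zero."""
    return {
        "dense": num_rows * num_cols,
        "coo": nnz + nnz + nnz,             # row + col + value per non-zero
        "csr": (num_rows + 1) + nnz + nnz,  # indptr + column indices + values
        "csc": (num_cols + 1) + nnz + nnz,  # indptr over columns + row indices + values
    }


# A graph with 1,000,000 nodes and 10,000,000 edges, viewed as a square adjacency matrix:
print(storage_entries(1_000_000, 1_000_000, 10_000_000))
# {'dense': 1000000000000, 'coo': 30000000, 'csr': 21000001, 'csc': 21000001}
```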

Sparse Matrix Definitions in DGL

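    DGL represents sparse matrices with COOMatrix and CSRMatrix. Both store the sparse matrix in IdArray arrays. Both also provide interfaces (Save/Load) for serializing the matrix to and from a stream, which is one of their important functions. We analyze each of them below.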

COOMatrix

    The COOMatrix struct, defined in include/dgl/aten/coo.h dglsrc_include_dgl_aten_coo_h, stores a sparse matrix in COO format. An excerpt of its definition follows:

/*!
* \brief Plain COO structure
*
* The data array stores integer ids for reading edge features.
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/

constexpr uint64_t kDGLSerialize_AtenCooMatrixMagic = 0xDD61ffd305dff127;

// TODO(BarclayII): Graph queries on COO formats should support the case where
// data ordered by rows/columns instead of EID.
struct COOMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief COO index arrays */
IdArray row, col;
/*! \brief data index array. When is null, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the row indices are sorted */
bool row_sorted = false;
/*! \brief whether the column indices per row are sorted */
bool col_sorted = false;
/*! \brief whether the matrix is in pinned memory */
bool is_pinned = false;
/*! \brief default constructor */
COOMatrix() = default;
/*! \brief constructor */
COOMatrix(int64_t nrows, int64_t ncols, IdArray rarr, IdArray carr,
IdArray darr = NullArray(), bool rsorted = false,
bool csorted = false)
: num_rows(nrows),
num_cols(ncols),
row(rarr),
col(carr),
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {
CheckValidity();
}

/*! \brief constructor from SparseMatrix object */
explicit COOMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
row(spmat.indices[0]),
col(spmat.indices[1]),
data(spmat.indices[2]),
row_sorted(spmat.flags[0]),
col_sorted(spmat.flags[1]) {
CheckValidity();
}

// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
}

bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCooMatrixMagic)
<< "Invalid COOMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&row)) << "Invalid row";
CHECK(fs->Read(&col)) << "Invalid col";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&row_sorted)) << "Invalid row_sorted";
CHECK(fs->Read(&col_sorted)) << "Invalid col_sorted";
CheckValidity();
return true;
}

void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCooMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(row);
fs->Write(col);
fs->Write(data);
fs->Write(row_sorted);
fs->Write(col_sorted);
}

inline void CheckValidity() const {
CHECK_SAME_DTYPE(row, col);
CHECK_SAME_CONTEXT(row, col);
if (!aten::IsNullArray(data)) {
CHECK_SAME_DTYPE(row, data);
CHECK_SAME_CONTEXT(row, data);
}
CHECK_NO_OVERFLOW(row->dtype, num_rows);
CHECK_NO_OVERFLOW(row->dtype, num_cols);
}

/*! \brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == row->ctx)
return *this;
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx, stream),
col.CopyTo(ctx, stream),
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
row_sorted, col_sorted);
}

/*!
* \brief Pin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned)
return;
row.PinMemory_();
col.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
}

/*!
* \brief Unpin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned)
return;
row.UnpinMemory_();
col.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
}
};

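    COOMatrix stores a COO sparse matrix in three IdArrays. row and col hold the coordinates of the non-zero entries. data holds the ID of each entry (e.g., the edge IDs used to look up edge features). When data is a null array, the IDs are taken to be 0 to NNZ - 1.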
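
    The following hypothetical Python mirror of COOMatrix's fields makes the role of the three arrays concrete. COOMatrixSketch is not a DGL class. It reproduces the documented rule that a null data array means entry IDs 0 to NNZ - 1. Its validity check is only a loose analogue: the real CheckValidity compares dtypes and device contexts and checks for integer overflow, while this sketch compares array lengths instead.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class COOMatrixSketch:
    """Hypothetical Python mirror of the fields of dgl::aten::COOMatrix."""
    num_rows: int
    num_cols: int
    row: List[int]
    col: List[int]
    data: Optional[List[int]] = None  # None plays the role of NullArray()
    row_sorted: bool = False
    col_sorted: bool = False

    def __post_init__(self) -> None:
        self.check_validity()  # the C++ constructors also validate on construction

    def check_validity(self) -> None:
        # Loose analogue only: DGL's CheckValidity compares dtypes/contexts and
        # checks for overflow; this sketch checks array lengths instead.
        if len(self.row) != len(self.col):
            raise ValueError("row and col must have the same length")
        if self.data is not None and len(self.data) != len(self.row):
            raise ValueError("data must hold one ID per non-zero entry")

    def entry_ids(self) -> List[int]:
        # "When is null, assume it is from 0 to NNZ - 1."
        if self.data is None:
            return list(range(len(self.row)))
        return list(self.data)


# Duplicate (row, col) pairs are allowed: here (0, 1) appears twice (a multigraph).
m = COOMatrixSketch(3, 3, row=[0, 0, 2], col=[1, 1, 0])
print(m.entry_ids())  # [0, 1, 2]
```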

CSRMatrix

    The CSRMatrix struct, defined in include/dgl/aten/csr.h dglsrc_include_dgl_aten_csr_h, stores a sparse matrix in CSR format. An excerpt of its definition follows:

/*!
* \brief Plain CSR matrix
*
* The column indices are 0-based and are not necessarily sorted. The data array stores
* integer ids for reading edge features.
*
* Note that we do allow duplicate non-zero entries -- multiple non-zero entries
* that have the same row, col indices. It corresponds to multigraph in
* graph terminology.
*/

constexpr uint64_t kDGLSerialize_AtenCsrMatrixMagic = 0xDD6cd31205dff127;

struct CSRMatrix {
/*! \brief the dense shape of the matrix */
int64_t num_rows = 0, num_cols = 0;
/*! \brief CSR index arrays */
IdArray indptr, indices;
/*! \brief data index array. When is null, assume it is from 0 to NNZ - 1. */
IdArray data;
/*! \brief whether the column indices per row are sorted */
bool sorted = false;
/*! \brief whether the matrix is in pinned memory */
bool is_pinned = false;
/*! \brief default constructor */
CSRMatrix() = default;
/*! \brief constructor */
CSRMatrix(int64_t nrows, int64_t ncols, IdArray parr, IdArray iarr,
IdArray darr = NullArray(), bool sorted_flag = false)
: num_rows(nrows),
num_cols(ncols),
indptr(parr),
indices(iarr),
data(darr),
sorted(sorted_flag) {
CheckValidity();
}

/*! \brief constructor from SparseMatrix object */
explicit CSRMatrix(const SparseMatrix& spmat)
: num_rows(spmat.num_rows),
num_cols(spmat.num_cols),
indptr(spmat.indices[0]),
indices(spmat.indices[1]),
data(spmat.indices[2]),
sorted(spmat.flags[0]) {
CheckValidity();
}

// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCSR), num_rows,
num_cols, {indptr, indices, data}, {sorted});
}

bool Load(dmlc::Stream* fs) {
uint64_t magicNum;
CHECK(fs->Read(&magicNum)) << "Invalid Magic Number";
CHECK_EQ(magicNum, kDGLSerialize_AtenCsrMatrixMagic)
<< "Invalid CSRMatrix Data";
CHECK(fs->Read(&num_cols)) << "Invalid num_cols";
CHECK(fs->Read(&num_rows)) << "Invalid num_rows";
CHECK(fs->Read(&indptr)) << "Invalid indptr";
CHECK(fs->Read(&indices)) << "Invalid indices";
CHECK(fs->Read(&data)) << "Invalid data";
CHECK(fs->Read(&sorted)) << "Invalid sorted";
CheckValidity();
return true;
}

void Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_AtenCsrMatrixMagic);
fs->Write(num_cols);
fs->Write(num_rows);
fs->Write(indptr);
fs->Write(indices);
fs->Write(data);
fs->Write(sorted);
}

inline void CheckValidity() const {
CHECK_SAME_DTYPE(indptr, indices);
CHECK_SAME_CONTEXT(indptr, indices);
if (!aten::IsNullArray(data)) {
CHECK_SAME_DTYPE(indptr, data);
CHECK_SAME_CONTEXT(indptr, data);
}
CHECK_NO_OVERFLOW(indptr->dtype, num_rows);
CHECK_NO_OVERFLOW(indptr->dtype, num_cols);
CHECK_EQ(indptr->shape[0], num_rows + 1);
}

/*! \brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (ctx == indptr->ctx)
return *this;
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx, stream),
indices.CopyTo(ctx, stream),
aten::IsNullArray(data) ? data : data.CopyTo(ctx, stream),
sorted);
}

/*!
* \brief Pin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned)
return;
indptr.PinMemory_();
indices.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
}

/*!
* \brief Unpin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned)
return;
indptr.UnpinMemory_();
indices.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
}
};

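    Likewise, CSRMatrix stores a CSR sparse matrix in three IdArrays. indptr holds the row pointers, and CheckValidity requires its length to be num_rows + 1. indices holds the column index of each non-zero entry. data holds the entry IDs and is again optional.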
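
    The row pointers make the neighbors of a vertex one contiguous slice. The sketch below uses out_edges, a hypothetical helper that operates on plain lists. It returns the successors and edge IDs of vertex v from CSR arrays. When data is null, it falls back to positions as edge IDs, as the CSRMatrix comment describes. The out-degree is simply indptr[v + 1] - indptr[v].

```python
from typing import List, Optional, Tuple


def out_edges(indptr: List[int], indices: List[int], data: Optional[List[int]],
              v: int) -> Tuple[List[int], List[int]]:
    """Return (successor IDs, edge IDs) of vertex v, read from CSR arrays."""
    start, end = indptr[v], indptr[v + 1]  # v's entries occupy [start, end)
    successors = indices[start:end]
    # A null data array means the edge ID of an entry is just its position.
    eids = list(range(start, end)) if data is None else data[start:end]
    return successors, eids


# 4 vertices, edges 0:(2->3), 1:(0->1), 2:(2->0), 3:(3->2) stored as CSR:
indptr, indices, data = [0, 1, 1, 3, 4], [1, 3, 0, 2], [1, 0, 2, 3]
print(out_edges(indptr, indices, data, 2))  # ([3, 0], [0, 2])
print(indptr[2 + 1] - indptr[2])            # out-degree of vertex 2: 2
```

Finding the in-edges of v would require scanning every row, because in-neighbors are scattered across the indices array. A transposed (CSC) copy avoids that scan.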

Graph Topology

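    So far we have analyzed the NDArray structure that underlies DGL's array storage. We have also seen that, because graph topologies are sparse, they are usually stored in COO, CSR, or CSC format rather than as a dense adjacency matrix. Building on these definitions, this section shows how DGL abstracts graph topology.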

The GraphInterface Base Class

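    We start with the GraphInterface base class, defined in include/dgl/graph_interface.h dglsrc_include_dgl_graph_interface_h. DGL calls a graph topology a graph index, and GraphInterface is an interface class that defines the operations on a graph index. It is inherited by the downstream classes Graph, ImmutableGraph, COO, and CSR, which we introduce later. GraphInterface is defined as follows: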

class GraphInterface : public runtime::Object {
public:
virtual ~GraphInterface() = default;

// Add vertices to the graph
virtual void AddVertices(uint64_t num_vertices) = 0;

// Add one edge to the graph
virtual void AddEdge(dgl_id_t src, dgl_id_t dst) = 0;

// Add edges to the graph
virtual void AddEdges(IdArray src_ids, IdArray dst_ids) = 0;

// Clear the graph. Remove all vertices/edges
virtual void Clear() = 0;

// Get the device context of this graph
virtual DLContext Context() const = 0;

// Get the number of integer bits used to store node/edge ids (32 or 64)
virtual uint8_t NumBits() const = 0;

// return whether the graph is a multigraph
virtual bool IsMultigraph() const = 0;

// return whether the graph is read-only
virtual bool IsReadonly() const = 0;

// return the number of vertices in the graph
virtual uint64_t NumVertices() const = 0;

// return the number of edges in the graph
virtual uint64_t NumEdges() const = 0;

// return true if the given vertex is in the graph
virtual bool HasVertex(dgl_id_t vid) const {
return vid < NumVertices();
}

// return a 0-1 array indicating whether the given vertices are in the graph
virtual BoolArray HasVertices(IdArray vids) const = 0;

// return true if the given edge is in the graph
virtual bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const = 0;

// return a 0-1 array indicating whether the given edges are in the graph
virtual BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const = 0;

// Find the predecessors of a vertex.
virtual IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const = 0;

// Find the successors of a vertex.
virtual IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const = 0;

// Get all edge ids between the two given endpoints
virtual IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const = 0;

// Get all edge ids between the given endpoint pairs.
virtual EdgeArray EdgeIds(IdArray src, IdArray dst) const = 0;

// Find the edge ID and return the pair of endpoints
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const = 0;

// Find the edge IDs and return their source and target node IDs.
virtual EdgeArray FindEdges(IdArray eids) const = 0;

// Get the in edges of the vertex
virtual EdgeArray InEdges(dgl_id_t vid) const = 0;

// Get the in edges of the vertices.
virtual EdgeArray InEdges(IdArray vids) const = 0;

// Get the out edges of the vertex.
virtual EdgeArray OutEdges(dgl_id_t vid) const = 0;

// Get the out edges of the vertices.
virtual EdgeArray OutEdges(IdArray vids) const = 0;

// Get all the edges in the graph
virtual EdgeArray Edges(const std::string &order = "") const = 0;

// Get the in degree of the given vertex
virtual uint64_t InDegree(dgl_id_t vid) const = 0;

// Get the in degrees of the given vertices
virtual DegreeArray InDegrees(IdArray vids) const = 0;

// Get the out degree of the given vertex
virtual uint64_t OutDegree(dgl_id_t vid) const = 0;

// Get the out degrees of the given vertices
virtual DegreeArray OutDegrees(IdArray vids) const = 0;

// Construct the induced subgraph of the given vertices
virtual Subgraph VertexSubgraph(IdArray vids) const = 0;

// Construct the induced edge subgraph of the given edges
virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0;

// Return the successor vector
virtual DGLIdIters SuccVec(dgl_id_t vid) const = 0;

// Return the out edge id vector
virtual DGLIdIters OutEdgeVec(dgl_id_t vid) const = 0;

// Return the predecessor vector
virtual DGLIdIters PredVec(dgl_id_t vid) const = 0;

// Return the in edge id vector

virtual DGLIdIters InEdgeVec(dgl_id_t vid) const = 0;

// Get the adjacency matrix of the graph.
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0;

// Sort the columns in CSR.
virtual void SortCSR() {
}

static constexpr const char* _type_key = "graph.Graph";
DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);
};

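    This interface class defines all the operations DGL can perform on a graph index. Graph, ImmutableGraph, COO, and CSR are all derived from GraphInterface. Table graph_interface summarizes the purpose of every interface GraphInterface provides and which subclasses implement it.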

API dgl::Graph dgl::ImmutableGraph dgl::COO dgl::CSR
void AddVertices(uint64_t num_vertices)
Add the given number of vertices to the graph index
$\checkmark$
void AddEdge(dgl_id_t src, dgl_id_t dst)
Add one edge connecting the given (source node, destination node) pair to the graph index
$\checkmark$
void AddEdges(IdArray src_ids, IdArray dst_ids)
Add edges connecting multiple given (source node, destination node) pairs to the graph index
$\checkmark$
void Clear()
Clear the graph (remove all vertices and edges)
$\checkmark$
DLContext Context() const
Get the context of the device on which the graph is stored
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
uint8_t NumBits() const
Get the integer width used to store the node/edge IDs of the graph index (32 or 64 bits)
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
bool IsMultigraph() const
Return whether the stored graph is a multigraph
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
bool IsReadonly() const
Return whether the graph is read-only
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
uint64_t NumVertices() const
Return the number of vertices in the graph
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
uint64_t NumEdges() const
Return the number of edges in the graph
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
bool HasVertex(dgl_id_t vid) const
Return whether a (single) vertex with the given ID exists in the graph
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
BoolArray HasVertices(IdArray vids) const
Return a 0-1 array indicating whether each of the given (multiple) vertex IDs exists in the graph
$\checkmark$ $\checkmark$
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const
Return whether there is an edge between the two given vertices
$\checkmark$ $\checkmark$ $\checkmark$
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const
Return a 0-1 array indicating, for each (src, dst) pair, whether an edge exists between them
$\checkmark$ $\checkmark$ $\checkmark$
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const
Return the predecessors of a vertex (the vertices pointing to it); a radius can be specified
$\checkmark$ $\checkmark$ (use Successors on the reversed graph)
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const
Return the successors of a vertex (the vertices it points to); a radius can be specified
$\checkmark$ $\checkmark$ $\checkmark$
IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const
Get the IDs of all edges between a single (source node, destination node) pair
$\checkmark$ $\checkmark$ $\checkmark$
EdgeArray EdgeIds(IdArray src, IdArray dst) const
Get the IDs of all edges between multiple (source node, destination node) pairs
$\checkmark$ $\checkmark$ $\checkmark$
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const
Find the (single) edge with the given ID and return its source and destination node IDs
$\checkmark$ $\checkmark$ $\checkmark$
EdgeArray FindEdges(IdArray eids) const
Find the (multiple) edges with the given IDs and return their source and destination node IDs
$\checkmark$ $\checkmark$ $\checkmark$
EdgeArray InEdges(dgl_id_t vid) const
Get all in-edges of a (single) given vertex
$\checkmark$ $\checkmark$ (use OutEdges on the reversed graph)
EdgeArray InEdges(IdArray vids) const
Get all in-edges of the (multiple) given vertices
$\checkmark$ $\checkmark$ (use OutEdges on the reversed graph)
EdgeArray OutEdges(dgl_id_t vid) const
Get all out-edges of a (single) given vertex
$\checkmark$ $\checkmark$ $\checkmark$
EdgeArray OutEdges(IdArray vids) const
Get all out-edges of the (multiple) given vertices
$\checkmark$ $\checkmark$ $\checkmark$
EdgeArray Edges(const std::string &order = "") const
Get all edges in the graph, in the specified order
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
uint64_t InDegree(dgl_id_t vid) const
Get the in-degree of a (single) given vertex
$\checkmark$ $\checkmark$ (use OutDegree on the reversed graph)
DegreeArray InDegrees(IdArray vids) const
Get the in-degrees of the (multiple) given vertices
$\checkmark$ $\checkmark$ (use OutDegrees on the reversed graph)
uint64_t OutDegree(dgl_id_t vid) const
Get the out-degree of a (single) given vertex
$\checkmark$ $\checkmark$ $\checkmark$
DegreeArray OutDegrees(IdArray vids) const
Get the out-degrees of the (multiple) given vertices
$\checkmark$ $\checkmark$ $\checkmark$
Subgraph VertexSubgraph(IdArray vids) const
Build the vertex-induced subgraph of the given vertex IDs
$\checkmark$ $\checkmark$ $\checkmark$
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const
Build the edge-induced subgraph of the given edge IDs
$\checkmark$ $\checkmark$ $\checkmark$
DGLIdIters SuccVec(dgl_id_t vid) const
Return the successor vector of a vertex
$\checkmark$ $\checkmark$ $\checkmark$
DGLIdIters OutEdgeVec(dgl_id_t vid) const
Return the out-edge ID vector of a vertex
$\checkmark$ $\checkmark$ $\checkmark$
DGLIdIters PredVec(dgl_id_t vid) const
Return the predecessor vector of a vertex
$\checkmark$ $\checkmark$ (use SuccVec on the reversed graph)
DGLIdIters InEdgeVec(dgl_id_t vid) const
Return the in-edge ID vector of a vertex
$\checkmark$ $\checkmark$ (use OutEdgeVec on the reversed graph)
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const
Get the adjacency matrix of the stored graph in the given format, optionally transposed
$\checkmark$ $\checkmark$ $\checkmark$ $\checkmark$
void SortCSR()
Sort the column indices within each row of the CSR
$\checkmark$ $\checkmark$
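
    A Python abstract base class can sketch the split between an abstract interface and format-specific subclasses. GraphInterfaceSketch and CSRGraphSketch below are hypothetical and keep only a handful of methods. The default has_vertex mirrors GraphInterface::HasVertex. The CSR subclass refuses predecessor queries, following the table's note that a CSR answers them via Successors on the reversed graph.

```python
from abc import ABC, abstractmethod
from typing import List


class GraphInterfaceSketch(ABC):
    """Hypothetical, heavily trimmed Python analogue of dgl::GraphInterface."""

    @abstractmethod
    def num_vertices(self) -> int: ...

    @abstractmethod
    def successors(self, vid: int) -> List[int]: ...

    @abstractmethod
    def predecessors(self, vid: int) -> List[int]: ...

    def has_vertex(self, vid: int) -> bool:
        # Default shared by all subclasses, like GraphInterface::HasVertex.
        # (dgl_id_t is unsigned in C++, so only the upper bound is checked there.)
        return 0 <= vid < self.num_vertices()


class CSRGraphSketch(GraphInterfaceSketch):
    """Out-edge CSR: successor queries are cheap, predecessor queries are refused."""

    def __init__(self, indptr: List[int], indices: List[int]) -> None:
        self.indptr, self.indices = indptr, indices

    def num_vertices(self) -> int:
        return len(self.indptr) - 1

    def successors(self, vid: int) -> List[int]:
        return self.indices[self.indptr[vid]:self.indptr[vid + 1]]

    def predecessors(self, vid: int) -> List[int]:
        # In-neighbors are scattered across all rows; query the reversed graph instead.
        raise NotImplementedError("use successors() on the reversed (CSC) graph")


g = CSRGraphSketch(indptr=[0, 1, 1, 3, 4], indices=[1, 3, 0, 2])
print(g.successors(2), g.has_vertex(4))  # [3, 0] False
```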

    Below we analyze each subclass of GraphInterface.

COO

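    The COO class builds on the COOMatrix struct and stores a read-only graph topology in the COO sparse format. It implements part of the topology operations declared in GraphInterface. COO is defined in include/dgl/immutable_graph.h dglsrc_include_dgl_immutable_graph_h, and table graph_interface shows which GraphInterface interfaces it implements. An excerpt of the COO definition follows: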

class COO : public GraphInterface {
public:
// Create a coo graph that shares the given src and dst
COO(int64_t num_vertices, IdArray src, IdArray dst,
bool row_sorted = false, bool col_sorted = false);


// ...

/* Implementations of the GraphInterface interfaces */

// ...

/*! \brief Return the transpose of this COO */
COOPtr Transpose() const {
return COOPtr(new COO(adj_.num_rows, adj_.col, adj_.row));
}

/*! \brief Convert this COO to CSR */
CSRPtr ToCSR() const;

/*!
* \brief Get the coo matrix that represents this graph.
* \note The coo matrix shares the storage with this graph.
* The data field of the coo matrix is none.
*/
aten::COOMatrix ToCOOMatrix() const {
return adj_;
}

/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The graph under another context.
*/
COO CopyTo(const DLContext& ctx) const;

/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
COO CopyToSharedMem(const std::string &name) const;

/*!
* \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64).
* \return The graph with new bit size storage.
*/
COO AsNumBits(uint8_t bits) const;

/*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const {
return false;
}

// member getters

IdArray src() const { return adj_.row; }

IdArray dst() const { return adj_.col; }

private:
/* !\brief private default constructor */
COO() {}

// The internal COO adjacency matrix.
// The data field is empty
aten::COOMatrix adj_;
};

    In addition, include/dgl/immutable_graph.h dglsrc_include_dgl_immutable_graph_h defines a smart pointer cpp_smart_pointer type, COOPtr, through which COO instances are referenced and shared:

typedef std::shared_ptr<COO> COOPtr;
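
    COO's layout explains why it suits some queries and not others. The sketch below uses COOGraphSketch, a hypothetical class that is not part of DGL. The edge ID is simply the position in the src/dst arrays, because the data field of COO's internal matrix is empty. As a result, FindEdge is a direct index, and Transpose just swaps the two arrays, as COO::Transpose does in the excerpt above. A query such as HasEdgeBetween, by contrast, needs a linear scan over all edges unless the arrays are sorted or indexed.

```python
from typing import List, Tuple


class COOGraphSketch:
    """Hypothetical mirror of dgl::COO: parallel src/dst arrays, no data field."""

    def __init__(self, num_vertices: int, src: List[int], dst: List[int]) -> None:
        if len(src) != len(dst):
            raise ValueError("src and dst must have the same length")
        self.num_vertices = num_vertices
        self.src, self.dst = src, dst

    def find_edge(self, eid: int) -> Tuple[int, int]:
        # Edge ID == position in the arrays, so this is a direct O(1) lookup.
        return self.src[eid], self.dst[eid]

    def transpose(self) -> "COOGraphSketch":
        # Like COO::Transpose(): swap the arrays; the lists are shared, not copied.
        return COOGraphSketch(self.num_vertices, self.dst, self.src)

    def has_edge_between(self, u: int, v: int) -> bool:
        # No index over (src, dst) pairs, so this is an O(E) linear scan.
        return any(s == u and d == v for s, d in zip(self.src, self.dst))


g = COOGraphSketch(4, src=[2, 0, 2, 3], dst=[3, 1, 0, 2])
print(g.find_edge(2))                                                  # (2, 0)
print(g.transpose().find_edge(2))                                      # (0, 2)
print(g.has_edge_between(1, 0), g.transpose().has_edge_between(1, 0))  # False True
```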

CSR

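    Similarly, the CSR class builds on the CSRMatrix struct and stores a read-only graph topology in the CSR sparse format. It implements part of the topology operations declared in GraphInterface. CSR is also defined in include/dgl/immutable_graph.h dglsrc_include_dgl_immutable_graph_h, and table graph_interface shows which GraphInterface interfaces it implements. An excerpt of the CSR definition follows: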

class CSR : public GraphInterface {
public:
// Create a csr graph that has the given number of verts and edges.
CSR(int64_t num_vertices, int64_t num_edges);

// Create a csr graph whose memory is stored in the shared memory
// that has the given number of verts and edges.
CSR(const std::string &shared_mem_name,
int64_t num_vertices, int64_t num_edges);

// Create a csr graph that shares the given indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids);

// Create a csr graph by data iterator
template <typename IndptrIter, typename IndicesIter, typename EdgeIdIter>
CSR(int64_t num_vertices, int64_t num_edges,
IndptrIter indptr_begin, IndicesIter indices_begin, EdgeIdIter edge_ids_begin);

// Create a csr graph whose memory is stored in the shared memory
// and the structure is given by the indptr and indices.
CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name);

// ...

/* Implementations of the GraphInterface methods */

// ...

/*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const {
return !shared_mem_name_.empty();
}

/*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
CSRPtr Transpose() const;

/*! \brief Convert this CSR to COO */
COOPtr ToCOO() const;

/*!
* \return the csr matrix that represents this graph.
* \note The csr matrix shares the storage with this graph.
* The data field of the CSR matrix stores the edge ids.
*/
aten::CSRMatrix ToCSRMatrix() const {
return adj_;
}

/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The graph under another context.
*/
CSR CopyTo(const DLContext& ctx) const;

/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
CSR CopyToSharedMem(const std::string &name) const;

/*!
* \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64).
* \return The graph with new bit size storage.
*/
CSR AsNumBits(uint8_t bits) const;

// member getters

IdArray indptr() const { return adj_.indptr; }

IdArray indices() const { return adj_.indices; }

IdArray edge_ids() const { return adj_.data; }

/*! \return Load CSR from stream */
bool Load(dmlc::Stream *fs);

/*! \return Save CSR to stream */
void Save(dmlc::Stream* fs) const;

void SortCSR() override {
if (adj_.sorted)
return;
aten::CSRSort_(&adj_);
}

private:
friend class Serializer;

/*! \brief private default constructor */
CSR() {adj_.sorted = false;}
// The internal CSR adjacency matrix.
// The data field stores edge ids.
aten::CSRMatrix adj_;

// The name of the shared memory to store data.
// If it's empty, data isn't stored in shared memory.
std::string shared_mem_name_;
};

    DGL also defines a smart pointer cpp_smart_pointer, CSRPtr, in include/dgl/immutable_graph.h dglsrc_include_dgl_immutable_graph_h. Every reference to a CSR instance is held through it:

typedef std::shared_ptr<CSR> CSRPtr;
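The CSR layout and the idea behind Transpose() can be illustrated with the following sketch. The layout consists of indptr, indices, and a data field that, per the comments in the listing, stores edge ids. ToyCSR, BuildCSR and Transpose are hypothetical helpers built on std::vector rather than DGL's IdArray. The counting-sort construction is one common way to build a CSR, not necessarily the one DGL uses.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a CSR adjacency matrix: the neighbors of row v
// are indices[indptr[v] .. indptr[v + 1]), and edge_ids holds the id of
// each of those edges (the role of the data field in DGL's CSRMatrix).
struct ToyCSR {
  std::vector<int64_t> indptr;
  std::vector<int64_t> indices;
  std::vector<int64_t> edge_ids;
};

// Builds a CSR from parallel arrays: edge eids[i] goes from row rows[i] to
// column cols[i]. Uses a counting sort, so edges within a row keep their
// input order.
ToyCSR BuildCSR(int64_t num_rows, const std::vector<int64_t>& rows,
                const std::vector<int64_t>& cols,
                const std::vector<int64_t>& eids) {
  ToyCSR csr;
  csr.indptr.assign(num_rows + 1, 0);
  for (int64_t r : rows) ++csr.indptr[r + 1];  // count edges per row
  for (int64_t v = 0; v < num_rows; ++v)       // prefix sum -> row offsets
    csr.indptr[v + 1] += csr.indptr[v];
  csr.indices.resize(rows.size());
  csr.edge_ids.resize(rows.size());
  // next[v] is the next free slot in row v.
  std::vector<int64_t> next(csr.indptr.begin(), csr.indptr.end() - 1);
  for (std::size_t i = 0; i < rows.size(); ++i) {
    const int64_t pos = next[rows[i]]++;
    csr.indices[pos] = cols[i];
    csr.edge_ids[pos] = eids[i];
  }
  return csr;
}

// Transpose: every entry (row v, column c, edge e) becomes (row c, column v,
// edge e). Applied to the out-CSR this yields the in-CSR (i.e., the CSC view).
ToyCSR Transpose(const ToyCSR& csr, int64_t num_cols) {
  std::vector<int64_t> rows, cols, eids;
  const int64_t num_rows = static_cast<int64_t>(csr.indptr.size()) - 1;
  for (int64_t v = 0; v < num_rows; ++v) {
    for (int64_t p = csr.indptr[v]; p < csr.indptr[v + 1]; ++p) {
      rows.push_back(csr.indices[p]);
      cols.push_back(v);
      eids.push_back(csr.edge_ids[p]);
    }
  }
  return BuildCSR(num_cols, rows, cols, eids);
}
```

Rows of the out-CSR are source vertices, while rows of its transpose are destination vertices. This is why the transposed CSR can serve as the in-edge view while each edge keeps its original id.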

ImmutableGraph

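    ImmutableGraph is essentially a wrapper around the COO and CSR classes. It provides a single class for creating and managing read-only graph topologies. It is defined in include/dgl/immutable_graph.h dglsrc_include_dgl_immutable_graph_h. An excerpt of the definition follows: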

class ImmutableGraph: public GraphInterface {
public:
/*! \brief Construct an immutable graph from the COO format. */
explicit ImmutableGraph(COOPtr coo): coo_(coo) { }

/*!
* \brief Construct an immutable graph from the CSR format.
*
* For a single graph, we need two CSRs, one stores the in-edges of vertices and
* the other stores the out-edges of vertices. These two CSRs stores the same edges.
* The reason we need both is that some operators are faster on in-edge CSR and
* the other operators are faster on out-edge CSR.
*
* However, not both CSRs are required. Technically, one CSR contains all information.
* Thus, when we construct a temporary graphs (e.g., the sampled subgraphs), we only
* construct one of the CSRs that runs fast for some operations we expect and construct
* the other CSR on demand.
*/
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr)
: in_csr_(in_csr), out_csr_(out_csr) {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
}

/*! \brief Construct an immutable graph from one CSR. */
explicit ImmutableGraph(CSRPtr csr): out_csr_(csr) { }

/*! \brief default copy constructor */
ImmutableGraph(const ImmutableGraph& other) = default;

#ifndef _MSC_VER
/*! \brief default move constructor */
ImmutableGraph(ImmutableGraph&& other) = default;
#else
ImmutableGraph(ImmutableGraph&& other) {
this->in_csr_ = other.in_csr_;
this->out_csr_ = other.out_csr_;
this->coo_ = other.coo_;
other.in_csr_ = nullptr;
other.out_csr_ = nullptr;
other.coo_ = nullptr;
}
#endif // _MSC_VER

/*! \brief default assign constructor */
ImmutableGraph& operator=(const ImmutableGraph& other) = default;

/*! \brief default destructor */
~ImmutableGraph() = default;

// ...

/* Implementations of the GraphInterface methods */

// ...

/* !\brief Return in csr. If not exist, transpose the other one.*/
CSRPtr GetInCSR() const;

/* !\brief Return out csr. If not exist, transpose the other one.*/
CSRPtr GetOutCSR() const;

/* !\brief Return coo. If not exist, create from csr.*/
COOPtr GetCOO() const;

/*! \brief Create an immutable graph from CSR. */
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);

static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);

/*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst,
bool row_sorted = false, bool col_sorted = false);

/*!
* \brief Convert the given graph to an immutable graph.
*
* If the graph is already an immutable graph. The result graph will share
* the storage with the given one.
*
* \param graph The input graph.
* \return an immutable graph object.
*/
static ImmutableGraphPtr ToImmutable(GraphPtr graph);

/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The graph under another context.
*/
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DLContext& ctx);

/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
static ImmutableGraphPtr CopyToSharedMem(ImmutableGraphPtr g, const std::string &name);

/*!
* \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64).
* \return The graph with new bit size storage.
*/
static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits);

/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
*
* \return the reversed graph
*/
ImmutableGraphPtr Reverse() const;

/*! \return Load ImmutableGraph from stream, using out csr */
bool Load(dmlc::Stream *fs);

/*! \return Save ImmutableGraph to stream, using out csr */
void Save(dmlc::Stream* fs) const;

void SortCSR() override {
GetInCSR()->SortCSR();
GetOutCSR()->SortCSR();
}

bool HasInCSR() const {
return in_csr_ != NULL;
}

bool HasOutCSR() const {
return out_csr_ != NULL;
}

/*! \brief Cast this graph to a heterograph */
HeteroGraphPtr AsHeteroGraph() const;

protected:
friend class Serializer;
friend class UnitGraph;

/* !\brief internal default constructor */
ImmutableGraph() {}

/* !\brief internal constructor for all the members */
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, COOPtr coo)
: in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
CHECK(AnyGraph()) << "At least one graph structure should exist.";
}

ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)
: in_csr_(in_csr), out_csr_(out_csr) {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
this->shared_mem_name_ = shared_mem_name;
}

/* !\brief return pointer to any available graph structure */
GraphPtr AnyGraph() const {
if (in_csr_) {
return in_csr_;
} else if (out_csr_) {
return out_csr_;
} else {
return coo_;
}
}

// Store the in csr (i.e, the reverse csr)
CSRPtr in_csr_;
// Store the out csr (i.e, the normal csr)
CSRPtr out_csr_;
// Store the edge list indexed by edge id (COO)
COOPtr coo_;

// The name of shared memory for this graph.
// If it's empty, the graph isn't stored in shared memory.
std::string shared_mem_name_;
// We serialize the metadata of the graph index here for shared memory.
NDArray serialized_shared_meta_;
};

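    Note that ImmutableGraph can be created either from a COO or from CSR. When it is created from CSR, there are two options:

  • from a single CSR, which is stored in the member variable out_csr_;
  • from two CSRs: the normal (out-edge) CSR is stored in out_csr_ and the reverse (in-edge) CSR in in_csr_. Both are kept because some operations run faster on the in-edge CSR and others on the out-edge CSR.

    ImmutableGraph can initialize in_csr_, which holds the reverse CSR, on demand. If in_csr_ has not been set, it is built the first time an API that needs the reverse CSR is called; according to the comment on GetInCSR(), this is done by transposing the other CSR. GetOutCSR() and GetCOO() follow the same pattern for out_csr_ and coo_.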
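Below is a minimal sketch of this on-demand behavior. It assumes hypothetical types MiniCSR and LazyGraph, which are not DGL classes, and omits edge ids for brevity. Only the out-CSR is supplied. The in-CSR is produced by transposing it the first time a query needs it, and is then cached. How DGL itself stores the result from inside a const method is not shown in the excerpt; the mutable member here is just one way to do it.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical minimal CSR (edge ids omitted to keep the sketch short).
struct MiniCSR {
  std::vector<int64_t> indptr, indices;
};
using MiniCSRPtr = std::shared_ptr<MiniCSR>;

// Transposes a CSR with a counting sort: row v of the result lists the rows
// of the input that have an entry in column v.
MiniCSRPtr TransposeCSR(const MiniCSR& csr, int64_t num_cols) {
  auto t = std::make_shared<MiniCSR>();
  t->indptr.assign(num_cols + 1, 0);
  for (int64_t c : csr.indices) ++t->indptr[c + 1];
  for (int64_t v = 0; v < num_cols; ++v) t->indptr[v + 1] += t->indptr[v];
  t->indices.resize(csr.indices.size());
  std::vector<int64_t> next(t->indptr.begin(), t->indptr.end() - 1);
  const int64_t num_rows = static_cast<int64_t>(csr.indptr.size()) - 1;
  for (int64_t r = 0; r < num_rows; ++r)
    for (int64_t p = csr.indptr[r]; p < csr.indptr[r + 1]; ++p)
      t->indices[next[csr.indices[p]]++] = r;
  return t;
}

// Hypothetical toy graph mirroring the in_csr_/out_csr_ pair: only the
// out-CSR is required, and the in-CSR is built the first time it is needed.
class LazyGraph {
 public:
  LazyGraph(MiniCSRPtr out_csr, int64_t num_vertices)
      : out_csr_(std::move(out_csr)), num_vertices_(num_vertices) {}

  bool HasInCSR() const { return in_csr_ != nullptr; }

  MiniCSRPtr GetInCSR() const {
    if (!in_csr_) in_csr_ = TransposeCSR(*out_csr_, num_vertices_);  // built once, then cached
    return in_csr_;
  }

  // In-degree needs the in-CSR, so calling it triggers the lazy construction.
  int64_t InDegree(int64_t v) const {
    MiniCSRPtr in = GetInCSR();
    return in->indptr[v + 1] - in->indptr[v];
  }

 private:
  MiniCSRPtr out_csr_;
  mutable MiniCSRPtr in_csr_;  // mutable so that const getters can fill it in
  int64_t num_vertices_;
};
```

The trade-off is memory versus latency. Temporary graphs such as sampled subgraphs pay for the second CSR only if an operation actually needs it.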

Graph

    Unlike ImmutableGraph, Graph is a mutable graph-topology class based on adjacency lists. It is defined in include/dgl/graph.h dglsrc_include_dgl_graph_h. The relevant code is excerpted below:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class Graph: public GraphInterface {
public:
/*! \brief default constructor */
Graph() {}

/*! \brief construct a graph from the coo format. */
Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes);

/*! \brief default copy constructor */
Graph(const Graph& other) = default;

#ifndef _MSC_VER
/*! \brief default move constructor */
Graph(Graph&& other) = default;
#else
Graph(Graph&& other) {
adjlist_ = other.adjlist_;
reverse_adjlist_ = other.reverse_adjlist_;
all_edges_src_ = other.all_edges_src_;
all_edges_dst_ = other.all_edges_dst_;
read_only_ = other.read_only_;
num_edges_ = other.num_edges_;
other.Clear();
}
#endif // _MSC_VER

/*! \brief default assign constructor */
Graph& operator=(const Graph& other) = default;

/*! \brief default destructor */
~Graph() = default;

// ...

/* Implementations of the GraphInterface methods */

// ...

/*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO(
int64_t num_nodes, IdArray src_ids, IdArray dst_ids) {
return std::make_shared<Graph>(src_ids, dst_ids, num_nodes);
}

protected:
friend class GraphOp;
/*! \brief Internal edge list type */
struct EdgeList {
/*! \brief successor vertex list */
std::vector<dgl_id_t> succ;
/*! \brief out edge list */
std::vector<dgl_id_t> edge_id;
};
typedef std::vector<EdgeList> AdjacencyList;

/*! \brief adjacency list using vector storage */
AdjacencyList adjlist_;
/*! \brief reverse adjacency list using vector storage */
AdjacencyList reverse_adjlist_;

/*! \brief all edges' src endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_src_;

/*! \brief all edges' dst endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_dst_;

/*! \brief read only flag */
bool read_only_ = false;

/*! \brief number of edges */
uint64_t num_edges_ = 0;
};

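    Inside the Graph class, Lines 48~53 define the nested struct EdgeList. Each EdgeList stores the successors of one vertex (succ) together with the ids of the corresponding out-edges (edge_id). On top of EdgeList, Graph defines the type AdjacencyList (a std::vector of EdgeList). It then declares two members of that type. adjlist_ is the adjacency list of out-neighbors. reverse_adjlist_ is the reverse adjacency list, whose entries hold each vertex's predecessors and in-edge ids. In addition, all_edges_src_ and all_edges_dst_ record the endpoints of every edge in edge-id order.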
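The following sketch mimics the storage layout described above. It assumes a hypothetical ToyMutableGraph class, which is not DGL's Graph, and takes dgl_id_t to be a 64-bit unsigned integer. Adding an edge appends to both adjacency lists and to the per-edge endpoint arrays, and the new edge's id is the current edge count. This illustrates the data structure only; it is not DGL's implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using dgl_id_t = uint64_t;  // assumption: ids are 64-bit unsigned integers

// Hypothetical toy version of Graph's storage: adjlist_[u] holds u's
// successors and out-edge ids, reverse_adjlist_[v] holds v's predecessors
// and in-edge ids.
class ToyMutableGraph {
 public:
  struct EdgeList {
    std::vector<dgl_id_t> succ;     // vertex at the other end of each edge
    std::vector<dgl_id_t> edge_id;  // id of each of those edges
  };

  void AddVertices(uint64_t num) {
    adjlist_.resize(adjlist_.size() + num);
    reverse_adjlist_.resize(reverse_adjlist_.size() + num);
  }

  // Appends edge src -> dst and returns its id (the current edge count).
  dgl_id_t AddEdge(dgl_id_t src, dgl_id_t dst) {
    const dgl_id_t eid = num_edges_++;
    adjlist_[src].succ.push_back(dst);
    adjlist_[src].edge_id.push_back(eid);
    reverse_adjlist_[dst].succ.push_back(src);
    reverse_adjlist_[dst].edge_id.push_back(eid);
    all_edges_src_.push_back(src);
    all_edges_dst_.push_back(dst);
    return eid;
  }

  const std::vector<dgl_id_t>& Successors(dgl_id_t v) const { return adjlist_[v].succ; }
  const std::vector<dgl_id_t>& Predecessors(dgl_id_t v) const { return reverse_adjlist_[v].succ; }

  // Edge id -> endpoints is O(1) thanks to the per-edge src/dst arrays.
  std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const {
    return {all_edges_src_[eid], all_edges_dst_[eid]};
  }

  uint64_t NumEdges() const { return num_edges_; }

 private:
  std::vector<EdgeList> adjlist_;
  std::vector<EdgeList> reverse_adjlist_;
  std::vector<dgl_id_t> all_edges_src_;
  std::vector<dgl_id_t> all_edges_dst_;
  uint64_t num_edges_ = 0;
};
```

Keeping both directions costs twice the adjacency storage. In return, successor and predecessor queries are both simple vector lookups, and appending an edge stays amortized O(1).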

    DGL also defines a smart pointer cpp_smart_pointer, MutableGraphPtr, in include/dgl/graph.h dglsrc_include_dgl_graph_h. Every reference to a Graph instance is held through it:

typedef std::shared_ptr<Graph> MutableGraphPtr;

Storing Heterogeneous Graphs

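    The topology of real-world datasets is often heterogeneous, i.e., a heterogeneous graph. According to the comments in the DGL source, "heterogeneous" here means a graph that contains multiple node types and multiple edge types. There can be several edge types between nodes of NodeType A and NodeType B. However, each edge type fixes the types of the node pair it connects: every edge of EdgeType A links the same source node type to the same destination node type. A homogeneous graph, i.e., a graph with only one node type and one edge type, can be viewed as a special case of a heterogeneous graph. DGL therefore uses the heterogeneous graph as the abstraction for all graph topologies.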

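    In short, a heterogeneous graph can be seen as a combination of several topologies. Each topology is the graph formed by the edges of one edge type.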

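    Below, we first look at the concepts of the unit graph and the meta graph. We then use the source code to walk through DGL's concrete implementation of heterogeneous graphs, shown in hetero.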

Unit Graph

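    Under the hood, DGL stores top-level graph topology as unit graphs. According to the DGL documentation dgl_doc_heterogeneous_graphs and the comments in the shared-library code dglsrc_src_graph_unit_graph_cc, a unit graph is a graph that contains exactly one relation (utype, etype, vtype). Under this definition, a unit graph falls into one of two cases:

  • utype and vtype are two different types;
  • utype and vtype are the same type.

    Homogeneous graphs (i.e., graphs with a single node type) and bipartite graphs satisfy this definition by construction, so both are unit graphs. A heterogeneous graph (i.e., one with multiple node types and edge relations) is decomposed by DGL into several graphs. Each of these contains a single relation (utype, etype, vtype), i.e., each is a unit graph, and DGL stores the heterogeneous graph in this form.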
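As a rough illustration of this decomposition, the sketch below groups a typed edge list by its relation (utype, etype, vtype). The result is one COO-style unit graph per relation. Relation, ToyUnitGraph, TypedEdge and SplitIntoUnitGraphs are hypothetical names. The relations user-follows-user and user-plays-game used in the test are made up for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical canonical relation (utype, etype, vtype).
using Relation = std::tuple<std::string, std::string, std::string>;

// Hypothetical unit graph: a single relation stored as a COO edge list.
// Node ids are local to utype (for src) and vtype (for dst).
struct ToyUnitGraph {
  std::vector<int64_t> src, dst;
};

// One edge of the heterogeneous graph together with its relation.
struct TypedEdge {
  Relation rel;
  int64_t src, dst;
};

// Splits a heterogeneous edge list into one unit graph per relation.
std::map<Relation, ToyUnitGraph> SplitIntoUnitGraphs(const std::vector<TypedEdge>& edges) {
  std::map<Relation, ToyUnitGraph> units;
  for (const TypedEdge& e : edges) {
    ToyUnitGraph& g = units[e.rel];  // creates the unit graph on first use
    g.src.push_back(e.src);
    g.dst.push_back(e.dst);
  }
  return units;
}
```

Each resulting unit graph is either homogeneous (utype equals vtype, as with follows) or bipartite (utype differs from vtype, as with plays). These are the two cases listed above.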

Meta Graph

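    The meta graph is a concept derived from the heterogeneous graph. Each node of the meta graph represents one node type of the heterogeneous graph. Each edge of the meta graph represents one edge type (relation) and points from that relation's source node type to its destination node type. A schematic of the meta graph is shown above.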

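    Note that the meta graph of a unit graph can only take one of the two forms shown in the figure above, matching the two cases of a unit graph. When utype and vtype differ, the meta graph has two nodes joined by one edge. When they are the same, it has a single node with a self-loop.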
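The sketch below models a meta graph as a graph whose nodes are vertex types and whose edges are edge types. It assumes a hypothetical ToyMetaGraph class and takes type ids to be 64-bit unsigned integers. It builds the two meta graphs a unit graph can have: two vertex types joined by one edge type, or one vertex type with a self-loop.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

using dgl_type_t = uint64_t;  // assumption: type ids are unsigned integers

// Hypothetical toy meta graph: node i is vertex type i, and edge k is edge
// type k, running from its source vertex type to its destination vertex type.
class ToyMetaGraph {
 public:
  explicit ToyMetaGraph(uint64_t num_vertex_types)
      : num_vertex_types_(num_vertex_types) {}

  dgl_type_t AddEdgeType(dgl_type_t src_type, dgl_type_t dst_type) {
    ends_.emplace_back(src_type, dst_type);
    return ends_.size() - 1;
  }

  uint64_t NumVertices() const { return num_vertex_types_; }  // = number of vertex types
  uint64_t NumEdges() const { return ends_.size(); }          // = number of edge types

  // Endpoint types of an edge type (what a GetEndpointTypes-style query returns).
  std::pair<dgl_type_t, dgl_type_t> FindEdge(dgl_type_t etype) const { return ends_[etype]; }

 private:
  uint64_t num_vertex_types_;
  std::vector<std::pair<dgl_type_t, dgl_type_t>> ends_;
};

// utype != vtype: two nodes joined by one edge 0 -> 1.
ToyMetaGraph BipartiteMeta() {
  ToyMetaGraph m(2);
  m.AddEdgeType(0, 1);
  return m;
}

// utype == vtype: a single node with a self-loop 0 -> 0.
ToyMetaGraph HomogeneousMeta() {
  ToyMetaGraph m(1);
  m.AddEdgeType(0, 0);
  return m;
}
```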

The BaseHeteroGraph Base Class

    We start with the base class BaseHeteroGraph, defined in include/dgl/base_heterograph.h dglsrc_include_dgl_base_heterograph_h. It is the base class of the heterogeneous graph representations and is defined as follows:

/*!
* \brief Base heterogenous graph.
*
* In heterograph, nodes represent entities and edges represent relations.
* Nodes and edges are associated with types. The same pair of entity types
* can have multiple relation types between them, but relation type **uniquely**
* identifies the source and destination entity types.
*
* In a high-level, a heterograph is a data structure composed of:
* - A meta-graph that stores the entity-entity relation graph.
* - A dictionary of relation type to the bipartite graph representing the
* actual connections among entity nodes.
*/
class BaseHeteroGraph : public runtime::Object {
public:
explicit BaseHeteroGraph(GraphPtr meta_graph): meta_graph_(meta_graph) {}
virtual ~BaseHeteroGraph() = default;

////////////////////////// query/operations on meta graph ////////////////////////

/*! \return the number of vertex types */
virtual uint64_t NumVertexTypes() const {
return meta_graph_->NumVertices();
}

/*! \return the number of edge types */
virtual uint64_t NumEdgeTypes() const {
return meta_graph_->NumEdges();
}

/*! \return given the edge type, find the source type */
virtual std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(dgl_type_t etype) const {
return meta_graph_->FindEdge(etype);
}

/*! \return the meta graph */
virtual GraphPtr meta_graph() const {
return meta_graph_;
}

/*!
* \brief Return the bipartite graph of the given edge type.
* \param etype The edge type.
* \return The bipartite graph.
*/
virtual HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const = 0;

////////////////////////// query/operations on realized graph ////////////////////////

/*! \brief Add vertices to the given vertex type */
virtual void AddVertices(dgl_type_t vtype, uint64_t num_vertices) = 0;

/*! \brief Add one edge to the given edge type */
virtual void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) = 0;

/*! \brief Add edges to the given edge type */
virtual void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) = 0;

/*!
* \brief Clear the graph. Remove all vertices/edges.
*/
virtual void Clear() = 0;

/*!
* \brief Get the data type of node and edge IDs of this graph.
*/
virtual DLDataType DataType() const = 0;

/*!
* \brief Get the device context of this graph.
*/
virtual DLContext Context() const = 0;

/*!
* \brief Pin graph.
*/
virtual void PinMemory_() = 0;

/*!
* \brief Check if this graph is pinned.
*/
virtual bool IsPinned() const = 0;

/*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).
*/
// TODO(BarclayII) replace NumBits() calls to DataType() calls
virtual uint8_t NumBits() const = 0;

/*!
* \return whether the graph is a multigraph
*/
virtual bool IsMultigraph() const = 0;

/*! \return whether the graph is read-only */
virtual bool IsReadonly() const = 0;

/*! \return the number of vertices in the graph.*/
virtual uint64_t NumVertices(dgl_type_t vtype) const = 0;

/*! \return the number of vertices for each type in the graph as a vector */
inline virtual std::vector<int64_t> NumVerticesPerType() const {
LOG(FATAL) << "[BUG] NumVerticesPerType() not supported on this object.";
return {};
}

/*! \return the number of edges in the graph.*/
virtual uint64_t NumEdges(dgl_type_t etype) const = 0;

/*! \return true if the given vertex is in the graph.*/
virtual bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const = 0;

/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const = 0;

/*! \return true if the given edge is in the graph.*/
virtual bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;

/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
virtual BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const = 0;

/*!
* \brief Find the predecessors of a vertex.
* \note The given vertex should belong to the source vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id.
* \return the predecessor id array.
*/
virtual IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const = 0;

/*!
* \brief Find the successors of a vertex.
* \note The given vertex should belong to the dest vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id.
* \return the successor id array.
*/
virtual IdArray Successors(dgl_type_t etype, dgl_id_t src) const = 0;

/*!
* \brief Get all edge ids between the two given endpoints
* \note The given src and dst vertices should belong to the source vertex type
* and the dest vertex type of the given edge type, respectively.
* \param etype The edge type
* \param src The source vertex.
* \param dst The destination vertex.
* \return the edge id array.
*/
virtual IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const = 0;

/*!
* \brief Get all edge ids between the given endpoint pairs.
*
* \param etype The edge type
* \param src The src vertex ids.
* \param dst The dst vertex ids.
* \return EdgeArray containing all edges between all pairs.
*/
virtual EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const = 0;

/*!
* \brief Get edge ids between the given endpoint pairs.
*
* Only find one matched edge Ids even if there are multiple matches due to parallel
* edges. The i^th Id in the returned array is for edge (src[i], dst[i]).
*
* \param etype The edge type
* \param src The src vertex ids.
* \param dst The dst vertex ids.
* \return EdgeArray containing all edges between all pairs.
*/
virtual IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const = 0;

/*!
* \brief Find the edge ID and return the pair of endpoints
* \param etype The edge type
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
*/
virtual std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const = 0;

/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param etype The edge type
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
*/
virtual EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const = 0;

/*!
* \brief Get the in edges of the vertex.
* \note The given vertex should belong to the dest vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id.
* \return the edges
*/
virtual EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Get the in edges of the vertices.
* \note The given vertex should belong to the dest vertex type
* of the given edge type.
* \param etype The edge type
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
virtual EdgeArray InEdges(dgl_type_t etype, IdArray vids) const = 0;

/*!
* \brief Get the out edges of the vertex.
* \note The given vertex should belong to the source vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
*/
virtual EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Get the out edges of the vertices.
* \note The given vertex should belong to the source vertex type
* of the given edge type.
* \param etype The edge type
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
virtual EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const = 0;

/*!
* \brief Get all the edges in the graph.
* \note If order is "srcdst", the returned edges list is sorted by their src and
* dst ids. If order is "eid", they are in their edge id order.
* Otherwise, in the arbitrary order.
* \param etype The edge type
* \param order The order of the returned edge list.
* \return the id arrays of the two endpoints of the edges.
*/
virtual EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const = 0;

/*!
* \brief Get the in degree of the given vertex.
* \note The given vertex should belong to the dest vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id.
* \return the in degree
*/
virtual uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Get the in degrees of the given vertices.
* \note The given vertex should belong to the dest vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id array.
* \return the in degree array
*/
virtual DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const = 0;

/*!
* \brief Get the out degree of the given vertex.
* \note The given vertex should belong to the source vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id.
* \return the out degree
*/
virtual uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Get the out degrees of the given vertices.
* \note The given vertex should belong to the source vertex type
* of the given edge type.
* \param etype The edge type
* \param vid The vertex id array.
* \return the out degree array
*/
virtual DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const = 0;

/*!
* \brief Return the successor vector
* \note The given vertex should belong to the source vertex type
* of the given edge type.
* \param vid The vertex id.
* \return the successor vector iterator pair.
*/
virtual DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Return the out edge id vector
* \note The given vertex should belong to the source vertex type
* of the given edge type.
* \param vid The vertex id.
* \return the out edge id vector iterator pair.
*/
virtual DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Return the predecessor vector
* \note The given vertex should belong to the dest vertex type
* of the given edge type.
* \param vid The vertex id.
* \return the predecessor vector iterator pair.
*/
virtual DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Return the in edge id vector
* \note The given vertex should belong to the dest vertex type
* of the given edge type.
* \param vid The vertex id.
* \return the in edge id vector iterator pair.
*/
virtual DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0;

/*!
* \brief Get the adjacency matrix of the graph.
*
* TODO(minjie): deprecate this interface; replace it with GetXXXMatrix.
*
* By default, a row of returned adjacency matrix represents the destination
* of an edge and the column represents the source.
*
* If the fmt is 'csr', the function should return three arrays, representing
* indptr, indices and edge ids
*
* If the fmt is 'coo', the function should return one array of shape (2, nnz),
* representing a horizontal stack of row and col indices.
*
* \param transpose A flag to transpose the returned adjacency matrix.
* \param fmt the format of the returned adjacency matrix.
* \return a vector of IdArrays.
*/
virtual std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const = 0;

/*!
* \brief Determine which format to use with a preference.
*

* Otherwise, it will return whatever DGL thinks is the most appropriate given
* the arguments.
*
* \param etype Edge type.
* \param preferred_formats Preferred sparse formats.
* \return Available sparse format.
*/
virtual SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const = 0;

/*!
* \brief Return sparse formats already created for the graph.
*
* \return a number of type dgl_format_code_t.
*/
virtual dgl_format_code_t GetCreatedFormats() const = 0;

/*!
* \brief Return allowed sparse formats for the graph.
*
* \return a number of type dgl_format_code_t.
*/
virtual dgl_format_code_t GetAllowedFormats() const = 0;

/*!
* \brief Return the graph in specified available formats.
*
* \return The new graph.
*/
virtual HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const = 0;

/*!
* \brief Get adjacency matrix in COO format.
* \param etype Edge type.
* \return COO matrix.
*/
virtual aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const = 0;

/*!
* \brief Get adjacency matrix in CSR format.
*
* The row and column sizes are equal to the number of dsttype and srctype
* nodes, respectively.
*
* \param etype Edge type.
* \return CSR matrix.
*/
virtual aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const = 0;

/*!
* \brief Get adjacency matrix in CSC format.
*
* A CSC matrix is equivalent to the transpose of a CSR matrix.
* We reuse the CSRMatrix data structure as return value. The row and column
* sizes are equal to the number of dsttype and srctype nodes, respectively.
*
* \param etype Edge type.
* \return A CSR matrix.
*/
virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0;

/*!
* \brief Extract the induced subgraph by the given vertices.
*
* The length of the given vector should be equal to the number of vertex types.
* Empty arrays can be provided if no vertex is needed for the type. The result
* subgraph has the same meta graph with the parent, but some types can have no
* node/edge.
*
* \param vids the induced vertices per type.
* \return the subgraph.
*/
virtual HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const = 0;

/*!
* \brief Extract the induced subgraph by the given edges.
*
* The length of the given vector should be equal to the number of edge types.
* Empty arrays can be provided if no edge is needed for the type. The result
* subgraph has the same meta graph with the parent, but some types can have no
* node/edge.
*
* \param eids The edges in the subgraph.
* \param preserve_nodes If true, the vertices will not be relabeled, so some vertices
* may have no incident edges.
* \return the subgraph.
*/
virtual HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const = 0;

/*!
* \brief Convert the list of requested unitgraph graphs into a single unitgraph graph.
*
* \param etypes The list of edge type IDs.
* \return The flattened graph, with induced source/edge/destination types/IDs.
*/
virtual FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const {
LOG(FATAL) << "Flatten operation unsupported";
return nullptr;
}

/*! \brief Cast this graph to immutable graph */
virtual GraphPtr AsImmutableGraph() const {
LOG(FATAL) << "AsImmutableGraph not supported.";
return nullptr;
}

static constexpr const char* _type_key = "graph.HeteroGraph";
DGL_DECLARE_OBJECT_TYPE_INFO(BaseHeteroGraph, runtime::Object);

protected:
/*! \brief meta graph */
GraphPtr meta_graph_;

// empty constructor
BaseHeteroGraph(){}
};

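    In BaseHeteroGraph, the member variable meta_graph_ stores the meta graph of the heterogeneous graph. Type-level queries such as NumVertexTypes(), NumEdgeTypes() and GetEndpointTypes() are answered directly from it. BaseHeteroGraph also declares a large number of interfaces, mostly pure virtual, for operating on heterogeneous graph topology. These are implemented by the classes that inherit from BaseHeteroGraph, described later.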
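The division of labor in BaseHeteroGraph can be sketched as follows: type-level queries are answered by meta_graph_, while topology queries are left to subclasses. ToyHeteroBase, ToyUnit, MiniMetaGraph and ToyHeteroPtr are hypothetical stand-ins, and type ids are assumed to be 64-bit unsigned integers. The DstType() rule (0 if the meta graph has one vertex type, otherwise 1) is the one UnitGraph::DstType uses in the UnitGraph excerpt.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

using dgl_type_t = uint64_t;  // assumption: type ids are unsigned integers

// Hypothetical minimal meta graph: vertex-type count plus one
// (src type, dst type) pair per edge type.
struct MiniMetaGraph {
  uint64_t num_vertex_types;
  std::vector<std::pair<dgl_type_t, dgl_type_t>> edge_types;
};

// Hypothetical base class: type-level queries come from the meta graph,
// while topology queries are pure virtual and left to subclasses.
class ToyHeteroBase {
 public:
  explicit ToyHeteroBase(std::shared_ptr<MiniMetaGraph> meta)
      : meta_graph_(std::move(meta)) {}
  virtual ~ToyHeteroBase() = default;

  uint64_t NumVertexTypes() const { return meta_graph_->num_vertex_types; }
  uint64_t NumEdgeTypes() const { return meta_graph_->edge_types.size(); }
  std::pair<dgl_type_t, dgl_type_t> GetEndpointTypes(dgl_type_t etype) const {
    return meta_graph_->edge_types[etype];
  }
  virtual uint64_t NumEdges(dgl_type_t etype) const = 0;

 protected:
  std::shared_ptr<MiniMetaGraph> meta_graph_;
};

// Hypothetical unit graph: a single relation, so the source type is 0 and
// the destination type is 0 or 1 depending on the meta graph's size.
class ToyUnit : public ToyHeteroBase {
 public:
  ToyUnit(std::shared_ptr<MiniMetaGraph> meta, uint64_t num_edges)
      : ToyHeteroBase(std::move(meta)), num_edges_(num_edges) {}

  dgl_type_t SrcType() const { return 0; }
  dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
  uint64_t NumEdges(dgl_type_t /*etype*/) const override { return num_edges_; }

 private:
  uint64_t num_edges_;
};

// Analogue of HeteroGraphPtr: one pointer type for every derived class.
using ToyHeteroPtr = std::shared_ptr<ToyHeteroBase>;
```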

    In addition, DGL defines a smart pointer cpp_smart_pointer, HeteroGraphPtr, in include/dgl/base_heterograph.h dglsrc_include_dgl_base_heterograph_h. It is used to point to any class derived from BaseHeteroGraph:

typedef std::shared_ptr<BaseHeteroGraph> HeteroGraphPtr;

UnitGraph

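    A unit graph can be viewed as a special kind of heterogeneous graph; its two cases were analyzed above. DGL models unit graphs with the class UnitGraph, which is defined in src/graph/unit_graph.h dglsrc_src_graph_unit_graph_h and inherits from BaseHeteroGraph. Its definition is as follows: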

class UnitGraph : public BaseHeteroGraph {
public:
// internal data structure
class COO;
class CSR;
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;

inline dgl_type_t SrcType() const {
return 0;
}

inline dgl_type_t DstType() const {
return NumVertexTypes() == 1? 0 : 1;
}

inline dgl_type_t EdgeType() const {
return 0;
}

HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself.";
return {};
}

void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void Clear() override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

DLDataType DataType() const override;

DLContext Context() const override;

bool IsPinned() const override;

uint8_t NumBits() const override;

bool IsMultigraph() const override;

bool IsReadonly() const override {
return true;
}

uint64_t NumVertices(dgl_type_t vtype) const override;

inline std::vector<int64_t> NumVerticesPerType() const override {
std::vector<int64_t> num_nodes_per_type;
for (dgl_type_t vtype = 0; vtype < NumVertexTypes(); ++vtype)
num_nodes_per_type.push_back(NumVertices(vtype));
return num_nodes_per_type;
}

uint64_t NumEdges(dgl_type_t etype) const override;

bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override;

BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;

bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;

BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;

IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override;

IdArray Successors(dgl_type_t etype, dgl_id_t src) const override;

IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;

EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override;

IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override;

std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override;

EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override;

EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override;

EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override;

EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override;

EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override;

EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override;

uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override;

DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override;

uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override;

DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override;

DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override;

// 32bit version functions, patch for SuccVec
DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) const;

DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override;

DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;

std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override;

HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;

HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;

// creators
/*! \brief Create a graph with no edges */
static HeteroGraphPtr Empty(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
DLDataType dtype, DLContext ctx) {
IdArray row = IdArray::Empty({0}, dtype, ctx);
IdArray col = IdArray::Empty({0}, dtype, ctx);
return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);
}

/*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, bool row_sorted = false,
bool col_sorted = false, dgl_format_code_t formats = ALL_CODE);

static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats = ALL_CODE);

/*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = ALL_CODE);

static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE);

/*! \brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = ALL_CODE);

static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats = ALL_CODE);

/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);

/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);

/*!
* \brief Pin the in_csr_, out_scr_ and coo_ of the current graph.
* \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
void PinMemory_() override;

/*!
* \brief Unpin the in_csr_, out_scr_ and coo_ of the current graph.
* \note The graph will be unpinned inplace. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
void UnpinMemory_();

/*!
* \brief Create in-edge CSR format of the unit graph.
* \param inplace if true and the in-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the in-edge CSR format. Create from other format if not exist.
*/
CSRPtr GetInCSR(bool inplace = true) const;

/*!
* \brief Create out-edge CSR format of the unit graph.
* \param inplace if true and the out-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the out-edge CSR format. Create from other format if not exist.
*/
CSRPtr GetOutCSR(bool inplace = true) const;

/*!
* \brief Create COO format of the unit graph.
* \param inplace if true and the COO format does not exist, the created
* format will be cached in this object unless the format is restricted.
* \return Return the COO format. Create from other format if not exist.
*/
COOPtr GetCOO(bool inplace = true) const;

/*! \return Return the COO matrix form */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;

/*! \return Return the in-edge CSC in the matrix form */
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override;

/*! \return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;

SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return SelectFormat(preferred_formats);
}

/*!
* \brief Return the graph in the given format. Perform format conversion if the
* requested format does not exist.
*
* \return A graph in the requested format.
*/
HeteroGraphPtr GetFormat(SparseFormat format) const;

dgl_format_code_t GetCreatedFormats() const override;

dgl_format_code_t GetAllowedFormats() const override;

HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;

/*! \return Load UnitGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);

/*! \return Save UnitGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;

/*! \brief Creat a LineGraph of self */
HeteroGraphPtr LineGraph(bool backtracking) const;

/*! \return the reversed graph */
UnitGraphPtr Reverse() const;

/*! \return the simpled (no-multi-edge) graph
* the count recording the number of duplicated edges from the original graph.
* the edge mapping from the edge IDs of original graph to those of the
* returned graph.
*/
std::tuple<UnitGraphPtr, IdArray, IdArray>ToSimple() const;

void InvalidateCSR();

void InvalidateCSC();

void InvalidateCOO();

private:
friend class Serializer;
friend class HeteroGraph;
friend class ImmutableGraph;
friend HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);

// private empty constructor
UnitGraph() {}

/*!
* \brief constructor
* \param metagraph metagraph
* \param in_csr in edge csr
* \param out_csr out edge csr
* \param coo coo
*/
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
dgl_format_code_t formats = ALL_CODE);

/*!
* \brief constructor
* \param num_vtypes number of vertex types (1 or 2)
* \param metagraph metagraph
* \param in_csr in edge csr
* \param out_csr out edge csr
* \param coo coo
* \param has_in_csr whether in_csr is valid
* \param has_out_csr whether out_csr is valid
* \param has_coo whether coo is valid
*/
static HeteroGraphPtr CreateUnitGraphFrom(
int num_vtypes,
const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo,
bool has_in_csr,
bool has_out_csr,
bool has_coo,
dgl_format_code_t formats = ALL_CODE);

/*! \return Return any existing format. */
HeteroGraphPtr GetAny() const;

/*!
* \brief Determine which format to use with a preference.
*
* If the storage of unit graph is "locked", i.e. no conversion is allowed, then
* it will return the locked format.
*
* Otherwise, it will return whatever DGL thinks is the most appropriate given
* the arguments.
*/
SparseFormat SelectFormat(dgl_format_code_t preferred_formats) const;

/*! \return Whether the graph is hypersparse */
bool IsHypersparse() const;

GraphPtr AsImmutableGraph() const override;

// Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked.
/*! \brief CSR graph that stores reverse edges */
CSRPtr in_csr_;
/*! \brief CSR representation */
CSRPtr out_csr_;
/*! \brief COO representation */
COOPtr coo_;
/*!
* \brief Storage format restriction.
*/
dgl_format_code_t formats_;
};

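    UnitGraph stores a heterogeneous graph with one or two node types (i.e. either way, exactly one edge type). With a single node type, source and destination nodes share type 0; with two, DstType() returns 1. As the code above shows, UnitGraph also relies on sparse storage formats such as COO and CSR to store the topology. It holds up to three representations: in_csr_ (the CSR of the reverse edges, i.e. the in-edge CSC), out_csr_ and coo_. According to the comment in the code, a format is only materialized when an operation suited to it is invoked, and formats_ restricts which formats may be created. UnitGraph defines two inner classes, COO and CSR, which provide the underlying storage; both of them also inherit from the BaseHeteroGraph base class.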
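    The on-demand strategy can be illustrated with a small sketch. Per its doc comment, GetOutCSR(inplace) creates the out-edge CSR from another existing format when it is missing, and caches it in the object only if inplace is true and the format is not restricted. The C++ below is a hypothetical simplification, not DGL code: it uses std::vector instead of NDArray, and the names SimpleCOO, SimpleCSR, COOToCSR, LazyUnitGraph, kCOOCode and kCSRCode are assumptions made for illustration (the bit codes only mimic the idea of dgl_format_code_t). DGL's real conversion runs through aten kernels and may differ.

```cpp
// Hypothetical sketch (not DGL code): a unit graph that keeps a COO copy and
// materializes an out-edge CSR on demand, caching it only when allowed.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

struct SimpleCOO {
  int64_t num_rows, num_cols;     // number of src / dst nodes
  std::vector<int64_t> row, col;  // edge e = (row[e], col[e]); its ID is e
};

struct SimpleCSR {
  int64_t num_rows, num_cols;
  std::vector<int64_t> indptr, indices, data;  // data[k] = edge ID of slot k
};

// Counting-sort conversion COO -> CSR; original edge IDs are kept in `data`.
inline SimpleCSR COOToCSR(const SimpleCOO& coo) {
  SimpleCSR csr{coo.num_rows, coo.num_cols,
                std::vector<int64_t>(coo.num_rows + 1, 0), {}, {}};
  const std::size_t nnz = coo.row.size();
  for (int64_t r : coo.row) ++csr.indptr[r + 1];  // count edges per row
  for (int64_t i = 0; i < coo.num_rows; ++i)      // prefix sum -> offsets
    csr.indptr[i + 1] += csr.indptr[i];
  csr.indices.resize(nnz);
  csr.data.resize(nnz);
  std::vector<int64_t> pos(csr.indptr.begin(), csr.indptr.end() - 1);
  for (std::size_t e = 0; e < nnz; ++e) {
    const int64_t slot = pos[coo.row[e]]++;  // next free slot in this row
    csr.indices[slot] = coo.col[e];
    csr.data[slot] = static_cast<int64_t>(e);
  }
  return csr;
}

// Hypothetical format bit codes (an assumption, mimicking dgl_format_code_t).
constexpr uint8_t kCOOCode = 0x1, kCSRCode = 0x2;

class LazyUnitGraph {
 public:
  LazyUnitGraph(SimpleCOO coo, uint8_t allowed)
      : coo_(std::move(coo)), allowed_(allowed) {}

  // Returns the out-edge CSR, building it from COO if it does not exist yet.
  // The result is cached only if `inplace` is true and CSR is allowed.
  SimpleCSR GetOutCSR(bool inplace = true) {
    if (has_out_csr_) return out_csr_;
    SimpleCSR csr = COOToCSR(coo_);
    if (inplace && (allowed_ & kCSRCode)) {
      out_csr_ = csr;
      has_out_csr_ = true;
    }
    return csr;
  }

  bool HasOutCSR() const { return has_out_csr_; }

 private:
  SimpleCOO coo_;
  uint8_t allowed_;
  SimpleCSR out_csr_{};
  bool has_out_csr_ = false;
};
```

    With allowed set to kCOOCode | kCSRCode, the first GetOutCSR() call builds and caches the CSR; with only kCOOCode, the CSR is rebuilt on every call and never stored, which loosely mirrors what "unless the format is restricted" means for formats_.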

UnitGraph::COO

    The inner class UnitGraph::COO is defined as follows:

class UnitGraph::COO : public BaseHeteroGraph {
public:
COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
IdArray dst, bool row_sorted = false, bool col_sorted = false)
: BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
adj_ = aten::COOMatrix{num_src, num_dst, src, dst,
NullArray(),
row_sorted, col_sorted};
}

COO(GraphPtr metagraph, const aten::COOMatrix& coo)
: BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always
// assigned ids from 0 to num_edges - 1.
CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
adj_.data = aten::NullArray();
}

COO() {
// set magic num_rows/num_cols to mark it as undefined
// adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
adj_.num_rows = -1;
adj_.num_cols = -1;
};

bool defined() const {
return (adj_.num_rows >= 0) && (adj_.num_cols >= 0);
}

inline dgl_type_t SrcType() const {
return 0;
}

inline dgl_type_t DstType() const {
return NumVertexTypes() == 1? 0 : 1;
}

inline dgl_type_t EdgeType() const {
return 0;
}

HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself.";
return {};
}

void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void Clear() override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

DLDataType DataType() const override {
return adj_.row->dtype;
}

DLContext Context() const override {
return adj_.row->ctx;
}

bool IsPinned() const override {
return adj_.is_pinned;
}

uint8_t NumBits() const override {
return adj_.row->dtype.bits;
}

COO AsNumBits(uint8_t bits) const {
if (NumBits() == bits)
return *this;

COO ret(
meta_graph_,
adj_.num_rows, adj_.num_cols,
aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits));
return ret;
}

COO CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx)
return *this;
return COO(meta_graph_, adj_.CopyTo(ctx, stream));
}


/*! \brief Pin the adj_: COOMatrix of the COO graph. */
void PinMemory_() {
adj_.PinMemory_();
}

/*! \brief Unpin the adj_: COOMatrix of the COO graph. */
void UnpinMemory_() {
adj_.UnpinMemory_();
}

bool IsMultigraph() const override {
return aten::COOHasDuplicate(adj_);
}

bool IsReadonly() const override {
return true;
}

uint64_t NumVertices(dgl_type_t vtype) const override {
if (vtype == SrcType()) {
return adj_.num_rows;
} else if (vtype == DstType()) {
return adj_.num_cols;
} else {
LOG(FATAL) << "Invalid vertex type: " << vtype;
return 0;
}
}

uint64_t NumEdges(dgl_type_t etype) const override {
return adj_.row->shape[0];
}

bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
return vid < NumVertices(vtype);
}

BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
LOG(FATAL) << "Not enabled for COO graph";
return {};
}

bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
return aten::COOIsNonZero(adj_, src, dst);
}

BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::COOIsNonZero(adj_, src_ids, dst_ids);
}

IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second;
}

IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
return aten::COOGetRowDataAndIndices(adj_, src).second;
}

IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
return aten::COOGetAllData(adj_, src, dst);
}

EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
return EdgeArray{arrs[0], arrs[1], arrs[2]};
}

IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
return aten::COOGetData(adj_, src, dst);
}

std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
return std::pair<dgl_id_t, dgl_id_t>(src, dst);
}

EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
BUG_IF_FAIL(aten::IsNullArray(adj_.data)) <<
"FindEdges requires the internal COO matrix not having EIDs.";
return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids),
eids};
}

EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
IdArray ret_src, ret_eid;
std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(
aten::COOTranspose(adj_), vid);
IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
return EdgeArray{ret_src, ret_dst, ret_eid};
}

EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids);
auto row = aten::IndexSelect(vids, coosubmat.row);
return EdgeArray{coosubmat.col, row, coosubmat.data};
}

EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
IdArray ret_dst, ret_eid;
std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return EdgeArray{ret_src, ret_dst, ret_eid};
}

EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
auto coosubmat = aten::COOSliceRows(adj_, vids);
auto row = aten::IndexSelect(vids, coosubmat.row);
return EdgeArray{row, coosubmat.col, coosubmat.data};
}

EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
CHECK(order.empty() || order == std::string("eid"))
<< "COO only support Edges of order \"eid\", but got \""
<< order << "\".";
IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
return EdgeArray{adj_.row, adj_.col, rst_eid};
}

uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid;
return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid);
}

DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids);
}

uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
return aten::COOGetRowNNZ(adj_, vid);
}

DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
return aten::COOGetRowNNZ(adj_, vids);
}

DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for COO graph.";
return {};
}

DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for COO graph.";
return {};
}

DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for COO graph.";
return {};
}

DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for COO graph.";
return {};
}

std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override {
CHECK(fmt == "coo") << "Not valid adj format request.";
if (transpose) {
return {aten::HStack(adj_.col, adj_.row)};
} else {
return {aten::HStack(adj_.row, adj_.col)};
}
}

aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
return adj_;
}

aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for COO graph";
return aten::CSRMatrix();
}

aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for COO graph";
return aten::CSRMatrix();
}

SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for COO graph";
return SparseFormat::kCOO;
}

dgl_format_code_t GetAllowedFormats() const override {
LOG(FATAL) << "Not enabled for COO graph";
return 0;
}

dgl_format_code_t GetCreatedFormats() const override {
LOG(FATAL) << "Not enabled for COO graph";
return 0;
}

HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
HeteroSubgraph subg;
const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
DLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
submat.row, submat.col);
subg.induced_vertices = vids;
subg.induced_edges.emplace_back(submat.data);
return subg;
}

HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
HeteroSubgraph subg;
if (!preserve_nodes) {
IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
subg.induced_vertices.emplace_back(aten::Relabel_({new_src}));
subg.induced_vertices.emplace_back(aten::Relabel_({new_dst}));
const auto new_nsrc = subg.induced_vertices[0]->shape[0];
const auto new_ndst = subg.induced_vertices[1]->shape[0];
subg.graph = std::make_shared<COO>(
meta_graph(), new_nsrc, new_ndst, new_src, new_dst);
subg.induced_edges = eids;
} else {
IdArray new_src = aten::IndexSelect(adj_.row, eids[0]);
IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]);
subg.induced_vertices.emplace_back(
aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
subg.induced_vertices.emplace_back(
aten::NullArray(DLDataType{kDLInt, NumBits(), 1}, Context()));
subg.graph = std::make_shared<COO>(
meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
subg.induced_edges = eids;
}
return subg;
}

HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
LOG(FATAL) << "Not enabled for COO graph.";
return nullptr;
}

aten::COOMatrix adj() const {
return adj_;
}

/*!
* \brief Determines whether the graph is "hypersparse", i.e. having significantly more
* nodes than edges.
*/
bool IsHypersparse() const {
return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
(NumVertices(SrcType()) > 1000000);
}

bool Load(dmlc::Stream* fs) {
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
meta_graph_ = meta_imgraph;
CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
return true;
}
void Save(dmlc::Stream* fs) const {
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
fs->Write(meta_graph_ptr);
fs->Write(adj_);
}

private:
friend class Serializer;

/*! \brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_;
};
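    The comment in the second COO constructor states that COO edges are always assigned IDs from 0 to num_edges - 1, which is why adj_.data is kept empty and FindEdges is just an IndexSelect on row and col. It also explains why in-edge queries call COOTranspose first: transposing a COO only swaps the roles of the two arrays. The following hypothetical sketch (not DGL code; PlainCOO, Transpose, OutDegree, OutEdges, InDegree and FindEdge are illustrative names, and the "cf." notes only point at the analogous aten calls above) shows these ideas on plain vectors. In this simplified form, per-node queries are linear scans over all edges.

```cpp
// Hypothetical sketch (not DGL code) of the queries UnitGraph::COO answers,
// using plain vectors instead of aten::COOMatrix.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

struct PlainCOO {
  int64_t num_rows, num_cols;     // number of src / dst nodes
  std::vector<int64_t> row, col;  // edge e = (row[e], col[e]); there is no
                                  // data array, so the edge ID is just e
};

// Transposing a COO swaps the roles of row and col (cf. COOTranspose).
inline PlainCOO Transpose(const PlainCOO& a) {
  return PlainCOO{a.num_cols, a.num_rows, a.col, a.row};
}

// Out-degree of v: count edges whose source is v (cf. COOGetRowNNZ).
inline int64_t OutDegree(const PlainCOO& a, int64_t v) {
  int64_t n = 0;
  for (int64_t r : a.row) n += (r == v);
  return n;
}

// Edge IDs and destinations of v's out-edges (cf. COOGetRowDataAndIndices).
inline std::pair<std::vector<int64_t>, std::vector<int64_t>> OutEdges(
    const PlainCOO& a, int64_t v) {
  std::vector<int64_t> eids, dsts;
  for (int64_t e = 0; e < static_cast<int64_t>(a.row.size()); ++e) {
    if (a.row[e] == v) {
      eids.push_back(e);
      dsts.push_back(a.col[e]);
    }
  }
  return {eids, dsts};
}

// In-degree is the out-degree in the transposed graph.
inline int64_t InDegree(const PlainCOO& a, int64_t v) {
  return OutDegree(Transpose(a), v);
}

// Since edge IDs are positions, finding an edge is an index select.
inline std::pair<int64_t, int64_t> FindEdge(const PlainCOO& a, int64_t eid) {
  assert(eid >= 0 && eid < static_cast<int64_t>(a.row.size()));
  return {a.row[eid], a.col[eid]};
}
```

    For the edge list (0,1), (2,0), (0,2), (1,2), node 0 has out-edges with IDs 0 and 2, node 2 has in-degree 2, and edge 1 maps back to (2, 0).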

UnitGraph::CSR
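    Compared with COO, CSR groups the out-edges of each source node into one contiguous slice: the neighbors of node v are indices[indptr[v]] up to (but excluding) indices[indptr[v + 1]], and data holds the matching edge IDs. SuccVec and OutEdgeVec in the class code wrap exactly this range in DGLIdIters (after checking NumBits() == 64; SuccVec32 covers 32-bit graphs), and Edges() converts back to COO. Queries that need in-edges (InEdges, InDegree, Predecessors) are disabled for this class; answering them would presumably go through the reverse-edge layout that UnitGraph keeps separately in in_csr_. The hypothetical sketch below (not DGL code; PlainCSR, OutDegree, Successors, Triples and Edges are illustrative names) shows the slice arithmetic on plain vectors.

```cpp
// Hypothetical sketch (not DGL code): how a CSR layout answers the queries
// that UnitGraph::CSR implements, using plain vectors instead of NDArray.
#include <cassert>
#include <cstdint>
#include <vector>

struct PlainCSR {
  int64_t num_rows, num_cols;
  std::vector<int64_t> indptr;   // size num_rows + 1, row offsets
  std::vector<int64_t> indices;  // destination node of each slot
  std::vector<int64_t> data;     // edge ID of each slot
};

// Out-degree is the difference of two offsets (cf. CSRGetRowNNZ).
inline int64_t OutDegree(const PlainCSR& a, int64_t v) {
  return a.indptr[v + 1] - a.indptr[v];
}

// Successors of v: the slice indices[indptr[v], indptr[v + 1]),
// the same range that SuccVec wraps in DGLIdIters.
inline std::vector<int64_t> Successors(const PlainCSR& a, int64_t v) {
  return std::vector<int64_t>(a.indices.begin() + a.indptr[v],
                              a.indices.begin() + a.indptr[v + 1]);
}

// Expanding indptr back into explicit source IDs turns CSR into COO
// (cf. CSRToCOO), which is how Edges() yields (src, dst, eid) triples.
struct Triples {
  std::vector<int64_t> src, dst, eid;
};

inline Triples Edges(const PlainCSR& a) {
  Triples t;
  for (int64_t v = 0; v < a.num_rows; ++v) {
    for (int64_t k = a.indptr[v]; k < a.indptr[v + 1]; ++k) {
      t.src.push_back(v);
      t.dst.push_back(a.indices[k]);
      t.eid.push_back(a.data[k]);
    }
  }
  return t;
}
```

    Because each row is a contiguous slice, OutDegree and Successors only touch the entries belonging to v, whereas a plain COO scan visits every edge.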

    The inner class UnitGraph::CSR is defined as follows:

class UnitGraph::CSR : public BaseHeteroGraph {
public:
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids)
: BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
if (aten::IsValidIdArray(edge_ids))
CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids))
<< "edge id arrays should have the same length as indices if not empty";
CHECK_EQ(num_src, indptr->shape[0] - 1)
<< "number of nodes do not match the length of indptr minus 1.";

adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
}

CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) {
}

CSR() {
// set magic num_rows/num_cols to mark it as undefined
// adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
adj_.num_rows = -1;
adj_.num_cols = -1;
};

bool defined() const {
return (adj_.num_rows >= 0) || (adj_.num_cols >= 0);
}

inline dgl_type_t SrcType() const {
return 0;
}

inline dgl_type_t DstType() const {
return NumVertexTypes() == 1? 0 : 1;
}

inline dgl_type_t EdgeType() const {
return 0;
}

HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself.";
return {};
}

void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

void Clear() override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}

DLDataType DataType() const override {
return adj_.indices->dtype;
}

DLContext Context() const override {
return adj_.indices->ctx;
}

bool IsPinned() const override {
return adj_.is_pinned;
}

uint8_t NumBits() const override {
return adj_.indices->dtype.bits;
}

CSR AsNumBits(uint8_t bits) const {
if (NumBits() == bits) {
return *this;
} else {
CSR ret(
meta_graph_,
adj_.num_rows, adj_.num_cols,
aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits));
return ret;
}
}

CSR CopyTo(const DLContext &ctx,
const DGLStreamHandle &stream = nullptr) const {
if (Context() == ctx) {
return *this;
} else {
return CSR(meta_graph_, adj_.CopyTo(ctx, stream));
}
}

/*! \brief Pin the adj_: CSRMatrix of the CSR graph. */
void PinMemory_() {
adj_.PinMemory_();
}

/*! \brief Unpin the adj_: CSRMatrix of the CSR graph. */
void UnpinMemory_() {
adj_.UnpinMemory_();
}

bool IsMultigraph() const override {
return aten::CSRHasDuplicate(adj_);
}

bool IsReadonly() const override {
return true;
}

uint64_t NumVertices(dgl_type_t vtype) const override {
if (vtype == SrcType()) {
return adj_.num_rows;
} else if (vtype == DstType()) {
return adj_.num_cols;
} else {
LOG(FATAL) << "Invalid vertex type: " << vtype;
return 0;
}
}

uint64_t NumEdges(dgl_type_t etype) const override {
return adj_.indices->shape[0];
}

bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
return vid < NumVertices(vtype);
}

BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override {
LOG(FATAL) << "Not enabled for COO graph";
return {};
}

bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
return aten::CSRIsNonZero(adj_, src, dst);
}

BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
}

IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
LOG(INFO) << "Not enabled for CSR graph.";
return {};
}

IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
return aten::CSRGetRowColumnIndices(adj_, src);
}

IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
return aten::CSRGetAllData(adj_, src, dst);
}

EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
return EdgeArray{arrs[0], arrs[1], arrs[2]};
}

IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
return aten::CSRGetData(adj_, src, dst);
}

std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return EdgeArray{ret_src, ret_dst, ret_eid};
}

EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
auto csrsubmat = aten::CSRSliceRows(adj_, vids);
auto coosubmat = aten::CSRToCOO(csrsubmat, false);
// Note that the row id in the csr submat is relabled, so
// we need to recover it using an index select.
auto row = aten::IndexSelect(vids, coosubmat.row);
return EdgeArray{row, coosubmat.col, coosubmat.data};
}

EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\".";
auto coo = aten::CSRToCOO(adj_, false);
if (order == std::string("srcdst")) {
// make sure the coo is sorted if an order is requested
coo = aten::COOSort(coo, true);
}
return EdgeArray{coo.row, coo.col, coo.data};
}

uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
return aten::CSRGetRowNNZ(adj_, vid);
}

DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
return aten::CSRGetRowNNZ(adj_, vids);
}

DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
CHECK_EQ(NumBits(), 64);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(indices_data + start, indices_data + end);
}

DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
const int32_t* indptr_data = static_cast<int32_t*>(adj_.indptr->data);
const int32_t* indices_data = static_cast<int32_t*>(adj_.indices->data);
const int32_t start = indptr_data[vid];
const int32_t end = indptr_data[vid + 1];
return DGLIdIters32(indices_data + start, indices_data + end);
}

DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
CHECK_EQ(NumBits(), 64);
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(eid_data + start, eid_data + end);
}

DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override {
CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
return {adj_.indptr, adj_.indices, adj_.data};
}

aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for CSR graph";
return aten::COOMatrix();
}

aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for CSR graph";
return aten::CSRMatrix();
}

aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
return adj_;
}

SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for CSR graph";
return SparseFormat::kCSR;
}

dgl_format_code_t GetAllowedFormats() const override {
LOG(FATAL) << "Not enabled for CSR graph";
return 0;
}

dgl_format_code_t GetCreatedFormats() const override {
LOG(FATAL) << "Not enabled for CSR graph";
return 0;
}

HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
HeteroSubgraph subg;
const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
DLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
submat.indptr, submat.indices, sub_eids);
subg.induced_vertices = vids;
subg.induced_edges.emplace_back(submat.data);
return subg;
}

HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return {};
}

HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override {
LOG(FATAL) << "Not enabled for CSR graph.";
return nullptr;
}

aten::CSRMatrix adj() const {
return adj_;
}

bool Load(dmlc::Stream* fs) {
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph";
meta_graph_ = meta_imgraph;
CHECK(fs->Read(&adj_)) << "Invalid adj matrix";
return true;
}
void Save(dmlc::Stream* fs) const {
auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph());
fs->Write(meta_graph_ptr);
fs->Write(adj_);
}

private:
friend class Serializer;

/*! \brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_;
};
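
    The CSR member functions above all reduce to slicing the three arrays of adj_ (indptr, indices, data). As a minimal standalone sketch, assuming int64 ids held in std::vector rather than DGL's NDArray, the code below mirrors what OutDegree, OutEdges and the CSR-to-COO expansion behind Edges() compute. ToyCSR, ToyEdgeArray, CheckRow and ToCOO are hypothetical names for illustration only; DGL's own aten::CSRGetRowNNZ, aten::CSRGetRowColumnIndices and aten::CSRToCOO additionally dispatch on device and index type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for aten::CSRMatrix: the out-neighbors of row i live in
// indices[indptr[i], indptr[i + 1]), and data holds the matching edge ids.
struct ToyCSR {
  int64_t num_rows;
  int64_t num_cols;
  std::vector<int64_t> indptr;   // size num_rows + 1
  std::vector<int64_t> indices;  // column (dst) ids
  std::vector<int64_t> data;     // edge ids
};

// Same shape as DGL's EdgeArray: three parallel arrays.
struct ToyEdgeArray {
  std::vector<int64_t> src, dst, id;
};

static void CheckRow(const ToyCSR& csr, int64_t vid) {
  if (vid < 0 || vid >= csr.num_rows) throw std::out_of_range("invalid src vertex id");
}

// Out-degree is the width of the row's slice in indptr.
int64_t OutDegree(const ToyCSR& csr, int64_t vid) {
  CheckRow(csr, vid);
  return csr.indptr[vid + 1] - csr.indptr[vid];
}

// dst and eid are contiguous slices of indices and data; src is vid repeated,
// which is the role aten::Full plays in CSR::OutEdges.
ToyEdgeArray OutEdges(const ToyCSR& csr, int64_t vid) {
  CheckRow(csr, vid);
  const int64_t start = csr.indptr[vid], end = csr.indptr[vid + 1];
  ToyEdgeArray out;
  out.dst.assign(csr.indices.begin() + start, csr.indices.begin() + end);
  out.id.assign(csr.data.begin() + start, csr.data.begin() + end);
  out.src.assign(out.dst.size(), vid);
  return out;
}

// CSR -> COO: expand indptr into an explicit row array; columns and edge ids
// are reused unchanged, so edges come out grouped by src in row order.
ToyEdgeArray ToCOO(const ToyCSR& csr) {
  assert(csr.indptr.size() == static_cast<std::size_t>(csr.num_rows) + 1);
  ToyEdgeArray coo;
  for (int64_t r = 0; r < csr.num_rows; ++r)
    for (int64_t k = csr.indptr[r]; k < csr.indptr[r + 1]; ++k) coo.src.push_back(r);
  coo.dst = csr.indices;
  coo.id = csr.data;
  return coo;
}
```

    For example, edges 0->1, 0->2 and 2->0 with edge ids 0, 1, 2 give indptr = {0, 2, 2, 3}, indices = {1, 2, 0}, data = {0, 1, 2}. OutEdges(csr, 0) then yields dst = {1, 2} and id = {0, 1}, and ToCOO recovers the row array {0, 0, 2}. This is also why in-edge queries are "Not enabled" on this class: answering them from a CSR would need a scan over all rows, which DGL instead serves from a CSC or COO format.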

HeteroGraph

    HeteroGraph is the class that stores heterogeneous graphs. It is defined in src/graph/heterograph.h dglsrc_src_graph_heterograph_h as follows:

class HeteroGraph : public BaseHeteroGraph {
public:
HeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type = {});

HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
CHECK_LT(etype, meta_graph_->NumEdges()) << "Invalid edge type: " << etype;
return relation_graphs_[etype];
}

void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override {
LOG(FATAL) << "Bipartite graph is not mutable.";
}

void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override {
LOG(FATAL) << "Bipartite graph is not mutable.";
}

void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override {
LOG(FATAL) << "Bipartite graph is not mutable.";
}

void Clear() override {
LOG(FATAL) << "Bipartite graph is not mutable.";
}

DLDataType DataType() const override {
return relation_graphs_[0]->DataType();
}

DLContext Context() const override {
return relation_graphs_[0]->Context();
}

bool IsPinned() const override {
return relation_graphs_[0]->IsPinned();
}

uint8_t NumBits() const override {
return relation_graphs_[0]->NumBits();
}

bool IsMultigraph() const override;

bool IsReadonly() const override {
return true;
}

uint64_t NumVertices(dgl_type_t vtype) const override {
CHECK(meta_graph_->HasVertex(vtype)) << "Invalid vertex type: " << vtype;
return num_verts_per_type_[vtype];
}

inline std::vector<int64_t> NumVerticesPerType() const override {
return num_verts_per_type_;
}

uint64_t NumEdges(dgl_type_t etype) const override {
return GetRelationGraph(etype)->NumEdges(0);
}

bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override {
return vid < NumVertices(vtype);
}

BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;

bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
return GetRelationGraph(etype)->HasEdgeBetween(0, src, dst);
}

BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
return GetRelationGraph(etype)->HasEdgesBetween(0, src_ids, dst_ids);
}

IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override {
return GetRelationGraph(etype)->Predecessors(0, dst);
}

IdArray Successors(dgl_type_t etype, dgl_id_t src) const override {
return GetRelationGraph(etype)->Successors(0, src);
}

IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
return GetRelationGraph(etype)->EdgeId(0, src, dst);
}

EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsAll(0, src, dst);
}

IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsOne(0, src, dst);
}

std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
return GetRelationGraph(etype)->FindEdge(0, eid);
}

EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
return GetRelationGraph(etype)->FindEdges(0, eids);
}

EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->InEdges(0, vid);
}

EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
return GetRelationGraph(etype)->InEdges(0, vids);
}

EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->OutEdges(0, vid);
}

EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override {
return GetRelationGraph(etype)->OutEdges(0, vids);
}

EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
return GetRelationGraph(etype)->Edges(0, order);
}

uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->InDegree(0, vid);
}

DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
return GetRelationGraph(etype)->InDegrees(0, vids);
}

uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->OutDegree(0, vid);
}

DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override {
return GetRelationGraph(etype)->OutDegrees(0, vids);
}

DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->SuccVec(0, vid);
}

DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->OutEdgeVec(0, vid);
}

DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->PredVec(0, vid);
}

DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
return GetRelationGraph(etype)->InEdgeVec(0, vid);
}

std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override {
return GetRelationGraph(etype)->GetAdj(0, transpose, fmt);
}

aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
return GetRelationGraph(etype)->GetCOOMatrix(0);
}

aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
return GetRelationGraph(etype)->GetCSCMatrix(0);
}

aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
return GetRelationGraph(etype)->GetCSRMatrix(0);
}

SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return GetRelationGraph(etype)->SelectFormat(0, preferred_formats);
}

dgl_format_code_t GetAllowedFormats() const override {
return GetRelationGraph(0)->GetAllowedFormats();
}

dgl_format_code_t GetCreatedFormats() const override {
return GetRelationGraph(0)->GetCreatedFormats();
}

HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;

HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;

HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;

FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;

GraphPtr AsImmutableGraph() const override;

/*! \return Load HeteroGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs);

/*! \return Save HeteroGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const;

/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);

/*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext &ctx,
const DGLStreamHandle &stream = nullptr);

/*!
* \brief Pin all relation graphs of the current graph.
* \note The graph will be pinned inplace. Behavior depends on the current context,
* kDLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
void PinMemory_() override;

/*!
* \brief Unpin all relation graphs of the current graph.
* \note The graph will be unpinned inplace. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
void UnpinMemory_();

/*! \brief Copy the data to shared memory.
*
* Also save names of node types and edge types of the HeteroGraph object to shared memory
*/
static HeteroGraphPtr CopyToSharedMem(
HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes,
const std::vector<std::string>& etypes, const std::set<std::string>& fmts);

/*! \brief Create a heterograph from shared memory
* \return the HeteroGraphPtr, names of node types, names of edge types
*/
static std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
CreateFromSharedMem(const std::string &name);

/*! \brief Create a LineGraph of self */
HeteroGraphPtr LineGraph(bool backtracking) const;

const std::vector<UnitGraphPtr>& relation_graphs() const {
return relation_graphs_;
}

private:
// To create empty class
friend class Serializer;

// Empty Constructor, only for serializer
HeteroGraph() : BaseHeteroGraph() {}

/*! \brief A map from edge type to unit graph */
std::vector<UnitGraphPtr> relation_graphs_;

/*! \brief A map from vert type to the number of verts in the type */
std::vector<int64_t> num_verts_per_type_;

/*! \brief The shared memory object for meta info*/
std::shared_ptr<runtime::SharedMemory> shared_mem_;

/*! \brief The name of the shared memory. Return empty string if it is not in shared memory. */
std::string SharedMemName() const;

/*! \brief template class for Flatten operation
*
* \tparam IdType Graph's index data type, can be int32_t or int64_t
* \param etypes vector of etypes to be flattened
* \return pointer of FlattenedHeteroGraph
*/
template <class IdType>
FlattenedHeteroGraphPtr FlattenImpl(const std::vector<dgl_type_t>& etypes) const;
};

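    Storing a heterogeneous graph requires two kinds of information: the Meta Graph of the heterograph, and one Unit Graph for each edge type. In HeteroGraph, the member variable meta_graph_, inherited from BaseHeteroGraph, stores the MetaGraph of the current heterograph, while the member variable relation_graphs_ stores the Unit Graphs, one per edge type.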
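
    The two-part layout described above (a meta graph plus one unit graph per edge type) can be illustrated with a small standalone sketch. It assumes the meta graph is reduced to a list of (src_type, dst_type) pairs indexed by edge type, and that each relation is a plain CSR; ToyUnitGraph and ToyHeteroGraph are hypothetical names, not DGL classes. As in HeteroGraph, every per-edge-type query first picks relation_graphs_[etype] and then forwards the call to that unit graph.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical single-relation graph standing in for a UnitGraph, stored as CSR.
struct ToyUnitGraph {
  int64_t num_src;
  int64_t num_dst;
  std::vector<int64_t> indptr;   // size num_src + 1
  std::vector<int64_t> indices;  // dst ids

  int64_t OutDegree(int64_t vid) const {
    assert(vid >= 0 && vid < num_src);
    return indptr[vid + 1] - indptr[vid];
  }
  std::vector<int64_t> Successors(int64_t vid) const {
    assert(vid >= 0 && vid < num_src);
    return std::vector<int64_t>(indices.begin() + indptr[vid],
                                indices.begin() + indptr[vid + 1]);
  }
};

// Hypothetical heterograph: meta_edges_[etype] = (src node type, dst node type)
// plays the role of meta_graph_, and relation_graphs_[etype] holds that relation.
class ToyHeteroGraph {
 public:
  ToyHeteroGraph(std::vector<std::pair<int, int>> meta_edges,
                 std::vector<ToyUnitGraph> rel_graphs,
                 std::vector<int64_t> num_nodes_per_type)
      : meta_edges_(std::move(meta_edges)),
        relation_graphs_(std::move(rel_graphs)),
        num_verts_per_type_(std::move(num_nodes_per_type)) {
    if (meta_edges_.size() != relation_graphs_.size())
      throw std::invalid_argument("one unit graph per edge type expected");
    for (std::size_t e = 0; e < meta_edges_.size(); ++e) {
      // Each relation's shape must match the sizes of its endpoint node types.
      const ToyUnitGraph& g = relation_graphs_[e];
      if (g.num_src != num_verts_per_type_.at(meta_edges_[e].first) ||
          g.num_dst != num_verts_per_type_.at(meta_edges_[e].second))
        throw std::invalid_argument("relation shape mismatch");
    }
  }

  const ToyUnitGraph& GetRelationGraph(std::size_t etype) const {
    if (etype >= relation_graphs_.size()) throw std::out_of_range("invalid edge type");
    return relation_graphs_[etype];
  }

  int64_t NumVertices(std::size_t vtype) const { return num_verts_per_type_.at(vtype); }

  // Per-edge-type queries are forwarded to the matching unit graph.
  int64_t OutDegree(std::size_t etype, int64_t vid) const {
    return GetRelationGraph(etype).OutDegree(vid);
  }
  std::vector<int64_t> Successors(std::size_t etype, int64_t src) const {
    return GetRelationGraph(etype).Successors(src);
  }

 private:
  std::vector<std::pair<int, int>> meta_edges_;
  std::vector<ToyUnitGraph> relation_graphs_;
  std::vector<int64_t> num_verts_per_type_;
};
```

    For instance, with node types user (type 0, two nodes) and item (type 1, three nodes), edge type 0 = (user, item) and edge type 1 = (item, user), Successors(0, 0) returns the items adjacent to user 0 in relation 0. Keeping each relation as its own bipartite graph means the heterograph object stores adjacency only inside relation_graphs_, and node counts per type in num_verts_per_type_.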