框架使用:构建 GNN 的单个层次

⚠ 转载请注明出处:作者:ZobinHuang,更新日期:June 19 2022


知识共享许可协议

    本作品ZobinHuang 采用 知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议 进行许可,在进行使用或分享前请查看权限要求。若发现侵权行为,会采取法律手段维护作者正当合法权益,谢谢配合。


目录

有特定需要的内容直接跳转到相关章节查看即可。

正在加载目录...

GraphSAGE 模型

    本文将以 GraphSAGE 模型为例,展开说明如何在 DGL 中构建 GNN Model Layer。下面我们先简要地复习一下 GraphSAGE 模型的定义。

    GraphSAGE 模型来自 NIPS 2017 年的论文 sage, dgl_sage,其定义的模型如下所示。

    在前向传播过程中,首先 Node 会聚合其邻居节点在上一层的表示:

$h_{\mathcal{N}(i)}^{(l+1)} = \mathrm{aggregate}\left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)$

    然后 Node 会把聚合结果和其在上一层的表示进行 Concatenate:

$h_{i}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat}(h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)$

    最后 Node 会将其在本层的表示进行归一化:

$h_{i}^{(l+1)} = \mathrm{norm}(h_{i}^{(l+1)})$

在 DGL 上构建 GraphSAGE Layer

    本节参考自 DGL 官方提供的 GraphSAGE 代码 dgl_sage_class 以及 DGL 官方提供的说明文档 dgl_build_own_block,但是对它们做了更新的说明。

构造函数

    首先我们需要为我们的 GraphSAGE Layer 构建一个 class。取决于所使用的后端框架的不同,构建出来的 class 需要继承自不同的父类。对于 PyTorch 后端, 它应该继承 PyTorch 的 NN.Module 类。

1
2
from torch import nn
class SAGEConv(nn.Module):

    GraphSAGE 类的构造函数如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def __init__(
self,
in_feats,
out_feats,
aggregator_type,
feat_drop=0.,
bias=True,
norm=None,
activation=None
):
super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
)

self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.feat_drop = nn.Dropout(feat_drop)
self.activation = activation
# aggregator type: mean/pool/lstm/gcn
if aggregator_type == 'pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type != 'gcn':
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=False)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
if bias:
self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
else:
self.register_buffer('bias', None)
self.reset_parameters()

    下图展示了各个传入参数和实际模型 Layer 之间的关系:

    在传入的参数中:

  • in_feats & out_feats: 对于一般的 PyTorch 模块,维度通常包括输入的维度、输出的维度和隐层的维度。对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。in_feats 指明了 $h^{(l)}_i$ 以及点 $i$ 的各个邻居 $h^{(l)}_j, \; j \in \mathcal{N}(i)$ 的维度,out_feats 则指明了 $h^{(l+1)}_i$ 的维度。

    值得注意的是,in_feats 有可能传入的是一个数,也可能传入的是一个 pair

    • 当传入的是一个数时,说明当前 Layer 处理的是一个 同构图 (Homogeneous Graph),边的关系只有一种,所有点的类型都是一样的,因此所有点的 Feature 的长度也是一样的;
    • 当传入的是一个 pair 时,说明当前 Layer 处理的是一个 异构图 (Heterogeneous Graph),边的关系有多种,点的类型也不止一种。在 DGL 中,异构图会被拆分成为若干个 二部图 (Bipartite Graph),每一个二部图对应于一种关系。这样一来,每一个 Layer 处理的就是一个二部图。对于这样一个二部图来说,图上两种不同类型的 Node 的 Feature 长度不一定是相等的,因此这里 in_feats 传入的就是一个 pair,形式是 (_in_src_feats, _in_dst_feats),分别代表图上两种 Node 的 Feature 长度,我们在下面将会看到相关的处理代码。
  • aggregator_type: 指定 forward_1 中的聚合函数类型,对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。常用的聚合类型包括 $\mathrm{mean}$、$\mathrm{sum}$、$\mathrm{max}$ 和 $\mathrm{min}$。一些模块可能会使用更加复杂的聚合函数,比如 $\mathrm{lstm}$。
  • norm: 一个 callable function,用于对 Aggregate 和 Update 后的 $h^{(l+1)}_i$ 进行归一化操作,也即 forward_3
  • activation: 一个 callable function,用于对 Aggregate 和 Update 后的 $h^{(l+1)}_i$ 进行激活操作,也即 forward_2

    构造函数的 Line 19 中我们调用了 expand_as_pair 函数,基于传入的 in_feats 的值,初始化了私有变量 _in_src_feats_in_dst_feats,这两个变量的含义分别是点 $i$ 的邻居传入的特征的维度 $|h^{(l)}_j|, \forall j \in \mathcal{N}(i)$,以及点 $i$ 本身的特征的维度 $|h^{(l)}_i|$。在上面我们提到构造函数传入的参数 in_feats 的值有可能是一个数,也有可能是一个 pair,分别对应于处理的图的类型的不同。DGL 中定义了 expand_as_pair 函数,以应对在不同图类型下对特征维度值的初始化,其定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def expand_as_pair(input_, g=None):
if isinstance(input_, tuple):
# 二分图的情况 (i.e. 传入的 input_ 是 pair): 直接返回传入的 pair
return input_
elif g is not None and g.is_block:
# 子图块的情况
if isinstance(input_, Mapping):
input_dst = {
k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
for k, v in input_.items()}
else:
input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
return input_, input_dst
else:
# 同构图的情况 (i.e. 传入的 input_ 是一个数): 返回该数形成的 pair,也即所有点的特征有相同的维度
return input_, input_

    在 Line 23 中,我们初始化了一个 nn.Dropout 实例 torch_nn_dropout,并且将其存储在私有变量 self.feat_drop 中。nn.Dropout 用于在训练过程中随机地将输入 Tensor 的某些 Elements 置为 0,以防止训练的过拟合。注意到我们传入了一个 float 类型的参数 feat_drop,这个参数可以理解为采样率,最终 nn.Dropout 将会以 feat_drop 的采样率来随机置零。

    构造函数的 Line 25~32 对本层 Layer 的 $\text{aggregate}$ 函数进行了初始化。根据传入的 $\text{aggregate}$ 函数的类型的不同,初始化私有变量成员也有所不同:

  • pool: 使用 nn.Linear 函数 torch_nn_linear 定义了一个全连接层,全连接层输入和输出的维度都为 _in_src_feats,也即邻居节点的特征的维度大小。
  • lstm: 使用 nn.LSTM 函数 torch_nn_lstm 定义了一个单层 LSTM RNN,全连接层输入和输出的维度都为 _in_src_feats,也即邻居节点的特征的维度大小,同时代码中还设置了 batch_first 参数,使得输入和输出的 tensor 的格式为 [batch, seq, feature]

    另外,如果我们传入的 $\text{aggregate}$ 函数的类型不为 gcn,则在 Line 31 中我们还使用 nn.Linear 函数 torch_nn_linear 创建了一个名为 fc_self 的全连接层,其输入维度为 _in_dst_feats,也即点 $i$ 在第 $l$ 层的特征的维度 $|h^{(l)}_i|$,输出维度为 out_feats,也即点 $i$ 在第 $l+1$ 层的特征的维度 $|h^{(l+1)}_i|$。这个 fc_neigh 层的作用实际上适用于实现 输入源节点维度 $\rightarrow$ 输出维度 的维度转换。

    另外,在 Line 32 中我们还定义了一个名为 fc_neigh 的全连接层,其输入维度为 _in_src_feats,也即点 $i$ 的邻居们 $j \in \mathcal{N}(i)$ 在第 $l$ 层的特征的维度 $|h^{(l)}_j|$,输出维度为 out_feats,也即点 $i$ 在第 $l+1$ 层的特征的维度 $|h^{(l+1)}_i|$。这个 fc_neigh 层的作用实际上适用于实现 输入目的节点维度 $\rightarrow$ 输出维度 的维度转换。

    最后,构造函数的 Line 37 中我们调用了 reset_parameters 私有方法对上述创建的神经网络的可学习参数进行了初始化。reset_parameters 私有方法定义如下所示:

1
2
3
4
5
6
7
8
9
def reset_parameters(self):
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

Forward 函数

    在 PyTorch 的 NN.Module 类中,forward 函数执行了实际的前向传播计算。与传统以张量为参数的 PyTorch NN.Module 相比,DGL 的 NN.Module 额外增加了 $1$ 个参数 dgl.DGLGraph —— 也即指定被训练的图,其函数原型如下所示:

1
def forward(self, graph, feat):

    在查看 forward 的源码时,我们会发现所有的代码都被包装在 garph.local_scope() dgl_local_scope 中:

1
2
3
def forward(self, graph, feat):
with graph.local_scope():
# ......

    简单来说,这个函数的目的是为我们操作的 graph 创建一个 local 的操作空间,在这个空间中操作 graph 的各种 Node Feature 和 Edge Feature 的时候,都不会影响到原有的 graph 实例,仅在这个局部空间中生效。并且值得注意的是,只有 "out-place" 的修改是不影响的,而 "in-place" 的修改在 local_scope 中仍然会引起对原有 graph 实例的修改,举例来说如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
# 在 local_scope 中的 out-place 的修改不会引起对原有 graph 数据的更新
def foo(g):
with g.local_scope():
g.edata['h'] = torch.ones((g.num_edges(), 3))
g.edata['h2'] = torch.ones((g.num_edges(), 3))
return g.edata['h']

# 在 local_scope 中的 in-place 的修改将会引起对原有 graph 数据的更新
def foo(g):
with g.local_scope():
# in-place operation
g.edata['h'] += 1
return g.edata['h']

    具体例子可见 DGL 的参考文档 dgl_local_scope

    基于 forward 函数的函数原型,我们可以发现其传入了三个参数,这三个参数的含义分别是:

  • graph: 一个 DGLGraph 实例,包含了图的拓扑信息;
  • feat: 图上各个 Node 的特征,具体来说有两种可能的情况:
    • 传入的是 单个 torch.Tensor: 代表着此时处理的是一个同构图,该 Tensor 中包含了图上各个 Node 的 Feature (i.e. 代表着构造函数的传入参数 in_feats 是一个数,也即图上所有 Nodes 的 Features 的维度都是相同的);
    • 传入的是 一对 torch.Tensor: 代表着此时处理的是一个二部图,两个 Tensors 分别包含了图上两种类型的 Node 的 Feature (i.e. 代表着构造函数的传入参数 in_feats 是一个 pair);
  • edge_weight (可选地): 包含了各条 Edge 的 Weight 的 torch.Tensor

    基于上述传入的三个参数,forward 函数的内容一般可以分为3项操作:

  1. 对特征进行处理;
  2. 消息传递和聚合;
  3. 聚合后,更新特征作为输出

    完整的 forward 函数的定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def forward(self, graph, feat, edge_weight=None):
self._compatibility_check()
with graph.local_scope():
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
msg_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

h_self = feat_dst

# Handle the case of graphs without edges
if graph.number_of_edges() == 0:
graph.dstdata['neigh'] = torch.zeros(
feat_dst.shape[0], self._in_src_feats).to(feat_dst)

# Determine whether to apply linear transformation before message passing A(XW)
lin_before_mp = self._in_src_feats > self._out_feats

# Message Passing
if self._aggre_type == 'mean':
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
graph.update_all(msg_fn, fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'gcn':
check_eq_shape(feat)
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
if isinstance(feat, tuple): # heterogeneous
graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
else:
if graph.is_block:
graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
else:
graph.dstdata['h'] = graph.srcdata['h']
graph.update_all(msg_fn, fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
elif self._aggre_type == 'pool':
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(msg_fn, fn.max('m', 'neigh'))
h_neigh = self.fc_neigh(graph.dstdata['neigh'])
elif self._aggre_type == 'lstm':
graph.srcdata['h'] = feat_src
graph.update_all(msg_fn, self._lstm_reducer)
h_neigh = self.fc_neigh(graph.dstdata['neigh'])
else:
raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
rst = h_neigh
else:
rst = self.fc_self(h_self) + h_neigh

# bias term
if self.bias is not None:
rst = rst + self.bias

# activation
if self.activation is not None:
rst = self.activation(rst)
# normalization
if self.norm is not None:
rst = self.norm(rst)
return rst

    下面我们分别就 forward 的如上三个步骤进行展开说明。

对特征进行处理

    在进入 forward 函数后,首先对传入的特征参数 feat 进行了处理:

1
2
3
4
5
6
7
if isinstance(feat, tuple):
feat_src = self.feat_drop(feat[0])
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]

    我们首先调用了私有函数 feat_drop 对 Node 的输入特征进行 Dropout 处理,上面的代码中针对传入的 feat 的情况的不同进行了不同的处理,这里不再赘述。

定义消息函数

    接着我们对消息函数进行定义。如下所示,我们使用了 copy_src 的内置消息函数作为我们我们生成 Message 的方法 (p.s. 该函数已经被弃用,替代函数是 copy_u),该函数简单地将源节点的特征 h 作为 Message m。注意到以下程序还对带权重的图进行了适配,如果输入的 edge_weight 不为空,则首先把 edge_weight 存为图的边上的特征 _edge_weight,然后应用 u_mul_e 内置消息函数进行带权的 Message 的计算。

1
2
3
4
5
msg_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
assert edge_weight.shape[0] == graph.number_of_edges()
graph.edata['_edge_weight'] = edge_weight
msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

消息传递过程

    在进行消息传递之前,DGL 进行了一个有趣的判断,如下所示。如果 DGL 发现源节点的输入特征维度 _in_src_feats 要比输出特征维度 _out_feats 要高的话,那么指示变量 lin_before_mp 将被赋为 true,也即在进行消息传递之前进行线性的矩阵计算,也即在运行消息传递之前,先完成 fc_neigh 全连接层所实现的维度转换,这样一来就可以在消息传递过程开始之前降低源节点的输入特征维度,以减小消息本身的维度,以减小计算量。我们在 在 DGL 上实现消息传递编程范式 一文中有过讨论。

1
2
# Determine whether to apply linear transformation before message passing A(XW)
lin_before_mp = self._in_src_feats > self._out_feats

    由于我们当下构造的模块提供了四种 $\text{Aggregate}$ 函数的类型选择,因此它们在消息传递的过程上也有若干区别。下面我们分别对它们进行分析。

mean 聚合
1
2
3
4
5
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
graph.update_all(msg_fn, fn.mean('m', 'neigh'))
h_neigh = graph.dstdata['neigh']
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)

    上面的代码比较简单,首先对源节点的 h 特征进行初始化以供后面借助 msg_fn 消息函数产生 Message:如果 lin_before_mptrue,则将源节点的 h 特征初始化为应用了线性计算后的结果,也即 fc_neigh(feat_src);若不然则直接应用为 feat_src

    接着调用 update_all,以产生 Messages 和使用 mean 方法来更新各个节点的 neigh 特征。

    如果 lin_before_mptrue,则我们已经完成了特征的更新,我们把节点的 neigh 特征提取到变量 h_neigh 中并返回;若不然则需要对节点的 neigh 特征进行线性运算之后,再保存到变量 h_neigh 中。

    此处我们基于 Mean 聚合定义的消息传播层次可以用下图进行表示:

gcn 聚合

    使用 gcn 聚合的代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
check_eq_shape(feat)
graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
if isinstance(feat, tuple): # heterogeneous
graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
else:
if graph.is_block:
graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
else:
graph.dstdata['h'] = graph.srcdata['h']
graph.update_all(msg_fn, fn.sum('m', 'neigh'))
# divide in_degrees
degs = graph.in_degrees().to(feat_dst)
h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
if not lin_before_mp:
h_neigh = self.fc_neigh(h_neigh)
pool 聚合

    使用 pool 聚合的代码如下所示:

1
2
3
graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
graph.update_all(msg_fn, fn.max('m', 'neigh'))
h_neigh = self.fc_neigh(graph.dstdata['neigh'])
lstm 聚合

    使用 lstm 聚合的代码如下所示:

1
2
3
graph.srcdata['h'] = feat_src
graph.update_all(msg_fn, self._lstm_reducer)
h_neigh = self.fc_neigh(graph.dstdata['neigh'])

    其中聚合函数 _lstm_reducer 的定义如下所示:

1
2
3
4
5
6
7
8
9
10
11
def _lstm_reducer(self, nodes):
"""LSTM reducer
NOTE(zihao): lstm reducer with default schedule (degree bucketing)
is slow, we could accelerate this with degree padding in the future.
"""
m = nodes.mailbox['m'] # (B, L, D)
batch_size = m.shape[0]
h = (m.new_zeros((1, batch_size, self._in_src_feats)),
m.new_zeros((1, batch_size, self._in_src_feats)))
_, (rst, _) = self.lstm(m, h)
return {'neigh': rst.squeeze(0)}

Concatenate

    在完成 Messages 的聚合值 $h^{l+1}_{\mathcal{N}(i)}$ 的运算后,下一步我们需要将每一个节点上的聚合结果,结合节点自身在上一层输出的特征,进行 Concatenation。值得注意的是,如果我们使用了 GCN 作为我们的 Aggregator,则不需要进行 Concatenation 操作。具体代码如下所示。

1
2
3
4
5
# GraphSAGE GCN does not require fc_self.
if self._aggre_type == 'gcn':
rst = h_neigh
else:
rst = self.fc_self(h_self) + h_neigh

Activation

    下一步我们将激活函数应用在 Concatenate 结果上,代码如下所示:

1
2
3
# activation
if self.activation is not None:
rst = self.activation(rst)

Normalization

    最后我们进行归一化,就可以输出最终的特征结果了,代码如下所示:

1
2
3
# normalization
if self.norm is not None:
rst = self.norm(rst)