图神经网络(GNN)的数学原理
本文结构
首先,我将对图和图神经网络进行深入分析。在这里,我深入探讨了向前传球所采取的细粒度步骤。然后,我继续使用熟悉的端到端技术来训练这些网络。最后,我使用前向传递部分中的步骤作为框架或指南,从文献中介绍流行的图神经网络。注:本文手机端阅读有乱码,可以网页打开
- 前言
- 本文结构
- 图表示
- 与图像的联系
- 图神经网络
- 节点中有什么?
- 边缘也很重要!!
- 消息传递
- 集合体
- 更新
- 把它们放在一起
- 使用 Edge 功能
- 使用邻接矩阵
- 堆叠 GNN 层
- 训练 GNN(上下文:节点分类)
- 训练和测试图形数据
- 反向支撑和梯度下降
- 流行的图神经网络
- 消息传递神经网络
- 图卷积网络
- 图注意力网络
- GraphSAGE
- 时态图网络
在我们进入图神经网络之前,让我们先探讨一下计算机科学中的图是什么。
图G(V,E)是一种数据结构,包含一组顶点(节点) i∈V 和一组 eij∈E连接顶点 i 和 j 的边。如果两个节点 i 和 j 已连接,则 , eij=1 否则 eij=0 。可以将此连接信息存储在邻接矩阵中 A :
⚠️ 我假设本文中的图形是未加权的(没有边缘权重或距离)和无方向的(节点之间没有关联方向)。我假设这些图是同质的(单一类型的节点和边;相反的是“异构”)。
图形与常规数据的不同之处在于,它们具有神经网络必须尊重的结构;不利用它是一种浪费。下面是一个社交媒体图的例子,其中节点是用户,边缘是他们的互动(如关注/喜欢/转发)。
与图像的联系
图像本身就是一个图表!它是一种称为“网格图”的特殊变体,其中所有内部节点和角节点的节点传出边的数量都是恒定的。图像网格图中存在一些一致的结构,允许对其执行简单的类似卷积的操作。
图像可以被认为是一个特殊的图形,其中每个像素都是一个节点,并通过假想边缘连接到它周围的其他像素。当然,从这种角度查看图像是不切实际的,因为这意味着要有一个非常大的图形。例如,一个简单的 CIFAR-10 图像 32×32×332×32×3 将具有 30723072 节点和 1984 条边。对于较大的 ImageNet 图像 224×224×3224×224×3 ,这些数字会爆炸。
图像可以被认为是一个特殊的图形,其中每个像素都是一个节点,并通过假想边缘连接到它周围的其他像素。当然,从这种角度查看图像是不切实际的,因为这意味着要有一个非常大的图形。例如,一个简单的 CIFAR-10 图像 32×32×332×32×3 将具有 30723072 节点和 1984 条边。对于较大的 ImageNet 图像 224×224×3224×224×3 ,这些数字会爆炸。
然而,正如你所观察到的,图表并不是那么完美。不同的节点具有不同的程度(与其他节点的连接数),并且无处不在。没有固定的结构,但结构是为图形增加价值的原因。因此,任何在此图上学习的神经网络都必须在学习节点(和边)之间的空间关系时尊重这种结构。
😌 尽管我们想在这里使用图像处理技术,但最好有特殊的特定于图形的方法,这些方法对于小型和大型图形都是高效和全面的。
图神经网络
单个图神经网络 (GNN) 层具有在图中的每个节点上执行的一系列步骤:
- 消息传递
- 集合体
- 更新
它们共同构成了通过图形学习的构建块。GDL 中的创新主要涉及对这 3 个步骤的更改。
图的节点
请记住:节点表示实体或对象,例如用户或原子。因此,此节点具有所表示实体的一系列特征属性。这些节点属性构成了节点的特征(即“节点特征”或“节点嵌入”)。
通常,这些特征可以使用 Rd 中的向量来表示。此向量要么是潜在维度嵌入,要么是以每个条目都是实体的不同属性的方式构造的。
🤔 例如,在社交媒体图中,用户节点具有年龄、性别、政治倾向、关系状态等属性,这些属性可以用数字表示。
同样,在分子图中,原子节点可能具有化学性质,例如对水的亲和力、力、能量等,也可以用数字表示。
这些节点特征是GNN的输入,我们将在后面的章节中看到。从形式上讲,每个节点都有关联的节点 i 特征 xi∈Rd和标签 yi(可以是连续的,也可以是离散的,就像独热编码一样)。
图的边(关系)
边缘也可以具有特征 aij∈Rd′ ,例如,在边缘有意义的情况下(如原子之间的化学键)。我们可以将下面显示的分子视为一个图形,其中原子是节点,键是边缘。
虽然原子节点本身具有各自的特征向量,但边可以具有不同的边缘特征,这些特征编码不同类型的键(单键、双键、三键)。不过,为了简单起见,我将在下一篇文章中省略边缘功能。
现在我们知道了如何在图中表示节点和边,让我们从一个简单的图开始,其中包含一堆节点(具有节点特征)和边。
消息传递
GNN以其学习结构信息的能力而闻名。通常,具有相似特征或属性的节点会相互连接(在社交媒体设置中也是如此)。GNN利用这一事实,了解特定节点如何以及为什么相互连接,而某些节点则不连接。为此,GNN 会查看节点的邻域。
节点的邻域定义为由边连接到 Ni 的一组节点 i j 。 i 形式上, Ni={j : eij∈E} .
一个人是由他所处的圈子塑造的。类似地,GNN 可以通过查看其邻域中的节点来了解很多关于节点 i 的信息 Ni 。为了实现源节点 i 和邻居之间的信息共享 j ,GNN参与消息传递。
对于 GNN 层,消息传递被定义为获取邻居的节点特征、转换它们并将它们“传递”到源节点的过程。对图中的所有节点并行重复此过程。这样,在这一步结束时,所有社区都会被检查。
让我们放大节点 66 并检查邻域 N6={1, 3, 4}6={1, 3, 4} 。我们获取每个节点特征 x1 、 x3 和 x4 ,并使用函数对其进行转换,该函数 F 可以是简单的神经网络(MLP 或 RNN)或仿射变换 F(xj)=Wj⋅xj+b 。简单地说,“消息”是从源节点传入的转换节点特征。
F� 可以是简单的仿射变换或神经网络。
现在,为了数学上的方便,让我们说 F(xj)=Wj⋅xj。这里, □⋅□◻⋅◻ 表示简单的矩阵乘法。
集合体
现在我们已经将转换后的消息 {F(x1),F(x3),F(x4)}传递给了 node 66 ,我们必须以某种方式聚合(“组合”)它们。可以做很多事情来将它们结合起来。常用的聚合函数包括:
Sum =∑j∈NiWj⋅xj
Mean =∑j∈NiWj⋅xj|Ni|
Max =maxj∈Ni({Wj⋅xj})
Min =minj∈Ni({Wj⋅xj})
假设我们使用一个函数 G来聚合邻居的消息(使用总和、平均值、最大值或最小值)。最终聚合的消息可以表示如下:
¯mi=G({Wj⋅xj:j∈Ni})
更新
使用这些聚合消息,GNN层现在必须更新源节点 i的特征。在此更新步骤结束时,节点不仅应该了解自己,还应该了解其邻居。这是通过获取节点 i的特征向量并将其与聚合消息相结合来确保的。同样,一个简单的加法或串联操作可以解决这个问题。
使用加法:
hi=σ(K(H(xi)+¯mi)))(6)(6)ℎ
其中 σ是激活函数(ReLU、ELU、Tanh), H是简单神经网络 (MLP) 或仿射变换,是 K另一个将添加的向量投影到另一个维度的 MLP。
使用串联:
hi=σ(K(H(xi) ⊕ ¯mi)))(7)(7)ℎ
为了进一步抽象此更新步骤,我们可以将其 K� 视为将消息和源节点嵌入在一起的投影函数:
hi=σ(K(H(xi), ¯mi)))(8)(8)ℎ
👉🏻 在表示法方面,初始节点特征称为 xi
在前向通过第一个 GNN 层后,我们改为调用节点特征 hiℎ 。假设我们有更多的 GNN 层,我们可以将节点特征表示为 hliℎ 当前 GNN 层索引的位置 l 。此外,很明显 h0i=xiℎ (即GNN的输入)。
把它们放在一起
现在我们已经完成了消息传递、聚合和更新步骤,让我们把它们放在一起,在单个节点 i 上形成一个 GNN 层:
hi=σ(W1⋅hi+∑j∈NiW2⋅hj)(9)(9)ℎ
在这里,我们使用 sum
聚合和简单的前馈层作为函数 F 和 H.
⚠️ 请确保节点嵌入的 W1 尺寸和 W2与节点嵌入正确交换。如果 hi∈Rdℎ∈, W1,W2⊆Rd′×d 其中 d是嵌入维度。
使用 Edge 功能
在处理边缘特征时,我们必须找到一种方法来对它们进行 GNN 前向传递。假设边具有特征 aij∈Rd′。为了在特定层 l更新它们,我们可以考虑边缘两侧节点的嵌入。正式
alij=T(hli, hlj, al−1ij)(10)(10)
其中 T是一个简单的神经网络(MLP 或 RNN),它接收来自连接节点 i的嵌入以及 j前一层的边缘嵌入 al−1ij。
使用邻接矩阵
到目前为止,我们研究了整个GNN前向通过孤立的单个节点 i 及其邻域的透镜 Ni。但是,在给定整个邻接矩阵 A 和所有 N=∥V∥节点特征时 X⊆RN×d,了解如何实现 GNN 前向传递也很重要。
在普通的机器学习中,在MLP前向传递中,我们希望对特征向量中的项目进行加权 xi 。这可以看作是节点特征向量 xi∈Rd和参数矩阵的点积, W⊆Rd′×d 其中 d 是嵌入维度:
zi=W⋅xi ∈Rd′(11)(11)
如果我们想对数据集中的所有样本(矢量化)执行此操作,我们只需对参数矩阵和特征进行矩阵相乘即可获得转换后的节点特征(消息):
Z=(WX)T=XW ⊆RN×d′(12)(12)
现在,在 GNN 中,对于每个节点,消息聚合操作涉及获取相邻节点 i 的特征向量,转换它们并将它们相加(在聚合的情况下 sum
)。
邻接矩阵中的一行 Ai 告诉我们哪些节点 j 连接到 i 。对于每个 indiex j where Aij=1,我们知道节点 i 并 j 连接→ eij∈E 。
例如,如果 A2=[1,0,1,1,0]2=[1,0,1,1,0] ,我们知道节点连接到节点 22 11 、 33 和 44 。因此,当我们乘 A2 以 Z=XW时,我们只考虑列 、 33 ,而 44 忽略列 22 11 和 55 。在矩阵乘法方面,我们正在做:
让我们关注 A的第 2 行。
矩阵乘法只是每 A 一行与每一列的点积 Z=XW!!
…而这正是消息聚合的本质!!
为了根据连接获得图中所有 N 节点的聚合消息,我们可以 A 将整个邻接矩阵与转换后的节点特征进行矩阵相乘:
Y=AZ=AXW(13)(13)Y=
!️ 一个小问题:观察聚合消息没有考虑节点 i 自己的特征向量(就像我们上面所做的那样)。为此,我们添加了自循环 A (每个节点 i 都连接到自身)。
这意味着在每个位置 Aii (即对角线)更改为 00 a 11 。
通过一些线性代数,我们可以使用恒等矩阵来做到这一点!
~A=A+IN
添加自循环允许 GNN 将源节点的特征与其邻居的特征聚合在一起!!
这样一来,您就可以使用矩阵而不是单个节点进行 GNN 前向传递。
⭐ 要执行 mean
聚合,我们可以简单地将总和除以 11 中的 s 计数 Ai 。对于上面的例子,由于 中有三个 11 ,我们可以 ∑j∈N2Wxj∑除以 33 … A2=[1,0,0,1,1]这正是平均值!!
但是,不可能通过GNN的邻接矩阵公式来实现 max
和 min
聚合。
堆叠 GNN 层
现在我们已经弄清楚了单个GNN层是如何工作的,我们如何构建这些层的整个“网络”呢?信息如何在层之间流动,以及GNN如何优化节点(和/或边缘)的嵌入/表示?
- 第一GNN层的输入是节点特征 X⊆RN×d。输出是中间节点嵌入, H1⊆RN×d1 其中 d1是第一个嵌入维度。 H1由 h1i : 1→N∈Rd1ℎ .
- H1是第二层的输入。下一个输出是 H2⊆RN×d2第二层的嵌入维度。同样, H2由 h2i : 1→N∈Rd2ℎ.
- 几层之后,在输出层 L,输出为 HL⊆RN×dL。最后, HL由 hLi : 1→N∈RdLℎ .
的选择 {d1,d2,…,dL} 完全取决于我们,并且是GNN的超参数。可以把这些看作是为一堆MLP层选择单位(“神经元”的数量)。
节点特征/嵌入(“表示”)通过 GNN 传递。结构保持不变,但节点表示在各层中不断变化。或者,边表示也会更改,但不会更改连接或方向。
现在,我们可以做 HL一些事情:
- 我们可以沿着第一个轴(即 ∑Nk=1hLk∑)添加它以获得一个向量 RdL。此向量是整个图形的最新维度表示。它可以用于图分类(例如:这是什么分子?
- 我们可以将向量连接起来 HL (即向量连接操作 ⨁Nk=1hk⨁在哪里 ⊕⊕ ),并通过图形自动编码器传递它。当输入图嘈杂或损坏并且我们想要重建去噪图时,这可能会有所帮助。
- 我们可以做节点分类→这个节点属于哪个类?嵌入在特定索引 hLiℎ( ) 的节点可以通过分类器(如 MLP i:1→N)放入 K 类中(例如:这是碳原子、氢原子还是氧原子?
- 我们可以执行链接预测→某个节点 i 和 j ?节点嵌入 hLiℎ可以馈送到另一个基于 Sigmoid 的 MLP 中,该 MLP 会吐出这些节点之间存在边缘的概率。
无论哪种方式,有趣的是,每个 h1→N∈HLℎ1→都可以堆叠起来,并被认为是一批样品。人们可以很容易地将其视为一个批次。
🚨 对于给定的节点,GNN聚合中的 lth 层具有节点 i 的 i l -hop邻域。最初,节点看到它的近邻,并深入到网络中,它与邻居的邻居进行交互,等等。
这就是为什么对于非常小的、稀疏的(很少的边)图,大量的GNN层通常会导致性能下降。这是因为所有嵌入的节点都收敛到一个单一向量,因为每个节点都看到了许多跳之外的节点。这是一个无用的情况!!
这就解释了为什么大多数GNN论文经常使用 ≤4≤4 层进行实验,以防止网络死亡。
训练 GNN(上下文:节点分类)
🥳 在训练过程中,可以使用损失函数(例如:交叉熵)将节点、边或整个图的预测与数据集中的真值标签进行比较。
这使得 GNN 能够使用原版反向道具和梯度下降以端到端的方式进行训练。
训练和测试图形数据
与常规 ML 一样,图形数据也可以拆分为训练和测试。这可以通过以下两种方式之一完成:
透导性
训练和测试数据都存在于同一个图中。每个集合中的节点相互连接。只不过,在训练过程中,测试节点的标签是隐藏的,而训练节点的标签是可见的。但是,所有节点的特征对 GNN 都是可见的。
我们可以使用所有节点上的二进制掩码来做到这一点(如果训练节点连接到测试节点 i j ,只需在邻接矩阵中设置 Aij=0 即可)。
在转导设置中,训练和测试节点都是 SAME 图的一部分。只是训练节点暴露其特征和标签,而测试节点仅暴露其特征。测试标签在模型中是隐藏的。需要二进制掩码来告诉GNN什么是训练节点,什么是测试节点。
归纳
在这里,有单独的训练图和测试图,它们彼此隐藏。这类似于常规 ML,其中模型在训练期间仅看到特征和标签,并且只看到用于测试的特征。训练和测试在两个独立的、孤立的图形上进行。有时,这些测试图是分布外的,以检查训练期间泛化的质量。
与常规 ML 一样,训练和测试数据是分开保存的。GNN 仅使用来自训练节点的特征和标签。这里不需要二进制掩码来隐藏测试节点,因为它们来自不同的集合。
反向支撑和梯度下降
在训练过程中,一旦我们通过GNN进行前向传递,我们就会得到最终的 hLi∈HLℎ节点表示。要以端到端的方式训练网络,我们可以执行以下操作:
- 将每个 hLiℎ数据馈送到 MLP 分类器中以获得预测 ^yi�^�
- 使用地面实况 yi和预测 ^yi → J(^yi,yi)计算损失
- 使用 Backpropagatino 计算梯度, ∂J∂Wl∂其中 Wl 是层的参数矩阵
- 使用一些优化器(如梯度下降)来更新 GNN 中每个层的参数 Wl
- (可选)您还可以微调分类器 (MLP) 网络的权重。
🥳 这意味着GNN在消息传递和训练方面都很容易并行化。整个过程可以矢量化(如上所示)并在 GPU 上执行!!
流行的图神经网络
在本节中,我将介绍文献中的一些流行作品,并将其方程式和数学归类为上述 3 个 GNN 步骤(或者至少我尝试过)。许多流行的体系结构将消息传递和聚合步骤合并到一个一起执行的函数中,而不是一个接一个地显式执行。我试图在本节中分解它们,但为了数学上的方便,最好将它们视为一个单一的运算!
我调整了本节中介绍的网络符号,使其与本文的符号一致。
消息传递神经网络
量子化学的神经信息传递
消息传递神经网络(MPNN)将前向传递分解为具有消息传递功能的消息传递阶段,以及具有顶点更新功能 MlUl的读出阶段。
MPNN 将消息传递和聚合步骤合并到单个消息传递阶段:
ml+1i=∑j∈NiMl(hli, hlj, eij)(15)(15)
读出阶段是更新步骤:
hl+1i=Ul(hli, ml+1i)(16)(16)ℎ
其中 ml+1v是聚合消息, hl+1vℎ是嵌入的更新节点。这与我上面提到的过程非常相似。消息函数是 F 和 G 的混合,函数 Ul Ml是 K 。这里, eij指的是也可以省略的可能边缘特征。
本文以MPNN为一般框架,并将文献中的其他作品作为MPNN的特殊变体。作者进一步将MPNN用于量子化学应用。
图卷积网络
基于图卷积网络的半监督分类
图卷积网络 (GCN) 论文以邻接矩阵的形式研究整个图。首先,将自连接添加到邻接矩阵中,以确保所有节点都连接到自身以获得 ~A~ .这确保了我们在消息聚合期间考虑了源节点的嵌入。组合的 Message Aggregation 和 Update 步骤如下所示:
Hl+1=σ(~AHlWl)(17)(17)
其中 Wl是可学习的参数矩阵。当然,我改 X 为 H 在任意层 l 概括节点特征,其中 H0=X
🤔 由于矩阵乘法 ( A(BC)=(AB)C 的关联性质,我们在哪个序列中复配矩阵并不重要(要么是 ~AHl第一个,下一个乘法后,要么是 HlWl 下一个乘法前乘法 ~AWl )。
然而,作者 Kipf 和 Welling 进一步引入了度矩阵 ~D作为重整化的一种形式,以避免数值不稳定和梯度爆炸/消失:
~Dii=∑j~Aij(18)(18)
“重整化”是在增强邻接矩阵上进行的 ^A=~D−12~A~D−12。总而言之,新的组合消息传递和更新步骤如下所示:
Hl+1=σ(^AHlWl)(19)(19)
图注意力网络
聚合通常涉及在总和、平均值、最大值和最小值设置中平等对待所有邻居。然而,在大多数情况下,一些邻居比其他邻居更重要。图注意力网络(GAT)通过使用Vaswani等人(2017)的自注意力对源节点与其邻居之间的边缘进行加权来确保这一点。
边缘权重 αij 的生成方式如下。
αij=Softmax(LeakyReLU(WaT⋅[Whli ⊕ Whlj]))(20)(20)
其中 Wa∈R2d′和 W⊆Rd′×d是学习参数,是嵌入维度, d′ ⊕⊕ 是向量串联操作。
虽然初始消息传递步骤与 MPNN/GCN 相同,但组合的消息聚合和更新步骤是所有邻居和节点本身的加权总和:
hi=∑j∈Ni ∪ {i}αij ⋅ Whlj
边缘重要性权重有助于了解邻居对源节点的影响程度。
与 GCN 一样,添加了自循环,以便源节点可以考虑自己的表示,以便将来的表示。
GraphSAGE
GraphSAGE 代表 Graph SAmple 和 AggreGatE。这是一个为大型、非常密集的图形生成节点嵌入的模型(用于 Pinterest 等公司)。
这项工作介绍了节点邻域上的学习聚合器。与考虑邻域中所有节点的传统 GAT 或 GCN 不同,GraphSAGE 统一对邻域进行采样,并在其上使用学习到的聚合器。
假设我们在网络(深度)中有层,每一 L层 l∈{1,…,L}都着眼于源节点的较大 l 跃点邻域(正如人们所期望的那样)。然后,在通过 MLP F和非线性 σ 传递之前,通过将嵌入的节点与采样消息连接起来来更新每个源节点。
对于某一层 l ,
hlN(i)=AGGREGATEl({hl−1j:j∈N(i)})hli=σ(F(hl−1i ⊕ hlN(i)))(22)
其中 ⊕⊕ 是向量串联运算, N(i) 是返回所有邻居的子集的统一采样函数。因此,如果一个节点有 5 个邻居 {1,2,3,4,5}{1,2,3,4,5} ,则可能的 N(i)输出将是 {1,4,5}{1,4,5} 或 {2,5}{2,5} 。
聚合器 k=1聚合来自 -hop 邻域的采样节点(彩色),而聚合器 k=2聚合来自 22 11 -hop 邻域的采样节点(彩色)
未来可能的工作可能是试验非均匀抽样函数来选择邻居。
注意:在本文中,作者使用 K和 k 来表示图层索引。在本文中,我分别使用 L 和 l 来保持一致。此外,本文还用于 v 表示源节点和 u 表示邻居 j 。 i
奖励:GraphSAGE之前的工作包括DeepWalk。一探究竟!
时态图网络
用于动态图深度学习的时态图网络
到目前为止所描述的网络在静态图上工作。大多数现实情况都适用于动态图,其中节点和边在一段时间内添加、删除或更新。时态图网络 (TGN) 适用于连续时间动态图 (CTDG),可以表示为按时间顺序排序的事件列表。
本文将事件分为两种类型:节点级事件和交互事件。节点级事件涉及一个孤立的节点(例如:用户更新其个人资料的简历),而交互事件涉及两个可能连接或可能不连接的节点(例如:用户 A 转发/关注用户 B)。
TGN 提供模块化的 CTDG 处理方法,包括以下组件:
- 消息传递函数 →隔离节点或交互节点之间的消息传递(对于任一类型的事件)。
- 消息聚合函数 → **通过查看多个时间步长的时间邻域而不是给定时间步长的局部邻域来使用 GAT 的聚合。
- 内存更新程序→允许节点具有长期依赖关系,并表示节点在潜在(“压缩”)空间中的历史记录。该模块根据随时间发生的交互来更新节点的内存。
- 时间嵌入→一种表示捕获时间本质的节点的方法。
- 链路预测→事件中涉及的节点的时间嵌入通过一些神经网络来计算边缘概率(即,边缘将来会发生吗?当然,在训练过程中,我们知道边缘存在,所以边缘标签是 11 。我们需要训练基于 Sigmoid 的网络来像往常一样预测这一点。
每当节点参与活动(节点更新或节点间交互)时,内存都会更新。
(1) 对于每个事件和 22 批处理中,TGN 为涉及该事件 11 的所有节点生成消息。
(2)接下来,for TGN聚合每个节点 mi 所有时间步 t 的消息;这称为节点 i 的时间邻域。
(3)接下来,TGN使用聚合的消息 ¯mi(t)来更新每个节点 si(t) 的内存。
(4) 一旦所有节点的内存 si(t) 都是最新的,它就用于计算批处理中特定交互中使用的所有节点的“临时节点嵌入”。 zi(t)
(5) 然后将这些节点嵌入输入 MLP 或神经网络,以获得每个事件发生的概率(使用 Sigmoid 激活)。
(6) 然后,我们可以像往常一样使用二元交叉熵 (BCE) 计算损失(未显示)。
结论
上面就是我们对图神经网络的数学总结,图深度学习在处理具有类似网络结构的问题时是一个很好的工具。它们很容易理解,我们可以使用PyTorch Geometric、spectral、Deep Graph Library、Jraph(jax)以及TensorFlow-gnn来实现。GDL已经显示出前景,并将继续作为一个领域发展。
本文章转载微信公众号@Python人工智能前沿