百科问答小站 logo
百科问答小站 font logo



Graph Attention Network的本质是什么? 第1页

  

user avatar   tczhangzhi 网友的相关建议: 
      

泻药。

目前主流研究使用 Attention 进行边的动态调整,如果你的数据本身不是图结构,又找不到好的距离函数或者难以表达节点之间的相关关系,或许 Attention 是一个不错的选择 ~

什么是 Attention

在机器学习研究中,很多论文习惯用“人是怎样认识世界的“来类比“模型是怎样识别模式的”,虽然没那么清晰严谨但是形象生动。注意力也是类比了人的思维习惯。人在观察的时候是会抓重点的:我们在读句子的时候可能会更关注句子中的几个单词(NLP)揣测发言者的情感,在看图片的时候可能更关注感兴趣的区域(CV)判断图像内容,刷知乎的时候关注大 V 的发言了解舆论走向(Graph)。

个人比较认同 Attention 是一种加权平均 的说法。关注这种行为表现在数学上,就是某些属性或实例拥有更高的权重。比如,某个单词的权重高,这个单词的属性会比其他单词的属性更有力地对句子的属性产生影响。

当然,下面同学说的 Attention model 一种层次化的概率模型 也是有道理的。如果我们把加权平均中的权值之和变换为 1,也可以将这时的权值理解为关注的概率。

说句题外话,杠精就是一种没有训练好的注意力模型,重点总是抓错,所以会对正样本作出错误的判断。

基于 Graph Convolution 的 Attention

要了解基于 Graph Convolution 的 Attention,就得先了解 Graph Convolution 是在做什么。

在之前的回答中提到过,Graph Convolution 的核心思想是利用边的信息对节点进行聚合,从而生成新的节点表示。具体来说,给定一个图 ,其中 为节点集合, 为边的集合,节点的特征用 来表示。我们可以使用图卷积公式 生成新的节点的特征表示 ,其中 是节点 的邻居节点的特征 的加权平均。这里如果不懂请参考原回答。

在(简化的)Graph Convolution 中,权重是直接用边上的 weight 替代的:

其中, 代表图 的邻接矩阵 中的第 行第 列的值,即 边的 weight。

由于 Graph 的边是简单、固定的,因此 Convolution 加权平均过程中邻居节点的权值也是简单、固定的。有没有一种办法可以像人学得注意力一样,让模型学得邻居节点的权值呢。也就是:

其中, 是可学习的权重。可学习的 有两种设计思路,基于相似度的 Attention 和基于学习的 Attention。两种思路并没有明显的优劣差异,都可以尝试一下。

利用相似度

基于相似度的 Attention 需要一些先验信息,例如,余弦相似度衡量节点间的差异是有效的。

其中, 和 是训练参数, 是余弦相似度。废话几句其他的字母: 表示 节点是 节点的邻居, 代表 节点的属性特征。

思路比较好理解:

  1. 对节点特征做变换,得到

2. 求变换后的特征的余弦相似度,得到

3. 乘一个训练参数对余弦相似度进行缩放,得到

4. Softmax 归一化上述结果,得到注意力的概率

这就是 AGNN,个人感觉它对高维特征的处理还是挺有效的。

论文:Attention-based Graph Neural Network for semi-supervised learning

代码:dawnranger/pytorch-AGNN

完全利用学习

基于学习的 Attention 不需要任何先验知识,例如,上一方法中余弦相似度也可以由复杂的神经网络习得。

其中, 和 是训练参数, 代表组合向量 和向量 。废话几句其他的字母: 表示 节点是 节点的邻居, 代表 节点的属性特征。

思路也比较好理解:

  1. 对节点特征做变换,得到

2. 组合变换后的特征,使用训练参数变换,得到节点间的关系

3. LeakyReLU 增强非线性表达能力,得到

4. Softmax 归一化上述结果,最终得到注意力的概率

这就是 GAT,个人感觉它在一些 task 上表现惊人,但是结果不太稳定。

论文:Graph Attention Networks

代码:PetarV-/GAT

Graph 上的 Attention 为什么有效

在大规模 Graph 中由于节点较多,复杂的背景噪声会对 GNN 性能产生不良影响。在 Attention 的作用下,GNN 模型会关注到 Graph 中最重要的节点/节点中最重要的信息从而提高信噪比。

Attention 更巧妙地利用了 Graph 节点之间的相互联系,区分了联系的层级,能够增强任务中需要的有效信息。比如在玩狼人的时候预言家说你是平民,你的平民信息会得到大幅度增强,而普通人说你是平民,你的平民信息增强有限。

参考文献

理论方面,推荐大家看一下综述:Attention Models in Graphs: A Survey

实现方面,推荐学习一下 DGL 的实现,消息传递机制的 Graph Convolution 更好理解一些:Understand Graph Attention Network

       # 贴上来应该有的同学会一见钟情(逃 import torch import torch.nn as nn import torch.nn.functional as F   class GATLayer(nn.Module):     def __init__(self, g, in_dim, out_dim):         super(GATLayer, self).__init__()         self.g = g         # equation (1)         self.fc = nn.Linear(in_dim, out_dim, bias=False)         # equation (2)         self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)      def edge_attention(self, edges):         # edge UDF for equation (2)         z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)         a = self.attn_fc(z2)         return {'e': F.leaky_relu(a)}      def message_func(self, edges):         # message UDF for equation (3) & (4)         return {'z': edges.src['z'], 'e': edges.data['e']}      def reduce_func(self, nodes):         # reduce UDF for equation (3) & (4)         # equation (3)         alpha = F.softmax(nodes.mailbox['e'], dim=1)         # equation (4)         h = torch.sum(alpha * nodes.mailbox['z'], dim=1)         return {'h': h}      def forward(self, h):         # equation (1)         z = self.fc(h)         self.g.ndata['z'] = z         # equation (2)         self.g.apply_edges(self.edge_attention)         # equation (3) & (4)         self.g.update_all(self.message_func, self.reduce_func)         return self.g.ndata.pop('h')     




  

相关话题

  如何看待在某度搜不到megengine官网? 
  如何评价微软亚洲研究院提出的LightRNN? 
  如何评价生成模型框架 ZhuSuan? 
  如何看待FAIR提出的8-bit optimizer:效果和32-bit optimizer相当? 
  CTC和Encoder-Decoder有什么关系? 
  深度学习中有哪些数据增强方法? 
  目前 AI 在疾病的诊断和治疗上,有哪些成功的应用? 
  adversarial training为什么会起作用? 
  2021年深度学习哪些方向比较新颖,处于上升期或者朝阳阶段,没那么饱和,比较有研究潜力? 
  深度学习attention机制中的Q,K,V分别是从哪来的? 

前一个讨论
和女朋友在商场走丢了,随机乱逛和守在特定地点等候,哪个相遇的概率更高?
下一个讨论
如何评价斯大林作为父亲的角色?





© 2024-09-20 - tinynew.org. All Rights Reserved.
© 2024-09-20 - tinynew.org. 保留所有权利