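Our team has reproduced a number of models (GCN, TreeLSTM, RGCN, SSE, and others) and beaten the original papers on both speed and performance. DGL currently supports MXNet and PyTorch, lets you move freely between ordinary tensor operations and graph operations, and simplifies building graph-based neural networks. DGL is also very efficient (SSE on a graph with 50M nodes and 150M edges runs at 160 s/epoch). On top of that, the framework is easy to pick up. Here is a quick demo: suppose we want to create a graph with five nodes and assign some tensor features to every node and every edge. The code below shows how to build such a graph with DGL.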
MXNet:

    import dgl
    import mxnet as mx

    g = dgl.DGLGraph()
    g.add_nodes(5)                            # add 5 nodes
    g.add_edges([0, 0, 0, 0], [1, 2, 3, 4])   # add 4 edges: 0->1, 0->2, 0->3, 0->4
    g.ndata['h'] = mx.nd.random.randn(5, 3)   # assign a feature to each node
    g.edata['h'] = mx.nd.random.randn(4, 4)   # assign a feature to each edge
PyTorch:

    import dgl
    import torch as th

    g = dgl.DGLGraph()
    g.add_nodes(5)                            # add 5 nodes
    g.add_edges([0, 0, 0, 0], [1, 2, 3, 4])   # add 4 edges: 0->1, 0->2, 0->3, 0->4
    g.ndata['h'] = th.randn(5, 3)             # assign a feature to each node
    g.edata['h'] = th.randn(4, 4)             # assign a feature to each edge
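To make the `add_edges([0, 0, 0, 0], [1, 2, 3, 4])` call a bit more concrete, here is a minimal plain-Python sketch (no DGL required) of the same star graph. It only illustrates how the two lists pair up into edges; the variable names `src`, `dst`, `in_degree`, and `out_degree` are my own for illustration and are not DGL API.

```python
# Plain-Python view of the 5-node graph built above (illustration only, not DGL).
num_nodes = 5
src = [0, 0, 0, 0]   # source node of each edge
dst = [1, 2, 3, 4]   # destination node of each edge

# The i-th entries of src and dst together describe the i-th edge.
edges = list(zip(src, dst))

# Count outgoing and incoming edges per node.
out_degree = [0] * num_nodes
in_degree = [0] * num_nodes
for u, v in edges:
    out_degree[u] += 1
    in_degree[v] += 1

print(edges)
print(out_degree)
print(in_degree)
```

Node 0 is the only node with outgoing edges, and every other node receives exactly one edge, which is why the node feature tensor has 5 rows and the edge feature tensor has 4 rows.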
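As another example, say we want to do node classification. We know the citation relationships between papers, as well as the keywords each paper contains. We can then use the citation relationships together with the embedding of each paper's keywords to train a neural network that classifies papers (does this paper belong to astronomy, geography, or computer science?). An ordinary neural network cannot make use of the relationships between nodes in a graph, so this is where GCN comes in. For what a GCN is, please have a look at the paper. Here I will show how to build a GCN layer with DGL.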
MXNet:

    import os

    import numpy as np
    import mxnet as mx
    from mxnet import gluon
    import dgl
    from dgl import DGLGraph

    # set MXNet as the backend
    dgl.load_backend(os.environ.get('DGLBACKEND', 'mxnet').lower())

    def gcn_msg(edges):
        # message passing: send the source node's feature along each edge
        return {'m': edges.src['h']}

    def gcn_reduce(nodes):
        # sum the features received from incoming neighbors
        return {'h': mx.nd.sum(nodes.mailbox['m'], 1)}

    class NodeUpdateModule(gluon.Block):
        # define the GCN layer
        def __init__(self, out_feats, activation=None, dropout=0):
            super(NodeUpdateModule, self).__init__()
            self.linear = gluon.nn.Dense(out_feats, activation=activation)
            self.dropout = dropout

        def forward(self, node):
            h = self.linear(node.data['h'])
            if self.dropout:
                h = mx.nd.Dropout(h, p=self.dropout)
            return {'h': h}

    class GCN(gluon.Block):
        # define the model
        def __init__(self, g, in_feats, n_hidden, n_classes, n_layers,
                     activation, dropout):
            super(GCN, self).__init__()
            self.g = g
            self.dropout = dropout
            self.conv_layers = gluon.nn.Sequential()
            for i in range(n_layers):
                self.conv_layers.add(NodeUpdateModule(n_hidden, activation, dropout))
            self.out_layer = gluon.nn.Dense(n_classes)
            self.gcn_msg = gcn_msg
            self.gcn_reduce = gcn_reduce

        def forward(self, features):
            self.g.ndata['h'] = features
            for layer in self.conv_layers:
                # update the graph: send messages, reduce, then apply the layer
                self.g.update_all(self.gcn_msg, self.gcn_reduce, layer)
            return self.out_layer(self.g.ndata.pop('h'))
PyTorch:

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from dgl import DGLGraph
    from dgl.data import register_data_args, load_data

    def gcn_msg(edges):
        # message passing: send the source node's feature along each edge
        return {'m': edges.src['h']}

    def gcn_reduce(nodes):
        # sum the features received from incoming neighbors
        return {'h': torch.sum(nodes.mailbox['m'], 1)}

    class NodeApplyModule(nn.Module):
        # define the GCN layer
        def __init__(self, in_feats, out_feats, activation=None):
            super(NodeApplyModule, self).__init__()
            self.linear = nn.Linear(in_feats, out_feats)
            self.activation = activation

        def forward(self, nodes):
            # linear transform of the aggregated feature, then optional activation
            h = nodes.data['h']
            h = self.linear(h)
            if self.activation:
                h = self.activation(h)
            return {'h': h}

    class GCN(nn.Module):
        # define the model
        def __init__(self, g, in_feats, n_hidden, n_classes, n_layers,
                     activation, dropout):
            super(GCN, self).__init__()
            self.g = g
            if dropout:
                self.dropout = nn.Dropout(p=dropout)
            else:
                self.dropout = 0.
            self.layers = nn.ModuleList()
            # input layer
            self.layers.append(NodeApplyModule(in_feats, n_hidden, activation))
            # hidden layers
            for i in range(n_layers - 1):
                self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
            # output layer
            self.layers.append(NodeApplyModule(n_hidden, n_classes))

        def forward(self, features):
            self.g.ndata['h'] = features
            for idx, layer in enumerate(self.layers):
                # apply dropout before every layer except the first
                if idx > 0 and self.dropout:
                    self.g.ndata['h'] = self.dropout(self.g.ndata['h'])
                self.g.update_all(gcn_msg, gcn_reduce, layer)
            return self.g.ndata.pop('h')
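And with that, a GCN model is done. DGL lets you define the message-passing process yourself, and it also ships many built-in functions (fn.copy_src, fn.sum, and so on) that speed up graph computation, while still letting users see clearly what happens at each step of the graph computation. We have reproduced many models, and the results are generally better than the baselines.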
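For readers who want to see what the "send the source feature, then sum the mailbox" pattern computes, here is a rough plain-Python sketch. It is not DGL code and not how DGL implements it internally: `update_all_sketch`, `copy_src`, and `sum_messages` are hypothetical helpers written only for this illustration, and leaving nodes that receive no message unchanged is a simplification of this sketch rather than a statement about DGL's behavior.

```python
def copy_src(src_feat):
    # Message function: the message is simply the source node's feature.
    return src_feat

def sum_messages(messages):
    # Reduce function: element-wise sum over all messages in the mailbox.
    return [sum(column) for column in zip(*messages)]

def update_all_sketch(num_nodes, edges, feats, message_fn, reduce_fn):
    # Hypothetical stand-in for the idea behind update_all (illustration only).
    # Step 1: every edge (u, v) delivers a message from u into v's mailbox.
    mailbox = {v: [] for v in range(num_nodes)}
    for u, v in edges:
        mailbox[v].append(message_fn(feats[u]))
    # Step 2: every node reduces its mailbox into a new feature.
    new_feats = []
    for v in range(num_nodes):
        if mailbox[v]:
            new_feats.append(reduce_fn(mailbox[v]))
        else:
            # Simplification of this sketch: no messages, keep the old feature.
            new_feats.append(list(feats[v]))
    return new_feats

# The star graph 0->1..4 from earlier, plus two extra edges back into node 0.
edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 0), (2, 0)]
feats = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 1.0], [0.5, 0.5]]

new_feats = update_all_sketch(5, edges, feats, copy_src, sum_messages)
print(new_feats)
```

Node 0 ends up with the sum of the features of nodes 1 and 2, and nodes 1 to 4 each receive a copy of node 0's feature. In the GCN models above, the layer passed as the third argument to `update_all` is then applied to this aggregated feature.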
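There are many more features than I can cover here. If you want to learn more, come and browse our community (and feel free to leave a star while you are there), and take a look at the documentation. This is my first time writing a promo, so the wording is not very polished. If you have any questions, please leave feedback in the comments.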