图神经网络(GNN)详解

图结构因其能够以可分析的方式表示现实世界而受到广泛关注。图可以用来表示社交网络、分子结构、地理地图、网站链接数据、自然科学、蛋白质-蛋白质相互作用网络、知识图谱等多种现实世界数据集。此外,图像和文本等非结构化数据也可以以图的形式建模。图是一种非欧几里得数据结构,用于机器学习,图分析专注于节点分类、图分类、链接预测、图聚类和图可视化等任务。图神经网络(GNN)是基于深度学习的方法,它们在非欧几里得空间的实际问题中表现出色,因此最近成为了一种广泛应用的图分析方法。

图神经网络算法

在图神经网络中,一个节点可以通过其特征和图中的邻近节点来表示。GNN的目标是为每个节点学习一个状态嵌入,该嵌入编码了邻域的信息。状态嵌入用于产生输出,例如预测节点标签的分布。GNN是信息扩散机制和神经网络的结合体,代表了一组转换函数和一组输出函数。信息扩散机制由节点表示,它们的状态通过向邻近节点传递“消息”来更新和交换信息,直到达到稳定的平衡。转换函数以每个节点的特征、每个节点的边特征、邻近节点的状态和邻近节点的特征作为输入,输出是节点的新状态。

GNN在Karate网络上的实现

本节让看看如何将GNN应用于Karate网络,这是一个简单的图网络之一。Karate网络数据背景:两个34×34矩阵——1. ZACHE对称,二进制;2. ZACHC对称,有值。这些数据由Wayne Zachary从大学空手道俱乐部的成员那里收集。ZACHE矩阵代表俱乐部成员之间联系的存在或缺失;ZACHC矩阵指示关联的相对强度(在俱乐部内外发生互动的情况数量)。Zachary(1977)使用这些数据和一个网络冲突解决的信息流模型来解释这个群体在成员之间的争议之后分裂的原因。

可以使用DGL库来表示图,其中每个节点是一个俱乐部成员,每条边代表他们的互动。在DGL中,节点是从零开始的连续整数。因此,在准备数据时,重要的是重新标记或重新洗牌行顺序,以便第一行对应第一节点,依此类推。在这个例子中,已经以正确的顺序准备了数据,所以可以通过edges.csv表中的‘Src’和‘Dst’列创建图。

import dgl src = edges_data['Src'].to_numpy() dst = edges_data['Dst'].to_numpy() # 从numpy数组对创建DGL图 g = dgl.graph((src, dst))

为了可视化目的,可以将dgl图转换为网络图:

import networkx as nx # 由于实际图是无向的,将其转换为可视化目的。 nx_g = g.to_networkx().to_undirected() # Kamada-Kawai布局通常对任意图看起来很漂亮 pos = nx.kamada_kawai_layout(nx_g) nx.draw(nx_g,pos, with_labels=True)

在Karate网络上训练GNN模型:

# “Club”列表示每个节点属于哪个社区。 # 值是字符串类型,因此必须将其转换为分类整数值或独热编码。 club = nodes_data['Club'].to_list() # 将其转换为分类整数值,'Mr. Hi'为0,'Officer'为1。 club = torch.tensor([c == 'Officer' for c in club]).long() # 也可以将其转换为独热编码。 club_onehot = F.one_hot(club) print(club_onehot) # 使用`g.ndata`像使用普通字典一样 g.ndata.update({'club' : club, 'club_onehot' : club_onehot})

更新dgl图中的边特征:

# 从DataFrame中获取边特征并将其输入到图中。 edge_weight = torch.tensor(edges_data['Weight'].to_numpy()) # 类似地,使用`g.edata`获取/设置边特征。 g.edata['weight'] = edge_weight

更新节点嵌入:

node_embed = nn.Embedding(g.number_of_nodes(), 5) # 每个节点都有一个大小为5的嵌入。 inputs = node_embed.weight # 使用嵌入权重作为节点特征。 nn.init.xavier_uniform_(inputs)

更新标签特征,对于两个组领导——0和33 ID:

labels = g.ndata['club'] labeled_nodes = [0, 33]

使用GraphSage模型实现GNN:

from dgl.nn import SAGEConv # 构建一个两层GraphSAGE模型 class GraphSAGE(nn.Module): def __init__(self, in_feats, h_feats, num_classes): super(GraphSAGE, self).__init__() self.conv1 = SAGEConv(in_feats, h_feats, 'mean') self.conv2 = SAGEConv(h_feats, num_classes, 'mean') def forward(self, g, in_feat): h = self.conv1(g, in_feat) h = F.relu(h) h = self.conv2(g, h) return h # 使用给定的维度创建模型 # 输入层维度:5,节点嵌入 # 隐藏层维度:16 # 输出层维度:2,两个类别,0和1 net = GraphSAGE(5, 16, 2)

设置损失和优化器并训练模型:

# 在这种情况下,损失将在训练循环中 optimizer = torch.optim.Adam(itertools.chain(net.parameters(), node_embed.parameters()), lr=0.01) all_logits = [] for e in range(100): # 前向 logits = net(g, inputs) # 计算损失 logp = F.log_softmax(logits, 1) loss = F.nll_loss(logp[labeled_nodes], labels[labeled_nodes]) # 后向 optimizer.zero_grad() loss.backward() optimizer.step() all_logits.append(logits.detach()) if e % 5 == 0: print('In epoch {}, loss: {}'.format(e, loss))

输出:

Training epochs – Self project

获取结果:

pred = torch.argmax(logits, axis=1) print('Accuracy', (pred == labels).sum().item() / len(pred))

图神经网络的应用

图神经网络面临的挑战

  • A Comprehensive Survey on Graph Neural Networks. arxiv 2019. 论文 Zonghan Wu, Shirui Pan, Fengwen Chen, Guodong Long, Chengqi Zhang, Philip S. Yu.
  • Graph Neural Networks: A Review of Methods and Applications. AI Open 2020. 论文 Jie Zhou, Ganqu Cui, Zhengyan Zhang, Cheng Yang, Zhiyuan Liu, Maosong Sun.
  • Supervised Neural Networks for the Classification of Structures. IEEE TNN 1997. 论文 Alessandro Sperduti and Antonina Starita.
  • A new model for learning in graph domains. IJCNN 2005. 论文 Marco Gori, Gabriele Monfardini, Franco Scarselli.
  • Deep Learning on Graphs: A Survey. arxiv 2018. 论文 Ziwei Zhang, Peng Cui, Wenwu Zhu.
沪ICP备2024098111号-1
上海秋旦网络科技中心:上海市奉贤区金大公路8218号1幢 联系电话:17898875485