PyG搭建GCN实现节点分类
I. 前言
GCN原理可以参考:ICLR 2017 | GCN:基于图卷积网络的半监督分类。
一开始是打算手写一下GCN,毕竟原理也不是很难,但想了想还是直接调包吧。在使用各种深度学习框架时我们首先需要知道的是框架内的数据结构,因此这篇文章分为两个部分:第一部分数据处理,主要讲解PyG中的数据结构,第二部分模型搭建。
PyG (PyTorch Geometric)是一个基于PyTorch构建的库,可轻松编写和训练图形神经网络 (GNN),用于与结构化数据相关的广泛应用。
II. PyG数据结构
原始论文中使用的数据集:
这里就以Citeseer网络为例。Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。
使用PyG加载数据集:
data = Planetoid(root='/data/CiteSeer', name='CiteSeer')
print(len(data))
输出为1,说明CiteSeer中只有一个网络,然后我们输出一下这个网络:
data = data[0]
print(data)
print(data.is_directed())
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
False
1. x=[3327, 3703]。表示一共有3327个节点,然后节点的特征维度为3703,这里实际上是去除停用词和在文档中出现频率小于10次的词,整理得到的3703个唯一词。
2. edge_index=[2, 9104],表示一共9104条edge。数据一共两行,每一行都表示节点编号。
3. 输出一下data.y:
tensor([3, 1, 5, ..., 3, 1, 5])
data.y表示节点的标签编号,比如3表示该篇论文属于第3类。
4. 输出data.train_mask:
tensor([ True, True, True, ..., False, False, False])
data.train_mask的长度和y的长度一致,如果某个位置为True就表示该样本为训练样本。val_mask和test_mask类似,分别表示验证集和训练集。
那么很显然,如果我们最终得到了预测值,我们就可以通过以下代码来计算分类的正确数:
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
III. GCN
首先导入包:
from torch_geometric.nn import GCNConv
模型参数:
1. in_channels:输入通道,比如节点分类中表示每个节点的特征数。
2. out_channels:输出通道,最后一层GCNConv的输出通道为节点类别数(节点分类)。
3. improved:如果为True表示自环加强,也就是原始邻接矩阵基础上加上2I而不是I,默认为False。
4. cached:如果为True,GCNConv在第一次对邻接矩阵进行归一化时会进行缓存,以后将不再重复计算。
5. add_self_loops:如果为False不再强制添加自环,默认为True。
6. normalize:默认为True,表示对邻接矩阵进行归一化。
7. bias:默认添加偏置。
于是模型搭建如下:
class GCN(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_node_features, 16)
self.conv2 = GCNConv(16, num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = F.softmax(x, dim=1)
return x
1. 前向传播
查看官方文档中GCNConv的输入输出要求:
可以发现,GCNConv中需要输入的是节点特征矩阵x和邻接关系edge_index,还有一个可选项edge_weight。因此我们首先:
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
此时我们不妨输出一下x及其size:
tensor([[0.0000, 0.1630, 0.0000, ..., 0.0000, 0.0488, 0.0000],
[0.0000, 0.2451, 0.1614, ..., 0.0000, 0.0125, 0.0000],
[0.1175, 0.0262, 0.2141, ..., 0.2592, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.1825, 0.0000],
[0.0000, 0.1024, 0.0000, ..., 0.0498, 0.0000, 0.0000],
[0.0000, 0.3263, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
device='cuda:0', grad_fn=<FusedDropoutBackward0>)
torch.Size([3327, 16])
此时的x一共3327行,每一行表示一个节点经过第一层卷积更新后的状态向量。第二层卷积同理,即最终输出为:
torch.Size([3327, 6])
即每个节点的维度为6的状态向量。由于我们需要进行6分类,所以最后需要加上一个softmax:
x = F.softmax(x, dim=1)
dim=1表示对每一行进行运算,最终每一行之和加起来为1,也就表示了该节点为每一类的概率。
2. 反向传播
在训练时,我们首先利用前向传播计算出输出,然后算出损失函数:
out = model(data)
loss = loss_function(out[data.train_mask], data.y[data.train_mask])
然后计算梯度,反向更新!
3. 模型训练
def train():
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
loss_function = torch.nn.CrossEntropyLoss().to(device)
model.train()
for epoch in range(500):
out = model(data)
optimizer.zero_grad()
loss = loss_function(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
print('Epoch {:03d} loss {:.4f}'.format(epoch, loss.item()))
4. 模型测试
def test(model, data):
model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('GCN Accuracy: {:.4f}'.format(acc))
IV. 完整代码
完整代码及数据:https://github.com/ki-ljl/PyG-GCN,点击阅读原文即可跳转至代码下载界面。
项目结构:
README文件:
相关文章
- 浅谈Oracle RAC(6) 之实战:节点reboot问题的调查方法
- 单向链表之删除节点(C语言实现)「建议收藏」
- 降本超30%,智聆口语通过 TKE 注册节点实现 IDC GPU 节点降本增效实践
- 如何集中式管理多个客户端节点传输任务-镭速
- 【Linux 内核 内存管理】物理内存组织结构 ⑥ ( 物理页 page 简介 | 物理页 page 与 MMU 内存管理单元 | 内存节点 pglist_data 与 物理页 page 联系 )
- 思科VPP系列砖题三:VPP节点注册
- IPFS 本地节点搭建(命令行)
- Redis有序集合类型的操作_动力节点Java学院整理
- Mysql 实现向上递归查找父节点并返回树结构的示例代码
- 全面优化Linux服务器节点,实现高性能运行(linux服务器节点)
- 创建利用Oracle实现子节点的快速创建(oracle子节点)
- K-3D是基于GNU/Linux和Win32的一个三维建模、动画和绘制系统,是一款免费、开放原始码的 3D 模型和动画制作与渲染 (rendering) 工具,它强大的功能可以满足专业人士的需求。它可以创建和编辑 3D 几何图形,提供极具弹性的面向对象的插件增强功能及以节点作基础的可视化管线架构,所有参数和选项的调整,都会立即显现结果,而且可以无限次数地复原与取消复原。此外,它使用与 RenderMan 相符的渲染引擎 (render engine),可创作出电影质量的 3D 动画。
- 生成Linux设备节点的指南(linux创建设备节点)
- 节点配置MySQL数据库双主节点配置实现高可用(mysql数据库双主)
- Redis集群稳定拓展,增加新节点(redis集群 添加节点)
- 节点构建Redis集群指定节点实现(redis集群指定)
- 基于Redis集群实现动态增加节点(redis集群加入节点)
- 个节点Redis越过16384节点实现巨大规模分布式集群(redis超过16384)
- 实现Redis主从节点的稳定配置(redis 设置主从节点)
- 使用Redis虚拟节点实现高可用性(redis 虚拟节点)
- 使用Redis节点实现高性能的应用(redis 节点用来干嘛)
- 利用Redis节点实现高效的数据分片(redis节点分片)
- Redis群集节点故障排查与修复(redis群集节点故障)
- 学习YUI.Ext第六天--关于树TreePanel(Part2异步获取节点)
- jquery实现点击TreeView文本父节点展开/折叠子节点
- DevExpress实现TreeList父子节点CheckState状态同步的方法