异质图链接预测

由于时效问题,该文某些代码、技术可能已经过期,请注意!!!本文最后更新于:2 年前

如题:本文使用 PyTorch Geometric 的 RGCNConv,在 DBLP 异质图数据集上实现链接预测。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch import nn
import copy

from torch_geometric.nn import RGCNConv
from tqdm.notebook import tqdm
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score
from torch_geometric.data import HeteroData
from torch_geometric.datasets import DBLP


# Load the DBLP dataset rooted at ./DBLP and take its first element as the
# heterogeneous graph used throughout the rest of the script.
dataset = DBLP(root="./DBLP")
graph = dataset[0]

HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={ x=[14328, 4231] },
term={ x=[7723, 50] },
conference={ num_nodes=20 },
(author, to, paper)={ edge_index=[2, 19645] },
(paper, to, author)={ edge_index=[2, 19645] },
(paper, to, term)={ edge_index=[2, 85810] },
(paper, to, conference)={ edge_index=[2, 14328] },
(term, to, paper)={ edge_index=[2, 85810] },
(conference, to, paper)={ edge_index=[2, 14328] }
)

1
2
3
4
5
6
7
8
9
10
11
12
# Conference nodes carry no feature matrix, so every conference node gets a
# constant one-dimensional feature equal to 1.
graph['conference'].x = torch.ones((graph['conference'].num_nodes, 1))
graph

# metadata() returns the node types and the edge types of the graph.
node_types, edge_types = graph.metadata()
print('node_types:', node_types)
print('edge_types:', edge_types)
num_relations = len(edge_types)
print('num_relations:', num_relations)
# Input feature width of each node type, listed in node_types order.
init_sizes = [graph[node_type].x.size(1) for node_type in node_types]
print(init_sizes)
device = torch.device('cpu')
1
2
3
4
5
6
7
8
9
10
11
12
# `T` is used below but was never imported anywhere in the script, which
# raises NameError; bring the transforms module into scope here.
import torch_geometric.transforms as T

# Split the edges of the homogeneous view of the graph into train / validation
# / test link-prediction sets (10% validation, 10% test).
# - is_undirected=True: the graph stores every relation together with its
#   reverse relation.
# - add_negative_train_samples=False: no negatives are attached to the
#   training split; they are drawn separately by negative_sample().
# NOTE(review): the transform is applied to graph.to_homogeneous(), so
# edge_types / rev_edge_types look like they have no effect here — confirm
# against the RandomLinkSplit documentation.
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=False,
    disjoint_train_ratio=0,
    edge_types=[('author', 'to', 'paper'), ('paper', 'to', 'term'),
                ('paper', 'to', 'conference')],
    rev_edge_types=[('paper', 'to', 'author'), ('term', 'to', 'paper'),
                    ('conference', 'to', 'paper')]
)(graph.to_homogeneous())
train_data

Data(edge_index=[2, 191654], node_type=[26128], edge_type=[191654], edge_label=[95827], edge_label_index=[2, 95827])

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def negative_sample(data):
    """Append freshly sampled negative edges to the label edges of ``data``.

    Requests as many negative edges as ``data.edge_label_index`` has columns,
    places them after the existing label edges and labels them with zeros.

    Args:
        data: Object exposing ``edge_index``, ``num_nodes``, ``edge_label``
            and ``edge_label_index``.

    Returns:
        Tuple ``(edge_label, edge_label_index)`` covering the original label
        edges followed by the sampled negatives.
    """
    num_label_edges = data.edge_label_index.size(1)
    sampled_negatives = negative_sampling(
        edge_index=data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=num_label_edges,
        method='sparse',
    )

    # Negatives get label 0, created with the dtype/device of the existing
    # labels.
    negative_labels = data.edge_label.new_zeros(sampled_negatives.size(1))

    edge_label_index = torch.cat(
        (data.edge_label_index, sampled_negatives), dim=-1)
    edge_label = torch.cat((data.edge_label, negative_labels), dim=0)
    return edge_label, edge_label_index
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class RGCN_LP(nn.Module):
    """Two-layer RGCN encoder with a dot-product decoder for link prediction.

    The class reads the module-level ``node_types``, ``init_sizes`` and
    ``num_relations``: one linear layer per node type projects that type's
    raw features to ``in_channels`` so that all node types share one feature
    width before the graph is converted to a homogeneous graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        """Build the per-type input projections and the two RGCN layers.

        Args:
            in_channels: Common feature width after the per-type projection.
            hidden_channels: Output width of the first RGCN layer.
            out_channels: Width of the final node embeddings.
        """
        super().__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels,
                              num_relations=num_relations, num_bases=30)
        self.conv2 = RGCNConv(hidden_channels, out_channels,
                              num_relations=num_relations, num_bases=30)
        # One projection per node type; init_sizes is listed in node_types
        # order, so lins[i] belongs to node_types[i].
        self.lins = nn.ModuleList(
            [nn.Linear(init_size, in_channels) for init_size in init_sizes]
        )

    def trans_dimensions(self, g):
        """Return a copy of ``g`` whose node features are projected to a
        common width; ``g`` itself is left unchanged.
        """
        # A shallow copy is sufficient: only the ``x`` attribute of each node
        # type is re-bound on the copy. The previous deepcopy duplicated every
        # tensor of the graph on every forward pass.
        data = copy.copy(g)
        for node_type, lin in zip(node_types, self.lins):
            data[node_type].x = lin(data[node_type].x)

        return data

    def encode(self, data):
        """Compute embeddings for all nodes of the heterogeneous graph."""
        data = self.trans_dimensions(data)
        homogeneous_data = data.to_homogeneous()
        edge_index = homogeneous_data.edge_index
        edge_type = homogeneous_data.edge_type
        # NOTE(review): the two RGCN layers are applied back to back without
        # a non-linearity in between — confirm this is intended.
        x = self.conv1(homogeneous_data.x, edge_index, edge_type)
        x = self.conv2(x, edge_index, edge_type)
        return x

    def decode(self, z, edge_label_index):
        """Score each candidate edge by the dot product of its endpoints.

        Args:
            z: Embeddings of all nodes, indexed by node id.
            edge_label_index: Index tensor whose row 0 holds source ids and
                row 1 holds destination ids.

        Returns:
            One raw score (logit) per candidate edge.
        """
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        return (src * dst).sum(dim=-1)

    def forward(self, data, edge_label_index):
        """Encode ``data`` and score the edges in ``edge_label_index``."""
        z = self.encode(data)
        return self.decode(z, edge_label_index)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def test(model, data):
    """Return the ROC-AUC of ``model`` on the label edges of ``data``.

    If ``data.edge_label`` already contains negative (zero) labels, those
    labels are used as they are, so repeated evaluations of the same split
    are comparable. Only when the split holds no negatives are they drawn
    with ``negative_sample``. The model is put in eval mode for scoring and
    its previous train/eval mode is restored afterwards.

    Args:
        model: Link-prediction model called as ``model(graph, edge_label_index)``.
        data: Split exposing ``edge_label`` and ``edge_label_index``.

    Returns:
        The ROC-AUC score as computed by ``roc_auc_score``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            edge_label = data.edge_label
            edge_label_index = data.edge_label_index
            # AUC needs both classes; sample negatives only if none exist.
            if not bool((edge_label == 0).any()):
                edge_label, edge_label_index = negative_sample(data)
            out = model(graph, edge_label_index).view(-1)
    finally:
        # Restore whatever mode the caller had instead of forcing train().
        model.train(was_training)
    return roc_auc_score(edge_label.cpu().numpy(), out.cpu().numpy())



# Train the link predictor and select the model by validation AUC.
model = RGCN_LP(128, 64, 4).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss().to(device)
min_epochs = 10          # model selection starts only after this many epochs
best_model = None        # snapshot of the model with the best validation AUC
best_val_auc = 0
final_test_auc = 0       # test AUC measured at the best validation epoch
model.train()
for epoch in range(500):
    optimizer.zero_grad()
    # Draw new negative edges for the training positives on every epoch.
    edge_label, edge_label_index = negative_sample(train_data)
    # NOTE(review): the encoder is run on the full `graph`, which appears to
    # still contain the validation/test edges — confirm whether message
    # passing should be restricted to the training edges.
    out = model(graph, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    # validation
    val_auc = test(model, val_data)
    if epoch + 1 > min_epochs and val_auc > best_val_auc:
        best_val_auc = val_auc
        # Previously best_model / final_test_auc were never updated; keep a
        # snapshot and evaluate on the test split only when validation
        # improves.
        best_model = copy.deepcopy(model)
        final_test_auc = test(model, test_data)
    if epoch % 10 == 0:
        print('epoch {:03d} train_loss {:.8f} val_auc {:.4f}'.format(epoch, loss.item(), val_auc))

print('best val_auc {:.4f} test_auc {:.4f}'.format(best_val_auc, final_test_auc))

参考:https://blog.csdn.net/circle2015/article/details/128004889
https://blog.csdn.net/Cyril_KI/article/details/126186418?spm=1001.2014.3001.5502
https://blog.csdn.net/Cyril_KI/article/details/126048682
https://zhuanlan.zhihu.com/p/354258797
https://blog.csdn.net/PolarisRisingWar/article/details/128130870
https://blog.csdn.net/PolarisRisingWar/article/details/126980943


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!