import torch
from torch_geometric.data import Data
from federatedscope.core.mlp import MLP
from federatedscope.gfl.model.gcn import GCN_Net
from federatedscope.gfl.model.sage import SAGE_Net
from federatedscope.gfl.model.gat import GAT_Net
from federatedscope.gfl.model.gin import GIN_Net
from federatedscope.gfl.model.gpr import GPR_Net
class GNN_Net_Link(torch.nn.Module):
    r"""GNN model with a LinkPredictor head for link prediction tasks.

    A GNN backbone maps node features to ``hidden``-dimensional node
    embeddings (see :meth:`forward`). :meth:`link_predictor` then scores
    node pairs by passing the element-wise product of their embeddings
    through an MLP.

    Arguments:
        in_channels (int): input channels.
        out_channels (int): output channels of the LinkPredictor.
        hidden (int): hidden dim for all modules (also the gnn's output
            dim).
        max_depth (int): number of layers for the gnn (forwarded as ``K``
            when ``gnn == "gpr"``).
        dropout (float): dropout probability passed to the gnn.
        gnn (str): name of gnn type, one of ``"gcn"``, ``"sage"``,
            ``"gat"``, ``"gin"`` or ``"gpr"``.
        layers (int): number of extra ``hidden``-dim entries in the
            LinkPredictor's channel list, which is
            ``[hidden] * (layers + 1) + [out_channels]``.

    Raises:
        ValueError: if ``gnn`` is not one of the supported types.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden=64,
                 max_depth=2,
                 dropout=.0,
                 gnn='gcn',
                 layers=2):
        super(GNN_Net_Link, self).__init__()
        self.dropout = dropout
        # Every backbone emits ``hidden``-dim node embeddings so that the
        # LinkPredictor input size does not depend on the gnn choice.
        common = dict(in_channels=in_channels,
                      out_channels=hidden,
                      hidden=hidden,
                      dropout=dropout)
        if gnn == 'gcn':
            self.gnn = GCN_Net(max_depth=max_depth, **common)
        elif gnn == 'sage':
            self.gnn = SAGE_Net(max_depth=max_depth, **common)
        elif gnn == 'gat':
            self.gnn = GAT_Net(max_depth=max_depth, **common)
        elif gnn == 'gin':
            self.gnn = GIN_Net(max_depth=max_depth, **common)
        elif gnn == 'gpr':
            # GPR_Net names its depth/propagation argument ``K``.
            self.gnn = GPR_Net(K=max_depth, **common)
        else:
            raise ValueError(f'Unsupported gnn type: {gnn}.')
        dim_list = [hidden] * layers
        self.output = MLP([hidden] + dim_list + [out_channels],
                          batch_norm=True)

    def forward(self, data):
        """Compute node embeddings with the GNN backbone.

        Args:
            data: a ``torch_geometric.data.Data`` providing ``x`` and
                ``edge_index``, or an ``(x, edge_index)`` tuple.

        Returns:
            The output of ``self.gnn`` on ``(x, edge_index)``; the backbone
            was built with ``out_channels=hidden``.

        Raises:
            TypeError: if ``data`` is neither ``Data`` nor a tuple.
        """
        if isinstance(data, Data):
            x, edge_index = data.x, data.edge_index
        elif isinstance(data, tuple):
            x, edge_index = data
        else:
            raise TypeError('Unsupported data type!')
        return self.gnn((x, edge_index))

    def link_predictor(self, x, edge_index):
        """Score candidate links from node embeddings.

        Args:
            x: node embeddings (e.g. the output of :meth:`forward`), with
                one row per node id.
            edge_index: ``edge_index[0]`` / ``edge_index[1]`` hold the
                source / target node indices of the pairs to score.

        Returns:
            ``self.output`` applied to the element-wise product of the two
            endpoint embeddings of each pair.
        """
        pair_features = x[edge_index[0]] * x[edge_index[1]]
        return self.output(pair_features)