import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv


class SAGE_Net(torch.nn.Module):
    r"""GraphSAGE model from the paper "Inductive Representation Learning
    on Large Graphs" (NeurIPS 2017).

    Source:
    https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py

    Arguments:
        in_channels (int): dimension of input.
        out_channels (int): dimension of output.
        hidden (int): dimension of hidden units, default=64.
        max_depth (int): number of GNN layers, default=2.
        dropout (float): dropout ratio, default=0.
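
    Example (a minimal full-batch sketch; the toy tensors below are
    illustrative and not part of the original source):
        >>> import torch
        >>> from torch_geometric.data import Data
        >>> x = torch.randn(10, 16)
        >>> edge_index = torch.randint(0, 10, (2, 40))
        >>> model = SAGE_Net(in_channels=16, out_channels=7)
        >>> out = model(Data(x=x, edge_index=edge_index))  # shape [10, 7]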
"""

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden=64,
                 max_depth=2,
                 dropout=0.0):
        super(SAGE_Net, self).__init__()
        self.num_layers = max_depth
        self.dropout = dropout
        # Build an `in -> hidden -> ... -> out` stack of SAGEConv layers.
        # Note: `max_depth` is assumed to be >= 2; with `max_depth=1` the
        # list below still ends up holding two layers.
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden))
        for _ in range(self.num_layers - 2):
            self.convs.append(SAGEConv(hidden, hidden))
        self.convs.append(SAGEConv(hidden, out_channels))

    def reset_parameters(self):
for conv in self.convs:
conv.reset_parameters()

    def forward_full(self, data):
        """Full-batch forward pass over an entire graph."""
if isinstance(data, Data):
x, edge_index = data.x, data.edge_index
elif isinstance(data, tuple):
x, edge_index = data
else:
raise TypeError('Unsupported data type!')
        # Apply ReLU + dropout after every layer except the last.
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if (i + 1) == len(self.convs):
                break
            x = F.relu(F.dropout(x, p=self.dropout, training=self.training))
return x

    def forward(self, x, edge_index=None, edge_weight=None, adjs=None):
        r"""
        `train_loader` computes the k-hop neighborhood of a batch of nodes,
        and returns, for each layer, a bipartite graph object holding the
        bipartite edges `edge_index`, the index `e_id` of the original
        edges, and the size/shape `size` of the bipartite graph.
        Target nodes are also included in the source nodes so that one can
        easily apply skip-connections or add self-loops.

        Arguments:
            x (torch.Tensor or PyG.data or tuple): node features or \
                full-batch data.
            edge_index (torch.Tensor): edge index.
            edge_weight (torch.Tensor): edge weight.
            adjs (List[PyG.loader.neighbor_sampler.EdgeIndex]): \
                batched edge index.

        :returns:
            x: output
        :rtype:
            torch.Tensor
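
        Example (a sketch of the mini-batch path, assuming PyG's
        ``NeighborSampler``; ``data``, ``train_idx`` and ``model`` are
        illustrative names):
            >>> from torch_geometric.loader import NeighborSampler
            >>> loader = NeighborSampler(data.edge_index, sizes=[10, 5],
            ...                          node_idx=train_idx, batch_size=128)
            >>> for batch_size, n_id, adjs in loader:
            ...     out = model(data.x[n_id], adjs=adjs)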
"""
        if isinstance(x, torch.Tensor):
            if edge_index is None:
                # Mini-batch path: `adjs` holds one bipartite graph per
                # layer; the target nodes of each layer are the first
                # `size[1]` entries of `x`.
                for i, (edge_index, _, size) in enumerate(adjs):
                    x_target = x[:size[1]]
                    x = self.convs[i]((x, x_target), edge_index)
                    if i != self.num_layers - 1:
                        x = F.relu(x)
                        x = F.dropout(x,
                                      p=self.dropout,
                                      training=self.training)
            else:
                # Full-batch path with an explicit edge index.
                for conv in self.convs[:-1]:
                    x = conv(x, edge_index, edge_weight)
                    x = F.relu(x)
                    x = F.dropout(x, p=self.dropout, training=self.training)
                x = self.convs[-1](x, edge_index, edge_weight)
            return x
        elif isinstance(x, (Data, tuple)):
            return self.forward_full(x)
        else:
            raise TypeError('Unsupported data type!')

    def inference(self, x_all, subgraph_loader, device):
        r"""
        Compute representations of nodes layer by layer, using *all*
        available edges. This leads to faster computation in contrast to
        immediately computing the final representations of each batch.

        Arguments:
            x_all (torch.Tensor): all node features.
            subgraph_loader (PyG.dataloader): dataloader.
            device (str): device.

        :returns:
            x_all: output
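
        Example (a sketch, assuming a full-neighborhood ``NeighborSampler``
        over an illustrative ``data`` object):
            >>> from torch_geometric.loader import NeighborSampler
            >>> subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1],
            ...                                   batch_size=1024,
            ...                                   shuffle=False)
            >>> out = model.inference(data.x, subgraph_loader, 'cpu')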
"""
        total_edges = 0  # accumulated edge count (not used further here)
        # Propagate layer by layer: compute the representations of *all*
        # nodes at layer `i` before moving on to layer `i + 1`.
        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                total_edges += edge_index.size(1)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                if i != self.num_layers - 1:
                    x = F.relu(x)
                xs.append(x.cpu())
            x_all = torch.cat(xs, dim=0)
return x_all
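

if __name__ == '__main__':
    # Minimal smoke test (not part of the original source): build a toy
    # graph and check that the full-batch and layer-wise paths produce
    # outputs of the same shape. All sizes are arbitrary.
    from torch_geometric.loader import NeighborSampler

    num_nodes, in_dim, out_dim = 50, 16, 7
    data = Data(x=torch.randn(num_nodes, in_dim),
                edge_index=torch.randint(0, num_nodes, (2, 200)))

    model = SAGE_Net(in_channels=in_dim, out_channels=out_dim)
    model.eval()  # disable dropout for a deterministic comparison

    full = model(data)  # full-batch forward via `forward_full`
    loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=16,
                             shuffle=False, num_nodes=num_nodes)
    layerwise = model.inference(data.x, loader, 'cpu')
    assert full.shape == layerwise.shape == (num_nodes, out_dim)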