from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import torch
import numpy as np
import scipy.sparse as sp
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from federatedscope.gfl.model import SAGE_Net
"""
https://proceedings.neurips.cc//paper/2021/file/ \
34adeb8e3242824038aa65460a47c29e-Paper.pdf
Fedsageplus models from the "Subgraph Federated Learning with Missing
Neighbor Generation" (FedSage+) paper, in NeurIPS'21
Source: https://github.com/zkhku/fedsage
"""
class Sampling(nn.Module):
    """Perturb a tensor with additive standard-normal noise.

    Noise is added in both training and eval mode (there is no
    ``self.training`` check). It is drawn with the default floating dtype
    directly on ``inputs.device``, instead of being sampled on the CPU and
    copied over on every call. Type promotion between ``inputs`` and the
    default dtype is therefore the same as before.
    """

    def __init__(self):
        super(Sampling, self).__init__()

    def forward(self, inputs):
        """Return ``inputs + N(0, 1)`` noise of the same shape."""
        noise = torch.randn(inputs.shape, device=inputs.device)
        return inputs + noise
class FeatGenerator(nn.Module):
    """Generate ``num_pred`` candidate neighbor feature vectors per node.

    The input embedding is perturbed with Gaussian noise (``Sampling``),
    passed through two ReLU layers (``latent_dim -> 256 -> 2048``), dropped
    out, and projected to ``num_pred * feat_shape`` values squashed into
    ``[-1, 1]`` by ``tanh``.

    Args:
        latent_dim: width of the input embedding.
        dropout: dropout probability applied before the output projection.
        num_pred: number of feature vectors generated per input row.
        feat_shape: dimensionality of each generated feature vector.
    """

    def __init__(self, latent_dim, dropout, num_pred, feat_shape):
        super(FeatGenerator, self).__init__()
        self.num_pred, self.feat_shape = num_pred, feat_shape
        self.dropout = dropout
        # Submodule registration order (sample, fc1, fc2, fc_flat) fixes the
        # state_dict key order; keep it stable.
        self.sample = Sampling()
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 2048)
        out_width = num_pred * feat_shape
        self.fc_flat = nn.Linear(2048, out_width)

    def forward(self, x):
        """Return a ``(..., num_pred * feat_shape)`` tensor in ``[-1, 1]``."""
        noisy = self.sample(x)
        hidden = F.relu(self.fc2(F.relu(self.fc1(noisy))))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return torch.tanh(self.fc_flat(hidden))
class NumPredictor(nn.Module):
    """Regress the number of missing neighbors from a node embedding.

    A single linear layer maps each ``latent_dim``-wide embedding to one
    value. ReLU clips it at zero, so predicted counts are never negative.

    Args:
        latent_dim: width of the input embedding.
    """

    def __init__(self, latent_dim):
        # Initialise nn.Module bookkeeping before assigning any attribute.
        super(NumPredictor, self).__init__()
        self.latent_dim = latent_dim
        self.reg_1 = nn.Linear(self.latent_dim, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` tensor of non-negative degree predictions."""
        return F.relu(self.reg_1(x))
# Mend the graph via NeighGen
class MendGraph(nn.Module):
    """Append generated neighbor nodes to a graph (the "mend" step).

    For each original node ``i``, the predicted missing-neighbor count is
    rounded (half-to-even, via ``torch.round``) and clipped to
    ``[0, num_pred]``. That many of node ``i``'s generated feature vectors
    are then linked to ``i`` by new edges ``(i, generated)``: row 0 holds
    the original node and row 1 the generated node. All generated vectors
    are appended to the feature matrix; slots beyond the predicted count
    are left unconnected.

    The module has no learnable parameters.

    Args:
        num_pred: maximum number of generated neighbors per node.
    """

    def __init__(self, num_pred):
        super(MendGraph, self).__init__()
        self.num_pred = num_pred

    def mend_graph(self, x, edge_index, pred_degree, gen_feats):
        """Build the mended feature matrix and edge index.

        Args:
            x: ``(num_node, num_feature)`` original node features. They are
                detached, so no gradient flows into them.
            edge_index: original edges. They are concatenated along dim 1
                with a ``(2, K)`` tensor, so they must be ``(2, E)``.
            pred_degree: predicted neighbor counts. Any shape is accepted;
                it is flattened and its first ``num_node`` entries are used.
                It is detached.
            gen_feats: generated features whose element count is a multiple
                of ``num_pred * num_feature``. Row block
                ``[i * num_pred, (i + 1) * num_pred)`` belongs to node ``i``.
                Gradients flow through it into the mended features.

        Returns:
            ``(fill_feats, fill_edges)``: ``fill_feats`` is ``x`` stacked
            above all generated vectors; ``fill_edges`` is ``edge_index``
            followed by the new edges, as ``int64``.

        Raises:
            IndexError: if ``pred_degree`` has fewer than ``num_node``
                entries.
        """
        device = gen_feats.device
        num_node, num_feature = x.shape
        # The 3-D view validates that gen_feats holds whole groups of num_pred
        # vectors. Slot (i, j) lands at row num_node + i * num_pred + j.
        gen_feats = gen_feats.view(-1, self.num_pred, num_feature)
        fill_feats = torch.vstack((x.detach(),
                                   gen_feats.view(-1, num_feature)))

        degrees = pred_degree.detach().reshape(-1)
        if degrees.numel() < num_node:
            raise IndexError(
                'pred_degree has {} entries but x has {} nodes'.format(
                    degrees.numel(), num_node))
        # Round, map NaN to 0, and clip to [0, num_pred] before the integer
        # cast, so +/-inf and NaN cannot produce garbage counts.
        degrees = torch.round(degrees[:num_node].to(device))
        degrees = torch.nan_to_num(degrees, nan=0.0)
        degrees = degrees.clamp(0, self.num_pred).to(torch.int64)

        # mask[i, j] is True for each slot j that node i links to.
        # nonzero() yields (i, j) pairs in row-major order: nodes first,
        # then slots, as a nested loop would.
        slots = torch.arange(self.num_pred, device=device)
        mask = slots.unsqueeze(0) < degrees.unsqueeze(1)
        src, slot = mask.nonzero(as_tuple=True)
        dst = num_node + src * self.num_pred + slot
        new_edges = torch.stack((src, dst)).to(torch.int64)

        # hstack always returns a new tensor. With no new edges it is simply
        # a copy of edge_index.
        fill_edges = torch.hstack((edge_index, new_edges))
        return fill_feats, fill_edges

    def forward(self, x, edge_index, pred_missing, gen_feats):
        """Alias for :meth:`mend_graph`."""
        fill_feats, fill_edges = self.mend_graph(x, edge_index, pred_missing,
                                                 gen_feats)
        return fill_feats, fill_edges
class LocalSage_Plus(nn.Module):
    """Local FedSage+ model: neighbor generation followed by classification.

    An encoder SAGE_Net embeds the input graph. From that embedding,
    ``reg_model`` predicts per-node missing-neighbor counts and ``gen``
    produces candidate neighbor features. ``mend_graph`` appends those
    features to the graph, and a second SAGE_Net classifies the mended
    graph.

    Args:
        in_channels: node feature width. It is the encoder/classifier input
            width and the width of each generated feature vector.
        out_channels: classifier output width.
        hidden: hidden width of both SAGE_Nets.
        gen_hidden: encoder output width, i.e. the generator latent width.
        dropout: dropout used by both SAGE_Nets and the generator.
        num_pred: maximum generated neighbors per node.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden,
                 gen_hidden,
                 dropout=0.5,
                 num_pred=5):
        super(LocalSage_Plus, self).__init__()
        self.encoder_model = SAGE_Net(in_channels=in_channels,
                                      out_channels=gen_hidden,
                                      hidden=hidden,
                                      max_depth=2,
                                      dropout=dropout)
        self.reg_model = NumPredictor(latent_dim=gen_hidden)
        self.gen = FeatGenerator(latent_dim=gen_hidden,
                                 dropout=dropout,
                                 num_pred=num_pred,
                                 feat_shape=in_channels)
        self.mend_graph = MendGraph(num_pred)
        self.classifier = SAGE_Net(in_channels=in_channels,
                                   out_channels=out_channels,
                                   hidden=hidden,
                                   max_depth=2,
                                   dropout=dropout)

    def _mend_and_classify(self, encode_data, mend_data):
        """Run generation on ``encode_data``, then mend and classify ``mend_data``.

        Returns ``(degree, gen_feat, nc_pred)``. ``nc_pred`` keeps only the
        first ``mend_data.num_nodes`` rows, i.e. the original, non-generated
        nodes.
        """
        x = self.encoder_model(encode_data)
        degree = self.reg_model(x)
        gen_feat = self.gen(x)
        mend_feats, mend_edge_index = self.mend_graph(mend_data.x,
                                                      mend_data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:mend_data.num_nodes]

    def forward(self, data):
        """Generate neighbors from ``data`` and classify the mended ``data``."""
        return self._mend_and_classify(data, data)

    def inference(self, impared_data, raw_data):
        """Generate neighbors from ``impared_data`` and mend ``raw_data`` with them.

        The parameter name ``impared_data`` is kept as-is because callers may
        pass it by keyword.
        """
        return self._mend_and_classify(impared_data, raw_data)
class FedSage_Plus(nn.Module):
    """FedSage+ wrapper in which only the feature generator stays trainable.

    The submodules of ``local_graph`` are shared (assigned by reference,
    not copied). The encoder, degree regressor, graph mender and classifier
    are frozen with ``requires_grad_(False)``. Because they are shared,
    this also freezes them inside ``local_graph``. ``gen`` is left
    trainable.

    Args:
        local_graph: the LocalSage_Plus whose submodules are reused.
    """

    def __init__(self, local_graph: LocalSage_Plus):
        super(FedSage_Plus, self).__init__()
        self.encoder_model = local_graph.encoder_model
        self.reg_model = local_graph.reg_model
        self.gen = local_graph.gen
        self.mend_graph = local_graph.mend_graph
        self.classifier = local_graph.classifier
        # Freeze everything except the generator.
        self.encoder_model.requires_grad_(False)
        self.reg_model.requires_grad_(False)
        self.mend_graph.requires_grad_(False)
        self.classifier.requires_grad_(False)

    def forward(self, data):
        """Generate neighbors for ``data``, mend it, and classify it.

        Returns ``(degree, gen_feat, nc_pred)``. ``nc_pred`` keeps only the
        first ``data.num_nodes`` rows, i.e. the original nodes.
        """
        x = self.encoder_model(data)
        degree = self.reg_model(x)
        gen_feat = self.gen(x)
        mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:data.num_nodes]