# Source code for federatedscope.gfl.model.fedsageplus

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import numpy as np
import scipy.sparse as sp

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data

from federatedscope.gfl.model import SAGE_Net
"""
https://proceedings.neurips.cc//paper/2021/file/ \
34adeb8e3242824038aa65460a47c29e-Paper.pdf
Fedsageplus models from the "Subgraph Federated Learning with Missing
Neighbor Generation" (FedSage+) paper, in NeurIPS'21
Source: https://github.com/zkhku/fedsage
"""


class Sampling(nn.Module):
    """Perturb the input with additive standard-normal noise.

    Every forward call returns ``inputs + eps`` where ``eps ~ N(0, 1)`` is
    drawn element-wise with the same shape as ``inputs``.  The noise is
    added regardless of training/evaluation mode.
    """
    def __init__(self):
        super(Sampling, self).__init__()

    def forward(self, inputs):
        """Return ``inputs`` plus element-wise standard-normal noise."""
        if inputs.is_floating_point():
            # Sample directly on the input's device and in its dtype: no
            # per-call host-to-device copy, and half/bfloat16/float64 inputs
            # keep their precision instead of mixing with float32 noise.
            noise = torch.randn_like(inputs)
        else:
            # Integer inputs cannot hold Gaussian noise; keep the previous
            # behaviour of adding default-dtype (float32) noise.
            noise = torch.randn(inputs.shape, device=inputs.device)
        return inputs + noise


class FeatGenerator(nn.Module):
    """Decode embeddings into ``num_pred`` feature vectors per input row.

    Each input row is first perturbed with Gaussian noise by ``Sampling``
    and then passed through a three-layer MLP
    (``latent_dim -> 256 -> 2048 -> num_pred * feat_shape``) with ReLU
    activations, dropout before the last layer, and a final ``tanh``.
    The output is flat: one row of length ``num_pred * feat_shape`` per
    input row.

    Args:
        latent_dim (int): width of each input embedding.
        dropout (float): dropout probability, only active in training mode.
        num_pred (int): number of feature vectors generated per input row.
        feat_shape (int): length of each generated feature vector.
    """
    def __init__(self, latent_dim, dropout, num_pred, feat_shape):
        super(FeatGenerator, self).__init__()
        self.num_pred = num_pred
        self.feat_shape = feat_shape
        self.dropout = dropout
        # Registration order (sample, fc1, fc2, fc_flat) determines the
        # parameter order and state_dict layout; keep it stable.
        self.sample = Sampling()
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 2048)
        self.fc_flat = nn.Linear(2048, num_pred * feat_shape)

    def forward(self, x):
        """Map embeddings ``x`` to flattened, tanh-squashed features."""
        hidden = self.sample(x)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return torch.tanh(self.fc_flat(hidden))


class NumPredictor(nn.Module):
    """Regress one non-negative scalar per node embedding.

    A single linear layer maps each ``latent_dim``-dimensional row to one
    value, and ReLU clamps it at zero.  In this file the output is used as
    the predicted degree passed to ``MendGraph``.

    Args:
        latent_dim (int): width of the input node embedding.
    """
    def __init__(self, latent_dim):
        # Initialise nn.Module before assigning attributes: Module overrides
        # __setattr__ and relies on its internal registries existing.
        super(NumPredictor, self).__init__()
        self.latent_dim = latent_dim
        self.reg_1 = nn.Linear(self.latent_dim, 1)

    def forward(self, x):
        """Return a ``(..., 1)`` tensor of non-negative predictions."""
        return F.relu(self.reg_1(x))


# Mend the graph via NeighGen
class MendGraph(nn.Module):
    """Attach generated neighbour nodes to a graph.

    For each node ``i`` the predicted degree is rounded and clamped to
    ``[0, num_pred]``, giving ``k_i``.  The first ``k_i`` generated feature
    vectors of node ``i`` are linked to ``i`` by new edges.  The module has
    no learnable parameters.

    Args:
        num_pred (int): number of generated feature vectors per node, and
            hence the maximum number of neighbours attached to one node.
    """
    def __init__(self, num_pred):
        super(MendGraph, self).__init__()
        self.num_pred = num_pred

    def mend_graph(self, x, edge_index, pred_degree, gen_feats):
        """Append generated nodes and their edges to a graph.

        Args:
            x (Tensor): node features, shape ``(num_node, num_feature)``.
            edge_index (Tensor): edge list of shape ``(2, num_edges)``.
            pred_degree (Tensor): one predicted degree per node, e.g. shape
                ``(num_node, 1)`` or ``(num_node,)``.
            gen_feats (Tensor): generated features, viewable as
                ``(-1, num_pred, num_feature)``; row ``i`` holds the
                ``num_pred`` candidates of node ``i``.

        Returns:
            tuple: ``(fill_feats, fill_edges)``.  ``fill_feats`` stacks the
            detached ``x`` above all generated feature rows (whether or not
            they get connected).  ``fill_edges`` is ``edge_index`` followed
            by the new edges ``(i, num_node + i * num_pred + j)`` for
            ``j < k_i``, ordered by ``i`` then ``j``.
        """
        device = gen_feats.device
        num_node, num_feature = x.shape
        # Viewing as (-1, num_pred, num_feature) makes a generator output of
        # the wrong width fail here rather than silently mis-indexing.
        gen_feats = gen_feats.view(-1, self.num_pred, num_feature)
        # x is detached so gradients reach only the generated features.
        fill_feats = torch.vstack(
            (x.detach(), gen_feats.view(-1, num_feature)))

        # k_i: rounded prediction clamped to [0, num_pred], one per node.
        num_new = torch.round(pred_degree.detach()).to(torch.int64)
        num_new = num_new.reshape(-1).clamp(0, self.num_pred).to(device)

        slots = torch.arange(self.num_pred, device=device).unsqueeze(0)
        src = torch.arange(num_node, device=device).unsqueeze(1)
        # Generated row for (node i, slot j) is num_node + i * num_pred + j.
        dst = num_node + src * self.num_pred + slots
        keep = slots < num_new.unsqueeze(1)
        # Boolean masking walks the (node, slot) grid row-major, i.e. in the
        # same node-then-slot order as a nested loop over i and j.
        new_edges = torch.stack((src.expand_as(dst)[keep], dst[keep]))

        # When no edge is kept, new_edges is (2, 0) and hstack returns a
        # copy of edge_index.
        fill_edges = torch.hstack((edge_index, new_edges))
        return fill_feats, fill_edges

    def forward(self, x, edge_index, pred_missing, gen_feats):
        """Mend the graph; ``pred_missing`` is used as ``pred_degree``.

        Returns the same ``(fill_feats, fill_edges)`` tuple as
        ``mend_graph``.
        """
        fill_feats, fill_edges = self.mend_graph(x, edge_index, pred_missing,
                                                 gen_feats)
        return fill_feats, fill_edges


class LocalSage_Plus(nn.Module):
    """Local FedSage+ model: neighbour generation plus a SAGE classifier.

    Pipeline: ``encoder_model`` embeds the nodes, ``reg_model`` predicts a
    degree per node, ``gen`` produces ``num_pred`` candidate neighbour
    feature vectors per node, ``mend_graph`` attaches them to the graph,
    and ``classifier`` runs on the mended graph.

    Args:
        in_channels (int): dimension of the input node features.
        out_channels (int): output dimension of the node classifier.
        hidden (int): hidden size of both ``SAGE_Net`` models.
        gen_hidden (int): embedding size produced by the encoder and
            consumed by ``reg_model`` and ``gen``.
        dropout (float): dropout probability shared by the sub-models.
        num_pred (int): maximum number of generated neighbours per node.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden,
                 gen_hidden,
                 dropout=0.5,
                 num_pred=5):
        super(LocalSage_Plus, self).__init__()
        self.encoder_model = SAGE_Net(in_channels=in_channels,
                                      out_channels=gen_hidden,
                                      hidden=hidden,
                                      max_depth=2,
                                      dropout=dropout)
        self.reg_model = NumPredictor(latent_dim=gen_hidden)
        self.gen = FeatGenerator(latent_dim=gen_hidden,
                                 dropout=dropout,
                                 num_pred=num_pred,
                                 feat_shape=in_channels)
        self.mend_graph = MendGraph(num_pred)
        self.classifier = SAGE_Net(in_channels=in_channels,
                                   out_channels=out_channels,
                                   hidden=hidden,
                                   max_depth=2,
                                   dropout=dropout)

    def _mend_and_classify(self, encode_data, mend_data):
        """Generate neighbours from ``encode_data``, mend ``mend_data``
        with them and classify the result.

        Generated features for row ``i`` are attached to node ``i`` of
        ``mend_data``, so both graphs are expected to share the same node
        set (NOTE(review): confirm for ``inference`` callers).

        Returns:
            tuple: ``(degree, gen_feat, nc_pred)`` - the predicted degrees,
            the raw generator output, and the classifier output restricted
            to the first ``mend_data.num_nodes`` rows, i.e. the original
            nodes, which the mender stacks before the generated ones.
        """
        x = self.encoder_model(encode_data)
        degree = self.reg_model(x)
        gen_feat = self.gen(x)
        mend_feats, mend_edge_index = self.mend_graph(mend_data.x,
                                                      mend_data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:mend_data.num_nodes]

    def forward(self, data):
        """Encode, mend and classify the same graph ``data``."""
        return self._mend_and_classify(data, data)

    def inference(self, impared_data, raw_data):
        """Generate neighbours from ``impared_data`` but mend ``raw_data``.

        The parameter name ``impared_data`` (sic) is kept for keyword-call
        compatibility.
        """
        return self._mend_and_classify(impared_data, raw_data)
class FedSage_Plus(nn.Module):
    """FedSage+ model in which only the feature generator is trainable.

    Reuses (shares, does not copy) the sub-modules of ``local_graph`` and
    calls ``requires_grad_(False)`` on the encoder, the degree regressor,
    the mender and the classifier; ``gen`` keeps its ``requires_grad``
    flags.  Because the modules are shared, the freezing also affects
    ``local_graph`` itself.

    Args:
        local_graph (LocalSage_Plus): model whose sub-modules are reused.
    """
    def __init__(self, local_graph: LocalSage_Plus):
        super(FedSage_Plus, self).__init__()
        self.encoder_model = local_graph.encoder_model
        self.reg_model = local_graph.reg_model
        self.gen = local_graph.gen
        self.mend_graph = local_graph.mend_graph
        self.classifier = local_graph.classifier

        self.encoder_model.requires_grad_(False)
        self.reg_model.requires_grad_(False)
        self.mend_graph.requires_grad_(False)
        self.classifier.requires_grad_(False)

    def forward(self, data):
        """Encode ``data``, mend it with generated neighbours and classify.

        Returns:
            tuple: ``(degree, gen_feat, nc_pred)`` where ``nc_pred`` keeps
            only the first ``data.num_nodes`` rows (the original nodes).
        """
        x = self.encoder_model(data)
        degree = self.reg_model(x)
        gen_feat = self.gen(x)
        mend_feats, mend_edge_index = self.mend_graph(data.x, data.edge_index,
                                                      degree, gen_feat)
        nc_pred = self.classifier(
            Data(x=mend_feats, edge_index=mend_edge_index))
        return degree, gen_feat, nc_pred[:data.num_nodes]