Source code for federatedscope.core.splitters.graph.louvain_splitter

import torch

from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx, from_networkx

import networkx as nx
import community as community_louvain

from federatedscope.core.splitters import BaseSplitter


[docs]class LouvainSplitter(BaseTransform, BaseSplitter): """ Split Data into small data via louvain algorithm. Args: client_num (int): Split data into ``client_num`` of pieces. delta (int): The gap between the number of nodes on each client. """ def __init__(self, client_num, delta=20): self.delta = delta BaseSplitter.__init__(self, client_num) def __call__(self, data, **kwargs): data.index_orig = torch.arange(data.num_nodes) G = to_networkx( data, node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'], to_undirected=True) nx.set_node_attributes(G, dict([(nid, nid) for nid in range(nx.number_of_nodes(G))]), name="index_orig") partition = community_louvain.best_partition(G) cluster2node = {} for node in partition: cluster = partition[node] if cluster not in cluster2node: cluster2node[cluster] = [node] else: cluster2node[cluster].append(node) max_len = len(G) // self.client_num - self.delta max_len_client = len(G) // self.client_num tmp_cluster2node = {} for cluster in cluster2node: while len(cluster2node[cluster]) > max_len: tmp_cluster = cluster2node[cluster][:max_len] tmp_cluster2node[len(cluster2node) + len(tmp_cluster2node) + 1] = tmp_cluster cluster2node[cluster] = cluster2node[cluster][max_len:] cluster2node.update(tmp_cluster2node) orderedc2n = (zip(cluster2node.keys(), cluster2node.values())) orderedc2n = sorted(orderedc2n, key=lambda x: len(x[1]), reverse=True) client_node_idx = {idx: [] for idx in range(self.client_num)} idx = 0 for (cluster, node_list) in orderedc2n: while len(node_list) + len( client_node_idx[idx]) > max_len_client + self.delta: idx = (idx + 1) % self.client_num client_node_idx[idx] += node_list idx = (idx + 1) % self.client_num graphs = [] for owner in client_node_idx: nodes = client_node_idx[owner] graphs.append(from_networkx(nx.subgraph(G, nodes))) return graphs