Source code for federatedscope.core.splitters.graph.random_splitter

import torch

from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx, from_networkx

import numpy as np
import networkx as nx

from federatedscope.core.splitters import BaseSplitter

EPSILON = 1e-5


[docs]class RandomSplitter(BaseTransform, BaseSplitter): """ Split Data into small data via random sampling. Args: client_num (int): Split data into client_num of pieces. sampling_rate (str): Samples of the unique nodes for each client, \ eg. ``'0.2,0.2,0.2'`` overlapping_rate(float): Additional samples of overlapping data, \ eg. ``'0.4'`` drop_edge(float): Drop edges (drop_edge / client_num) for each \ client within overlapping part. """ def __init__(self, client_num, sampling_rate=None, overlapping_rate=0, drop_edge=0): BaseSplitter.__init__(self, client_num) self.ovlap = overlapping_rate if sampling_rate is not None: self.sampling_rate = np.array( [float(val) for val in sampling_rate.split(',')]) else: # Default: Average self.sampling_rate = (np.ones(client_num) - self.ovlap) / client_num if len(self.sampling_rate) != client_num: raise ValueError( f'The client_num ({client_num}) should be equal to the ' f'lenghth of sampling_rate and overlapping_rate.') if abs((sum(self.sampling_rate) + self.ovlap) - 1) > EPSILON: raise ValueError( f'The sum of sampling_rate:{self.sampling_rate} and ' f'overlapping_rate({self.ovlap}) should be 1.') self.drop_edge = drop_edge def __call__(self, data, **kwargs): data.index_orig = torch.arange(data.num_nodes) G = to_networkx( data, node_attrs=['x', 'y', 'train_mask', 'val_mask', 'test_mask'], to_undirected=True) nx.set_node_attributes(G, dict([(nid, nid) for nid in range(nx.number_of_nodes(G))]), name="index_orig") client_node_idx = {idx: [] for idx in range(self.client_num)} indices = np.random.permutation(data.num_nodes) sum_rate = 0 for idx, rate in enumerate(self.sampling_rate): client_node_idx[idx] = indices[round(sum_rate * data.num_nodes):round( (sum_rate + rate) * data.num_nodes)] sum_rate += rate if self.ovlap: ovlap_nodes = indices[round(sum_rate * data.num_nodes):] for idx in client_node_idx: client_node_idx[idx] = np.concatenate( (client_node_idx[idx], ovlap_nodes)) # Drop_edge index for each client if self.drop_edge: ovlap_graph = nx.Graph(nx.subgraph(G, ovlap_nodes)) ovlap_edge_ind = np.random.permutation( ovlap_graph.number_of_edges()) drop_all = ovlap_edge_ind[:round(ovlap_graph.number_of_edges() * self.drop_edge)] drop_client = [ drop_all[s:s + round(len(drop_all) / self.client_num)] for s in range(0, len(drop_all), round(len(drop_all) / self.client_num)) ] graphs = [] for owner in client_node_idx: nodes = client_node_idx[owner] sub_g = nx.Graph(nx.subgraph(G, nodes)) if self.drop_edge: sub_g.remove_edges_from( np.array(ovlap_graph.edges)[drop_client[owner]]) graphs.append(from_networkx(sub_g)) return graphs