Source code for federatedscope.core.splitters.graph.randchunk_splitter

import numpy as np

from torch_geometric.transforms import BaseTransform
from federatedscope.core.splitters import BaseSplitter


class RandChunkSplitter(BaseTransform, BaseSplitter):
    """
    Split graph-level dataset via random chunk strategy.

    Each client first receives one contiguous chunk of
    ``min(50, num_graph // client_num)`` graphs, taken in dataset order.
    Every remaining graph is then appended to a client chosen uniformly at
    random with NumPy's global RNG (``np.random``). Seed that RNG if a
    reproducible split is required.

    Arguments:
        client_num (int): Number of clients to split the dataset across.
            Passed to ``BaseSplitter.__init__``; ``__call__`` reads it back
            as ``self.client_num``.
    """

    # Upper bound on the size of the deterministic per-client chunk; graphs
    # beyond ``client_num * chunk_size`` are scattered randomly instead.
    _MAX_CHUNK_SIZE = 50

    def __init__(self, client_num):
        BaseSplitter.__init__(self, client_num)

    def __call__(self, dataset, **kwargs):
        """
        Split ``dataset`` into ``self.client_num`` lists of graphs.

        Arguments:
            dataset (List or PyG.dataset): The graph-level dataset. It is
                iterated once and materialized into a Python list.
            **kwargs: Ignored by this splitter.

        Returns:
            list[list]: ``self.client_num`` lists, one per client. Client
            ``i`` starts with ``graphs[i * chunk_size:(i + 1) * chunk_size]``
            and may receive additional, randomly assigned graphs.
        """
        graphs = list(dataset)
        num_graph = len(graphs)

        # Integer floor division (no float round-trip). With an empty dataset
        # or fewer graphs than clients the chunk size is 0, so every graph is
        # placed by the random pass below.
        chunk_size = min(self._MAX_CHUNK_SIZE, num_graph // self.client_num)

        # Deterministic phase: one contiguous chunk per client, in order.
        data_list = [
            graphs[i * chunk_size:(i + 1) * chunk_size]
            for i in range(self.client_num)
        ]

        # Random phase: every leftover graph goes to a uniformly drawn client.
        for graph in graphs[self.client_num * chunk_size:]:
            client_idx = np.random.randint(low=0, high=self.client_num)
            data_list[client_idx].append(graph)

        return data_list