Source code for federatedscope.core.splitters.graph.reltype_splitter

import torch

from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.transforms import BaseTransform

from federatedscope.core.splitters.utils import \
    dirichlet_distribution_noniid_slice
from federatedscope.core.splitters import BaseSplitter


[docs]class RelTypeSplitter(BaseTransform, BaseSplitter): """ Split Data into small data via dirichlet distribution to \ generate non-i.i.d data split. Arguments: client_num (int): Split data into client_num of pieces. alpha (float): Partition hyperparameter in LDA, smaller alpha \ generates more extreme heterogeneous scenario see \ ``np.random.dirichlet`` """ def __init__(self, client_num, alpha=0.5, realloc_mask=False): BaseSplitter.__init__(self, client_num) self.alpha = alpha self.realloc_mask = realloc_mask def __call__(self, data, **kwargs): data_list = [] label = data.edge_type.numpy() idx_slice = dirichlet_distribution_noniid_slice( label, self.client_num, self.alpha) # Reallocation train/val/test mask train_ratio = data.train_edge_mask.sum().item() / data.num_edges test_ratio = data.test_edge_mask.sum().item() / data.num_edges for idx_j in idx_slice: edge_index = data.edge_index.T[idx_j].T edge_type = data.edge_type[idx_j] train_edge_mask = data.train_edge_mask[idx_j] valid_edge_mask = data.valid_edge_mask[idx_j] test_edge_mask = data.test_edge_mask[idx_j] if self.realloc_mask: num_edges = edge_index.size(-1) indices = torch.randperm(num_edges) train_edge_mask = torch.zeros(num_edges, dtype=torch.bool) train_edge_mask[indices[:round(train_ratio * num_edges)]] = True valid_edge_mask = torch.zeros(num_edges, dtype=torch.bool) valid_edge_mask[ indices[round(train_ratio * num_edges):-round(test_ratio * num_edges)]] = True test_edge_mask = torch.zeros(num_edges, dtype=torch.bool) test_edge_mask[indices[-round(test_ratio * num_edges):]] = True sub_g = Data(x=data.x, edge_index=edge_index, index_orig=data.index_orig, edge_type=edge_type, train_edge_mask=train_edge_mask, valid_edge_mask=valid_edge_mask, test_edge_mask=test_edge_mask, input_edge_index=to_undirected( edge_index.T[train_edge_mask].T)) data_list.append(sub_g) return data_list