Source code for federatedscope.gfl.dataloader.dataloader_node

import torch
import numpy as np

from torch_geometric.datasets import Planetoid
from torch_geometric.utils import add_self_loops, remove_self_loops, \
    to_undirected
from torch_geometric.data import Data

from federatedscope.core.auxiliaries.splitter_builder import get_splitter
from federatedscope.core.auxiliaries.transform_builder import get_transform

# Largest value representable by a signed 64-bit integer. It is used below as
# the `num_test` entry of the Planetoid split sizes.
# NOTE(review): presumably this makes every node left over after the
# train/val selection a test node -- confirm against Planetoid's handling of
# `split='random'`.
INF = np.iinfo(np.int64).max


def load_nodelevel_dataset(config=None):
    r"""Build the per-client node-level graph data selected by ``config``.

    The dataset is picked from ``config.data.type`` (case-insensitive):
    ``cora`` / ``citeseer`` / ``pubmed`` are loaded with ``Planetoid`` and
    partitioned by the splitter obtained from ``get_splitter``;
    ``dblp_conf`` / ``dblp_org`` are loaded with ``DBLPNew``; any name
    starting with ``csbm`` is loaded with ``dataset_ContextualSBM``.
    ``config.federate.client_num`` is overwritten through
    ``config.merge_from_list``.

    :param config: configuration object; ``config.data`` and
        ``config.federate`` are read. Although the default is ``None``,
        its attributes are accessed unconditionally.
    :returns: ``(data_dict, config)``. ``data_dict`` maps every client id
        (starting at 1) to a dict with the keys ``'data'``, ``'train'``,
        ``'val'`` and ``'test'``. Key ``0`` holds the global graph
        whenever a global dataset exists (all datasets except cSBM).
    :rtype: tuple
    :raises ValueError: if ``config.data.type`` names no known dataset.
    """
    root = config.data.root
    name = config.data.type.lower()

    # TODO: remove splitter
    splitter = get_splitter(config)

    # Keyword arguments forwarded to the Planetoid / DBLPNew constructors
    transforms_funcs, _, _ = get_transform(config, 'torch_geometric')

    if name in ("cora", "citeseer", "pubmed"):
        # (num_train_per_class, num_val, num_test) for each dataset
        split_sizes = {
            'cora': (232, 542, INF),
            'citeseer': (332, 665, INF),
            'pubmed': (3943, 3943, INF),
        }
        num_train_per_class, num_val, num_test = split_sizes[name]
        planetoid_kwargs = dict(split='random',
                                num_train_per_class=num_train_per_class,
                                num_val=num_val,
                                num_test=num_test,
                                **transforms_funcs)
        # The single Planetoid graph is partitioned into client subgraphs
        dataset = splitter(Planetoid(root, name, **planetoid_kwargs)[0])
        # A separately constructed instance supplies the global graph
        global_dataset = Planetoid(root, name, **planetoid_kwargs)
    elif name in ("dblp_conf", "dblp_org"):
        from federatedscope.gfl.dataset.dblp_new import DBLPNew

        # FL=1 for 'dblp_conf', FL=2 for 'dblp_org'; FL=0 is used for the
        # global dataset in both cases
        fl_mode = 1 if name == "dblp_conf" else 2
        dataset = DBLPNew(root,
                          FL=fl_mode,
                          splits=config.data.splits,
                          **transforms_funcs)
        global_dataset = DBLPNew(root,
                                 FL=0,
                                 splits=config.data.splits,
                                 **transforms_funcs)
    elif name.startswith("csbm"):
        from federatedscope.gfl.dataset.cSBM_dataset import \
            dataset_ContextualSBM

        # A bare "csbm" has nothing after the prefix, so no name is passed
        csbm_name = name if len(name) > len("csbm") else None
        dataset = dataset_ContextualSBM(root=root,
                                        name=csbm_name,
                                        theta=config.data.cSBM_phi,
                                        epsilon=3.25,
                                        n=2500,
                                        d=5,
                                        p=1000,
                                        train_percent=0.2)
        global_dataset = None
    else:
        raise ValueError(f'No dataset named: {name}!')

    # Materialise the per-client graphs so they can be counted and indexed
    dataset = list(dataset)

    # A non-positive configured client number means "use every subgraph"
    requested_clients = config.federate.client_num
    if requested_clients > 0:
        client_num = min(len(dataset), requested_clients)
    else:
        client_num = len(dataset)
    config.merge_from_list(['federate.client_num', client_num])

    # get local dataset (client ids start at 1)
    data_dict = dict()
    for client_idx, local_data in enumerate(dataset, start=1):
        # To undirected and add self-loop
        loop_free = remove_self_loops(local_data.edge_index)[0]
        symmetric = to_undirected(loop_free)
        local_data.edge_index = add_self_loops(
            symmetric, num_nodes=local_data.x.shape[0])[0]
        data_dict[client_idx] = {
            'data': local_data,
            'train': [local_data],
            'val': [local_data],
            'test': [local_data]
        }

    if global_dataset is None:
        return data_dict, config

    # Keep ML split consistent with local graphs: a node of the global graph
    # is flagged in a mask when some client subgraph flags it in the same
    # mask (client nodes are mapped back through `index_orig`)
    global_graph = global_dataset[0]
    mask_names = ('train_mask', 'val_mask', 'test_mask')
    merged_masks = {
        mask_name: torch.zeros_like(getattr(global_graph, mask_name))
        for mask_name in mask_names
    }
    for entry in data_dict.values():
        subgraph = entry if isinstance(entry, Data) else entry['data']
        for mask_name, merged in merged_masks.items():
            merged[subgraph.index_orig[getattr(subgraph, mask_name)]] = True
    for mask_name, merged in merged_masks.items():
        setattr(global_graph, mask_name, merged)

    data_dict[0] = {
        'data': global_graph,
        'train': [global_graph],
        'val': [global_graph],
        'test': [global_graph]
    }
    return data_dict, config