import torch
import numpy as np
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import add_self_loops, remove_self_loops, \
    to_undirected
from torch_geometric.data import Data
from federatedscope.core.auxiliaries.splitter_builder import get_splitter
from federatedscope.core.auxiliaries.transform_builder import get_transform
INF = np.iinfo(np.int64).max
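# Passing INF as Planetoid's num_test below effectively assigns every node
# not used for train/val to the test set (the random split takes at most
# the number of remaining nodes).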


def load_nodelevel_dataset(config=None):
    r"""Load a node-level graph dataset and split it into per-client
    subgraphs.

    :returns:
        data_dict and the updated config
    :rtype:
        Tuple: (dict{'client_id': Data()}, config)
    """
    path = config.data.root
    name = config.data.type.lower()

    # TODO: remove splitter
    # Splitter
    splitter = get_splitter(config)
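    # (the splitter is a callable that partitions one global graph into a
    # list of per-client subgraphs; which splitter is used is controlled by
    # the splitter-related fields in config.data)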

    # Transforms
    transforms_funcs, _, _ = get_transform(config, 'torch_geometric')

    # Dataset
    if name in ["cora", "citeseer", "pubmed"]:
        num_split = {
            'cora': [232, 542, INF],
            'citeseer': [332, 665, INF],
            'pubmed': [3943, 3943, INF],
        }
        dataset = Planetoid(path,
                            name,
                            split='random',
                            num_train_per_class=num_split[name][0],
                            num_val=num_split[name][1],
                            num_test=num_split[name][2],
                            **transforms_funcs)
        dataset = splitter(dataset[0])
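        # splitter(dataset[0]) yields a list of per-client subgraphs; the
        # full graph is loaded again below as global_dataset for the
        # server-side (global) evaluation split.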
        global_dataset = Planetoid(path,
                                   name,
                                   split='random',
                                   num_train_per_class=num_split[name][0],
                                   num_val=num_split[name][1],
                                   num_test=num_split[name][2],
                                   **transforms_funcs)
elif name == "dblp_conf":
from federatedscope.gfl.dataset.dblp_new import DBLPNew
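        # DBLPNew's FL flag selects the partition granularity; judging by
        # the branch names, FL=1 splits by conference, FL=2 by organization,
        # and FL=0 loads the unpartitioned global graph.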
        dataset = DBLPNew(path,
                          FL=1,
                          splits=config.data.splits,
                          **transforms_funcs)
        global_dataset = DBLPNew(path,
                                 FL=0,
                                 splits=config.data.splits,
                                 **transforms_funcs)
elif name == "dblp_org":
from federatedscope.gfl.dataset.dblp_new import DBLPNew
dataset = DBLPNew(path,
FL=2,
splits=config.data.splits,
**transforms_funcs)
global_dataset = DBLPNew(path,
FL=0,
splits=config.data.splits,
**transforms_funcs)
elif name.startswith("csbm"):
from federatedscope.gfl.dataset.cSBM_dataset import \
dataset_ContextualSBM
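        # Build a synthetic contextual SBM graph with the parameters below.
        # Assumption: a name longer than plain "csbm" identifies a
        # previously generated variant; otherwise a fresh graph is sampled.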
        dataset = dataset_ContextualSBM(
            root=path,
            name=name if len(name) > len("csbm") else None,
            theta=config.data.cSBM_phi,
            epsilon=3.25,
            n=2500,
            d=5,
            p=1000,
            train_percent=0.2)
        global_dataset = None
    else:
        raise ValueError(f'No dataset named: {name}!')

    dataset = [ds for ds in dataset]
    client_num = min(len(dataset), config.federate.client_num
                     ) if config.federate.client_num > 0 else len(dataset)
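    # Write the effective client number back into the config so that
    # downstream components see a consistent federate.client_num.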
    config.merge_from_list(['federate.client_num', client_num])

    # get local dataset
    data_dict = dict()

    for client_idx in range(1, len(dataset) + 1):
        local_data = dataset[client_idx - 1]
        # To undirected and add self-loop
        local_data.edge_index = add_self_loops(
            to_undirected(remove_self_loops(local_data.edge_index)[0]),
            num_nodes=local_data.x.shape[0])[0]
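        # Each client uses its whole subgraph for train/val/test; the
        # train/val/test masks on the Data object select the actual nodes.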
        data_dict[client_idx] = {
            'data': local_data,
            'train': [local_data],
            'val': [local_data],
            'test': [local_data]
        }

    # Keep ML split consistent with local graphs
    if global_dataset is not None:
        global_graph = global_dataset[0]
        train_mask = torch.zeros_like(global_graph.train_mask)
        val_mask = torch.zeros_like(global_graph.val_mask)
        test_mask = torch.zeros_like(global_graph.test_mask)
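        # Rebuild the global masks from the clients' local masks; index_orig
        # maps each subgraph node back to its index in the full graph.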
        for client_sampler in data_dict.values():
            if isinstance(client_sampler, Data):
                client_subgraph = client_sampler
            else:
                client_subgraph = client_sampler['data']
            train_mask[client_subgraph.index_orig[
                client_subgraph.train_mask]] = True
            val_mask[client_subgraph.index_orig[
                client_subgraph.val_mask]] = True
            test_mask[client_subgraph.index_orig[
                client_subgraph.test_mask]] = True
        global_graph.train_mask = train_mask
        global_graph.val_mask = val_mask
        global_graph.test_mask = test_mask
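        # Key 0 is reserved for the server-side (global) view of the data.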
        data_dict[0] = {
            'data': global_graph,
            'train': [global_graph],
            'val': [global_graph],
            'test': [global_graph]
        }

    return data_dict, config
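

# Minimal usage sketch (assumption: a FederatedScope-style CfgNode; the
# import path and field names may vary across FederatedScope versions):
#
#     from federatedscope.core.configs.config import global_cfg
#     cfg = global_cfg.clone()
#     cfg.data.root = 'data/'
#     cfg.data.type = 'cora'
#     cfg.federate.client_num = 5
#     data_dict, cfg = load_nodelevel_dataset(cfg)
#     # data_dict[0] is the global graph; keys 1..client_num hold the
#     # per-client subgraphs.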