Source code for federatedscope.gfl.dataloader.dataloader_graph

from torch_geometric import transforms
from torch_geometric.datasets import TUDataset, MoleculeNet

from federatedscope.core.auxiliaries.transform_builder import get_transform
from federatedscope.gfl.dataset.cikm_cup import CIKMCUPDataset


def load_graphlevel_dataset(config=None):
    r"""Load the graph-level dataset(s) selected by ``config.data.type``.

    The dataset name is upper-cased and then matched against four families:
    single ``TUDataset`` names, ``MoleculeNet`` names,
    ``graph_multi_domain_*`` collections of ``TUDataset`` objects, and
    ``CIKM``.

    :param config: configuration object. This function reads
        ``config.data.splits``, ``config.data.root``, ``config.data.type``
        and ``config.federate.client_num``, and writes
        ``federate.client_num`` back through ``config.merge_from_list``.
    :returns: ``(dataset, config)`` for MoleculeNet names, which return
        early and leave ``federate.client_num`` untouched; otherwise
        ``(data_dict, config)`` where ``data_dict[i]`` is
        ``dataset[i - 1]`` for every ``i`` in ``1..len(dataset)``.
    :rtype: tuple
    :raises ValueError: if the name matches no known dataset, or if a
        ``graph_multi_domain_*mix`` dataset is requested while
        ``'pre_transform'`` is missing from the transform functions.
    """
    # `splits` is read here but not used further in this function; the
    # lookup is kept so the same config keys are accessed as before.
    splits = config.data.splits
    root = config.data.root
    name = config.data.type.upper()

    # Transforms
    transforms_funcs, _, _ = get_transform(config, 'torch_geometric')

    # The dataset groups overlap heavily, so they are built from each other.
    mol_names = ('MUTAG', 'BZR', 'COX2', 'DHFR', 'PTC_MR', 'AIDS', 'NCI1')
    small_names = mol_names[:5] + ('ENZYMES', 'DD', 'PROTEINS')
    biochem_names = mol_names + ('ENZYMES', 'DD', 'PROTEINS')
    mix_names = biochem_names + ('COLLAB', 'IMDB-BINARY', 'IMDB-MULTI')
    tu_names = mix_names + ('REDDIT-BINARY', )
    moleculenet_names = ('HIV', 'ESOL', 'FREESOLV', 'LIPO', 'PCBA', 'MUV',
                         'BACE', 'BBBP', 'TOX21', 'TOXCAST', 'SIDER',
                         'CLINTOX')

    # MoleculeNet datasets are handed back as-is: no client_num update and
    # no per-client dictionary.
    if name in moleculenet_names:
        return MoleculeNet(root, name, **transforms_funcs), config

    if name in tu_names:
        # Add a constant feature for the datasets without attributes,
        # unless a pre_transform was already configured.
        lacks_attributes = name in ('IMDB-BINARY', 'IMDB-MULTI')
        if lacks_attributes and 'pre_transform' not in transforms_funcs:
            transforms_funcs['pre_transform'] = transforms.Constant(
                value=1.0, cat=False)
        dataset = TUDataset(root, name, **transforms_funcs)
    elif name.startswith('graph_multi_domain'.upper()):
        # The `graph_multi_domain` datasets follow GCFL:
        # Federated Graph Classification over Non-IID Graphs (NeurIPS 2021)
        suffix_to_names = {
            'mol': mol_names,
            'small': small_names,
            'mix': mix_names,
            'biochem': biochem_names,
        }
        matched = None
        for suffix in suffix_to_names:
            if name.endswith(suffix.upper()):
                matched = suffix
                break
        if matched is None:
            raise ValueError(f'No dataset named: {name}!')
        if matched == 'mix' and 'pre_transform' not in transforms_funcs:
            raise ValueError('pre_transform is None!')

        # Datasets other than IMDB*/COLLAB only receive the plain
        # `transform` (if one is configured) and never a pre_transform.
        if 'transform' in transforms_funcs:
            plain_transform = transforms_funcs['transform']
        else:
            plain_transform = None

        dataset = []
        # Some datasets contain x
        for dname in suffix_to_names[matched]:
            if dname == 'COLLAB' or dname.startswith('IMDB'):
                domain_dataset = TUDataset(root, dname, **transforms_funcs)
            else:
                domain_dataset = TUDataset(root,
                                           dname,
                                           pre_transform=None,
                                           transform=plain_transform)
            dataset.append(domain_dataset)
    elif name == 'CIKM':
        dataset = CIKMCUPDataset(config.data.root)
    else:
        raise ValueError(f'No dataset named: {name}!')

    # A non-positive configured client number means "use len(dataset)";
    # a positive one is capped at len(dataset).
    if config.federate.client_num > 0:
        client_num = min(len(dataset), config.federate.client_num)
    else:
        client_num = len(dataset)
    config.merge_from_list(['federate.client_num', client_num])

    # get local dataset: client ids are 1-based and cover every entry of
    # `dataset`, independent of the client_num computed above.
    data_dict = {
        position + 1: dataset[position]
        for position in range(len(dataset))
    }
    return data_dict, config