Source code for federatedscope.core.auxiliaries.dataloader_builder

from federatedscope.core.data.utils import filter_dict

try:
    import torch
    from torch.utils.data import Dataset
except ImportError:
    torch = None
    Dataset = object


def get_dataloader(dataset, config, split='train'):
    """
    Instantiate a DataLoader via config.

    Args:
        dataset: dataset from which to load the data.
        config: configs providing ``backend`` and the ``dataloader`` node
            (``type``, ``batch_size``, ``shuffle``, etc.).
        split: current split (default: ``train``). For any other split,
            ``shuffle`` and ``drop_last`` are forced to ``False`` and
            ``sizes`` is set to ``[-1]``. For ``graphsaint-rw``, non-train
            splits use ``NeighborSampler``.

    Returns:
        Instance of specific ``DataLoader`` configured by config, the
        untouched ``dataset`` when ``dataloader.type`` is ``raw``, or
        ``None`` when ``config.backend`` is not ``torch``.

    Raises:
        ValueError: if ``config.dataloader.type`` is not a supported type.

    Note:
        The key-value pairs of ``dataloader.type`` and ``DataLoader``:

        - ``raw``: no DataLoader
        - ``base``: ``torch.utils.data.DataLoader``
        - ``pyg``: ``torch_geometric.loader.DataLoader``
        - ``graphsaint-rw``:
          ``torch_geometric.loader.GraphSAINTRandomWalkSampler``
        - ``neighbor``: ``torch_geometric.loader.NeighborSampler``
        - ``mf``: ``federatedscope.mf.dataloader.MFDataLoader``
    """
    # DataLoader builder only support torch backend now.
    if config.backend != 'torch':
        return None

    loader_type = config.dataloader.type

    # Loader classes are imported lazily so that optional dependencies
    # (e.g. torch_geometric) are only required when actually requested.
    if loader_type == 'base':
        from torch.utils.data import DataLoader
        loader_cls = DataLoader
    elif loader_type == 'raw':
        # No DataLoader
        return dataset
    elif loader_type == 'pyg':
        from torch_geometric.loader import DataLoader as PyGDataLoader
        loader_cls = PyGDataLoader
    elif loader_type == 'graphsaint-rw':
        if split == 'train':
            from torch_geometric.loader import GraphSAINTRandomWalkSampler
            loader_cls = GraphSAINTRandomWalkSampler
        else:
            from torch_geometric.loader import NeighborSampler
            loader_cls = NeighborSampler
    elif loader_type == 'neighbor':
        from torch_geometric.loader import NeighborSampler
        loader_cls = NeighborSampler
    elif loader_type == 'mf':
        from federatedscope.mf.dataloader import MFDataLoader
        loader_cls = MFDataLoader
    else:
        # Report the value that was actually dispatched on; reading
        # ``config.data.loader.type`` here could fail or name another
        # setting instead of raising the intended ValueError.
        raise ValueError(f'dataloader.type {loader_type} not found!')

    # Work on a copy so the overrides below never leak into the config.
    raw_args = dict(config.dataloader)
    if split != 'train':
        raw_args['shuffle'] = False
        raw_args['sizes'] = [-1]
        raw_args['drop_last'] = False
        # For evaluation in GFL
        if loader_type in ('graphsaint-rw', 'neighbor'):
            raw_args['batch_size'] = 4096
            dataset = dataset[0].edge_index
    else:
        if loader_type == 'graphsaint-rw':
            # Raw graph
            dataset = dataset[0]
        elif loader_type == 'neighbor':
            # edge_index of raw graph
            dataset = dataset[0].edge_index

    # NOTE(review): filter_dict presumably keeps only the keyword arguments
    # accepted by ``loader_cls.__init__`` -- confirm against its definition.
    filtered_args = filter_dict(loader_cls.__init__, raw_args)
    dataloader = loader_cls(dataset, **filtered_args)
    return dataloader