# Source code for federatedscope.core.data.base_translator

import logging
import numpy as np

from federatedscope.core.auxiliaries.splitter_builder import get_splitter
from federatedscope.core.data import ClientData, StandaloneDataDict

logger = logging.getLogger(__name__)


class BaseDataTranslator:
    """
    Translator is a tool to convert a centralized dataset to \
    ``StandaloneDataDict``, which is the input data of runner.

    Notes:
        The ``Translator`` consists of several stages:
        Dataset -> ML split (``split_train_val_test()``) -> \
        FL split (``split_to_client()``) -> ``StandaloneDataDict``
    """
    def __init__(self, global_cfg, client_cfgs=None):
        """
        Convert data to ``StandaloneDataDict``.

        Args:
            global_cfg: global CfgNode
            client_cfgs: client cfg ``Dict``
        """
        self.global_cfg = global_cfg
        self.client_cfgs = client_cfgs
        self.splitter = get_splitter(global_cfg)

    def __call__(self, dataset):
        """
        Args:
            dataset: ``torch.utils.data.Dataset``, ``List`` of
                (feature, label) or split dataset tuple of
                (train, val, test) or Tuple of split dataset with
                [train, val, test]

        Returns:
            datadict: instance of ``StandaloneDataDict``, which is a
            subclass of ``dict``.
        """
        datadict = self.split(dataset)
        datadict = StandaloneDataDict(datadict, self.global_cfg)
        return datadict

    def split(self, dataset):
        """
        Perform ML split and FL split.

        Returns:
            dict of ``ClientData`` with client_idx as key to build \
            ``StandaloneDataDict``
        """
        train, val, test = self.split_train_val_test(dataset)
        datadict = self.split_to_client(train, val, test)
        return datadict

    def split_train_val_test(self, dataset, cfg=None):
        """
        Split dataset to train, val, test if not provided.

        Args:
            dataset: dataset to split, or an already-split 3-tuple which
                is passed through unchanged.
            cfg: optional CfgNode whose ``data.splits`` overrides
                ``self.global_cfg.data.splits``.

        Returns:
            List: List of split dataset, like ``[train, val, test]``
        """
        from torch.utils.data import Dataset, Subset

        if cfg is not None:
            splits = cfg.data.splits
        else:
            splits = self.global_cfg.data.splits

        if isinstance(dataset, tuple):
            # No need to split train/val/test for tuple dataset.
            error_msg = 'If dataset is tuple, it must contains ' \
                        'train, valid and test split.'
            assert len(dataset) == 3, error_msg
            return [dataset[0], dataset[1], dataset[2]]

        # Random shuffle before splitting; sizes come from the configured
        # split ratios (the test split takes the remainder).
        index = np.random.permutation(np.arange(len(dataset)))
        train_size = int(splits[0] * len(dataset))
        val_size = int(splits[1] * len(dataset))

        if isinstance(dataset, Dataset):
            train_dataset = Subset(dataset, index[:train_size])
            val_dataset = Subset(dataset,
                                 index[train_size:train_size + val_size])
            test_dataset = Subset(dataset, index[train_size + val_size:])
        else:
            train_dataset = [dataset[x] for x in index[:train_size]]
            val_dataset = [
                dataset[x] for x in index[train_size:train_size + val_size]
            ]
            test_dataset = [dataset[x] for x in index[train_size + val_size:]]
        return train_dataset, val_dataset, test_dataset

    def split_to_client(self, train, val, test):
        """
        Split dataset to clients and build ``ClientData``.

        Returns:
            dict: dict of ``ClientData`` with ``client_idx`` as key.
        """
        # Initialization
        client_num = self.global_cfg.federate.client_num
        # Three independent placeholder lists. The previous
        # ``[[None] * client_num] * 3`` aliased one inner list to all
        # three names, a mutation hazard.
        split_train = [None] * client_num
        split_val = [None] * client_num
        split_test = [None] * client_num
        train_label_distribution = None

        # Split train/val/test to client
        if len(train) > 0:
            split_train = self.splitter(train)
            if self.global_cfg.data.consistent_label_distribution:
                # Best-effort: samples may not be (feature, label) pairs,
                # in which case val/test fall back to an unguided split.
                try:
                    train_label_distribution = [[j[1] for j in x]
                                                for x in split_train]
                except Exception:
                    logger.warning(
                        'Cannot access train label distribution for '
                        'splitter.')
        if len(val) > 0:
            split_val = self.splitter(val, prior=train_label_distribution)
        if len(test) > 0:
            split_test = self.splitter(test, prior=train_label_distribution)

        # Build data dict with `ClientData`, key `0` for server.
        data_dict = {
            0: ClientData(self.global_cfg, train=train, val=val, test=test)
        }
        for client_id in range(1, client_num + 1):
            if self.client_cfgs is not None:
                client_cfg = self.global_cfg.clone()
                client_cfg.merge_from_other_cfg(
                    self.client_cfgs.get(f'client_{client_id}'))
            else:
                client_cfg = self.global_cfg
            data_dict[client_id] = ClientData(client_cfg,
                                              train=split_train[client_id - 1],
                                              val=split_val[client_id - 1],
                                              test=split_test[client_id - 1])
        return data_dict