Source code for federatedscope.core.data.dummy_translator

from federatedscope.core.data.base_translator import BaseDataTranslator
from federatedscope.core.data.base_data import ClientData


class DummyDataTranslator(BaseDataTranslator):
    """
    Translator that converts a datadict into the per-client ``ClientData``
    objects used to build a ``StandaloneDataDict``.

    Compared to ``core.data.base_translator.BaseDataTranslator``, it does
    not perform the FL split: the keys of the incoming dict are used
    directly as the client ids.
    """
    def split(self, dataset):
        """
        Perform the ML split for every client found in ``dataset``.

        Args:
            dataset: dict keyed by client id. A value that is itself a dict
                is forwarded as keyword arguments to ``ClientData``; any
                other value is first passed through
                ``split_train_val_test``.

        Returns:
            dict of ``ClientData`` with client_idx as key to build
            ``StandaloneDataDict``

        Raises:
            TypeError: if ``dataset`` is not a dict.
        """
        if not isinstance(dataset, dict):
            raise TypeError(f'Not support data type {type(dataset)}')

        client_data_map = {}
        for client_id in dataset:
            raw_data = dataset[client_id]

            # Resolve the configuration that applies to this client.
            if self.client_cfgs is None:
                client_cfg = self.global_cfg
            else:
                # Work on a clone so the client-specific overrides are not
                # merged into the shared global config.
                client_cfg = self.global_cfg.clone()
                client_specific = self.client_cfgs.get(f'client_{client_id}')
                client_cfg.merge_from_other_cfg(client_specific)

            # A dict payload is handed to ``ClientData`` unchanged.
            if isinstance(raw_data, dict):
                client_data_map[client_id] = ClientData(client_cfg, **raw_data)
                continue

            # Do not have train/val/test: derive them from the raw data.
            train, val, test = self.split_train_val_test(raw_data, client_cfg)
            split_kwargs = {'train': train, 'val': val, 'test': test}

            # Only for graph-level task, get number of graph labels
            model_cfg = client_cfg.model
            if model_cfg.task.startswith('graph') and \
                    model_cfg.out_channels == 0:
                distinct_labels = {graph.y.item() for graph in raw_data}
                split_kwargs['num_label'] = len(distinct_labels)

            client_data_map[client_id] = ClientData(client_cfg, **split_kwargs)

        return client_data_map