import logging
import numpy as np
from federatedscope.core.auxiliaries.splitter_builder import get_splitter
from federatedscope.core.data import ClientData, StandaloneDataDict
logger = logging.getLogger(__name__)
class BaseDataTranslator:
"""
    Translator is a tool to convert a centralized dataset to \
    ``StandaloneDataDict``, which is the input data of the runner.

    Notes:
        The ``Translator`` consists of several stages:
Dataset -> ML split (``split_train_val_test()``) -> \
FL split (``split_to_client()``) -> ``StandaloneDataDict``
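
    Example:
        A minimal usage sketch (assuming ``cfg`` is an already-configured
        ``CfgNode`` and ``dataset`` is a list of ``(feature, label)``
        pairs; variable names here are illustrative only)::

            translator = BaseDataTranslator(cfg)
            datadict = translator(dataset)
            server_data = datadict[0]  # full, unsplit data for the server
            client_data = datadict[1]  # ClientData of client 1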
"""
def __init__(self, global_cfg, client_cfgs=None):
"""
Convert data to `StandaloneDataDict`.
Args:
global_cfg: global CfgNode
client_cfgs: client cfg `Dict`
"""
self.global_cfg = global_cfg
self.client_cfgs = client_cfgs
self.splitter = get_splitter(global_cfg)
def __call__(self, dataset):
"""
Args:
            dataset: a ``torch.utils.data.Dataset``, a ``List`` of
                ``(feature, label)`` pairs, or a pre-split ``Tuple`` of
                ``(train, val, test)``
Returns:
datadict: instance of `StandaloneDataDict`, which is a subclass of
`dict`.
"""
datadict = self.split(dataset)
datadict = StandaloneDataDict(datadict, self.global_cfg)
return datadict
    def split(self, dataset):
"""
Perform ML split and FL split.
Returns:
dict of ``ClientData`` with client_idx as key to build \
``StandaloneDataDict``
"""
train, val, test = self.split_train_val_test(dataset)
datadict = self.split_to_client(train, val, test)
return datadict
    def split_train_val_test(self, dataset, cfg=None):
"""
        Split the dataset into train, val and test if it is not already
        split.

        Returns:
            List: the split datasets, i.e., ``[train, val, test]``
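
        Example:
            With ``cfg.data.splits = [0.8, 0.1, 0.1]`` and a dataset of
            100 samples, roughly 80/10/10 samples go to train/val/test;
            the split sizes are truncated by ``int()``, so the test
            split absorbs any remainder.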
"""
from torch.utils.data import Dataset, Subset
if cfg is not None:
splits = cfg.data.splits
else:
splits = self.global_cfg.data.splits
if isinstance(dataset, tuple):
# No need to split train/val/test for tuple dataset.
            error_msg = 'If dataset is a tuple, it must contain ' \
                        'train, valid and test splits.'
            assert len(dataset) == 3, error_msg
return [dataset[0], dataset[1], dataset[2]]
index = np.random.permutation(np.arange(len(dataset)))
train_size = int(splits[0] * len(dataset))
val_size = int(splits[1] * len(dataset))
if isinstance(dataset, Dataset):
train_dataset = Subset(dataset, index[:train_size])
val_dataset = Subset(dataset,
index[train_size:train_size + val_size])
test_dataset = Subset(dataset, index[train_size + val_size:])
else:
train_dataset = [dataset[x] for x in index[:train_size]]
val_dataset = [
dataset[x] for x in index[train_size:train_size + val_size]
]
test_dataset = [dataset[x] for x in index[train_size + val_size:]]
return train_dataset, val_dataset, test_dataset
    def split_to_client(self, train, val, test):
"""
        Split the dataset among clients and build ``ClientData``.
Returns:
dict: dict of ``ClientData`` with ``client_idx`` as key.
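
        Example:
            With ``federate.client_num = 2`` and an IID splitter (an
            illustrative setting), each client receives roughly half of
            every split, while key ``0`` keeps the full, unsplit data
            for the server.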
"""
# Initialization
client_num = self.global_cfg.federate.client_num
        # Three independent lists; `[[None] * client_num] * 3` would
        # alias all three names to the same list object.
        split_train, split_val, split_test = [
            [None] * client_num for _ in range(3)
        ]
train_label_distribution = None
# Split train/val/test to client
if len(train) > 0:
split_train = self.splitter(train)
if self.global_cfg.data.consistent_label_distribution:
try:
train_label_distribution = [[j[1] for j in x]
for x in split_train]
                except Exception:
logger.warning(
'Cannot access train label distribution for '
'splitter.')
if len(val) > 0:
split_val = self.splitter(val, prior=train_label_distribution)
if len(test) > 0:
split_test = self.splitter(test, prior=train_label_distribution)
# Build data dict with `ClientData`, key `0` for server.
data_dict = {
0: ClientData(self.global_cfg, train=train, val=val, test=test)
}
for client_id in range(1, client_num + 1):
if self.client_cfgs is not None:
client_cfg = self.global_cfg.clone()
client_cfg.merge_from_other_cfg(
self.client_cfgs.get(f'client_{client_id}'))
else:
client_cfg = self.global_cfg
data_dict[client_id] = ClientData(client_cfg,
train=split_train[client_id - 1],
val=split_val[client_id - 1],
test=split_test[client_id - 1])
return data_dict