import logging
from scipy.sparse.csc import csc_matrix
from federatedscope.core.data.utils import merge_data
from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
logger = logging.getLogger(__name__)
class StandaloneDataDict(dict):
    """
    ``StandaloneDataDict`` maintains several ``ClientData`` objects. It is
    only used in ``Standalone`` mode to be passed to ``Runner``, and applies
    several preprocessing steps driven by ``global_cfg``; see
    ``preprocess()`` for details.

    Args:
        datadict: ``Dict`` with ``client_id`` as key, ``ClientData`` as value.
        global_cfg: global ``CfgNode``
    """
    def __init__(self, datadict, global_cfg):
        """
        Args:
            datadict: `Dict` with `client_id` as key, `ClientData` as value.
            global_cfg: global CfgNode
        """
        self.global_cfg = global_cfg
        self.client_cfgs = None
        # ``preprocess`` reads ``self.global_cfg``, so it must run after the
        # attribute assignments above and before the dict is populated.
        datadict = self.preprocess(datadict)
        super().__init__(datadict)

    def resetup(self, global_cfg, client_cfgs=None):
        """
        Re-setup every ``ClientData`` with new configs, e.g. when the
        configs change during HPO.

        Args:
            global_cfg: new global config for ``ClientData``
            client_cfgs: optional client-specific configs, looked up with \
                the key ``client_{client_id}``
        """
        self.global_cfg, self.client_cfgs = global_cfg, client_cfgs
        for client_id, client_data in self.items():
            if not isinstance(client_data, ClientData):
                logger.warning('`client_data` is not subclass of '
                               '`ClientData`, and cannot re-setup '
                               'DataLoader with new configs.')
                continue

            if client_cfgs is None:
                client_cfg = global_cfg
            else:
                client_cfg = global_cfg.clone()
                specific_cfg = client_cfgs.get(f'client_{client_id}')
                # ``get`` yields ``None`` when there is no entry for this id
                # (presumably the case for the server id ``0``); fall back
                # to the plain global config instead of merging ``None``.
                if specific_cfg is not None:
                    client_cfg.merge_from_other_cfg(specific_cfg)
            client_data.setup(client_cfg)

    def preprocess(self, datadict):
        """
        Preprocess for:
            (1) Global evaluation (merge test data).
            (2) Global mode (train with centralized setting, merge all data).
            (3) Apply data attack algorithms.

        Args:
            datadict: dict with `client_id` as key, `ClientData` as value.

        Returns:
            The same ``datadict``, possibly with entries ``0`` and/or ``1`` \
            (re)assigned, after ``attack()`` has been applied.

        Raises:
            AssertionError: if ``cfg.data.server_holds_all`` is set in \
                global mode but ``datadict[0]`` is missing, ``None`` or empty.
        """
        federate_cfg = self.global_cfg.federate

        if federate_cfg.merge_test_data:
            merge_split = ['test']
            if federate_cfg.merge_val_data:
                merge_split.append('val')
            server_data = merge_data(
                all_data=datadict,
                merged_max_data_id=federate_cfg.client_num,
                specified_dataset_name=merge_split)
            # `0` indicate Server
            datadict[0] = ClientData(self.global_cfg, **server_data)

        if federate_cfg.method == "global" and federate_cfg.client_num != 1:
            if self.global_cfg.data.server_holds_all:
                # Check membership first so that a missing entry reports the
                # explanatory message below rather than a bare ``KeyError``.
                assert 0 in datadict and datadict[0] is not None \
                    and len(datadict[0]) != 0, \
                    "You specified cfg.data.server_holds_all=True " \
                    "but data[0] is None. Please check whether you " \
                    "pre-process the data[0] correctly"
                datadict[1] = datadict[0]
            else:
                logger.info(f"Will merge data from clients whose ids in "
                            f"[1, {federate_cfg.client_num}]")
                merged_data = merge_data(
                    all_data=datadict,
                    merged_max_data_id=federate_cfg.client_num)
                datadict[1] = ClientData(self.global_cfg, **merged_data)

        return self.attack(datadict)

    def attack(self, datadict):
        """
        Apply attack to ``StandaloneDataDict``.

        Args:
            datadict: dict with `client_id` as key, `ClientData` as value.

        Returns:
            The ``datadict`` that was passed in; ``poisoning`` is called on \
            it when a backdoor attack is configured.
        """
        attack_cfg = self.global_cfg.attack
        use_backdoor = 'backdoor' in attack_cfg.attack_method

        if use_backdoor and 'edge' in attack_cfg.trigger_type:
            # Imported lazily: only needed for the edge-case backdoor attack.
            import os
            import torch
            from federatedscope.attack.auxiliary import \
                create_ardis_poisoned_dataset, create_ardis_test_dataset

            edge_path = attack_cfg.edge_path
            # ``exist_ok`` avoids the check-then-create race.
            os.makedirs(edge_path, exist_ok=True)

            poisoned_edgeset = create_ardis_poisoned_dataset(
                data_path=edge_path)
            ardis_test_dataset = create_ardis_test_dataset(edge_path)

            logger.info("Writing poison_data to: {}".format(edge_path))

            # NOTE(review): plain concatenation assumes ``edge_path`` ends
            # with a path separator. Kept as-is because the code that reads
            # these files back is not visible here -- confirm before changing.
            with open(edge_path + "poisoned_edgeset_training",
                      "wb") as saved_data_file:
                torch.save(poisoned_edgeset, saved_data_file)

            with open(edge_path + "ardis_test_dataset.pt",
                      "wb") as ardis_data_file:
                torch.save(ardis_test_dataset, ardis_data_file)

            # Implicit concatenation keeps source indentation out of the
            # logged message.
            logger.warning(
                'please notice: downloading the poisoned dataset '
                'on cifar-10 from '
                'https://github.com/ksreenivasan/OOD_Federated_Learning')

        if use_backdoor:
            from federatedscope.attack.auxiliary import poisoning
            poisoning(datadict, self.global_cfg)
        return datadict
class ClientData(dict):
    """
    ``ClientData`` converts split data to ``DataLoader``.

    Args:
        client_cfg: client-specific ``CfgNode``
        train: train dataset, which will be converted to ``Dataloader``
        val: valid dataset, which will be converted to ``Dataloader``
        test: test dataset, which will be converted to ``Dataloader``
        **kwargs: extra entries, stored under their own keys unchanged

    Note:
        Attribute ``{split}_data`` of ``ClientData`` is the raw dataset.
        Key ``{split}`` in ``ClientData`` is the dataloader.
    """
    SPLIT_NAMES = ['train', 'val', 'test']

    def __init__(self, client_cfg, train=None, val=None, test=None, **kwargs):
        super().__init__()
        # ``setup`` treats ``client_cfg is None`` as "not configured yet".
        self.client_cfg = None
        self.train_data = train
        self.val_data = val
        self.test_data = test
        self.setup(client_cfg)
        # Extra entries are added after the split loaders, as-is.
        self.update(kwargs)

    def setup(self, new_client_cfg=None):
        """
        Set up ``DataLoader`` in ``ClientData`` with new configurations.

        Args:
            new_client_cfg: new client-specific CfgNode

        Returns:
            Bool: Status for indicating whether the client_cfg is updated. \
            ``False`` when the instance is already configured and \
            ``new_client_cfg`` is ``None`` or has an equal ``dataloader`` \
            section.
        """
        if self.client_cfg is not None:
            # No new config supplied: keep the current loaders instead of
            # failing on ``None.dataloader``.
            if new_client_cfg is None:
                return False
            # Only a change in the ``dataloader`` section triggers a
            # re-instantiation of the loaders.
            if dict(self.client_cfg.dataloader) == dict(
                    new_client_cfg.dataloader):
                return False

        self.client_cfg = new_client_cfg
        raw_splits = (self.train_data, self.val_data, self.test_data)
        for split_data, split_name in zip(raw_splits, self.SPLIT_NAMES):
            if split_data is None:
                continue
            # csc_matrix does not have ``__len__`` attributes, so it is
            # accepted without a length check; other datasets must be
            # non-empty.
            if isinstance(split_data, csc_matrix) or len(split_data) > 0:
                self[split_name] = get_dataloader(split_data,
                                                  self.client_cfg,
                                                  split_name)
        return True