import logging
import collections
from federatedscope.core.auxiliaries.criterion_builder import get_criterion
from federatedscope.core.auxiliaries.model_builder import \
get_trainable_para_names
from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer
from federatedscope.core.trainers.enums import MODE
from federatedscope.core.trainers.utils import calculate_batch_epoch_num
logger = logging.getLogger(__name__)
class LifecycleDict(dict):
"""A customized dict that provides lifecycle management
Arguments:
init_dict: the dict used to initialize the instance
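For example (a minimal sketch; ``"batch"`` is one of ``CtxVar.LIFECYCLES``):
```
d = LifecycleDict({'num_samples': 10})
d.loss = CtxVar(0.5, "batch")  # stored as 0.5, tracked under the "batch" lifecycle
d.clear("batch")               # removes ``d.loss``
```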
"""
__delattr__ = dict.__delitem__
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError("Attribute {} is not found".format(item))
def __init__(self, init_dict=None):
if init_dict is not None:
super(LifecycleDict, self).__init__(init_dict)
self.lifecycles = collections.defaultdict(set)
def __setattr__(self, key, value):
if isinstance(value, CtxVar):
self.lifecycles[value.lifecycle].add(key)
super(LifecycleDict, self).__setitem__(key, value.obj)
else:
super(LifecycleDict, self).__setitem__(key, value)
def clear(self, lifecycle):
keys = list(self.lifecycles[lifecycle])
for key in keys:
if key in self:
del self[key]
self.lifecycles[lifecycle].remove(key)
class Context(LifecycleDict):
"""
Record and pass variables among different hook functions.
Arguments:
model: training model
cfg: config
data (dict): a dict contains train/val/test dataset or dataloader
device: running device
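A minimal construction sketch (assuming ``model``, ``cfg``, ``data`` and
``device`` have been prepared elsewhere):
```
ctx = Context(model, cfg, data=data, device=device)
```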
Note:
- The variables within an instance of class `Context` can be set/get \
as an attribute.
```
ctx.${NAME_VARIABLE} = ${VALUE_VARIABLE}
```
where ``${NAME_VARIABLE}`` and ``${VALUE_VARIABLE}``
are the name and the value of the variable, respectively.
- To achieve automatic lifecycle management, you can \
wrap the variable with ``CtxVar`` and a lifecycle parameter \
as follows
```
ctx.${NAME_VARIABLE} = CtxVar(${VALUE_VARIABLE}, ${LIFECYCLE})
```
The parameter ``${LIFECYCLE}`` can be chosen from \
``LIFECYCLE.BATCH``, ``LIFECYCLE.EPOCH`` and ``LIFECYCLE.ROUTINE``. \
Then the variable ``ctx.${NAME_VARIABLE}`` will be deleted at \
the end of the corresponding stage:
- ``LIFECYCLE.BATCH``: the variables will \
be deleted after running a batch
- ``LIFECYCLE.EPOCH``: the variables will be \
deleted after running an epoch
- ``LIFECYCLE.ROUTINE``: the variables will be \
deleted after running a routine
For more details, please refer to our
[tutorial](https://federatedscope.io/docs/trainer/).
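For example, a batch-level loss variable (a hypothetical name) can be
registered as follows and will be deleted automatically once the batch ends:
```
ctx.loss_batch = CtxVar(loss, LIFECYCLE.BATCH)
```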
We classify and show the default attributes below:
Data-related attributes
- ``ctx.data``: the raw data (not split) the trainer holds
- ``ctx.num_samples``: the number of samples used in training
- ``ctx.train_data``, ``ctx.val_data``, ``ctx.test_data``: the \
split data the trainer holds
- ``ctx.train_loader``, ``ctx.val_loader``, ``ctx.test_loader``: \
the DataLoader of each split data
- ``ctx.num_train_data``, ``ctx.num_val_data``, \
``ctx.num_test_data``: the number of samples in each data split
Model-related attributes
- ``ctx.model``: the model used
- ``ctx.models``: the multiple models, if more than one model is used
- ``ctx.mirrored_models``: the mirrored models
- ``ctx.trainable_para_names``: the trainable parameter names of \
the model
Optimizer-related attributes
- ``ctx.optimizer``: see ``torch.optim``
- ``ctx.scheduler``: decays the learning rate of each parameter group
- ``ctx.criterion``: loss/criterion function
- ``ctx.regularizer``: regular terms
- ``ctx.grad_clip``: gradient clipping
Mode-related attributes
- ``ctx.cur_mode``: mode of trainer, which is one of ``['train', \
'val', 'test']``
- ``ctx.mode_stack``: stack of mode, only used for switching mode
- ``ctx.cur_split``: the data split in use, which is one of ``['train', \
'val', 'test']`` (Note: using ``train`` data in ``test`` mode is \
allowed)
- ``ctx.split_stack``: stack of split, only used for switching data \
split
Metric-related attributes
- ``ctx.loss_batch_total``: Loss of current batch
- ``ctx.loss_regular_total``: Loss of regular term
- ``ctx.y_true``: true label of batch data
- ``ctx.y_prob``: output of the model with batch data as input
- ``ctx.ys_true``: true label of data
- ``ctx.ys_prob``: output of the model
- ``ctx.eval_metrics``: evaluation metrics calculated by \
``ctx.monitor``
- ``ctx.monitor``: used to monitor the trainer's behavior and statistics
Other (statistics) attributes (@property, query from ``cfg`` if not \
set)
- ``ctx.cfg``: configuration of FL course
- ``ctx.device``: current device, such as ``cpu`` and ``gpu0``.
- ``ctx.num_train_batch_last_epoch``, \
``ctx.num_total_train_batch``: the number of batches in the last \
training epoch and the total number of training batches
- ``ctx.num_train_epoch``, ``ctx.num_val_epoch``, \
``ctx.num_test_epoch``: the number of epochs in each data split
- ``ctx.num_train_batch``, ``ctx.num_val_batch``, \
``ctx.num_test_batch``: the number of batches in each data split
"""
def __init__(self, model, cfg, data=None, device=None):
super(Context, self).__init__({})
self.cfg = cfg
self.model = model
self.data = data
self.device = device
self.cur_mode = None
self.mode_stack = list()
self.cur_split = None
self.split_stack = list()
self.lifecycles = collections.defaultdict(set)
# Set up optimization-related context variables
if self.cfg.backend == 'torch':
self.trainable_para_names = get_trainable_para_names(self.model)
# TODO: make `criterion` and `regularizer` @property and cached
# to compare whether changes happen
self.criterion = get_criterion(self.cfg.criterion.type,
self.device)
self.regularizer = get_regularizer(self.cfg.regularizer.type)
self.grad_clip = self.cfg.grad.grad_clip
if self.cfg.federate.process_num > 1:
self.model.to(self.device)
elif self.cfg.backend == 'tensorflow':
self.trainable_para_names = self.model.trainable_variables()
self.criterion = None
self.regularizer = None
self.optimizer = None
self.grad_clip = None
# Train related property, query from `cfg` if not set
@property
def num_train_batch(self):
if self.get('num_train_batch'):
return self.get('num_train_batch')
return self._calculate_batch_epoch_num(mode='train')[0]
@property
def num_train_batch_last_epoch(self):
if self.get('num_train_batch_last_epoch'):
return self.get('num_train_batch_last_epoch')
return self._calculate_batch_epoch_num(mode='train')[1]
@property
def num_train_epoch(self):
if self.get('num_train_epoch'):
return self.get('num_train_epoch')
return self._calculate_batch_epoch_num(mode='train')[2]
@property
def num_total_train_batch(self):
if self.get('num_total_train_batch'):
return self.get('num_total_train_batch')
return self._calculate_batch_epoch_num(mode='train')[3]
# Val related property, query from `cfg` if not set
@property
def num_val_batch(self):
if self.get('num_val_batch'):
return self.get('num_val_batch')
return self._calculate_batch_epoch_num(mode='val')[0]
@property
def num_val_epoch(self):
if self.get('num_val_epoch'):
return self.get('num_val_epoch')
return self._calculate_batch_epoch_num(mode='val')[2]
# Test related property, query from `cfg` if not set
@property
def num_test_batch(self):
if self.get('num_test_batch'):
return self.get('num_test_batch')
return self._calculate_batch_epoch_num(mode='test')[0]
@property
def num_test_epoch(self):
if self.get('num_test_epoch'):
return self.get('num_test_epoch')
return self._calculate_batch_epoch_num(mode='test')[2]
def _calculate_batch_epoch_num(self, mode='train'):
if self.cur_mode is not None and self.cur_mode != mode:
logger.warning(
f'cur_mode `{self.cur_mode}` mismatches mode `{mode}`, '
f'will use `{mode}` to calculate `ctx.var`.')
if self.cur_split is None:
logger.warning(
f'cur_split `{self.cur_split}` not found in data_split, '
f'will use `train` split to calculate `ctx.var`.')
cur_split = 'train'
else:
cur_split = self.cur_split
num_batch_last_epoch, num_total_batch = None, None
if mode in ['train', 'finetune']:
num_batch, num_batch_last_epoch, num_epoch, num_total_batch = \
calculate_batch_epoch_num(
self.cfg.train.local_update_steps *
self.cfg.grad.grad_accum_count,
self.cfg.train.batch_or_epoch,
self.get(f'num_{cur_split}_data'),
self.cfg.dataloader.batch_size,
self.cfg.dataloader.drop_last)
elif mode in ['val', 'test']:
num_epoch = 1
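# num_batch = floor(num_data / batch_size), plus one extra batch for the
# remainder when drop_last is False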
num_batch = self.get(f'num_{cur_split}_data'
) // self.cfg.dataloader.batch_size + int(
not self.cfg.dataloader.drop_last
and bool(
self.get(f'num_{cur_split}_data') %
self.cfg.dataloader.batch_size))
else:
raise ValueError(f'Invalid mode {mode}.')
return num_batch, num_batch_last_epoch, num_epoch, num_total_batch
def track_mode(self, mode):
self.mode_stack.append(mode)
self.cur_mode = self.mode_stack[-1]
self.change_mode(self.cur_mode)
def reset_mode(self):
self.mode_stack.pop()
self.cur_mode = self.mode_stack[-1] if len(
self.mode_stack) != 0 else None
if len(self.mode_stack) != 0:
self.change_mode(self.cur_mode)
def change_mode(self, mode):
# change state
if self.cfg.backend == 'torch':
getattr(
self.model, 'train'
if mode == MODE.TRAIN or mode == MODE.FINETUNE else 'eval')()
else:
pass
def track_split(self, dataset):
# stack-style to enable mixed usage, such as evaluating on the train
# dataset
self.split_stack.append(dataset)
self.cur_split = self.split_stack[-1]
def reset_split(self):
self.split_stack.pop()
self.cur_split = self.split_stack[-1] if \
len(self.split_stack) != 0 else None
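# A usage sketch (hypothetical flow): the stacks allow, e.g., evaluating on
# the train split while in test mode, then restoring the previous state:
#   ctx.track_mode(MODE.TEST)
#   ctx.track_split('train')
#   ...  # run evaluation hooks
#   ctx.reset_split()
#   ctx.reset_mode()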
def check_split(self, target_split_name, skip=False):
if self.get(f"{target_split_name}_data") is None and self.get(
f"{target_split_name}_loader") is None:
if skip:
logger.warning(
f"No {target_split_name}_data or"
f" {target_split_name}_loader in the trainer, "
f"will skip evaluation. "
f"If this is not what you want, please check "
f"whether there is a typo in the name.")
return False
else:
raise ValueError(f"No {target_split_name}_data or"
f" {target_split_name}_loader in the trainer")
else:
return True
def merge_from_dict(self, other_dict):
for key, value in other_dict.items():
setattr(self, key, value)
class CtxVar(object):
"""
A wrapper that binds a variable to a specific lifecycle
Arguments:
obj: the object to be wrapped
lifecycle: specific lifecycle of the attribute, one of ``CtxVar.LIFECYCLES``
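For example (a minimal sketch):
```
var = CtxVar([0.1, 0.2], "epoch")
print(var.obj, var.lifecycle)  # [0.1, 0.2] epoch
```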
"""
LIFECYCLES = ["batch", "epoch", "routine", None]
def __init__(self, obj, lifecycle=None):
assert lifecycle in CtxVar.LIFECYCLES
self.obj = obj
self.lifecycle = lifecycle
def lifecycle(lifecycle):
"""
Manage the lifecycle of the variables within context, \
and hide these operations from the user.
Arguments:
lifecycle: the type of lifecycle, choose from "batch/epoch/routine"
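For example, a trainer method might be decorated as follows (a sketch only;
``_run_batch`` and the ``"on_batch_forward"`` hook key are hypothetical names,
and ``self.ctx`` is assumed to be the trainer's ``Context``):
```
@lifecycle("batch")
def _run_batch(self, hooks_set):
    # when this method returns, the wrapper calls ``self.ctx.clear("batch")``
    for hook in hooks_set["on_batch_forward"]:
        hook(self.ctx)
```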
"""
if lifecycle == "routine":
def decorate(func):
def wrapper(self, mode, hooks_set, dataset_name=None):
self.ctx.track_mode(mode)
self.ctx.track_split(dataset_name or mode)
res = func(self, mode, hooks_set, dataset_name)
# Clear the variables at the end of lifecycles
self.ctx.clear(lifecycle)
# rollback the model and data_split
self.ctx.reset_mode()
self.ctx.reset_split()
# Move the model to CPU to avoid memory leaks
self.discharge_model()
return res
return wrapper
else:
def decorate(func):
def wrapper(self, *args, **kwargs):
res = func(self, *args, **kwargs)
# Clear the variables at the end of lifecycles
self.ctx.clear(lifecycle)
return res
return wrapper
return decorate