# Source code for federatedscope.core.trainers.context

import collections
import functools
import logging

from federatedscope.core.auxiliaries.criterion_builder import get_criterion
from federatedscope.core.auxiliaries.model_builder import \
    get_trainable_para_names
from federatedscope.core.auxiliaries.regularizer_builder import get_regularizer
from federatedscope.core.trainers.enums import MODE
from federatedscope.core.trainers.utils import calculate_batch_epoch_num

# Module-level logger named after this module; used for the warnings emitted
# by ``Context``.
logger = logging.getLogger(__name__)

class LifecycleDict(dict):
    """A customized dict that provides lifecycle management.

    Items can be read, written and deleted with attribute syntax
    (``d.key`` is ``d['key']``). When a ``CtxVar`` is assigned through
    attribute syntax, the wrapped object is stored and the key is
    registered under the lifecycle carried by the ``CtxVar``; calling
    ``clear(lifecycle)`` later deletes every key registered under that
    lifecycle.

    Arguments:
        init_dict: initialized dict
    """
    # Attribute deletion is item deletion (raises ``KeyError`` if missing).
    __delattr__ = dict.__delitem__

    def __init__(self, init_dict=None):
        if init_dict is not None:
            super(LifecycleDict, self).__init__(init_dict)
        # Maps a lifecycle name to the set of keys that expire with it.
        self.lifecycles = collections.defaultdict(set)

    def __getattr__(self, item):
        """Return ``self[item]``.

        Only invoked when normal attribute lookup fails, so methods and
        properties defined on the class take precedence over stored items.

        Raises:
            AttributeError: if ``item`` is not a key of the dict.
        """
        try:
            return self[item]
        except KeyError:
            raise AttributeError("Attribute {} is not found".format(item))

    def __setattr__(self, key, value):
        """Store ``value`` under ``key``, unwrapping ``CtxVar`` instances.

        A ``CtxVar`` is never stored itself: its ``obj`` is stored and the
        key is remembered under its ``lifecycle`` so ``clear`` can find it.
        """
        if isinstance(value, CtxVar):
            self.lifecycles[value.lifecycle].add(key)
            super(LifecycleDict, self).__setitem__(key, value.obj)
        else:
            super(LifecycleDict, self).__setitem__(key, value)

    def clear(self, lifecycle):
        """Delete every variable registered under ``lifecycle``.

        Note that this shadows ``dict.clear`` with a different signature.

        Arguments:
            lifecycle: the lifecycle whose variables should be removed
        """
        expired = self.lifecycles[lifecycle]
        for key in expired:
            # The key may already have been deleted by the user.
            self.pop(key, None)
        # Forget the registrations so a later plain assignment to the same
        # name is not deleted by a stale entry.
        expired.clear()

class Context(LifecycleDict):
    """
    Record and pass variables among different hook functions.

    Arguments:
        model: training model
        cfg: config
        data (dict): a dict contains train/val/test dataset or dataloader
        device: running device

    Note:
      - The variables within an instance of class `Context` can be set/get
        as an attribute.
        ```
        ctx.${NAME_VARIABLE} = ${VALUE_VARIABLE}
        ```
        where ``${NAME_VARIABLE}`` and ``${VALUE_VARIABLE}``
        is the name and value of the variable.

      - To achieve automatically lifecycle management, you can
        wrap the variable with ``CtxVar`` and a lifecycle parameter
        as follows
        ```
        ctx.${NAME_VARIABLE} = CtxVar(${VALUE_VARIABLE}, ${LIFECYCLE})
        ```
        The parameter ``${LIFECYCLE}`` can be chosen from
        ``LIFECYCLE.BATCH``, ``LIFECYCLE.EPOCH`` and ``LIFECYCLE.ROUTINE``.
        Then the variable ``ctx.${NAME_VARIABLE}`` will be deleted at
        the end of the corresponding stage
          - ``LIFECYCLE.BATCH``: the variables will
            be deleted after running a batch
          - ``LIFECYCLE.EPOCH``: the variables will be
            deleted after running a epoch
          - ``LIFECYCLE.ROUTINE``: the variables will be
            deleted after running a routine
        More details please refer to the project's trainer tutorial.

      We classify and show the default attributes below:

      Data-related attributes
        - ``ctx.data``: the raw data (not split) the trainer holds
        - ``ctx.num_samples``: the number of samples used in training
        - ``ctx.train_data``, ``ctx.val_data``, ``ctx.test_data``: the
          split data the trainer holds
        - ``ctx.train_loader``, ``ctx.val_loader``, ``ctx.test_loader``:
          the DataLoader of each split data
        - ``ctx.num_train_data``, ``ctx.num_val_data``,
          ``ctx.num_test_data``: the number of samples of the split data

      Model-related attributes
        - ``ctx.model``: the model used
        - ``ctx.models``: the multi models if use
        - ``ctx.mirrored_models``: the mirrored models
        - ``ctx.trainable_para_names``: the trainable parameter names of
          the model

      Optimizer-related attributes
        - ``ctx.optimizer``: see ``torch.optim``
        - ``ctx.scheduler``: decays the learning rate of each parameter
          group
        - ``ctx.criterion``: loss/criterion function
        - ``ctx.regularizer``: regular terms
        - ``ctx.grad_clip``: gradient clipping

      Mode-related attributes
        - ``ctx.cur_mode``: mode of trainer, which is one of
          ``['train', 'val', 'test']``
        - ``ctx.mode_stack``: stack of mode, only used for switching mode
        - ``ctx.cur_split``: split of data, which is one of
          ``['train', 'val', 'test']`` (Note: use ``train`` data in
          ``test`` mode is allowed)
        - ``ctx.split_stack``: stack of split, only used for switching data
          split

      Metric-related attributes
        - ``ctx.loss_batch_total``: Loss of current batch
        - ``ctx.loss_regular_total``: Loss of regular term
        - ``ctx.y_true``: true label of batch data
        - ``ctx.y_prob``: output of the model with batch data as input
        - ``ctx.ys_true``: true label of data
        - ``ctx.ys_prob``: output of the model
        - ``ctx.eval_metrics``: evaluation metrics calculated by
          ``ctx.monitor``
        - ``ctx.monitor``: used for monitor trainer's behavior and
          statistics

      Other (statistics) attributes (@property, query from ``cfg`` if not
      set)
        - ``ctx.cfg``: configuration of FL course
        - ``ctx.device``: current device, such as ``cpu`` and ``gpu0``.
        - ``ctx.num_train_batch_last_epoch``,
          ``ctx.num_total_train_batch``: the number of batch
        - ``ctx.num_train_epoch``, ``ctx.num_val_epoch``,
          ``ctx.num_test_epoch``: the number of epoch in each data split
        - ``ctx.num_train_batch``, ``ctx.num_val_batch``,
          ``ctx.num_test_batch``: the number of batch in each data split
    """
    def __init__(self, model, cfg, data=None, device=None):
        super(Context, self).__init__({})

        self.cfg = cfg
        self.model = model
        self.data = data
        self.device = device

        self.cur_mode = None
        self.mode_stack = list()

        self.cur_split = None
        self.split_stack = list()

        self.lifecycles = collections.defaultdict(set)

        # Setup optimize-related context variable
        if self.cfg.backend == 'torch':
            self.trainable_para_names = get_trainable_para_names(self.model)
            # TODO: make `criterion` and `regularizer` @property and cached
            #  to compare whether changes happen
            self.criterion = get_criterion(self.cfg.criterion.type,
                                           self.device)
            self.regularizer = get_regularizer(self.cfg.regularizer.type)
            self.grad_clip = self.cfg.grad.grad_clip
            if self.cfg.federate.process_num > 1:
                # NOTE(review): the body of this branch was missing from the
                # extracted source; moving the model onto ``self.device`` is
                # the presumed intent -- confirm against the project.
                self.model.to(self.device)
        elif self.cfg.backend == 'tensorflow':
            self.trainable_para_names = self.model.trainable_variables()
            self.criterion = None
            self.regularizer = None
            self.optimizer = None
            self.grad_clip = None

    def _stored_or_calculated(self, key, mode, index):
        """Return the stored ``key`` if truthy, else derive it from ``cfg``.

        ``index`` selects the entry of the tuple returned by
        ``_calculate_batch_epoch_num(mode)``.
        """
        stored = self.get(key)
        if stored:
            return stored
        return self._calculate_batch_epoch_num(mode=mode)[index]

    # Train related property, query from `cfg` if not set
    @property
    def num_train_batch(self):
        return self._stored_or_calculated('num_train_batch', 'train', 0)

    @property
    def num_train_batch_last_epoch(self):
        return self._stored_or_calculated('num_train_batch_last_epoch',
                                          'train', 1)

    @property
    def num_train_epoch(self):
        return self._stored_or_calculated('num_train_epoch', 'train', 2)

    @property
    def num_total_train_batch(self):
        return self._stored_or_calculated('num_total_train_batch', 'train',
                                          3)

    # Val related property, query from `cfg` if not set
    @property
    def num_val_batch(self):
        return self._stored_or_calculated('num_val_batch', 'val', 0)

    @property
    def num_val_epoch(self):
        return self._stored_or_calculated('num_val_epoch', 'val', 2)

    # Test related property, query from `cfg` if not set
    @property
    def num_test_batch(self):
        return self._stored_or_calculated('num_test_batch', 'test', 0)

    @property
    def num_test_epoch(self):
        return self._stored_or_calculated('num_test_epoch', 'test', 2)

    def _calculate_batch_epoch_num(self, mode='train'):
        """Derive batch/epoch counts for ``mode`` from ``cfg``.

        Returns:
            A tuple ``(num_batch, num_batch_last_epoch, num_epoch,
            num_total_batch)``; the second and fourth entries are ``None``
            for the ``val`` and ``test`` modes.

        Raises:
            ValueError: if ``mode`` is not one of ``train``, ``finetune``,
                ``val`` or ``test``.
        """
        if self.cur_mode is not None and self.cur_mode != mode:
            logger.warning(
                f'cur_mode `{self.cur_mode}` mismatch mode `{mode}`, '
                f'will use `{mode}` to calculate `ctx.var`.')
        if self.cur_split is None:
            logger.warning(
                f'cur_split `{self.cur_split}` not found in data_split, '
                f'will use `train` split to calculate `ctx.var`.')
            cur_split = 'train'
        else:
            cur_split = self.cur_split

        num_batch_last_epoch, num_total_batch = None, None
        if mode in ['train', 'finetune']:
            num_batch, num_batch_last_epoch, num_epoch, num_total_batch = \
                calculate_batch_epoch_num(
                    self.cfg.train.local_update_steps *
                    self.cfg.grad.grad_accum_count,
                    self.cfg.train.batch_or_epoch,
                    self.get(f'num_{cur_split}_data'),
                    self.cfg.dataloader.batch_size,
                    self.cfg.dataloader.drop_last)
        elif mode in ['val', 'test']:
            num_epoch = 1
            num_batch, remainder = divmod(self.get(f'num_{cur_split}_data'),
                                          self.cfg.dataloader.batch_size)
            # A trailing partial batch counts unless it is dropped.
            if remainder and not self.cfg.dataloader.drop_last:
                num_batch += 1
        else:
            raise ValueError(f'Invalid mode {mode}.')

        return num_batch, num_batch_last_epoch, num_epoch, num_total_batch

    def track_mode(self, mode):
        """Push ``mode`` onto the mode stack and switch the model to it."""
        self.mode_stack.append(mode)
        self.cur_mode = self.mode_stack[-1]
        self.change_mode(self.cur_mode)

    def reset_mode(self):
        """Pop the current mode and restore the previous one, if any."""
        self.mode_stack.pop()
        if self.mode_stack:
            self.cur_mode = self.mode_stack[-1]
            self.change_mode(self.cur_mode)
        else:
            self.cur_mode = None

    def change_mode(self, mode):
        """Put the model into train or eval state (torch backend only)."""
        # change state
        if self.cfg.backend == 'torch':
            if mode == MODE.TRAIN or mode == MODE.FINETUNE:
                self.model.train()
            else:
                self.model.eval()
        else:
            pass

    def track_split(self, dataset):
        """Push ``dataset`` onto the split stack and make it current."""
        # stack-style to enable mixture usage such as evaluation on train
        # dataset
        self.split_stack.append(dataset)
        self.cur_split = self.split_stack[-1]

    def reset_split(self):
        """Pop the current split and restore the previous one, if any."""
        self.split_stack.pop()
        self.cur_split = self.split_stack[-1] if self.split_stack else None

    def check_split(self, target_split_name, skip=False):
        """Check that data or a loader exists for ``target_split_name``.

        Arguments:
            target_split_name: name of the split, e.g. ``train``
            skip: if True, a missing split logs a warning and returns
                False instead of raising

        Returns:
            True if ``{name}_data`` or ``{name}_loader`` is set, else False
            (only when ``skip`` is True).

        Raises:
            ValueError: if the split is missing and ``skip`` is False.
        """
        if self.get(f"{target_split_name}_data") is None and self.get(
                f"{target_split_name}_loader") is None:
            if skip:
                logger.warning(
                    f"No {target_split_name}_data or"
                    f" {target_split_name}_loader in the trainer, "
                    f"will skip evaluation."
                    f"If this is not the case you want, please check "
                    f"whether there is typo for the name")
                return False
            else:
                raise ValueError(f"No {target_split_name}_data or"
                                 f" {target_split_name}_loader in the trainer")
        else:
            return True

    def merge_from_dict(self, other_dict):
        """Set every item of ``other_dict`` as an attribute of the context."""
        for key, value in other_dict.items():
            setattr(self, key, value)
class CtxVar(object):
    """
    Basic variable class

    Arguments:
        obj: the wrapped value
        lifecycle: specific lifecycle of the attribute
    """

    LIFECYCLES = ["batch", "epoch", "routine", None]

    def __init__(self, obj, lifecycle=None):
        assert lifecycle in CtxVar.LIFECYCLES
        self.obj = obj
        self.lifecycle = lifecycle


def lifecycle(lifecycle):
    """
    Manage the lifecycle of the variables within context, \
    and blind these operations from user.

    Arguments:
        lifecycle: the type of lifecycle, choose from "batch/epoch/routine"

    Returns:
        A decorator for methods of an object that owns a ``ctx`` attribute.
        After the wrapped method returns, ``self.ctx.clear(lifecycle)`` is
        called. For ``"routine"`` the wrapper additionally tracks the mode
        and data split before the call, resets both afterwards and calls
        ``self.discharge_model()``.
    """
    if lifecycle == "routine":

        def decorate(func):
            # ``wraps`` keeps the wrapped method's name and docstring.
            @functools.wraps(func)
            def wrapper(self, mode, hooks_set, dataset_name=None):
                self.ctx.track_mode(mode)
                self.ctx.track_split(dataset_name or mode)

                res = func(self, mode, hooks_set, dataset_name)

                # Clear the variables at the end of lifecycles
                self.ctx.clear(lifecycle)

                # rollback the model and data_split
                self.ctx.reset_mode()
                self.ctx.reset_split()

                # Move the model into CPU to avoid memory leak
                self.discharge_model()

                return res

            return wrapper
    else:

        def decorate(func):
            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                res = func(self, *args, **kwargs)
                # Clear the variables at the end of lifecycles
                self.ctx.clear(lifecycle)
                return res

            return wrapper

    return decorate