import collections
import copy
import logging
from federatedscope.core.trainers.base_trainer import BaseTrainer
from federatedscope.core.trainers.enums import MODE, LIFECYCLE
from federatedscope.core.auxiliaries.decorators import use_diff
from federatedscope.core.trainers.utils import format_log_hooks, \
filter_by_specified_keywords
from federatedscope.core.trainers.context import Context, CtxVar, lifecycle
logger = logging.getLogger(__name__)
class Trainer(BaseTrainer):
"""
Register, organize and run the train/test/val procedures
"""
HOOK_TRIGGER = [
"on_fit_start", "on_epoch_start", "on_batch_start", "on_batch_forward",
"on_batch_backward", "on_batch_end", "on_epoch_end", "on_fit_end"
]
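# The triggers above fire in this nesting order during a routine (see
# ``_run_routine`` below): on_fit_start -> [on_epoch_start ->
# [on_batch_start -> on_batch_forward -> on_batch_backward ->
# on_batch_end] * num_batch -> on_epoch_end] * num_epoch -> on_fit_end.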
def __init__(self,
model,
data,
device,
config,
only_for_eval=False,
monitor=None):
self._cfg = config
self.ctx = Context(model, self.cfg, data, device)
# Parse data and setup init vars in ctx
self._setup_data_related_var_in_ctx(self.ctx)
assert monitor is not None, \
f"Monitor not found in trainer with class {type(self)}"
self.ctx.monitor = monitor
# the "model_nums", and "models" are used for multi-model case and
# model size calculation
self.model_nums = 1
self.ctx.models = [model]
# "mirrored_models": whether the internal multi-models adopt the
# same architects and almost the same behaviors,
# which is used to simply the flops, model size calculation
self.ctx.mirrored_models = False
# Atomic operation during training/evaluation
self.hooks_in_train = collections.defaultdict(list)
# By default, use the same trigger keys
self.hooks_in_eval = copy.deepcopy(self.hooks_in_train)
self.hooks_in_ft = copy.deepcopy(self.hooks_in_train)
# register necessary hooks into self.hooks_in_train and
# self.hooks_in_eval
if not only_for_eval:
self.register_default_hooks_train()
if self.cfg.finetune.before_eval:
self.register_default_hooks_ft()
self.register_default_hooks_eval()
if self.cfg.federate.mode == 'distributed':
self.print_trainer_meta_info()
else:
# in standalone mode, by default, we print the trainer info only
# once for better log readability
pass
@property
def cfg(self):
return self._cfg
@cfg.setter
def cfg(self, new_cfg):
self._cfg = new_cfg
self.ctx.cfg = new_cfg
self._setup_data_related_var_in_ctx(self.ctx)
def parse_data(self, data):
"""
Populate ``${split}_data``, ``${split}_loader`` and \
``num_${split}_data`` for different data splits
"""
raise NotImplementedError
def setup_data(self, ctx):
"""
Initialize the data by ``cfg``.
"""
pass
def register_default_hooks_train(self):
pass
def register_default_hooks_eval(self):
pass
def register_default_hooks_ft(self):
pass
def reset_hook_in_train(self, target_trigger, target_hook_name=None):
hooks_dict = self.hooks_in_train
del_one_hook_idx = self._reset_hook_in_trigger(hooks_dict,
target_hook_name,
target_trigger)
return del_one_hook_idx
def reset_hook_in_eval(self, target_trigger, target_hook_name=None):
hooks_dict = self.hooks_in_eval
del_one_hook_idx = self._reset_hook_in_trigger(hooks_dict,
target_hook_name,
target_trigger)
return del_one_hook_idx
def replace_hook_in_train(self, new_hook, target_trigger,
target_hook_name):
del_one_hook_idx = self.reset_hook_in_train(
target_trigger=target_trigger, target_hook_name=target_hook_name)
self.register_hook_in_train(new_hook=new_hook,
trigger=target_trigger,
insert_pos=del_one_hook_idx)
def replace_hook_in_eval(self, new_hook, target_trigger, target_hook_name):
del_one_hook_idx = self.reset_hook_in_eval(
target_trigger=target_trigger, target_hook_name=target_hook_name)
self.register_hook_in_eval(new_hook=new_hook,
trigger=target_trigger,
insert_pos=del_one_hook_idx)
def _reset_hook_in_trigger(self, hooks_dict, target_hook_name,
target_trigger):
# Clean/delete existing hooks for a specific trigger:
# if target_hook_name is given, only that hook is removed;
# otherwise, all hooks for the trigger are removed.
assert target_trigger in self.HOOK_TRIGGER, \
f"Got {target_trigger} as hook trigger, you should specify a " \
f"string within {self.HOOK_TRIGGER}."
del_one_hook_idx = None
if target_hook_name is None:
hooks_dict[target_trigger] = []
del_one_hook_idx = -1 # -1 indicates del the whole list
else:
for hook_idx in range(len(hooks_dict[target_trigger])):
if target_hook_name == hooks_dict[target_trigger][
hook_idx].__name__:
del_one = hooks_dict[target_trigger].pop(hook_idx)
logger.info(f"Remove the hook `{del_one.__name__}` from "
f"hooks_set at trigger `{target_trigger}`")
del_one_hook_idx = hook_idx
break
if del_one_hook_idx is None:
logger.warning(
f"In hook del procedure, can't find the target hook "
f"named {target_hook_name}")
return del_one_hook_idx
def register_hook_in_train(self,
new_hook,
trigger,
insert_pos=None,
base_hook=None,
insert_mode="before"):
hooks_dict = self.hooks_in_train
self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
new_hook, trigger)
def register_hook_in_ft(self,
new_hook,
trigger,
insert_pos=None,
base_hook=None,
insert_mode="before"):
hooks_dict = self.hooks_in_ft
self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
new_hook, trigger)
def register_hook_in_eval(self,
new_hook,
trigger,
insert_pos=None,
base_hook=None,
insert_mode="before"):
hooks_dict = self.hooks_in_eval
self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
new_hook, trigger)
def _register_hook(self, base_hook, hooks_dict, insert_mode, insert_pos,
new_hook, trigger):
assert trigger in self.HOOK_TRIGGER, \
f"Got {trigger} as hook trigger, you should specify a string " \
f"within {self.HOOK_TRIGGER}."
# parse the insertion position
target_hook_set = hooks_dict[trigger]
if insert_pos is not None:
assert (insert_pos == -1) or (insert_pos == len(target_hook_set)
                              == 0) or \
    (0 <= insert_pos <= len(target_hook_set)), \
    f"Got {insert_pos} as insert pos, you should specify an " \
    f"integer that is (1) -1, " \
    f"(2) 0 for an empty target_hook_set, " \
    f"or (3) within [0, {len(target_hook_set)}]."
elif base_hook is not None:
base_hook_pos = target_hook_set.index(base_hook)
insert_pos = base_hook_pos - 1 if insert_mode == "before" else \
base_hook_pos + 1
# bound the insert_pos within a valid range
insert_pos = 0 if insert_pos < 0 else insert_pos
insert_pos = -1 if insert_pos > len(
target_hook_set) else insert_pos
else:
insert_pos = -1  # by default, the new hook is appended and called last
# register the new hook
if insert_pos == -1:
hooks_dict[trigger].append(new_hook)
else:
hooks_dict[trigger].insert(insert_pos, new_hook)
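# Illustrative usage (a sketch, not executed here): a hook is any callable
# taking the shared ``ctx`` object, registered under one of HOOK_TRIGGER,
# e.g.
#
#     def print_batch_hook(ctx):  # hypothetical custom hook
#         logger.info(f"batch {ctx.cur_batch_i} finished")
#
#     trainer.register_hook_in_train(new_hook=print_batch_hook,
#                                    trigger="on_batch_end")
#
# With the default ``insert_pos=None`` and no ``base_hook``, the new hook
# is appended and therefore runs after the existing hooks of that trigger.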
@use_diff
def train(self, target_data_split_name="train", hooks_set=None):
hooks_set = hooks_set or self.hooks_in_train
self.ctx.check_split(target_data_split_name)
num_samples = self._run_routine(MODE.TRAIN, hooks_set,
target_data_split_name)
return num_samples, self.get_model_para(), self.ctx.eval_metrics
def evaluate(self, target_data_split_name="test", hooks_set=None):
hooks_set = hooks_set or self.hooks_in_eval
if self.ctx.check_split(target_data_split_name, skip=True):
self._run_routine(MODE.TEST, hooks_set, target_data_split_name)
else:
self.ctx.eval_metrics = dict()
return self.ctx.eval_metrics
def finetune(self, target_data_split_name="train", hooks_set=None):
hooks_set = hooks_set or self.hooks_in_ft
self.ctx.check_split(target_data_split_name)
self._run_routine(MODE.FINETUNE, hooks_set, target_data_split_name)
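# Typical call pattern from an FL client (a sketch, assuming a concrete
# subclass, e.g. a torch-based trainer, that implements get_model_para()
# and update(); ``new_global_para`` is a hypothetical aggregated dict):
#
#     sample_size, model_para, results = trainer.train()
#     trainer.update(new_global_para)
#     metrics = trainer.evaluate(target_data_split_name="test")
#
# Note that ``train`` is wrapped by ``@use_diff``; depending on the
# configuration it may return parameter differences rather than the full
# parameters.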
@lifecycle(LIFECYCLE.ROUTINE)
def _run_routine(self, mode, hooks_set, dataset_name=None):
"""Run the hooks_set and maintain the mode
Arguments:
mode: running mode of client, chosen from train/val/test
Note:
Considering evaluation could be in ```hooks_set["on_epoch_end"]```,
there could be two data loaders in self.ctx, we must tell the
running hooks which data_loader to call and which
num_samples to count
"""
for hook in hooks_set["on_fit_start"]:
hook(self.ctx)
self._run_epoch(hooks_set)
for hook in hooks_set["on_fit_end"]:
hook(self.ctx)
return self.ctx.num_samples
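# Note: the ``@lifecycle(...)`` decorators (defined alongside ``Context``
# and ``CtxVar``) are expected to release the context variables declared
# with the corresponding lifecycle (e.g. ``cur_batch_i``) once the wrapped
# routine/epoch/batch loop finishes, so they do not leak across runs.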
@lifecycle(LIFECYCLE.EPOCH)
def _run_epoch(self, hooks_set):
for epoch_i in range(
getattr(self.ctx, f"num_{self.ctx.cur_split}_epoch")):
self.ctx.cur_epoch_i = CtxVar(epoch_i, "epoch")
for hook in hooks_set["on_epoch_start"]:
hook(self.ctx)
self._run_batch(hooks_set)
for hook in hooks_set["on_epoch_end"]:
hook(self.ctx)
@lifecycle(LIFECYCLE.BATCH)
def _run_batch(self, hooks_set):
for batch_i in range(
getattr(self.ctx, f"num_{self.ctx.cur_split}_batch")):
self.ctx.cur_batch_i = CtxVar(batch_i, LIFECYCLE.BATCH)
for hook in hooks_set["on_batch_start"]:
hook(self.ctx)
for hook in hooks_set["on_batch_forward"]:
hook(self.ctx)
for hook in hooks_set["on_batch_backward"]:
hook(self.ctx)
for hook in hooks_set["on_batch_end"]:
hook(self.ctx)
# In the final train/finetune epoch, break after the last batch
# of that epoch (num_train_batch_last_epoch)
if self.ctx.cur_mode in [
MODE.TRAIN, MODE.FINETUNE
] and self.ctx.cur_epoch_i == self.ctx.num_train_epoch - 1:
if batch_i >= self.ctx.num_train_batch_last_epoch - 1:
break
def update(self, model_parameters, strict=False):
"""
Called by the FL client to update the model parameters
Arguments:
model_parameters (dict): {model_name: model_val}
strict (bool): ensure the k-v pairs are strictly the same
"""
pass
def get_model_para(self):
"""
:return: model_parameters (dict): {model_name: model_val}
"""
pass
def _param_filter(self, state_dict, filter_keywords=None):
"""
Model parameter filter applied when parameters are transmitted
between the local and global models, which is useful in personalization.
e.g., setting cfg.personalization.local_param = ['bn', 'norms']
indicates the implementation of
"FedBN: Federated Learning on Non-IID Features via Local Batch
Normalization, ICML2021", which can be found in
https://openreview.net/forum?id=6YEQUn0QICG
Arguments:
state_dict (dict): PyTorch Module object's state_dict.
Returns:
state_dict (dict): the state_dict with the keys that match any of the
given keywords removed.
"""
if self.cfg.federate.method in ["local", "global"]:
return {}
if filter_keywords is None:
filter_keywords = self.cfg.personalization.local_param
# keep only trainable parameters (filtered by name) unless
# non-trainable ones are explicitly shared
trainable_filter = (lambda p: True) if \
    self.cfg.personalization.share_non_trainable_para else \
    (lambda p: p in self.ctx.trainable_para_names)
keyword_filter = filter_by_specified_keywords
return dict(
    filter(
        lambda elem: trainable_filter(elem[0]) and keyword_filter(
            elem[0], filter_keywords), state_dict.items()))
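# Illustrative example (a sketch): with filter_keywords=['bn'], a
# state_dict with keys {'conv1.weight', 'bn1.weight', 'bn1.bias'} is
# reduced to {'conv1.weight': ...}, i.e. batch-norm parameters are kept
# local as in FedBN.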
def save_model(self, path, cur_round=-1):
raise NotImplementedError(
"The function `save_model` should be implemented according to "
"the ML backend (Pytorch, Tensorflow ...).")
def load_model(self, path):
raise NotImplementedError(
"The function `load_model` should be implemented according to "
"the ML backend (Pytorch, Tensorflow ...).")