Source code for federatedscope.core.trainers.trainer_multi_model

import copy
from types import FunctionType
from typing import Type

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

import numpy as np


class GeneralMultiModelTrainer(GeneralTorchTrainer):
    def __init__(self,
                 model_nums,
                 models_interact_mode="sequential",
                 model=None,
                 data=None,
                 device=None,
                 config=None,
                 base_trainer: Type[GeneralTorchTrainer] = None):
        """
        `GeneralMultiModelTrainer` supports train/eval via multiple
        internal models.

        Arguments:
            model_nums (int): how many internal models and optimizers
                will be held by the trainer
            models_interact_mode (str): how the models interact, can be
                "sequential" or "parallel"
            model: training model
            data: a dict containing train/val/test data
            device: device to run on
            config: trainer-related configuration
            base_trainer: if given, the GeneralMultiModelTrainer will be
                initialized from a copy of base_trainer

        The sequential mode indicates interaction at the run_routine level
        [one model runs its whole routine, then does something for
        interaction, then the next model runs its whole routine]:
            ... -> run_routine_model_i
                -> _switch_model_ctx
                -> (on_fit_end, _interact_to_other_models)
                -> run_routine_model_i+1
                -> ...

        The parallel mode indicates interaction at the point-in-time level
        [at a specific point-in-time, one model calls its hooks (including
        interaction), then the next model calls its hooks]:
            ... -> (on_xxx_point, hook_xxx_model_i)
                -> (on_xxx_point, _interact_to_other_models)
                -> (on_xxx_point, _switch_model_ctx)
                -> (on_xxx_point, hook_xxx_model_i+1)
                -> ...
        """
        # support two initialization methods for the
        # `GeneralMultiModelTrainer`:
        # 1) from another trainer; or 2) the standard init manner given
        # (model, data, device, config)
        if base_trainer is None:
            assert model is not None and \
                   data is not None and \
                   device is not None and \
                   config is not None, \
                "when not copy construction, (model, data, device, config) " \
                "should not be None"
            super(GeneralMultiModelTrainer,
                  self).__init__(model, data, device, config)
        else:
            assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
                   issubclass(type(base_trainer), GeneralMultiModelTrainer) \
                   or isinstance(base_trainer, GeneralTorchTrainer) or \
                   issubclass(type(base_trainer), GeneralTorchTrainer), \
                "can only copy instances of `GeneralMultiModelTrainer` " \
                "and its subclasses, or " \
                "`GeneralTorchTrainer` and its subclasses"
            self.__dict__ = copy.deepcopy(base_trainer.__dict__)

        assert models_interact_mode in ["sequential", "parallel"], \
            f"Invalid models_interact_mode, should be `sequential` or " \
            f"`parallel`, but got {models_interact_mode}"
        self.models_interact_mode = models_interact_mode

        if int(model_nums) != model_nums or model_nums < 1:
            raise ValueError(
                f"model_nums should be an integer and >= 1, "
                f"got {model_nums}.")
        self.model_nums = model_nums

        self.ctx.cur_model_idx = 0  # used to mark the current model

        # different internal models can have different hook_sets
        self.hooks_in_train_multiple_models = [self.hooks_in_train]
        self.hooks_in_eval_multiple_models = [self.hooks_in_eval]

        self.init_multiple_models()
        self.init_multiple_model_hooks()
        assert len(self.ctx.models) == model_nums == \
               len(self.hooks_in_train_multiple_models) == \
               len(self.hooks_in_eval_multiple_models), \
            "After init, len(hooks_in_train_multiple_models), " \
            "len(hooks_in_eval_multiple_models), " \
            "len(ctx.models) and model_nums should be the same"
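    # Illustrative sketch of the two construction paths described in the
    # docstring above; `single_model_trainer`, `my_model`, `my_data`,
    # `my_device` and `my_cfg` are hypothetical placeholders for objects
    # built elsewhere, not part of this module.
    #
    #     # 1) copy construction from an existing single-model trainer
    #     multi_trainer = GeneralMultiModelTrainer(
    #         model_nums=2,
    #         models_interact_mode="sequential",
    #         base_trainer=single_model_trainer)
    #
    #     # 2) standard construction from (model, data, device, config)
    #     multi_trainer = GeneralMultiModelTrainer(
    #         model_nums=2,
    #         models_interact_mode="parallel",
    #         model=my_model,
    #         data=my_data,
    #         device=my_device,
    #         config=my_cfg)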
    def init_multiple_models(self):
        """
        Initialize multiple models and optimizers: the default
        implementation is the copy-init manner.

        ========================= Extension =============================
        Users can override this function according to their own
        requirements.
        """
        additional_models = [
            copy.deepcopy(self.ctx.model) for _ in range(self.model_nums - 1)
        ]
        self.ctx.models = [self.ctx.model] + additional_models

        self.ctx.optimizers = [
            get_optimizer(self.ctx.models[i], **self.cfg.train.optimizer)
            for i in range(0, self.model_nums)
        ]
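    # Illustrative sketch of overriding `init_multiple_models` to hold
    # heterogeneous internal models, as the extension note above suggests;
    # `build_model_for_idx` is a hypothetical user-supplied factory, while
    # `get_optimizer` and `self.cfg.train.optimizer` are the same helpers
    # used by the default implementation.
    #
    #     class MyMultiModelTrainer(GeneralMultiModelTrainer):
    #         def init_multiple_models(self):
    #             self.ctx.models = [
    #                 build_model_for_idx(i) for i in range(self.model_nums)
    #             ]
    #             self.ctx.optimizers = [
    #                 get_optimizer(model, **self.cfg.train.optimizer)
    #                 for model in self.ctx.models
    #             ]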
    def register_multiple_model_hooks(self):
        """
        By default, all internal models adopt the same hook_set.

        ========================= Extension =============================
        Users can override this function to register customized hooks
        for different internal models.

        Note:
            - for sequential mode, users can append an interact_hook on
              begin/end triggers such as
              " -> (on_fit_end, _interact_to_other_models) -> "
            - for parallel mode, users can append an interact_hook on any
              trigger they want, such as
              " -> (on_xxx_point, _interact_to_other_models) -> "
            - we must tell the running hooks which data_loader to
              call and which num_samples to count
        """
        self.hooks_in_train_multiple_models.extend([
            self.hooks_in_train_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
        self.hooks_in_eval_multiple_models.extend([
            self.hooks_in_eval_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
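    # Illustrative sketch of a customized override: each internal model gets
    # its own per-trigger hook lists, and an interaction hook is appended on
    # the "on_fit_end" trigger as suggested for the sequential mode above;
    # `_interact_to_other_models` is a hypothetical user-defined hook, not
    # provided by this module.
    #
    #     class MyMultiModelTrainer(GeneralMultiModelTrainer):
    #         def register_multiple_model_hooks(self):
    #             base_train = self.hooks_in_train_multiple_models[0]
    #             base_eval = self.hooks_in_eval_multiple_models[0]
    #             for _ in range(1, self.model_nums):
    #                 # per-trigger list copies so each model owns its hooks
    #                 self.hooks_in_train_multiple_models.append(
    #                     {k: list(v) for k, v in base_train.items()})
    #                 self.hooks_in_eval_multiple_models.append(
    #                     {k: list(v) for k, v in base_eval.items()})
    #             # sequential mode: interact after each model's fit ends
    #             for model_idx in range(self.model_nums):
    #                 self.register_hook_in_train(
    #                     new_hook=self._interact_to_other_models,
    #                     trigger="on_fit_end",
    #                     model_idx=model_idx)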
    def init_multiple_model_hooks(self):
        self.register_multiple_model_hooks()
        if self.models_interact_mode == "sequential":
            # hooks_in_xxx is a list of dicts, hooks_in_xxx[i] stores the
            # specific set for the i-th internal model;
            # for each dict, the key indicates point-in-time and the value
            # indicates the specific hook
            self.hooks_in_train = self.hooks_in_train_multiple_models
            self.hooks_in_eval = self.hooks_in_eval_multiple_models
        elif self.models_interact_mode == "parallel":
            # hooks_in_xxx is a dict whose key indicates point-in-time and
            # value indicates the specific hook
            for trigger in list(self.hooks_in_train.keys()):
                self.hooks_in_train[trigger] = []
                self.hooks_in_eval[trigger] = []
                for model_idx in range(len(self.ctx.models)):
                    self.hooks_in_train[trigger].extend(
                        self.hooks_in_train_multiple_models[model_idx]
                        [trigger])
                    self.hooks_in_train[trigger].extend(
                        [self._switch_model_ctx])
                    self.hooks_in_eval[trigger].extend(
                        self.hooks_in_eval_multiple_models[model_idx]
                        [trigger])
                    self.hooks_in_eval[trigger].extend(
                        [self._switch_model_ctx])
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`,"
                f" but got {self.models_interact_mode}")

    def register_hook_in_train(self,
                               new_hook,
                               trigger,
                               model_idx=0,
                               insert_pos=None,
                               base_hook=None,
                               insert_mode="before"):
        hooks_dict = self.hooks_in_train_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

    def register_hook_in_eval(self,
                              new_hook,
                              trigger,
                              model_idx=0,
                              insert_pos=None,
                              base_hook=None,
                              insert_mode="before"):
        hooks_dict = self.hooks_in_eval_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

    def _switch_model_ctx(self, next_model_idx=None):
        if next_model_idx is None:
            next_model_idx = (self.ctx.cur_model_idx + 1) % len(
                self.ctx.models)
        self.ctx.cur_model_idx = next_model_idx
        self.ctx.model = self.ctx.models[next_model_idx]
        self.ctx.optimizer = self.ctx.optimizers[next_model_idx]
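    # Illustrative sketch of registering a hook for one specific internal
    # model through the `model_idx` argument added by this class;
    # `multi_trainer` is the hypothetical instance from the earlier sketch
    # and `my_logging_hook` is a hypothetical user-defined hook.
    #
    #     def my_logging_hook(ctx):
    #         print(f"model {ctx.cur_model_idx} finished its routine")
    #
    #     multi_trainer.register_hook_in_train(new_hook=my_logging_hook,
    #                                          trigger="on_fit_end",
    #                                          model_idx=1)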
    def _run_routine(self, mode, hooks_set, dataset_name=None):
        """Run the hooks_set and maintain the mode for multiple internal
        models.

        Arguments:
            mode: running mode of client, chosen from train/val/test

        Note:
            Considering evaluation could be in ```hooks_set[
            "on_epoch_end"]```, there could be two data loaders in
            self.ctx, so we must tell the running hooks which data_loader
            to call and which num_samples to count
        """
        num_samples_model = list()
        if self.models_interact_mode == "sequential":
            assert isinstance(hooks_set, list) and \
                   isinstance(hooks_set[0], dict), \
                "When models_interact_mode=sequential, hooks_set should be " \
                "a list of dicts, where hooks_set[i] stores the specific " \
                "set for the i-th internal model. For each dict, the key " \
                "indicates point-in-time and the value indicates the " \
                "specific hook"
            for model_idx in range(len(self.ctx.models)):
                # switch hooks & ctx for the different internal models
                hooks_set_model_i = hooks_set[model_idx]
                self._switch_model_ctx(model_idx)
                # [Interaction at run_routine level]
                # one model runs its whole routine, then does something for
                # interaction, then the next model runs its whole routine
                # ... -> run_routine_model_i
                #     -> _switch_model_ctx
                #     -> (on_fit_end, _interact_to_other_models)
                #     -> run_routine_model_i+1
                #     -> ...
                num_samples = super()._run_routine(mode, hooks_set_model_i,
                                                   dataset_name)
                num_samples_model.append(num_samples)
        elif self.models_interact_mode == "parallel":
            assert isinstance(hooks_set, dict), \
                "When models_interact_mode=parallel, hooks_set should be a " \
                "dict whose key indicates point-in-time and whose value " \
                "indicates the specific hook"
            # [Interaction at point-in-time level]
            # at a specific point-in-time, one model calls its hooks
            # (including interaction), then the next model calls its hooks
            # ... -> (on_xxx_point, hook_xxx_model_i)
            #     -> (on_xxx_point, _interact_to_other_models)
            #     -> (on_xxx_point, _switch_model_ctx)
            #     -> (on_xxx_point, hook_xxx_model_i+1)
            #     -> ...
            num_samples = super()._run_routine(mode, hooks_set, dataset_name)
            num_samples_model.append(num_samples)
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`,"
                f" but got {self.models_interact_mode}")

        # For now, we return the average number of samples across the
        # different models
        return np.mean(num_samples_model)
    def get_model_para(self):
        """
        Return the parameters of the multiple internal models.

        :return: a single (filtered) state_dict if model_nums == 1,
            otherwise a list of state_dicts, one per internal model
        """
        trained_model_para = []
        for model_idx in range(self.model_nums):
            trained_model_para.append(
                self._param_filter(
                    self.ctx.models[model_idx].cpu().state_dict()))

        return trained_model_para[
            0] if self.model_nums == 1 else trained_model_para
    def update(self, model_parameters, strict=False):
        # update the parameters of the multiple models
        """
        Arguments:
            model_parameters (list[dict]): state_dicts of the multiple
            PyTorch Module objects.
        """
        if self.model_nums == 1:
            super().update(model_parameters, strict=strict)
        else:
            assert isinstance(model_parameters, list) and \
                   isinstance(model_parameters[0], dict), \
                "model_parameters should be a list of multiple state_dicts"
            assert len(model_parameters) == self.model_nums, \
                f"model_parameters should have the same length as " \
                f"self.model_nums, but got {len(model_parameters)} and " \
                f"{self.model_nums} respectively"
            for model_idx in range(self.model_nums):
                self.ctx.models[model_idx].load_state_dict(
                    self._param_filter(model_parameters[model_idx]),
                    strict=strict)
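    # Illustrative sketch of the exchange format implied by `get_model_para`
    # and `update`: with model_nums > 1 both methods work with a list of
    # state_dicts (one per internal model), while with model_nums == 1 they
    # fall back to a single state_dict; `multi_trainer` is the hypothetical
    # instance from the earlier sketch.
    #
    #     paras = multi_trainer.get_model_para()  # list[dict] if nums > 1
    #     multi_trainer.update(paras)             # accepts the same format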
    def train(self, target_data_split_name="train"):
        # return multiple model paras
        sample_size, _, results = super().train(target_data_split_name)

        return sample_size, self.get_model_para(), results
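# Illustrative sketch of one local training round with the overridden
# `train`, which returns the parameters of all internal models instead of a
# single model's parameters; `multi_trainer` is the hypothetical instance
# from the earlier sketch.
#
#     sample_size, model_paras, results = multi_trainer.train()
#     # `model_paras` is a list of state_dicts when model_nums > 1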