import copy
from types import FunctionType
from typing import Type
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
import numpy as np
class GeneralMultiModelTrainer(GeneralTorchTrainer):
def __init__(self,
model_nums,
models_interact_mode="sequential",
model=None,
data=None,
device=None,
config=None,
base_trainer: Type[GeneralTorchTrainer] = None):
"""
`GeneralMultiModelTrainer` supports train/eval via multiple
internal models
Arguments:
model_nums (int): how many internal models and optimizers
will be held by the trainer
models_interact_mode (str): how the models interact, can be
"sequential" or "parallel".
model: training model
data: a dict containing the train/val/test data
device: the device to run on
config: the trainer-related configuration
base_trainer: if given, the GeneralMultiModelTrainer is
initialized as a copy of base_trainer
The sequential mode indicates interaction at the
run_routine level
[one model runs its whole routine, then does something for
interaction, then the next model runs its whole routine]
... -> run_routine_model_i
-> _switch_model_ctx
-> (on_fit_end, _interact_to_other_models)
-> run_routine_model_i+1
-> ...
The parallel mode indicates interaction
at the point-in-time level
[at a specific point-in-time, one model calls its hooks (
including interaction hooks), then the next model calls its hooks]
... -> (on_xxx_point, hook_xxx_model_i)
-> (on_xxx_point, _interact_to_other_models)
-> (on_xxx_point, _switch_model_ctx)
-> (on_xxx_point, hook_xxx_model_i+1)
-> ...
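Example (a minimal sketch; `my_model`, `my_data`, `my_device` and
`my_cfg` are placeholders for objects built elsewhere, e.g. via the
usual FederatedScope builders):

    # 1) standard init manner
    trainer = GeneralMultiModelTrainer(
        model_nums=2,
        models_interact_mode="sequential",
        model=my_model,
        data=my_data,
        device=my_device,
        config=my_cfg)

    # 2) copy construction from an existing single-model trainer
    # (`base` is assumed to be a GeneralTorchTrainer built elsewhere)
    trainer = GeneralMultiModelTrainer(
        model_nums=2,
        models_interact_mode="parallel",
        base_trainer=base)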
"""
# support two initialization methods for the `GeneralMultiModelTrainer`
# 1) from another trainer; or 2) standard init manner given (model,
# data, device, config)
if base_trainer is None:
assert model is not None and \
data is not None and \
device is not None and \
config is not None, "when not using copy construction, " \
"(model, data, device, config) should not be None"
super(GeneralMultiModelTrainer,
self).__init__(model, data, device, config)
else:
assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
isinstance(base_trainer, GeneralTorchTrainer), \
"can only copy instances of `GeneralMultiModelTrainer` " \
"and its subclasses, or " \
"`GeneralTorchTrainer` and its subclasses"
self.__dict__ = copy.deepcopy(base_trainer.__dict__)
assert models_interact_mode in ["sequential", "parallel"], \
f"Invalid models_interact_mode, should be `sequential` or " \
f"`parallel`, but got {models_interact_mode}"
self.models_interact_mode = models_interact_mode
if int(model_nums) != model_nums or model_nums < 1:
raise ValueError(
f"model_nums should be an integer >= 1, but got {model_nums}.")
self.model_nums = model_nums
self.ctx.cur_model_idx = 0 # used to mark cur model
# different internal models can have different hook_set
self.hooks_in_train_multiple_models = [self.hooks_in_train]
self.hooks_in_eval_multiple_models = [self.hooks_in_eval]
self.init_multiple_models()
self.init_multiple_model_hooks()
assert len(self.ctx.models) == model_nums == \
len(self.hooks_in_train_multiple_models) == len(
self.hooks_in_eval_multiple_models),\
"After init, len(hooks_in_train_multiple_models), " \
"len(hooks_in_eval_multiple_models), " \
"len(ctx.models) and model_nums should be the same"
def init_multiple_models(self):
"""
Initialize multiple models and optimizers. By default, the
additional models are deep copies of the first model.
========================= Extension =============================
Users can override this function according to their own
requirements; see the sketch below.
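Example of an override (a sketch; `build_small_model` is a
hypothetical user helper, not part of FederatedScope):

    class MyMultiModelTrainer(GeneralMultiModelTrainer):
        def init_multiple_models(self):
            # keep the given model first, build the others ourselves
            self.ctx.models = [self.ctx.model] + [
                build_small_model(self.cfg)
                for _ in range(self.model_nums - 1)
            ]
            self.ctx.optimizers = [
                get_optimizer(m, **self.cfg.train.optimizer)
                for m in self.ctx.models
            ]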
"""
additional_models = [
copy.deepcopy(self.ctx.model) for _ in range(self.model_nums - 1)
]
self.ctx.models = [self.ctx.model] + additional_models
self.ctx.optimizers = [
get_optimizer(self.ctx.models[i], **self.cfg.train.optimizer)
for i in range(0, self.model_nums)
]
def register_multiple_model_hooks(self):
"""
By default, all internal models adopt the same hook_set.
========================= Extension =============================
Users can override this function to register customized hooks \
for different internal models; see the sketch below.
Note:
- for sequential mode, users can append interact_hook on \
begin/end triggers such as \
" -> (on_fit_end, _interact_to_other_models) -> "
- for parallel mode, users can append interact_hook on any \
trigger they want such as \
" -> (on_xxx_point, _interact_to_other_models) -> "
- we must tell the running hooks which data_loader to \
call and which num_samples to count
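Example of an override (a sketch; `_interact_to_other_models` stands
for a user-defined interaction hook, which this class does not
provide):

    def register_multiple_model_hooks(self):
        super().register_multiple_model_hooks()
        # append the interaction hook at the end of the fit routine;
        # with the default shared hook sets, the hook is registered
        # for every internal model
        self.register_hook_in_train(
            new_hook=self._interact_to_other_models,
            trigger="on_fit_end",
            model_idx=0)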
"""
self.hooks_in_train_multiple_models.extend([
self.hooks_in_train_multiple_models[0]
for _ in range(1, self.model_nums)
])
self.hooks_in_eval_multiple_models.extend([
self.hooks_in_eval_multiple_models[0]
for _ in range(1, self.model_nums)
])
def init_multiple_model_hooks(self):
self.register_multiple_model_hooks()
if self.models_interact_mode == "sequential":
# hooks_in_xxx is a list of dict, hooks_in_xxx[i] stores
# specific set for i-th internal model;
# for each dict, the key indicates point-in-time and the value
# indicates specific hook
self.hooks_in_train = self.hooks_in_train_multiple_models
self.hooks_in_eval = self.hooks_in_eval_multiple_models
elif self.models_interact_mode == "parallel":
# hooks_in_xxx is a dict whose key indicates point-in-time and
# value indicates specific hook
for trigger in list(self.hooks_in_train.keys()):
self.hooks_in_train[trigger] = []
self.hooks_in_eval[trigger] = []
for model_idx in range(len(self.ctx.models)):
self.hooks_in_train[trigger].extend(
self.hooks_in_train_multiple_models[model_idx]
[trigger])
self.hooks_in_train[trigger].extend(
[self._switch_model_ctx])
self.hooks_in_eval[trigger].extend(
self.hooks_in_eval_multiple_models[model_idx][trigger])
self.hooks_in_eval[trigger].extend(
[self._switch_model_ctx])
else:
raise RuntimeError(
f"Invalid models_interact_mode, should be `sequential` or "
f"`parallel`,"
f" but got {self.models_interact_mode}")
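# For illustration only: with model_nums == 2, the wiring above yields
#   - sequential: hooks_in_train == [hooks_model_0, hooks_model_1], i.e.,
#     a list of {trigger: [hook, ...]} dicts, one per internal model;
#   - parallel: for every trigger, hooks_in_train[trigger] ==
#     [*hooks_model_0[trigger], _switch_model_ctx,
#      *hooks_model_1[trigger], _switch_model_ctx],
#     so each trigger runs every model's hooks in turn and switches the
#     model context in between.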
def register_hook_in_train(self,
new_hook,
trigger,
model_idx=0,
insert_pos=None,
base_hook=None,
insert_mode="before"):
hooks_dict = self.hooks_in_train_multiple_models[model_idx]
self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
new_hook, trigger)
def register_hook_in_eval(self,
new_hook,
trigger,
model_idx=0,
insert_pos=None,
base_hook=None,
insert_mode="before"):
hooks_dict = self.hooks_in_eval_multiple_models[model_idx]
self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
new_hook, trigger)
def _switch_model_ctx(self, next_model_idx=None):
if next_model_idx is None:
next_model_idx = (self.ctx.cur_model_idx + 1) % len(
self.ctx.models)
self.ctx.cur_model_idx = next_model_idx
self.ctx.model = self.ctx.models[next_model_idx]
self.ctx.optimizer = self.ctx.optimizers[next_model_idx]
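# E.g., with three internal models, successive calls without an explicit
# index cycle ctx.cur_model_idx through 0 -> 1 -> 2 -> 0 -> ..., each time
# re-pointing ctx.model and ctx.optimizer to the corresponding entries.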
def _run_routine(self, mode, hooks_set, dataset_name=None):
"""Run the hooks_set and maintain the mode for multiple internal models
Arguments:
mode: running mode of client, chosen from train/val/test
Note:
Considering that evaluation could happen in ```hooks_set[
"on_epoch_end"]```, there could be two data loaders in \
self.ctx, so we must tell the running hooks which data_loader to \
call and which num_samples to count.
"""
num_samples_model = list()
if self.models_interact_mode == "sequential":
assert isinstance(hooks_set, list) and isinstance(hooks_set[0],
dict), \
"When models_interact_mode=sequential, " \
"hooks_set should be a list of dicts; " \
"hooks_set[i] stores the specific set for the i-th internal model. " \
"For each dict, the key indicates a point-in-time and the " \
"value indicates the specific hooks"
for model_idx in range(len(self.ctx.models)):
# switch different hooks & ctx for different internal models
hooks_set_model_i = hooks_set[model_idx]
self._switch_model_ctx(model_idx)
# [Interaction at run_routine level]
# one model runs its whole routine, then does something for
# interaction, then the next model runs its whole routine
# ... -> run_routine_model_i
# -> _switch_model_ctx
# -> (on_fit_end, _interact_to_other_models)
# -> run_routine_model_i+1
# -> ...
num_samples = super()._run_routine(mode, hooks_set_model_i,
dataset_name)
num_samples_model.append(num_samples)
elif self.models_interact_mode == "parallel":
assert isinstance(hooks_set, dict), \
"When models_interact_mode=parallel, hooks_set should be a " \
"dict whose key indicates point-in-time and value indicates " \
"specific hook"
# [Interaction at point-in-time level]
# at a specific point-in-time, one model calls its hooks (including
# interaction hooks), then the next model calls its hooks
# ... -> (on_xxx_point, hook_xxx_model_i)
# -> (on_xxx_point, _interact_to_other_models)
# -> (on_xxx_point, _switch_model_ctx)
# -> (on_xxx_point, hook_xxx_model_i+1)
# -> ...
num_samples = super()._run_routine(mode, hooks_set, dataset_name)
num_samples_model.append(num_samples)
else:
raise RuntimeError(
f"Invalid models_interact_mode, should be `sequential` or "
f"`parallel`,"
f" but got {self.models_interact_mode}")
# For now, we return the average number of samples for different models
return np.mean(num_samples_model)
def get_model_para(self):
"""
Return the parameters of the internal models.
:return: a single state_dict when model_nums == 1; otherwise a
list of state_dicts, one per internal model
"""
trained_model_para = []
for model_idx in range(self.model_nums):
trained_model_para.append(
self._param_filter(
self.ctx.models[model_idx].cpu().state_dict()))
return trained_model_para[
0] if self.model_nums == 1 else trained_model_para
def update(self, model_parameters, strict=False):
# update multiple model paras
"""
Arguments:
model_parameters (list[dict]): state_dicts of the multiple
internal models (a single state_dict when model_nums == 1).
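Example (a sketch; the two state_dicts are placeholders that would
typically come from server-side aggregation):

    # for a trainer with model_nums == 2
    trainer.update([state_dict_model_0, state_dict_model_1])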
"""
if self.model_nums == 1:
super().update(model_parameters, strict=strict)
else:
assert isinstance(model_parameters, list) and isinstance(
model_parameters[0], dict), \
"model_parameters should be a list of state_dicts"
assert len(model_parameters) == self.model_nums, \
f"model_parameters should have the same length as " \
f"self.model_nums, " \
f"but got {len(model_parameters)} and {self.model_nums} " \
f"respectively"
for model_idx in range(self.model_nums):
self.ctx.models[model_idx].load_state_dict(self._param_filter(
model_parameters[model_idx]),
strict=strict)
def train(self, target_data_split_name="train"):
# return multiple model paras
sample_size, _, results = super().train(target_data_split_name)
return sample_size, self.get_model_para(), results
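# Typical use (a sketch; `trainer` constructed as in the class docstring):
#     sample_size, multi_model_para, results = trainer.train()
#     trainer.update(multi_model_para)  # e.g., after server-side aggregation
# Note that with model_nums == 1, get_model_para()/update() fall back to a
# single state_dict, matching the single-model trainer interface.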