from typing import Type
import numpy as np
import torch
from torch.nn.functional import softmax as f_softmax
from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.trainers.trainer_multi_model import \
GeneralMultiModelTrainer


class FedEMTrainer(GeneralMultiModelTrainer):
"""
    The FedEM implementation of "Federated Multi-Task Learning under a \
    Mixture of Distributions" (NeurIPS 2021), \
    based on Algorithm 1 in the paper and the official code:
https://github.com/omarfoq/FedEM
"""
def __init__(self,
model_nums,
models_interact_mode="sequential",
model=None,
data=None,
device=None,
config=None,
base_trainer: Type[GeneralTorchTrainer] = None):
super(FedEMTrainer,
self).__init__(model_nums, models_interact_mode, model, data,
device, config, base_trainer)
device = self.ctx.device
# --------------- attribute-level modifications ----------------------
        # mixture weights used to combine the internal models
self.weights_internal_models = (torch.ones(self.model_nums) /
self.model_nums).to(device)
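        # per-sample mixture weights (EM responsibilities):
        # weights_data_sample[m][t] is the weight of internal model m on the
        # t-th train batch; each column sums to 1 over the models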
self.weights_data_sample = (
torch.ones(self.model_nums, self.ctx.num_train_batch) /
self.model_nums).to(device)
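        # cache of the evaluation losses of every internal model on every
        # train batch; feeds the mixture-weight update at fit start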
self.ctx.all_losses_model_batch = torch.zeros(
self.model_nums, self.ctx.num_train_batch).to(device)
self.ctx.cur_batch_idx = -1
# `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in
# func `_hook_on_fit_end_ensemble_eval`
# -> self.ctx.test_y_prob_ensemble = 0
# -> self.ctx.train_y_prob_ensemble = 0
# -> self.ctx.val_y_prob_ensemble = 0
# ---------------- action-level modifications -----------------------
# see register_multiple_model_hooks(),
# which is called in the __init__ of `GeneralMultiModelTrainer`

    def register_multiple_model_hooks(self):
"""
        Customized multiple-model hooks; this method is called
        in the ``__init__`` of ``GeneralMultiModelTrainer``
"""
# First register hooks for model 0
# ---------------- train hooks -----------------------
self.register_hook_in_train(
new_hook=self._hook_on_fit_start_mixture_weights_update,
trigger="on_fit_start",
insert_pos=0) # insert at the front
self.register_hook_in_train(
new_hook=self._hook_on_fit_start_flop_count,
trigger="on_fit_start",
insert_pos=1 # follow the mixture operation
)
self.register_hook_in_train(new_hook=self._hook_on_fit_end_flop_count,
trigger="on_fit_end",
insert_pos=-1)
self.register_hook_in_train(
new_hook=self._hook_on_batch_forward_weighted_loss,
trigger="on_batch_forward",
insert_pos=-1)
self.register_hook_in_train(
new_hook=self._hook_on_batch_start_track_batch_idx,
trigger="on_batch_start",
insert_pos=0) # insert at the front
# ---------------- eval hooks -----------------------
self.register_hook_in_eval(
new_hook=self._hook_on_batch_end_gather_loss,
trigger="on_batch_end",
insert_pos=0
        )  # insert at the front: gather the loss before it is cleaned
self.register_hook_in_eval(
new_hook=self._hook_on_batch_start_track_batch_idx,
trigger="on_batch_start",
insert_pos=0) # insert at the front
        # replace the original evaluation with the ensemble one
self.replace_hook_in_eval(new_hook=self._hook_on_fit_end_ensemble_eval,
target_trigger="on_fit_end",
target_hook_name="_hook_on_fit_end")
        # Then, for the other models, register the same hooks as model 0,
        # since the hook implementations differentiate the models
        # via ctx.cur_model_idx
self.hooks_in_train_multiple_models.extend([
self.hooks_in_train_multiple_models[0]
for _ in range(1, self.model_nums)
])
self.hooks_in_eval_multiple_models.extend([
self.hooks_in_eval_multiple_models[0]
for _ in range(1, self.model_nums)
])
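        # note: the hook lists are shared by reference across the models, so
        # all internal models run identical hooks and are distinguished only
        # via ctx.cur_model_idx inside the hook bodies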

    def _hook_on_batch_start_track_batch_idx(self, ctx):
"""
Note:
            The modified attributes and corresponding operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.cur_batch_idx``               Increment by 1, wrapping \
            around at ``ctx.num_train_batch``
            ==================================  ===========================
"""
# for both train & eval
ctx.cur_batch_idx = (self.ctx.cur_batch_idx +
1) % self.ctx.num_train_batch
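        # the index wraps around at num_train_batch so that eval passes over
        # the train split stay aligned with `all_losses_model_batch`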

    def _hook_on_batch_forward_weighted_loss(self, ctx):
"""
Note:
            The modified attributes and corresponding operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.loss_batch``                  Multiply by \
            ``weights_internal_models``
            ==================================  ===========================
"""
        # train only
ctx.loss_batch *= self.weights_internal_models[ctx.cur_model_idx]
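        # scale this model's batch loss by its mixture weight pi_m, so the
        # backward pass is weighted according to the EM objective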

    def _hook_on_batch_end_gather_loss(self, ctx):
"""
Note:
            The modified attributes and corresponding operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.all_losses_model_batch``      Gather loss
            ==================================  ===========================
"""
        # eval only
        # record loss_batch before it is cleaned, for the later
        # weights_data_sample update
ctx.all_losses_model_batch[ctx.cur_model_idx][
ctx.cur_batch_idx] = ctx.loss_batch.item()

    def _hook_on_fit_start_mixture_weights_update(self, ctx):
"""
Note:
            The modified attributes and corresponding operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.mode``                        Evaluate
            ``weights_data_sample``             Update via E-step
            ``weights_internal_models``         Update via M-step
            ==================================  ===========================
"""
        # train only
        if ctx.cur_model_idx != 0:
            # the mixture-weights update is done only once,
            # when model 0 is scheduled
            pass
        else:
            # gather losses over all samples in the iterator for each
            # internal model by calling `evaluate()`
for model_idx in range(self.model_nums):
self._switch_model_ctx(model_idx)
self.evaluate(target_data_split_name="train")
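            # E-step: the responsibility of model m for batch t is
            #   q(m | t) = softmax_m(log(pi_m) - loss(m, t)),
            # i.e., proportional to pi_m * exp(-loss(m, t))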
self.weights_data_sample = f_softmax(
(torch.log(self.weights_internal_models) -
ctx.all_losses_model_batch.T),
dim=1).T
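            # M-step: the new mixture weight of each model is its average
            # responsibility over all train batches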
self.weights_internal_models = self.weights_data_sample.mean(dim=1)
# restore the model_ctx
self._switch_model_ctx(0)

    def _hook_on_fit_start_flop_count(self, ctx):
"""
Note:
            The modified attributes and corresponding operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.monitor``                     Count total_flops
            ==================================  ===========================
"""
self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
self.model_nums * ctx.num_train_data

    def _hook_on_fit_end_flop_count(self, ctx):
"""
Note:
            The modified attributes and corresponding operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.monitor``                     Count total_flops
            ==================================  ===========================
"""
self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
self.model_nums * ctx.num_train_data

    def _hook_on_fit_end_ensemble_eval(self, ctx):
"""
Ensemble evaluation
Note:
            The modified attributes and corresponding operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.ys_prob_ensemble``            Ensemble ys_prob
            ``ctx.ys_true``                     Concatenate results
            ``ctx.ys_prob``                     Concatenate results
            ``ctx.eval_metrics``                Get evaluated results from \
            ``ctx.monitor``
            ==================================  ===========================
"""
if ctx.get("ys_prob_ensemble", None) is None:
ctx.ys_prob_ensemble = CtxVar(0, LIFECYCLE.ROUTINE)
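        # accumulate the predicted probabilities of each internal model,
        # weighted by its mixture weight pi_m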
ctx.ys_prob_ensemble += np.concatenate(
ctx.ys_prob) * self.weights_internal_models[
ctx.cur_model_idx].item()
        # compute the metrics after the evaluation of the last
        # internal model is done
if ctx.cur_model_idx == self.model_nums - 1:
ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true),
LIFECYCLE.ROUTINE)
ctx.ys_prob = ctx.ys_prob_ensemble
ctx.eval_metrics = self.ctx.monitor.eval(ctx)
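

# ---------------------------------------------------------------------------
# A minimal standalone sketch (not part of the FederatedScope API) of the
# mixture-weight update performed in
# `_hook_on_fit_start_mixture_weights_update`; the tensor sizes and the toy
# losses below are illustrative assumptions only.
if __name__ == "__main__":
    model_nums, num_train_batch = 3, 5
    # uniform prior over the internal models, as in FedEMTrainer.__init__
    pi = torch.ones(model_nums) / model_nums
    # stand-in for `ctx.all_losses_model_batch` gathered during evaluation
    losses = torch.rand(model_nums, num_train_batch)

    # E-step: q(m | t) = softmax_m(log(pi_m) - loss(m, t))
    weights_data_sample = f_softmax(torch.log(pi) - losses.T, dim=1).T
    # M-step: pi_m <- mean_t q(m | t)
    pi = weights_data_sample.mean(dim=1)

    # each batch's responsibilities sum to 1 over the models
    assert torch.allclose(weights_data_sample.sum(dim=0),
                          torch.ones(num_train_batch))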