Source code for federatedscope.core.trainers.trainer_FedEM

from typing import Type

import numpy as np
import torch
from torch.nn.functional import softmax as f_softmax

from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.trainers.trainer_multi_model import \
    GeneralMultiModelTrainer


class FedEMTrainer(GeneralMultiModelTrainer):
    """
    The FedEM implementation, "Federated Multi-Task Learning under a \
    Mixture of Distributions (NeurIPS 2021)" \
    based on Algorithm 1 in the paper and the official code:
    https://github.com/omarfoq/FedEM
    """
    def __init__(self,
                 model_nums,
                 models_interact_mode="sequential",
                 model=None,
                 data=None,
                 device=None,
                 config=None,
                 base_trainer: Type[GeneralTorchTrainer] = None):
        super(FedEMTrainer,
              self).__init__(model_nums, models_interact_mode, model, data,
                             device, config, base_trainer)
        device = self.ctx.device

        # --------------- attribute-level modifications ----------------------
        # used to mix the internal models
        self.weights_internal_models = (torch.ones(self.model_nums) /
                                        self.model_nums).to(device)
        self.weights_data_sample = (
            torch.ones(self.model_nums, self.ctx.num_train_batch) /
            self.model_nums).to(device)

        self.ctx.all_losses_model_batch = torch.zeros(
            self.model_nums, self.ctx.num_train_batch).to(device)
        self.ctx.cur_batch_idx = -1
        # `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in
        # func `_hook_on_fit_end_ensemble_eval`
        # -> self.ctx.test_y_prob_ensemble = 0
        # -> self.ctx.train_y_prob_ensemble = 0
        # -> self.ctx.val_y_prob_ensemble = 0

        # ---------------- action-level modifications ------------------------
        # see register_multiple_model_hooks(),
        # which is called in the __init__ of `GeneralMultiModelTrainer`
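    # For intuition: `weights_internal_models` plays the role of the mixture
    # coefficients pi_m in the FedEM formulation (one scalar per internal
    # model, initialized uniformly), while `weights_data_sample` holds the
    # per-batch responsibilities q(m | batch), with shape
    # (model_nums, num_train_batch). E.g., with model_nums=3 and
    # num_train_batch=10, they start as all-1/3 tensors of shapes (3,)
    # and (3, 10), respectively.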
    def register_multiple_model_hooks(self):
        """
        Customized multiple_model_hooks, which is called in the __init__ of
        `GeneralMultiModelTrainer`
        """
        # First register hooks for model 0
        # ---------------- train hooks -----------------------
        self.register_hook_in_train(
            new_hook=self._hook_on_fit_start_mixture_weights_update,
            trigger="on_fit_start",
            insert_pos=0)  # insert at the front
        self.register_hook_in_train(
            new_hook=self._hook_on_fit_start_flop_count,
            trigger="on_fit_start",
            insert_pos=1  # follow the mixture operation
        )
        self.register_hook_in_train(new_hook=self._hook_on_fit_end_flop_count,
                                    trigger="on_fit_end",
                                    insert_pos=-1)
        self.register_hook_in_train(
            new_hook=self._hook_on_batch_forward_weighted_loss,
            trigger="on_batch_forward",
            insert_pos=-1)
        self.register_hook_in_train(
            new_hook=self._hook_on_batch_start_track_batch_idx,
            trigger="on_batch_start",
            insert_pos=0)  # insert at the front
        # ---------------- eval hooks -----------------------
        self.register_hook_in_eval(
            new_hook=self._hook_on_batch_end_gather_loss,
            trigger="on_batch_end",
            insert_pos=0
        )  # insert at the front; gather the loss before it is cleaned
        self.register_hook_in_eval(
            new_hook=self._hook_on_batch_start_track_batch_idx,
            trigger="on_batch_start",
            insert_pos=0)  # insert at the front
        # replace the original evaluation with the ensemble one
        self.replace_hook_in_eval(
            new_hook=self._hook_on_fit_end_ensemble_eval,
            target_trigger="on_fit_end",
            target_hook_name="_hook_on_fit_end")

        # Then, for the other models, set the same hooks as model 0,
        # since the hook implementations differentiate the models
        # via ctx.cur_model_idx
        self.hooks_in_train_multiple_models.extend([
            self.hooks_in_train_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
        self.hooks_in_eval_multiple_models.extend([
            self.hooks_in_eval_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
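    # The resulting per-round training flow is roughly:
    #   on_fit_start:     mixture_weights_update -> fit_start_flop_count
    #   on_batch_start:   track_batch_idx
    #   on_batch_forward: (original hooks) -> weighted_loss
    #   on_fit_end:       (original hooks) -> fit_end_flop_count
    # In eval, the batch index is tracked the same way, per-batch losses are
    # gathered at on_batch_end, and the final metric computation is replaced
    # by `_hook_on_fit_end_ensemble_eval`.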
    def _hook_on_batch_start_track_batch_idx(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ================================== ===========================
            Attribute                          Operation
            ================================== ===========================
            ``ctx.cur_batch_idx``              Increment by 1, modulo \
            ``num_train_batch``
            ================================== ===========================
        """
        # for both train & eval
        ctx.cur_batch_idx = (self.ctx.cur_batch_idx +
                             1) % self.ctx.num_train_batch
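    # For intuition: with num_train_batch = 3, the tracked index wraps as
    # -1 -> 0 -> 1 -> 2 -> 0 -> ..., so train and eval batches line up with
    # the columns of `ctx.all_losses_model_batch`.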
    def _hook_on_batch_forward_weighted_loss(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ================================== ===========================
            Attribute                          Operation
            ================================== ===========================
            ``ctx.loss_batch``                 Multiply by \
            ``weights_internal_models``
            ================================== ===========================
        """
        # for train only
        ctx.loss_batch *= self.weights_internal_models[ctx.cur_model_idx]
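    # For intuition: if `weights_internal_models` is [0.7, 0.3], the batch
    # loss (and hence the gradient) of internal model 1 is scaled by 0.3, so
    # low-weight mixture components take proportionally smaller update steps.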
    def _hook_on_batch_end_gather_loss(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ================================== ===========================
            Attribute                          Operation
            ================================== ===========================
            ``ctx.all_losses_model_batch``     Gather loss
            ================================== ===========================
        """
        # for eval only
        # record loss_batch before it is cleaned;
        # it is used for the subsequent weights_data_sample update
        ctx.all_losses_model_batch[ctx.cur_model_idx][
            ctx.cur_batch_idx] = ctx.loss_batch.item()
    def _hook_on_fit_start_mixture_weights_update(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ================================== ===========================
            Attribute                          Operation
            ================================== ===========================
            ``ctx.mode``                       Evaluate
            ================================== ===========================
        """
        # for train only
        if ctx.cur_model_idx != 0:
            # only update the mixture weights once per round,
            # i.e., when model 0 starts its fit
            pass
        else:
            # gather the losses over all samples in the iterator
            # for each internal model by calling `evaluate()`
            for model_idx in range(self.model_nums):
                self._switch_model_ctx(model_idx)
                self.evaluate(target_data_split_name="train")

            self.weights_data_sample = f_softmax(
                (torch.log(self.weights_internal_models) -
                 ctx.all_losses_model_batch.T),
                dim=1).T
            self.weights_internal_models = self.weights_data_sample.mean(
                dim=1)

            # restore the model_ctx
            self._switch_model_ctx(0)
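    # EM view of the block above: the softmax over
    # log(pi_m) - loss(m, batch) yields responsibilities proportional to
    # pi_m * exp(-loss), i.e., the E-step q(m | batch); averaging them over
    # batches gives the updated mixture weights pi_m. A runnable toy sketch
    # of this update is appended at the end of this file.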
    def _hook_on_fit_start_flop_count(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ================================== ===========================
            Attribute                          Operation
            ================================== ===========================
            ``ctx.monitor``                    Count total_flops
            ================================== ===========================
        """
        self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
            self.model_nums * ctx.num_train_data
    def _hook_on_fit_end_flop_count(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ================================== ===========================
            Attribute                          Operation
            ================================== ===========================
            ``ctx.monitor``                    Count total_flops
            ================================== ===========================
        """
        self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
            self.model_nums * ctx.num_train_data
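    # One plausible reading of the two identical flop-count hooks: the
    # fit-start count covers the `model_nums` forward passes over the train
    # set used for the mixture-weight update, and the fit-end count covers
    # the training passes themselves. E.g., with flops_per_sample=1e6,
    # model_nums=3 and num_train_data=100, each hook adds 3e8 FLOPs to the
    # monitor.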
    def _hook_on_fit_end_ensemble_eval(self, ctx):
        """
        Ensemble evaluation

        Note:
          The modified attributes and according operations are shown below:
            ================================== ===========================
            Attribute                          Operation
            ================================== ===========================
            ``ctx.ys_prob_ensemble``           Ensemble ys_prob
            ``ctx.ys_true``                    Concatenate results
            ``ctx.ys_prob``                    Concatenate results
            ``ctx.eval_metrics``               Get evaluated results from \
            ``ctx.monitor``
            ================================== ===========================
        """
        if ctx.get("ys_prob_ensemble", None) is None:
            ctx.ys_prob_ensemble = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_prob_ensemble += np.concatenate(
            ctx.ys_prob) * self.weights_internal_models[
                ctx.cur_model_idx].item()

        # compute the metrics after the last internal model's evaluation
        if ctx.cur_model_idx == self.model_nums - 1:
            ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true),
                                 LIFECYCLE.ROUTINE)
            ctx.ys_prob = ctx.ys_prob_ensemble
            ctx.eval_metrics = self.ctx.monitor.eval(ctx)
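
# A minimal, self-contained sketch of the mixture-weight update performed in
# `_hook_on_fit_start_mixture_weights_update`, run on toy losses (the values
# below are made up for illustration and are not part of the trainer):
if __name__ == "__main__":
    model_nums, num_train_batch = 2, 3
    # uniform mixture weights, as in FedEMTrainer.__init__
    toy_weights_internal_models = torch.ones(model_nums) / model_nums
    # stand-in for ctx.all_losses_model_batch,
    # shape (model_nums, num_train_batch)
    toy_all_losses = torch.tensor([[0.5, 1.0, 0.2], [1.5, 2.0, 0.1]])
    # E-step: responsibilities q(m | batch); softmax over the model dimension
    toy_weights_data_sample = f_softmax(
        torch.log(toy_weights_internal_models) - toy_all_losses.T, dim=1).T
    # M-step: new mixture weights = mean responsibility per model
    toy_weights_internal_models = toy_weights_data_sample.mean(dim=1)
    print(toy_weights_data_sample)  # shape (2, 3), each column sums to 1
    print(toy_weights_internal_models)  # shape (2,), sums to 1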