Source code for federatedscope.core.trainers.trainer_Ditto

import copy
import logging

import torch

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.optimizer import wrap_regularized_optimizer
from federatedscope.core.trainers.utils import calculate_batch_epoch_num
from typing import Type

logger = logging.getLogger(__name__)

DEBUG_DITTO = False


def wrap_DittoTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """
    Build a `DittoTrainer` in a plug-in manner, by registering new
    functions into the given `BaseTrainer`.

    The Ditto implementation, "Ditto: Fair and Robust Federated Learning
    Through Personalization. (ICML2021)", follows Algorithm 2 in the paper
    and the official code:
    https://github.com/litian96/ditto
    """
    # ---------------- attribute-level plug-in -----------------------
    init_Ditto_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.register_hook_in_train(new_hook=_hook_on_fit_start_clean,
                                        trigger='on_fit_start',
                                        insert_pos=-1)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_fit_start_set_regularized_para,
        trigger="on_fit_start",
        insert_pos=0)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_batch_start_switch_model,
        trigger="on_batch_start",
        insert_pos=0)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_batch_forward_cnt_num,
        trigger="on_batch_forward",
        insert_pos=-1)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_batch_end_flop_count,
        trigger="on_batch_end",
        insert_pos=-1)
    base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_calibrate,
                                        trigger='on_fit_end',
                                        insert_pos=-1)

    # evaluation is based on the local personalized model
    base_trainer.register_hook_in_eval(
        new_hook=_hook_on_fit_start_switch_local_model,
        trigger="on_fit_start",
        insert_pos=0)
    base_trainer.register_hook_in_eval(
        new_hook=_hook_on_fit_end_switch_global_model,
        trigger="on_fit_end",
        insert_pos=-1)

    base_trainer.register_hook_in_train(new_hook=_hook_on_fit_end_free_cuda,
                                        trigger="on_fit_end",
                                        insert_pos=-1)
    base_trainer.register_hook_in_eval(new_hook=_hook_on_fit_end_free_cuda,
                                       trigger="on_fit_end",
                                       insert_pos=-1)

    return base_trainer
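
# Illustrative usage (a minimal sketch, not part of the original module):
# assuming `base_trainer` is an already-constructed `GeneralTorchTrainer`,
# the wrapper registers the Ditto hooks and context attributes in place and
# returns the same trainer object, e.g.
#
#     ditto_trainer = wrap_DittoTrainer(base_trainer)
#     ditto_trainer.train()  # alternates personalized- and global-model steps
#     # evaluation then uses the personalized (local) model via the eval hooks
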
def init_Ditto_ctx(base_trainer):
    """
    Init necessary attributes used in Ditto: `global_model` acts as the
    shared global model in FedAvg, while `local_model` acts as the
    personalized model, which will be optimized with regularization
    towards the weights of `global_model`.
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    ctx.global_model = copy.deepcopy(ctx.model)
    ctx.local_model = copy.deepcopy(ctx.model)  # the personalized model
    ctx.models = [ctx.local_model, ctx.global_model]

    ctx.model = ctx.global_model
    ctx.use_local_model_current = False

    ctx.num_samples_local_model_train = 0

    # track the batch_num, epoch_num, for local & global model respectively
    cfg_p_local_update_steps = cfg.personalization.local_update_steps
    ctx.num_train_batch_for_local_model, \
        ctx.num_train_batch_last_epoch_for_local_model, \
        ctx.num_train_epoch_for_local_model, \
        ctx.num_total_train_batch = \
        calculate_batch_epoch_num(cfg_p_local_update_steps,
                                  cfg.train.batch_or_epoch,
                                  ctx.num_train_data,
                                  cfg.dataloader.batch_size,
                                  cfg.dataloader.drop_last)

    # In the first
    # 1. `num_train_batch` and `num_train_batch_last_epoch`
    #    (batch_or_epoch == 'batch' case) or
    # 2. `num_train_epoch`
    #    (batch_or_epoch == 'epoch' case)
    # steps, we update the local (personalized) model, and update the global
    # model in the remaining steps
    if cfg.train.batch_or_epoch == 'batch':
        ctx.num_train_batch += ctx.num_train_batch_for_local_model
        ctx.num_train_batch_last_epoch += \
            ctx.num_train_batch_last_epoch_for_local_model
    else:
        ctx.num_train_epoch += ctx.num_train_epoch_for_local_model


def _hook_on_fit_start_set_regularized_para(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.global_model``                Move to ``ctx.device`` and set \
        to ``train`` mode
        ``ctx.local_model``                 Move to ``ctx.device`` and set \
        to ``train`` mode
        ``ctx.optimizer_for_global_model``  Initialize by ``ctx.cfg``
        ``ctx.optimizer_for_local_model``   Initialize by ``ctx.cfg``, \
        wrapped by ``wrap_regularized_optimizer``, and set the compared \
        parameter group
        ==================================  ===========================
    """
    # set the compared model data for the local personalized model
    ctx.global_model.to(ctx.device)
    ctx.local_model.to(ctx.device)
    ctx.global_model.train()
    ctx.local_model.train()
    compared_global_model_para = [{
        "params": list(ctx.global_model.parameters())
    }]

    ctx.optimizer_for_global_model = get_optimizer(ctx.global_model,
                                                   **ctx.cfg.train.optimizer)
    ctx.optimizer_for_local_model = get_optimizer(ctx.local_model,
                                                  **ctx.cfg.train.optimizer)

    ctx.optimizer_for_local_model = wrap_regularized_optimizer(
        ctx.optimizer_for_local_model, ctx.cfg.personalization.regular_weight)

    ctx.optimizer_for_local_model.set_compared_para_group(
        compared_global_model_para)
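
# The Ditto objective for the personalized model adds a proximal term toward
# the global weights. As a minimal illustrative sketch (not the actual
# ``wrap_regularized_optimizer`` implementation), the extra gradient applied
# to each personalized parameter before the optimizer step amounts to:
#
#     regular_weight = ctx.cfg.personalization.regular_weight  # lambda
#     for p_local, p_global in zip(ctx.local_model.parameters(),
#                                  ctx.global_model.parameters()):
#         if p_local.grad is not None:
#             p_local.grad.add_(regular_weight *
#                               (p_local.data - p_global.data))
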
def _hook_on_fit_start_clean(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.optimizer``                   Delete
        ``ctx.num_..._local_model_train``   Initialize to 0
        ==================================  ===========================
    """
    # remove the unnecessary optimizer
    del ctx.optimizer
    ctx.num_samples_local_model_train = 0


def _hook_on_fit_end_calibrate(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.num_samples``                 Minus \
        ``ctx.num_samples_local_model_train``
        ``ctx.eval_metrics``                Record ``train_total`` and \
        ``train_total_local_model``
        ==================================  ===========================
    """
    # make num_samples only related to the global model
    # (num_samples will be used in the aggregation process)
    ctx.num_samples -= ctx.num_samples_local_model_train
    ctx.eval_metrics['train_total'] = ctx.num_samples
    ctx.eval_metrics['train_total_local_model'] = \
        ctx.num_samples_local_model_train


def _hook_on_batch_end_flop_count(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.monitor``                     Monitor total flops
        ==================================  ===========================
    """
    # besides the normal forward flops, the regularization adds a cost
    # proportional to the number of model parameters
    ctx.monitor.total_flops += ctx.monitor.total_model_size / 2


def _hook_on_batch_forward_cnt_num(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.num_..._local_model_train``   Add ``ctx.batch_size``
        ==================================  ===========================
    """
    if ctx.use_local_model_current:
        ctx.num_samples_local_model_train += ctx.batch_size


def _hook_on_batch_start_switch_model(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.use_local_model_current``     Set to ``True`` or ``False``
        ``ctx.model``                       Set to ``ctx.local_model`` or \
        ``ctx.global_model``
        ``ctx.optimizer``                   Set to \
        ``ctx.optimizer_for_local_model`` or ``ctx.optimizer_for_global_model``
        ==================================  ===========================
    """
    if ctx.cfg.train.batch_or_epoch == 'batch':
        if ctx.cur_epoch_i == (ctx.num_train_epoch - 1):
            ctx.use_local_model_current = \
                ctx.cur_batch_i < \
                ctx.num_train_batch_last_epoch_for_local_model
        else:
            ctx.use_local_model_current = \
                ctx.cur_batch_i < ctx.num_train_batch_for_local_model
    else:
        ctx.use_local_model_current = \
            ctx.cur_epoch_i < ctx.num_train_epoch_for_local_model

    if DEBUG_DITTO:
        logger.info("====================================================")
        logger.info(f"cur_epoch_i: {ctx.cur_epoch_i}")
        logger.info(f"num_train_epoch: {ctx.num_train_epoch}")
        logger.info(f"cur_batch_i: {ctx.cur_batch_i}")
        logger.info(f"num_train_batch: {ctx.num_train_batch}")
        logger.info(f"num_train_batch_for_local_model: "
                    f"{ctx.num_train_batch_for_local_model}")
        logger.info(f"num_train_epoch_for_local_model: "
                    f"{ctx.num_train_epoch_for_local_model}")
        logger.info(f"use_local_model: {ctx.use_local_model_current}")

    if ctx.use_local_model_current:
        ctx.model = ctx.local_model
        ctx.optimizer = ctx.optimizer_for_local_model
    else:
        ctx.model = ctx.global_model
        ctx.optimizer = ctx.optimizer_for_global_model


# Note that Ditto only updates the parameters of the global_model received
# from other FL participants, and in the remaining steps ctx.model has been
# set to ctx.global_model, thus we do not need to register the following hook:
# def hook_on_fit_end_link_global_model(ctx):
#     ctx.model = ctx.global_model
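
# An illustrative example of the switching schedule above (values assumed for
# the sake of the example): with ``batch_or_epoch == 'epoch'``,
# ``num_train_epoch_for_local_model == 1`` and, after the increment in
# ``init_Ditto_ctx``, ``num_train_epoch == 3``, epoch 0 trains
# ``ctx.local_model`` with the regularized optimizer, while epochs 1 and 2
# train ``ctx.global_model``, whose updates are later shared for aggregation.
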
def _hook_on_fit_start_switch_local_model(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Set to ``ctx.local_model`` and \
        set to ``eval`` mode
        ==================================  ===========================
    """
    ctx.model = ctx.local_model
    ctx.model.eval()


def _hook_on_fit_end_switch_global_model(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Set to ``ctx.global_model``
        ==================================  ===========================
    """
    ctx.model = ctx.global_model


def _hook_on_fit_end_free_cuda(ctx):
    """
    Note:
      The modified attributes and corresponding operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.global_model``                Move to ``cpu``
        ``ctx.local_model``                 Move to ``cpu``
        ==================================  ===========================
    """
    ctx.global_model.to(torch.device("cpu"))
    ctx.local_model.to(torch.device("cpu"))
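
# Configuration keys read by this module, shown as an illustrative YAML-style
# sketch (key names follow the ``cfg.*`` accesses above; the values are
# assumptions for illustration, not verified defaults):
#
#     personalization:
#       local_update_steps: 1   # steps/epochs spent on the personalized model
#       regular_weight: 0.1     # lambda of the Ditto regularization term
#     train:
#       batch_or_epoch: batch   # or 'epoch'
#       optimizer: {...}        # passed through to ``get_optimizer``
#     dataloader:
#       batch_size: 32
#       drop_last: False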