Source code for federatedscope.core.trainers.trainer_pFedMe

import copy

from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.optimizer import wrap_regularized_optimizer
from typing import Type


def wrap_pFedMeTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """
    Build a `pFedMeTrainer` with a plug-in manner, by registering new
    functions into specific `BaseTrainer`

    The pFedMe implementation, "Personalized Federated Learning with Moreau
    Envelopes (NeurIPS 2020)", is based on Algorithm 1 in the paper and the
    official codes: https://github.com/CharlieDinh/pFedMe
    """
    # ---------------- attribute-level plug-in -----------------------
    init_pFedMe_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_fit_start_set_local_para_tmp,
        trigger="on_fit_start",
        insert_pos=-1)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_epoch_end_update_local,
        trigger="on_epoch_end",
        insert_pos=-1)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_fit_end_update_local,
        trigger="on_fit_end",
        insert_pos=-1)

    base_trainer.register_hook_in_train(
        new_hook=_hook_on_batch_end_flop_count,
        trigger="on_batch_end",
        insert_pos=-1)
    base_trainer.register_hook_in_train(
        new_hook=_hook_on_epoch_end_flop_count,
        trigger="on_epoch_end",
        insert_pos=-1)

    # for the "on_batch_start" trigger: replace the original hooks with the
    # pFedMe-specific ones
    # 1) cache the original hooks for "on_batch_start"
    base_trainer.ctx.original_hook_on_batch_start_train = \
        base_trainer.hooks_in_train["on_batch_start"]
    base_trainer.ctx.original_hook_on_batch_start_eval = \
        base_trainer.hooks_in_eval["on_batch_start"]
    # 2) replace the original hooks for "on_batch_start"
    base_trainer.replace_hook_in_train(
        new_hook=_hook_on_batch_start_init_pfedme,
        target_trigger="on_batch_start",
        target_hook_name=None)

    return base_trainer
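
# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of
# how the wrapper might be applied on top of a plain ``GeneralTorchTrainer``.
# The arguments ``model``, ``data``, ``device`` and ``cfg`` are assumed to be
# prepared by the caller, with ``cfg.personalization.K``,
# ``cfg.personalization.lr`` and ``cfg.personalization.regular_weight`` set.
# ---------------------------------------------------------------------------
def _pfedme_usage_sketch(model, data, device, cfg):
    """Build a vanilla trainer and register the pFedMe hooks onto it."""
    base_trainer = GeneralTorchTrainer(model=model,
                                       data=data,
                                       device=device,
                                       config=cfg)
    # after wrapping, every training routine runs K proximal steps per data
    # batch plus the local-weight updates registered above
    return wrap_pFedMeTrainer(base_trainer)
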
def init_pFedMe_ctx(base_trainer):
    """
    Init necessary attributes used in pFedMe; some new attributes are
    created with the prefix `pFedMe` to avoid namespace pollution

    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.optimizer_for_global_model``  False
        ==================================  ===========================
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    # pFedMe finds an approximate model with K steps using the same data
    # batch, so the complexity of each pFedMe client is K times that of
    # FedAvg
    ctx.pFedMe_K = cfg.personalization.K
    ctx.num_train_epoch *= ctx.pFedMe_K
    ctx.pFedMe_approx_fit_counter = 0

    # the local_model_tmp is used as the reference parameter when finding
    # the approximate \theta in the paper; it will be copied from the model
    # in every run_routine
    ctx.pFedMe_local_model_tmp = None


def _hook_on_fit_start_set_local_para_tmp(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.optimizer``                   Wrapped by \
        ``wrap_regularized_optimizer`` and set compared parameter group
        ``ctx.pFedMe_outer_lr``             Initialize to \
        ``ctx.cfg.train.optimizer.lr``
        ``ctx.pFedMe_local_model_tmp``      Copy from ``ctx.model``
        ==================================  ===========================
    """
    # the optimizer used in pFedMe is based on Moreau-envelope
    # regularization; besides, there are two distinct lrs for the
    # approximate model and the base model
    ctx.optimizer = wrap_regularized_optimizer(
        ctx.optimizer, ctx.cfg.personalization.regular_weight)
    for g in ctx.optimizer.param_groups:
        g['lr'] = ctx.cfg.personalization.lr
    ctx.pFedMe_outer_lr = ctx.cfg.train.optimizer.lr

    ctx.pFedMe_local_model_tmp = copy.deepcopy(ctx.model)
    # set the compared model data, then the optimizer will find the
    # approximate model using trainer.cfg.personalization.lr
    compared_global_model_para = [{
        "params": list(ctx.pFedMe_local_model_tmp.parameters())
    }]
    ctx.optimizer.set_compared_para_group(compared_global_model_para)
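
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): conceptually, the
# wrapped optimizer above drives the pFedMe inner problem, i.e. the
# approximate personalized model theta is obtained by descending
#     grad f(theta; batch) + regular_weight * (theta - w_local),
# where w_local is the cached ``pFedMe_local_model_tmp``. The helper below is
# a minimal plain-PyTorch sketch of one such proximal step; it does not
# reproduce the exact internals of ``wrap_regularized_optimizer``.
# ---------------------------------------------------------------------------
def _proximal_sgd_step_sketch(model, local_model_tmp, loss, lr,
                              regular_weight):
    """One illustrative proximal SGD step of the pFedMe inner problem."""
    import torch

    # gradient of the batch loss w.r.t. the approximate model parameters
    grads = torch.autograd.grad(loss, list(model.parameters()))
    with torch.no_grad():
        for p, p_ref, g in zip(model.parameters(),
                               local_model_tmp.parameters(), grads):
            # add the Moreau-envelope penalty gradient and step with the
            # personalization learning rate
            p -= lr * (g + regular_weight * (p - p_ref))
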
def _hook_on_batch_start_init_pfedme(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.data_batch_cache``            Copy from ``ctx.data_batch``
        ``ctx.pFedMe_approx_fit_counter``   Count to refresh data every \
        K step
        ==================================  ===========================
    """
    # refresh data every K step
    if ctx.pFedMe_approx_fit_counter == 0:
        if ctx.cur_mode == "train":
            for hook in ctx.original_hook_on_batch_start_train:
                hook(ctx)
        else:
            for hook in ctx.original_hook_on_batch_start_eval:
                hook(ctx)
        ctx.data_batch_cache = copy.deepcopy(ctx.data_batch)
    else:
        # reuse the data_cache since the original hook `_hook_on_batch_end`
        # will clean `data_batch`
        ctx.data_batch = copy.deepcopy(ctx.data_batch_cache)
    ctx.pFedMe_approx_fit_counter = (ctx.pFedMe_approx_fit_counter +
                                     1) % ctx.pFedMe_K


def _hook_on_batch_end_flop_count(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.monitor``                     Monitor total flops
        ==================================  ===========================
    """
    # besides the normal forward flops, pFedMe introduces
    # 1) the regularization adds the cost of the number of model parameters
    ctx.monitor.total_flops += ctx.monitor.total_model_size / 2


def _hook_on_epoch_end_flop_count(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.monitor``                     Monitor total flops
        ==================================  ===========================
    """
    # due to the local weight updating
    ctx.monitor.total_flops += ctx.monitor.total_model_size / 2


def _hook_on_epoch_end_update_local(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Update parameters by \
        ``ctx.pFedMe_local_model_tmp``
        ``ctx.optimizer``                   Set compared parameter group
        ==================================  ===========================
    """
    # update local weight after finding the approximate theta
    for client_param, local_para_tmp in zip(
            ctx.model.parameters(), ctx.pFedMe_local_model_tmp.parameters()):
        local_para_tmp.data = local_para_tmp.data - \
            ctx.optimizer.regular_weight * \
            ctx.pFedMe_outer_lr * (local_para_tmp.data - client_param.data)

    # set the compared model data, then the optimizer will find the
    # approximate model using trainer.cfg.personalization.lr
    compared_global_model_para = [{
        "params": list(ctx.pFedMe_local_model_tmp.parameters())
    }]
    ctx.optimizer.set_compared_para_group(compared_global_model_para)


def _hook_on_fit_end_update_local(ctx):
    """
    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Update parameters by \
        ``ctx.pFedMe_local_model_tmp``
        ``ctx.pFedMe_local_model_tmp``      Delete
        ==================================  ===========================
    """
    for param, local_para_tmp in zip(ctx.model.parameters(),
                                     ctx.pFedMe_local_model_tmp.parameters()):
        param.data = local_para_tmp.data

    del ctx.pFedMe_local_model_tmp
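
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): the two hooks above
# realize the pFedMe outer update from Algorithm 1 of the paper,
#     w_local <- w_local - regular_weight * outer_lr * (w_local - theta),
# followed by copying w_local back into the trainer's model at the end of
# fitting. The helper below is a standalone sketch of that outer step on two
# plain PyTorch modules.
# ---------------------------------------------------------------------------
def _outer_update_sketch(local_model, approx_model, regular_weight, outer_lr):
    """Apply the illustrative pFedMe outer step parameter-wise, in place."""
    import torch

    with torch.no_grad():
        for w, theta in zip(local_model.parameters(),
                            approx_model.parameters()):
            # move the local weight towards the approximate personalized model
            w -= regular_weight * outer_lr * (w - theta)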