Source code for federatedscope.core.trainers.trainer_fedprox

from typing import Type

from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.trainers.trainer_nbafl import \
    _hook_record_initialization, _hook_del_initialization


def wrap_fedprox_trainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Equip ``base_trainer`` with the FedProx plug-ins and return it.

    FedProx is described in `Federated Optimization in Heterogeneous
    Networks` [Tian Li, et al., 2020]:
    https://proceedings.mlsys.org/paper/2020/file/38af86134b65d0f10fe33d30dd76442e-Paper.pdf

    Two things happen, in this order:

    1. ``init_fedprox_ctx`` is applied to the trainer (attribute-level
       plug-in).
    2. ``_hook_record_initialization`` is registered on the
       ``'on_fit_start'`` trigger and ``_hook_del_initialization`` on the
       ``'on_fit_end'`` trigger, each for both the train and the eval
       routines, always with ``insert_pos=-1`` (action-level plug-in).

    Args:
        base_trainer: the trainer to extend; it is modified in place.
            NOTE(review): the annotation says ``Type[...]`` but the object
            is used like a trainer instance here -- confirm.

    Returns:
        The very same ``base_trainer`` object that was passed in.
    """
    # ---------------- attribute-level plug-in -----------------------
    init_fedprox_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    # Each (hook, trigger) pair is attached to the train routine first and
    # to the eval routine second, matching the original registration order.
    hook_schedule = (
        (_hook_record_initialization, 'on_fit_start'),
        (_hook_del_initialization, 'on_fit_end'),
    )
    for hook_fn, stage in hook_schedule:
        base_trainer.register_hook_in_train(new_hook=hook_fn,
                                            trigger=stage,
                                            insert_pos=-1)
        base_trainer.register_hook_in_eval(new_hook=hook_fn,
                                           trigger=stage,
                                           insert_pos=-1)

    return base_trainer
def init_fedprox_ctx(base_trainer):
    """Select the proximal regularizer for ``base_trainer`` and build it.

    The trainer's config is defrosted, ``cfg.regularizer.type`` is set to
    ``'proximal_regularizer'``, ``cfg.regularizer.mu`` is set to the value
    of ``cfg.fedprox.mu``, and the config is frozen again.  The regularizer
    returned by ``get_regularizer`` for that type is then stored on
    ``ctx.regularizer``.

    Args:
        base_trainer: object exposing ``ctx`` and ``cfg`` attributes; both
            are modified in place.
    """
    context = base_trainer.ctx
    config = base_trainer.cfg

    # The config has to be unlocked before its fields can be rewritten.
    config.defrost()
    regularizer_cfg = config.regularizer
    regularizer_cfg.type = 'proximal_regularizer'
    regularizer_cfg.mu = config.fedprox.mu
    config.freeze()

    # Imported only once the config has been updated, as in the original
    # (NOTE(review): presumably kept local to avoid an import cycle --
    # confirm before moving it to module level).
    from federatedscope.core.auxiliaries.regularizer_builder import (
        get_regularizer, )

    # Read the type back from the config rather than reusing the literal.
    regularizer_type = config.regularizer.type
    context.regularizer = get_regularizer(regularizer_type)