Source code for federatedscope.core.trainers.trainer_nbafl

from federatedscope.core.trainers.utils import get_random
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from typing import Type
from copy import deepcopy

import numpy as np
import torch


def wrap_nbafl_trainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Plug the NbAFL algorithm into ``base_trainer`` and return it.

    Implementation of NbAFL, see `Federated Learning with Differential
    Privacy: Algorithms and Performance Analysis` [et al., 2020]
    (https://ieeexplore.ieee.org/abstract/document/9069945/).

    Args:
        base_trainer: the trainer to extend. It is modified in place (its
            context gets NbAFL attributes and extra hooks are registered)
            and then returned.

    The algorithm itself is configured through ``base_trainer.cfg.nbafl``:
        mu: the factor of the proximal regularizer
        epsilon: the distinguishable bound
        w_clip: the threshold to clip weights
        constant: the constant factor used in the noise scale
    """
    # ---------------- attribute-level plug-in -----------------------
    init_nbafl_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    # (registration method, hook, trigger), listed in registration order.
    # Weights are snapshotted at fit start and released at fit end for both
    # training and evaluation; upload noise is only added after training.
    hook_plan = (
        (base_trainer.register_hook_in_train, _hook_record_initialization,
         'on_fit_start'),
        (base_trainer.register_hook_in_eval, _hook_record_initialization,
         'on_fit_start'),
        (base_trainer.register_hook_in_train, _hook_del_initialization,
         'on_fit_end'),
        (base_trainer.register_hook_in_eval, _hook_del_initialization,
         'on_fit_end'),
        (base_trainer.register_hook_in_train, _hook_inject_noise_in_upload,
         'on_fit_end'),
    )
    for register, hook, trigger in hook_plan:
        register(new_hook=hook, trigger=trigger, insert_pos=-1)

    return base_trainer
def init_nbafl_ctx(base_trainer):
    """Set the proximal regularizer and the scale of the upload noise.

    Side effects:
        * ``cfg.regularizer.type`` is set to ``'proximal_regularizer'`` and
          ``cfg.regularizer.mu`` to ``cfg.nbafl.mu``; the config is defrosted
          for the change and frozen again afterwards.
        * ``ctx.regularizer`` is built from the updated regularizer type.
        * ``ctx.nbafl_scale_u`` is set to the standard deviation of the
          Gaussian noise that ``_hook_inject_noise_in_upload`` adds to the
          model parameters.

    Args:
        base_trainer: trainer exposing ``ctx`` and ``cfg`` attributes.
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    # set proximal regularizer
    cfg.defrost()
    cfg.regularizer.type = 'proximal_regularizer'
    cfg.regularizer.mu = cfg.nbafl.mu
    cfg.freeze()
    # Function-level import kept as in the original.
    # NOTE(review): presumably done to avoid an import cycle - confirm.
    from federatedscope.core.auxiliaries.regularizer_builder import \
        get_regularizer
    ctx.regularizer = get_regularizer(cfg.regularizer.type)

    # set noise scale during upload
    if cfg.trainer.type == 'nodefullbatch_trainer':
        # Full-batch node-level training: count the entries selected by the
        # training mask of the first (single) graph in the dataset.
        num_train_data = sum(ctx.train_loader.dataset[0]['train_mask'])
    else:
        num_train_data = ctx.num_train_data
    ctx.nbafl_scale_u = cfg.nbafl.w_clip * cfg.federate.total_round_num * \
        cfg.nbafl.constant / num_train_data / \
        cfg.nbafl.epsilon


# ---------------------------------------------------------------------- #
# Additional functions for NbAFL algorithm
# ---------------------------------------------------------------------- #


# Trainer
def _hook_record_initialization(ctx):
    """
    Record the initialized weights within local updates

    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.weight_init``                 Copy from `ctx.model`
        ==================================  ===========================
    """
    # Detached, independent copies: the snapshot must neither share storage
    # with nor track gradients of the parameters that are about to change.
    ctx.weight_init = [
        param.detach().clone() for param in ctx.model.parameters()
    ]


def _hook_del_initialization(ctx):
    """
    Clear the variable to avoid memory leakage

    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.weight_init``                 Set to `None`
        ==================================  ===========================
    """
    ctx.weight_init = None


def _hook_inject_noise_in_upload(ctx):
    """
    Inject noise into weights before the client upload them to server

    Each parameter receives independent zero-mean Gaussian noise whose
    standard deviation is ``ctx.nbafl_scale_u`` (set by ``init_nbafl_ctx``).

    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Inject noise to parameters
        ==================================  ===========================
    """
    for param in ctx.model.parameters():
        noise = get_random("Normal", param.shape, {
            "loc": 0,
            "scale": ctx.nbafl_scale_u
        }, param.device)
        param.data += noise


# Server
def inject_noise_in_broadcast(cfg, sample_client_num, model):
    """Clip weights and, when required, add noise before the server
    broadcasts them.

    Every weight is first clipped element-wise to
    ``w / max(1, |w| / cfg.nbafl.w_clip)``. Gaussian noise is then added
    only if ``sample_client_num`` is non-empty and
    ``total_round_num > sqrt(client_num) * L``. Here ``L`` is
    ``cfg.federate.sample_client_num`` when positive, otherwise
    ``cfg.federate.client_num``.

    Args:
        cfg: global config; reads ``cfg.nbafl.{w_clip, constant, epsilon}``
            and ``cfg.federate.{sample_client_num, client_num,
            total_round_num}``.
        sample_client_num: sized collection whose minimum is the ``m`` in
            the noise scale. NOTE(review): presumably the number of training
            samples of each participating client; confirm with the caller.
        model: model whose parameters are modified in place.
    """
    w_clip = cfg.nbafl.w_clip

    # Clip weight. ``clamp(min=1.0)`` equals ``max(1, |w| / w_clip)``
    # without allocating a ones tensor. It also keeps the parameter dtype:
    # ``torch.max`` against a float32 ones tensor would promote, e.g.,
    # float16 parameters to float32.
    for p in model.parameters():
        p.data = p.data / (torch.abs(p.data) / w_clip).clamp(min=1.0)

    # ``len`` rather than truthiness so array-like inputs keep working.
    if len(sample_client_num) > 0:
        # Inject noise
        client_num = cfg.federate.client_num
        total_round_num = cfg.federate.total_round_num
        # L: clients sampled per round (all clients when sampling is unset).
        L = cfg.federate.sample_client_num \
            if cfg.federate.sample_client_num > 0 else client_num
        if total_round_num > np.sqrt(client_num) * L:
            scale_d = 2 * w_clip * cfg.nbafl.constant * np.sqrt(
                np.power(total_round_num, 2) -
                np.power(L, 2) * client_num) / (
                    min(sample_client_num) * client_num * cfg.nbafl.epsilon)
            for p in model.parameters():
                p.data += get_random("Normal", p.shape, {
                    "loc": 0,
                    "scale": scale_d
                }, p.device)
def wrap_nbafl_server(server):
    """Equip ``server`` with the NbAFL broadcast-time noise injector.

    ``inject_noise_in_broadcast`` is passed to
    ``server.register_noise_injector``; the function returns ``None``.
    """
    noise_injector = inject_noise_in_broadcast
    server.register_noise_injector(noise_injector)