from federatedscope.core.trainers.utils import get_random
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from typing import Type
from copy import deepcopy
import numpy as np
import torch
def wrap_nbafl_trainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Turn ``base_trainer`` into an NbAFL trainer and return it.

    Implementation of NbAFL, refer to `Federated Learning with
    Differential Privacy: Algorithms and Performance Analysis`
    [Wei et al., 2020]
    (https://ieeexplore.ieee.org/abstract/document/9069945/)

    The algorithm is configured through ``cfg.nbafl``:
        mu: the factor of the regularizer
        epsilon: the distinguishable bound
        w_clip: the threshold to clip weights

    Args:
        base_trainer: the trainer to extend; its ``ctx``/``cfg`` are
            updated by ``init_nbafl_ctx`` and the NbAFL hooks are
            registered on it.

    Returns:
        The same ``base_trainer`` object, modified in place.
    """
    # ---------------- attribute-level plug-in -----------------------
    init_nbafl_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    # Snapshot the incoming weights at the start of both training and
    # evaluation.
    base_trainer.register_hook_in_train(new_hook=_hook_record_initialization,
                                        trigger='on_fit_start',
                                        insert_pos=-1)
    base_trainer.register_hook_in_eval(new_hook=_hook_record_initialization,
                                       trigger='on_fit_start',
                                       insert_pos=-1)

    # Drop the snapshot again once fitting has finished.
    base_trainer.register_hook_in_train(new_hook=_hook_del_initialization,
                                        trigger='on_fit_end',
                                        insert_pos=-1)
    base_trainer.register_hook_in_eval(new_hook=_hook_del_initialization,
                                       trigger='on_fit_end',
                                       insert_pos=-1)

    # Upload noise is registered for training only; evaluation leaves the
    # weights untouched.
    base_trainer.register_hook_in_train(new_hook=_hook_inject_noise_in_upload,
                                        trigger='on_fit_end',
                                        insert_pos=-1)
    return base_trainer
def init_nbafl_ctx(base_trainer):
    """Prepare the trainer's config and context for NbAFL.

    Two things are set up:

    * ``cfg.regularizer`` is switched to ``'proximal_regularizer'`` with
      ``mu`` taken from ``cfg.nbafl.mu``, and the regularizer returned by
      ``get_regularizer`` is stored in ``ctx.regularizer``.
    * ``ctx.nbafl_scale_u``, the noise scale used at upload time, is set to
      ``w_clip * total_round_num * constant / num_train_data / epsilon``.

    Args:
        base_trainer: object exposing the ``ctx`` and ``cfg`` attributes
            that are read and updated here.
    """
    ctx = base_trainer.ctx
    cfg = base_trainer.cfg

    # Proximal regularizer: the edit is bracketed by defrost()/freeze().
    cfg.defrost()
    cfg.regularizer.type = 'proximal_regularizer'
    cfg.regularizer.mu = cfg.nbafl.mu
    cfg.freeze()

    from federatedscope.core.auxiliaries.regularizer_builder import (
        get_regularizer)
    ctx.regularizer = get_regularizer(cfg.regularizer.type)

    # Number of training samples that enters the upload noise scale.
    if cfg.trainer.type != 'nodefullbatch_trainer':
        num_train_data = ctx.num_train_data
    else:
        # Full-batch node trainer: add up the entries of the train mask of
        # the first dataset item.
        num_train_data = sum(ctx.train_loader.dataset[0]['train_mask'])

    # Same left-to-right evaluation order as a single chained expression.
    scaled_clip = (cfg.nbafl.w_clip * cfg.federate.total_round_num *
                   cfg.nbafl.constant)
    ctx.nbafl_scale_u = scaled_clip / num_train_data / cfg.nbafl.epsilon
# ---------------------------------------------------------------------- #
# Additional functions for NbAFL algorithm
# ---------------------------------------------------------------------- #
# Trainer
def _hook_record_initialization(ctx):
    """
    Snapshot the model weights as they are before the local updates

    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.weight_init``                 Deep copy of the parameter
                                            tensors of ``ctx.model``
        ==================================  ===========================
    """
    # Detached views first, then a deep copy so that later in-place updates
    # of the model do not alter the recorded values.
    current_weights = [param.data.detach() for param in ctx.model.parameters()]
    ctx.weight_init = deepcopy(current_weights)
def _hook_del_initialization(ctx):
    """
    Drop the recorded initial weights so they are not kept alive

    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.weight_init``                 Reset to `None`
        ==================================  ===========================
    """
    # Releasing the reference is all that is needed; nothing is returned.
    ctx.weight_init = None
def _hook_inject_noise_in_upload(ctx):
    """
    Perturb the model weights before the client uploads them to the server

    Note:
      The modified attributes and according operations are shown below:
        ==================================  ===========================
        Attribute                           Operation
        ==================================  ===========================
        ``ctx.model``                       Noise added in place to
                                            every parameter
        ==================================  ===========================
    """
    for param in ctx.model.parameters():
        # Zero-mean "Normal" noise with scale ``ctx.nbafl_scale_u``, matching
        # the parameter's shape and device.
        distribution_args = {"loc": 0, "scale": ctx.nbafl_scale_u}
        perturbation = get_random("Normal", param.shape, distribution_args,
                                  param.device)
        param.data += perturbation
# Server
def inject_noise_in_broadcast(cfg, sample_client_num, model):
    """Clip the model weights and, if required, add noise before broadcast.

    Every parameter is first rescaled element-wise so that no entry exceeds
    ``cfg.nbafl.w_clip`` in magnitude. Noise is then added only when
    ``sample_client_num`` is non-empty and
    ``total_round_num > sqrt(client_num) * L``, where ``L`` is
    ``cfg.federate.sample_client_num`` if that is positive and
    ``cfg.federate.client_num`` otherwise. The noise comes from
    ``get_random("Normal", ...)`` with ``loc=0`` and scale::

        2 * w_clip * constant * sqrt(T^2 - L^2 * N)
            / (min(sample_client_num) * N * epsilon)

    with ``T = total_round_num`` and ``N = client_num``.

    Args:
        cfg: config providing ``nbafl`` (``w_clip``, ``constant``,
            ``epsilon``) and ``federate`` (``sample_client_num``,
            ``client_num``, ``total_round_num``).
        sample_client_num: sized collection whose minimum enters the noise
            scale (presumably the clients' sample counts -- confirm
            against the caller).
        model: object whose ``parameters()`` are modified in place.
    """
    # Clip weight: divide by max(1, |w| / w_clip), which leaves small
    # entries untouched.
    for param in model.parameters():
        weights = param.data
        unit = torch.ones(size=param.shape, device=weights.device)
        param.data = weights / torch.max(
            unit,
            torch.abs(weights) / cfg.nbafl.w_clip)

    if len(sample_client_num) == 0:
        return

    # Inject noise
    federate = cfg.federate
    if federate.sample_client_num > 0:
        L = federate.sample_client_num
    else:
        L = federate.client_num

    if federate.total_round_num > np.sqrt(federate.client_num) * L:
        radicand = (np.power(federate.total_round_num, 2) -
                    np.power(L, 2) * federate.client_num)
        numerator = (2 * cfg.nbafl.w_clip * cfg.nbafl.constant *
                     np.sqrt(radicand))
        denominator = (min(sample_client_num) * federate.client_num *
                       cfg.nbafl.epsilon)
        scale_d = numerator / denominator

        for param in model.parameters():
            param.data += get_random("Normal", param.shape, {
                "loc": 0,
                "scale": scale_d
            }, param.device)
# def wrap_nbafl_server(server: Type[Server]) -> Type[Server]:
def wrap_nbafl_server(server):
    """Register the NbAFL noise injector for the server.

    Args:
        server: object exposing ``register_noise_injector``;
            ``inject_noise_in_broadcast`` is handed to it as the injector.
    """
    server.register_noise_injector(inject_noise_in_broadcast)