import copy
import logging

import torch
from federatedscope.core.aggregators import ClientsAvgAggregator

logger = logging.getLogger(__name__)
class NormboundingAggregator(ClientsAvgAggregator):
    """
    The server clips each client update to a maximum L2 norm to reduce
    the negative impact of malicious updates.
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(NormboundingAggregator, self).__init__(model, device, config)
        self.norm_bound = config.aggregator.BFT_args.normbounding_norm_bound
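
    # A minimal sketch of the corresponding config entry, assuming the
    # usual nested YAML layout of FederatedScope configs (the value 10.0
    # is a made-up example; only the key path is taken from the line
    # above):
    #
    #   aggregator:
    #     BFT_args:
    #       normbounding_norm_bound: 10.0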
    def aggregate(self, agg_info):
        """
        Perform aggregation with the norm-bounding aggregation rule.

        Arguments:
            agg_info (dict): the feedback from clients

        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_normbounding(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        # Client feedback holds model deltas, so the averaged (clipped)
        # delta is added onto the current global model.
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model
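
    # In symbols, given client deltas {delta_i} and bound tau, aggregate()
    # computes:
    #     w_new = w_global
    #             + WeightedAvg_i(delta_i * min(1, tau / ||delta_i||_2))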
    def _aggre_with_normbounding(self, models):
        models_temp = []
        for each_model in models:
            param = self._flatten_updates(each_model[1])
            param_norm = torch.norm(param, p=2)
            # Scale down any update whose L2 norm exceeds the bound so
            # that its norm becomes exactly self.norm_bound.
            if param_norm > self.norm_bound:
                scaling_rate = self.norm_bound / param_norm
                scaled_param = scaling_rate * param
                models_temp.append(
                    (each_model[0], self._reconstruct_updates(scaled_param)))
            else:
                models_temp.append(each_model)
        return self._para_weighted_avg(models_temp)
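
    # Worked example of the clipping step: with norm_bound = 5.0 and a
    # flattened update of L2 norm 20.0, scaling_rate = 5.0 / 20.0 = 0.25,
    # so the rescaled update has L2 norm exactly 5.0. Updates already
    # within the bound are forwarded unchanged.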
    def _flatten_updates(self, model):
        model_update = []
        init_model = self.model.state_dict()
        for key in init_model:
            model_update.append(model[key].view(-1))
        return torch.cat(model_update, dim=0)
    def _reconstruct_updates(self, flatten_updates):
        # Inverse of _flatten_updates: slice the flat vector back into
        # tensors with the shapes of the global model's parameters.
        start_idx = 0
        init_model = self.model.state_dict()
        reconstructed_model = copy.deepcopy(init_model)
        for key in init_model:
            num_params = init_model[key].numel()
            reconstructed_model[key] = flatten_updates[
                start_idx:start_idx + num_params].reshape(
                    init_model[key].shape)
            start_idx += num_params
        return reconstructed_model
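

# ----------------------------------------------------------------------
# Minimal usage sketch of the clipping machinery (illustration only). The
# toy Linear model and the SimpleNamespace config below are assumptions
# made for demonstration; a real FederatedScope run builds both through
# its own configuration pipeline, and aggregate() additionally depends on
# the weighted-averaging settings of the parent ClientsAvgAggregator.
if __name__ == '__main__':
    from types import SimpleNamespace

    toy_model = torch.nn.Linear(2, 1)
    toy_cfg = SimpleNamespace(aggregator=SimpleNamespace(
        BFT_args=SimpleNamespace(normbounding_norm_bound=1.0)))
    agg = NormboundingAggregator(model=toy_model,
                                 device='cpu',
                                 config=toy_cfg)

    # An oversized update: every entry is 100.0, so its L2 norm (~173.2
    # over the 3 parameters of Linear(2, 1)) far exceeds the bound of 1.0.
    big_update = {key: 100.0 * torch.ones_like(value)
                  for key, value in toy_model.state_dict().items()}
    flat = agg._flatten_updates(big_update)
    clipped = (agg.norm_bound / torch.norm(flat, p=2)) * flat

    print(torch.norm(flat, p=2).item())     # ~173.2 before clipping
    print(torch.norm(clipped, p=2).item())  # 1.0 after clipping
    restored = agg._reconstruct_updates(clipped)
    print({key: tuple(value.shape) for key, value in restored.items()})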