# Source code for federatedscope.core.aggregators.normbounding_aggregator

import logging
import copy
import torch
import numpy as np
from federatedscope.core.aggregators import ClientsAvgAggregator

logger = logging.getLogger(__name__)


class NormboundingAggregator(ClientsAvgAggregator):
    """
    The server clips each update to reduce the negative impact
    of malicious updates.

    Each client update is flattened into one vector, following the key order
    of the global model's ``state_dict()``. If the vector's L2 norm exceeds
    ``config.aggregator.BFT_args.normbounding_norm_bound``, the update is
    rescaled so that its norm equals that bound. The clipped updates are then
    averaged with the parent class's ``_para_weighted_avg``. The result is
    added to the current global weights.
    """

    def __init__(self, model=None, device='cpu', config=None):
        super(NormboundingAggregator, self).__init__(model, device, config)
        # Maximum L2 norm allowed for a single flattened client update.
        self.norm_bound = config.aggregator.BFT_args.normbounding_norm_bound

    def aggregate(self, agg_info):
        """
        To perform aggregation with normbounding aggregation rule

        Arguments:
            agg_info (dict): the feedbacks from clients; its
                ``"client_feedback"`` entry is a sequence of pairs whose
                second element is a client's model update (presumably
                ``(sample_size, update)`` -- confirm against
                ``_para_weighted_avg``)

        :returns: the aggregated results, i.e. the current global weights
            plus the averaged (clipped) update
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_normbounding(models)
        init_model = self.model.state_dict()
        # A shallow copy keeps avg_model's mapping type and attributes.
        # Every value is replaced below, so deep-copying the tensors (as
        # before) only wasted memory.
        updated_model = copy.copy(avg_model)
        for key in avg_model:
            # The averaged result is an update (delta) w.r.t. the global
            # model, so apply it to the current weights.
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    def _aggre_with_normbounding(self, models):
        """Clip every client update to ``self.norm_bound``, then average.

        Arguments:
            models: sequence of pairs; element ``[1]`` is the update mapping,
                element ``[0]`` is passed through unchanged to
                ``_para_weighted_avg``.

        :returns: the result of ``self._para_weighted_avg`` on the
            (possibly rescaled) updates
        """
        models_temp = []
        for each_model in models:
            param = self._flatten_updates(each_model[1])
            # Compute the norm once; it is needed for both test and scaling.
            param_norm = torch.norm(param, p=2)
            if param_norm > self.norm_bound:
                # Project onto the L2 ball of radius norm_bound.
                scaled_param = (self.norm_bound / param_norm) * param
                models_temp.append(
                    (each_model[0], self._reconstruct_updates(scaled_param)))
            else:
                models_temp.append(each_model)
        return self._para_weighted_avg(models_temp)

    def _flatten_updates(self, model):
        """Concatenate ``model``'s tensors into a single 1-D tensor.

        Tensors are taken in the key order of the global model's
        ``state_dict()``, so the layout matches ``_reconstruct_updates``.
        """
        init_model = self.model.state_dict()
        # reshape (unlike view) also works for non-contiguous tensors.
        return torch.cat([model[key].reshape(-1) for key in init_model],
                         dim=0)

    def _reconstruct_updates(self, flatten_updates):
        """Inverse of ``_flatten_updates``: split a flat vector back into a
        mapping shaped like the global model's ``state_dict()``."""
        init_model = self.model.state_dict()
        # state_dict() already returns a fresh mapping. A shallow copy keeps
        # its type and metadata without duplicating every weight tensor
        # (the old deepcopy did), since every entry is overwritten below.
        reconstructed_model = copy.copy(init_model)
        start_idx = 0
        for key, value in init_model.items():
            numel = value.numel()
            reconstructed_model[key] = flatten_updates[
                start_idx:start_idx + numel].reshape(value.shape)
            start_idx += numel
        return reconstructed_model