Source code for federatedscope.core.aggregators.median_aggregator
import copy
import torch
import numpy as np
from federatedscope.core.aggregators import ClientsAvgAggregator
import logging
logger = logging.getLogger(__name__)
class MedianAggregator(ClientsAvgAggregator):
"""
Implementation of median refers to `Byzantine-robust distributed
learning: Towards optimal statistical rates`
[Yin et al., 2018]
(http://proceedings.mlr.press/v80/yin18a/yin18a.pdf)
It computes the coordinate-wise median of recieved updates from clients
The code is adapted from https://github.com/bladesteam/blades
"""
    def __init__(self, model=None, device='cpu', config=None):
        super(MedianAggregator, self).__init__(model, device, config)
        self.byzantine_node_num = config.aggregator.byzantine_node_num
        # Median-based rules tolerate Byzantine clients only as a strict
        # minority, hence the requirement below.
        assert 2 * self.byzantine_node_num + 2 < config.federate.client_num, \
            "it should be satisfied that 2*byzantine_node_num + 2 < client_num"
    def aggregate(self, agg_info):
        """
        Perform aggregation with the coordinate-wise Median rule

        Arguments:
            agg_info (dict): the feedbacks from clients

        :returns: the aggregated results
        :rtype: dict
        """
        models = agg_info["client_feedback"]
        avg_model = self._aggre_with_median(models)
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        # The aggregated value is treated as an update, so add it onto the
        # current global model parameters.
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model
    def _aggre_with_median(self, models):
        init_model = self.model.state_dict()
        global_update = copy.deepcopy(init_model)
        for key in init_model:
            # Each feedback is a (sample_size, model_parameters) pair; stack
            # the clients' tensors for this parameter along a new dimension.
            temp = torch.stack([each_model[1][key] for each_model in models],
                               0)
            # For an even number of clients, torch.median returns the lower
            # of the two middle values, so combine the medians of temp and
            # -temp to obtain the coordinate-wise midpoint.
            temp_pos, _ = torch.median(temp, dim=0)
            temp_neg, _ = torch.median(-temp, dim=0)
            global_update[key] = (temp_pos - temp_neg) / 2
        return global_update
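
# Illustration (a minimal sketch, not part of the original module): with an
# even number of stacked client tensors, torch.median returns the lower of
# the two middle values along the reduced dimension, so _aggre_with_median
# combines median(temp) with median(-temp) to recover the usual midpoint.
# The names and values below are illustrative only.
if __name__ == "__main__":
    demo = torch.tensor([[1.0], [2.0], [4.0], [8.0]])  # four "client" values
    lower, _ = torch.median(demo, dim=0)    # lower middle value: 2.0
    neg, _ = torch.median(-demo, dim=0)     # -(upper middle value): -4.0
    print((lower - neg) / 2)                # tensor([3.]), the midpoint
    # Hypothetical usage sketch (the model/config objects are stand-ins, not
    # FederatedScope's real construction path):
    #     agg = MedianAggregator(model=my_model, device='cpu', config=my_cfg)
    #     result = agg.aggregate(
    #         {'client_feedback': [(n_samples, state_dict), ...]})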