import copy
import torch
from federatedscope.core.aggregators import ClientsAvgAggregator
class AsynClientsAvgAggregator(ClientsAvgAggregator):
    """
    The aggregator used in asynchronous training, which discounts the
    staled model updates.
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(AsynClientsAvgAggregator, self).__init__(model, device, config)

    def aggregate(self, agg_info):
        """
        To perform aggregation

        Arguments:
            agg_info (dict): the feedbacks from clients. The keys read here
                are ``"client_feedback"`` (a sequence of
                ``(sample_size, model_delta)`` pairs), ``'staleness'`` (a
                sequence of ``(client_id, staleness)`` pairs) and,
                optionally, ``'recover_fun'``.

        Returns:
            dict: the aggregated results
        """
        models = agg_info["client_feedback"]
        # ``recover_fun`` is only forwarded when secret sharing is enabled.
        recover_fun = agg_info['recover_fun'] if (
            'recover_fun' in agg_info and self.cfg.federate.use_ss) else None
        staleness = [x[1]
                     for x in agg_info['staleness']]  # (client_id, staleness)
        avg_model = self._para_weighted_avg(models,
                                            recover_fun=recover_fun,
                                            staleness=staleness)

        # When using asynchronous training, the return feedback is model delta
        # rather than the model param, so it is added onto the current model.
        # A shallow copy is enough here: every entry is rebound to a freshly
        # computed value below, so deep-copying the tensors first would only
        # waste memory.
        updated_model = copy.copy(avg_model)
        init_model = self.model.state_dict()
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model

    def discount_func(self, staleness):
        """
        Served as an example, we discount the model update with staleness
        tau as: ``(1.0/((1.0+tau)**factor))``,
        which has been used in previous studies such as FedAsync
        (Asynchronous Federated Optimization) and FedBuff
        (Federated Learning with Buffered Asynchronous Aggregation).

        Arguments:
            staleness: the staleness (tau) of one model update

        Returns:
            The multiplicative discount, where ``factor`` is read from
            ``self.cfg.asyn.staleness_discount_factor``.
        """
        return (1.0 /
                ((1.0 + staleness)**self.cfg.asyn.staleness_discount_factor))

    def _para_weighted_avg(self, models, recover_fun=None, staleness=None):
        """
        Calculates the staleness-discounted weighted average of models.

        Each update is weighted by ``1 / len(models)`` when
        ``cfg.federate.ignore_weight`` is set, otherwise by its share of the
        total sample size; the weight is then multiplied by
        ``discount_func(staleness[i])``.

        The entries of ``models`` are left untouched: the result is written
        into a new mapping instead of overwriting the first client's update.

        Arguments:
            models: a sequence of ``(sample_size, model)`` pairs, where
                ``model`` maps parameter names to values
            recover_fun: accepted for interface compatibility; it is not
                used by this implementation
            staleness: a sequence holding the staleness of each entry of
                ``models``, in the same order

        Returns:
            A mapping of the same type and with the same keys as the first
            model, holding the weighted sums as float tensors.

        Raises:
            ValueError: if ``staleness`` is None
        """
        if staleness is None:
            # Raised explicitly (rather than asserted) so that the check
            # survives ``python -O``.
            raise ValueError(
                'staleness is required for asynchronous aggregation')

        _, first_model = models[0]

        training_set_size = 0
        for sample_size, _ in models:
            training_set_size += sample_size

        # The weight of a client does not depend on the parameter key, so it
        # is computed once per client instead of once per (key, client).
        weights = []
        for idx, (local_sample_size, _) in enumerate(models):
            if self.cfg.federate.ignore_weight:
                weight = 1.0 / len(models)
            else:
                weight = local_sample_size / training_set_size
            weight *= self.discount_func(staleness[idx])
            weights.append(weight)

        # Shallow copy: keeps the mapping type and key order of the first
        # model while making sure the caller's dict is never rebound.
        avg_model = copy.copy(first_model)
        for key in first_model:
            accumulated = None
            for weight, (_, local_model) in zip(weights, models):
                param = local_model[key]
                if isinstance(param, torch.Tensor):
                    param = param.float()
                else:
                    param = torch.FloatTensor(param)
                if accumulated is None:
                    # The product is a new tensor, so the in-place additions
                    # below never touch a client's own parameters.
                    accumulated = param * weight
                else:
                    accumulated += param * weight
            avg_model[key] = accumulated

        return avg_model