Source code for federatedscope.core.aggregators.asyn_clients_avg_aggregator

import copy
import torch
from federatedscope.core.aggregators import ClientsAvgAggregator


class AsynClientsAvgAggregator(ClientsAvgAggregator):
    """
    The aggregator used in asynchronous training, which discounts the
    stale model updates
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(AsynClientsAvgAggregator, self).__init__(model, device, config)

    def aggregate(self, agg_info):
        """
        To perform aggregation

        Arguments:
            agg_info (dict): the feedback from clients

        Returns:
            dict: the aggregated results
        """
        models = agg_info["client_feedback"]
        recover_fun = agg_info['recover_fun'] if (
            'recover_fun' in agg_info and self.cfg.federate.use_ss) else None
        # Each entry of agg_info['staleness'] is a (client_id, staleness) pair
        staleness = [x[1] for x in agg_info['staleness']]
        avg_model = self._para_weighted_avg(models,
                                            recover_fun=recover_fun,
                                            staleness=staleness)

        # When using asynchronous training, the returned feedback is the
        # model delta rather than the model parameters, so add it to the
        # current global model
        updated_model = copy.deepcopy(avg_model)
        init_model = self.model.state_dict()
        for key in avg_model:
            updated_model[key] = init_model[key] + avg_model[key]
        return updated_model
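
    # Illustrative only (not part of the library source): the shape of
    # ``agg_info`` that ``aggregate`` expects, inferred from the accesses
    # above; the client ids and sample sizes here are made-up values.
    #
    #   agg_info = {
    #       # list of (num_samples, state-dict of model deltas) per client
    #       'client_feedback': [(100, delta_1), (50, delta_2)],
    #       # list of (client_id, staleness) pairs, aligned with the feedback
    #       'staleness': [(1, 0), (2, 3)],
    #       # optional; only used when federate.use_ss (secret sharing) is on
    #       'recover_fun': ...,
    #   }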

    def discount_func(self, staleness):
        """
        As an example, we discount the model update with staleness tau as
        ``(1.0/((1.0+tau)**factor))``, which has been used in previous
        studies such as FedAsync (Asynchronous Federated Optimization) and
        FedBuff (Federated Learning with Buffered Asynchronous Aggregation).
        """
        return (1.0 /
                ((1.0 + staleness)**self.cfg.asyn.staleness_discount_factor))
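
    # For intuition (illustrative, not part of the library source): with a
    # hypothetical staleness_discount_factor of 0.5, the discount decays
    # polynomially in the staleness tau:
    #   tau = 0 -> 1.0 / (1.0 ** 0.5) = 1.0    (fresh update, no discount)
    #   tau = 1 -> 1.0 / (2.0 ** 0.5) ~= 0.707
    #   tau = 4 -> 1.0 / (5.0 ** 0.5) ~= 0.447
    #   tau = 9 -> 1.0 / (10.0 ** 0.5) ~= 0.316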

    def _para_weighted_avg(self, models, recover_fun=None, staleness=None):
        """
        Calculates the weighted average of models.
        """
        training_set_size = 0
        for i in range(len(models)):
            sample_size, _ = models[i]
            training_set_size += sample_size

        sample_size, avg_model = models[0]
        for key in avg_model:
            for i in range(len(models)):
                local_sample_size, local_model = models[i]

                if self.cfg.federate.ignore_weight:
                    weight = 1.0 / len(models)
                else:
                    weight = local_sample_size / training_set_size

                # Discount each update by its staleness before averaging
                assert staleness is not None
                weight *= self.discount_func(staleness[i])

                if isinstance(local_model[key], torch.Tensor):
                    local_model[key] = local_model[key].float()
                else:
                    local_model[key] = torch.FloatTensor(local_model[key])

                if i == 0:
                    avg_model[key] = local_model[key] * weight
                else:
                    avg_model[key] += local_model[key] * weight

        return avg_model
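

if __name__ == '__main__':
    # A minimal usage sketch, not part of the library source. It assumes
    # ClientsAvgAggregator.__init__ only stores model/device/config, and uses
    # a SimpleNamespace stand-in for the FederatedScope config covering just
    # the fields this aggregator reads.
    from types import SimpleNamespace

    model = torch.nn.Linear(2, 1)
    cfg = SimpleNamespace(
        federate=SimpleNamespace(use_ss=False, ignore_weight=False),
        asyn=SimpleNamespace(staleness_discount_factor=0.5),
    )
    aggregator = AsynClientsAvgAggregator(model=model,
                                          device='cpu',
                                          config=cfg)

    # Both clients report all-ones deltas so the effect of the staleness
    # discount is easy to read off: client 2's update (staleness 3) is
    # scaled by 1 / (4 ** 0.5) = 0.5 before averaging.
    delta = {
        key: torch.ones_like(val)
        for key, val in model.state_dict().items()
    }
    agg_info = {
        'client_feedback': [(100, copy.deepcopy(delta)),
                            (50, copy.deepcopy(delta))],
        'staleness': [(1, 0), (2, 3)],  # (client_id, staleness)
    }
    updated_state = aggregator.aggregate(agg_info)
    model.load_state_dict(updated_state)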