Source code for federatedscope.core.aggregators.server_clients_interpolate_aggregator

from federatedscope.core.aggregators import ClientsAvgAggregator


class ServerClientsInterpolateAggregator(ClientsAvgAggregator):
    """
    Conduct aggregation by interpolating the global model from the server
    and the local models from the clients.

    Arguments:
        model: forwarded unchanged to ``ClientsAvgAggregator.__init__``;
            ``aggregate`` later reads ``self.model`` and calls
            ``state_dict()`` on it (presumably the parent stores this
            argument as ``self.model`` -- confirm against the parent class)
        device: forwarded unchanged to ``ClientsAvgAggregator.__init__``
        config: forwarded unchanged to ``ClientsAvgAggregator.__init__``
        beta (float): the weight for local models used in interpolation;
            the global model is paired with ``1 - beta``
    """
    def __init__(self, model=None, device='cpu', config=None, beta=1.0):
        super().__init__(model, device, config)
        # The weight for local models used in interpolation.
        self.beta = beta

    def aggregate(self, agg_info):
        """
        Interpolate between the server's global model and the average of
        the clients' local models.

        Arguments:
            agg_info (dict): must contain the key ``"client_feedback"``,
                an iterable whose elements are expected to be
                ``(sample_size, model_para)`` tuples, one per client

        Returns:
            The aggregated value, i.e. the result of the inherited
            ``_para_weighted_avg`` applied to the (global, client-average)
            pair.

        Raises:
            AssertionError: if the first element of the client feedback
                does not have length 2
        """
        models = agg_info["client_feedback"]

        # Only the first entry is inspected as a sanity check on the
        # feedback format; the remaining entries are not validated here.
        first_entry = next(iter(models))
        if len(first_entry) != 2:
            # Raised explicitly instead of via ``assert`` so that the check
            # is not stripped when Python runs with ``-O``; the exception
            # type and message are the same as before.
            raise AssertionError(
                f"Require (sample_size, model_para) tuple for each client, "
                f"i.e., len=2, but got len={len(first_entry)}")

        avg_model_by_clients = self._para_weighted_avg(models)

        # NOTE(review): ``1 - beta`` and ``beta`` occupy the slot in which
        # client entries carry ``sample_size``; presumably the inherited
        # ``_para_weighted_avg`` turns them into interpolation weights --
        # confirm against the parent implementation.
        global_local_models = [
            (1 - self.beta, self.model.state_dict()),
            (self.beta, avg_model_by_clients),
        ]
        return self._para_weighted_avg(global_local_models)