import os
import torch

from federatedscope.core.aggregators import Aggregator
from federatedscope.core.auxiliaries.utils import param2tensor


class ClientsAvgAggregator(Aggregator):
    """
    Implementation of vanilla FedAvg; see 'Communication-Efficient Learning
    of Deep Networks from Decentralized Data' [McMahan et al., 2017],
    http://proceedings.mlr.press/v54/mcmahan17a.html
    """
    def __init__(self, model=None, device='cpu', config=None):
        super(ClientsAvgAggregator, self).__init__()
        self.model = model
        self.device = device
        self.cfg = config

    def aggregate(self, agg_info):
        """
        To perform aggregation.

        Arguments:
            agg_info (dict): the feedback from clients

        Returns:
            dict: the aggregated results
        """
        models = agg_info["client_feedback"]
        recover_fun = agg_info['recover_fun'] if (
            'recover_fun' in agg_info and self.cfg.federate.use_ss) else None
        avg_model = self._para_weighted_avg(models, recover_fun=recover_fun)
        return avg_model

    def update(self, model_parameters):
        """
        Load the given parameters into the maintained model.

        Arguments:
            model_parameters (dict): PyTorch Module object's state_dict.
        """
        self.model.load_state_dict(model_parameters, strict=False)

    def save_model(self, path, cur_round=-1):
        assert self.model is not None

        ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()}
        torch.save(ckpt, path)

    def load_model(self, path):
        assert self.model is not None

        if os.path.exists(path):
            ckpt = torch.load(path, map_location=self.device)
            self.model.load_state_dict(ckpt['model'])
            return ckpt['cur_round']
        else:
            raise ValueError("The file {} does NOT exist".format(path))
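
    # A hedged round-trip sketch (illustration only, not part of the API):
    # checkpoints written by `save_model` bundle the round counter with the
    # state_dict, so `load_model` restores the weights and returns the round.
    #
    #     agg = ClientsAvgAggregator(model=torch.nn.Linear(2, 1))
    #     agg.save_model('/tmp/fedavg_ckpt.pt', cur_round=3)
    #     assert agg.load_model('/tmp/fedavg_ckpt.pt') == 3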

    def _para_weighted_avg(self, models, recover_fun=None):
        """
        Calculates the weighted average of models.
        """
        # First pass: total number of training samples across all clients,
        # used to normalize the per-client weights.
        training_set_size = 0
        for i in range(len(models)):
            sample_size, _ = models[i]
            training_set_size += sample_size

        sample_size, avg_model = models[0]
        for key in avg_model:
            for i in range(len(models)):
                local_sample_size, local_model = models[i]

                if self.cfg.federate.ignore_weight:
                    # Uniform averaging, regardless of local sample sizes
                    weight = 1.0 / len(models)
                elif self.cfg.federate.use_ss:
                    # When using secret sharing, what the server receives
                    # are sample_size * model_para
                    weight = 1.0
                else:
                    weight = local_sample_size / training_set_size

                if not self.cfg.federate.use_ss:
                    local_model[key] = param2tensor(local_model[key])
                if i == 0:
                    avg_model[key] = local_model[key] * weight
                else:
                    avg_model[key] += local_model[key] * weight

            if self.cfg.federate.use_ss and recover_fun:
                avg_model[key] = recover_fun(avg_model[key])
                # When using secret sharing, what the server receives are
                # sample_size * model_para
                avg_model[key] /= training_set_size
                avg_model[key] = torch.FloatTensor(avg_model[key])

        return avg_model
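
# A minimal usage sketch (an assumption-laden illustration, not part of the
# library): a SimpleNamespace stands in for the real FederatedScope config,
# exposing just the `federate.ignore_weight` and `federate.use_ss` flags read
# above. Client feedback follows the (sample_size, state_dict) convention
# expected by `aggregate`.
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(
#         federate=SimpleNamespace(ignore_weight=False, use_ss=False))
#     agg = ClientsAvgAggregator(model=torch.nn.Linear(2, 1), config=cfg)
#     feedback = [
#         (10, {'weight': torch.ones(1, 2), 'bias': torch.zeros(1)}),
#         (30, {'weight': torch.zeros(1, 2), 'bias': torch.ones(1)}),
#     ]
#     avg = agg.aggregate({'client_feedback': feedback})
#     # avg['weight'] -> 10/40 * 1 + 30/40 * 0 = 0.25 per entry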


class OnlineClientsAvgAggregator(ClientsAvgAggregator):
    """
    Implementation of online aggregation of FedAvg: client feedback is
    folded into a running weighted mean as it arrives.
    """
    def __init__(self,
                 model=None,
                 device='cpu',
                 src_device='cpu',
                 config=None):
        super(OnlineClientsAvgAggregator, self).__init__(model, device,
                                                         config)
        self.src_device = src_device

    def reset(self):
        """
        Reset the maintained state to zeros before a new round of
        online aggregation.
        """
        self.maintained = self.model.state_dict()
        for key in self.maintained:
            self.maintained[key].data = torch.zeros_like(
                self.maintained[key], device=self.src_device)
        self.cnt = 0

    def inc(self, content):
        """
        Increment the maintained running average by the given content.

        Arguments:
            content (tuple): (sample_size, model_params) from one client
        """
        if isinstance(content, tuple):
            sample_size, model_params = content
            for key in self.maintained:
                # Running weighted mean:
                # new_mean = (cnt * old_mean + n * params) / (cnt + n)
                # model_params[key] = model_params[key].to(
                #     self.maintained[key].device)
                self.maintained[key] = (self.cnt * self.maintained[key] +
                                        sample_size * model_params[key]) / (
                                            self.cnt + sample_size)
            self.cnt += sample_size
        else:
            raise TypeError(
                "{} is not a tuple (sample_size, model_para)".format(content))

    def aggregate(self, agg_info):
        """
        Returns the aggregated value
        """
        return self.maintained
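

if __name__ == '__main__':
    # A hedged demo of the online aggregator (a sketch, not part of the
    # library): stream two fake client updates through `reset`/`inc` and
    # read the running weighted mean back via `aggregate`.
    demo_model = torch.nn.Linear(2, 1)
    online = OnlineClientsAvgAggregator(model=demo_model)
    online.reset()
    online.inc((10, {k: torch.ones_like(v)
                     for k, v in demo_model.state_dict().items()}))
    online.inc((30, {k: torch.zeros_like(v)
                     for k, v in demo_model.state_dict().items()}))
    # (10 * 1 + 30 * 0) / 40 = 0.25 for every parameter entry
    print(online.aggregate(agg_info={}))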