# Source code for federatedscope.core.aggregators.aggregator

import os
import torch
from abc import ABC, abstractmethod

class Aggregator(ABC):
    """Base interface shared by every aggregator.

    Concrete subclasses must override :meth:`aggregate`. Because this class
    derives from ``abc.ABC`` and ``aggregate`` is abstract, ``Aggregator``
    itself cannot be instantiated.
    """

    def __init__(self):
        """Initialise the aggregator; the base class stores no state."""

    @abstractmethod
    def aggregate(self, agg_info):
        """Aggregation function.

        Args:
            agg_info: information to be aggregated.

        Returns:
            Defined by the concrete subclass; this abstract version
            returns ``None``.
        """
class NoCommunicationAggregator(Aggregator):
    """Aggregator for the setting where clients do not communicate.

    Each client works locally: :meth:`aggregate` combines nothing and
    returns an empty dict. The held model can still be updated from a
    state_dict and checkpointed to / restored from disk.
    """

    def __init__(self, model=None, device='cpu', config=None):
        """
        Args:
            model: model managed by this aggregator (must provide
                ``state_dict`` / ``load_state_dict``). May be ``None``,
                but :meth:`save_model` and :meth:`load_model` assert it
                is set.
            device: passed as ``map_location`` to ``torch.load`` when a
                checkpoint is restored.
            config: configuration object, stored as ``self.cfg``.
        """
        # Plain super() keeps Aggregator.__init__ in the initialisation
        # chain instead of jumping past it to the next class in the MRO.
        super().__init__()
        self.model = model
        self.device = device
        self.cfg = config

    def update(self, model_parameters):
        """Load ``model_parameters`` into the held model.

        Args:
            model_parameters (dict): PyTorch Module object's state_dict.
                Loaded with ``strict=False``, so keys missing from or
                unexpected by the model do not raise.
        """
        self.model.load_state_dict(model_parameters, strict=False)

    def save_model(self, path, cur_round=-1):
        """Write a checkpoint of the held model to ``path``.

        The checkpoint is the dict
        ``{'cur_round': cur_round, 'model': <model state_dict>}``
        serialised with ``torch.save``.

        Args:
            path: destination file path.
            cur_round: round number stored alongside the weights.
        """
        assert self.model is not None

        ckpt = {'cur_round': cur_round, 'model': self.model.state_dict()}
        torch.save(ckpt, path)

    def load_model(self, path):
        """Restore the held model from a checkpoint written by save_model.

        Args:
            path: checkpoint file to read.

        Returns:
            The ``cur_round`` value stored in the checkpoint.

        Raises:
            ValueError: if ``path`` does not exist.
        """
        assert self.model is not None

        if not os.path.exists(path):
            raise ValueError("The file {} does NOT exist".format(path))

        # NOTE(review): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        ckpt = torch.load(path, map_location=self.device)
        self.model.load_state_dict(ckpt['model'])
        return ckpt['cur_round']

    def aggregate(self, agg_info):
        """Aggregation function; a no-op because clients do not communicate.

        Args:
            agg_info: information to be aggregated (ignored).

        Returns:
            dict: always an empty dict.
        """
        # do nothing
        return {}