import logging
from federatedscope.core.configs import constants
# Module-level logger; used below to warn when aggregator selection falls
# back to a default.
logger = logging.getLogger(__name__)
def get_aggregator(method, model=None, device=None, online=False, config=None):
    """
    Build the aggregator, i.e. the protocol used to aggregate all clients'
    model(s).

    Arguments:
        method: key to determine which aggregator to use (case-insensitive)
        model: model to be aggregated
        device: where to aggregate models (``cpu`` or ``gpu``)
        online: ``True`` or ``False`` to use online aggregator.
        config: configurations for FL, see ``federatedscope.core.configs``.
            Required, although it defaults to ``None`` for signature
            compatibility.

    Returns:
        An instance of aggregator (see ``core.aggregator`` for details)

    Raises:
        ValueError: if ``config`` is ``None``.
        NotImplementedError: if the aggregator type resolved from ``method``
            is not one of the types handled here.

    Note:
        Selection happens in this order:

        1. ``config.backend == 'tensorflow'``:
           ``cross_backends.FedAvgAggregator`` (regardless of ``method``).
        2. ``method`` is mapped to an aggregator type through
           ``constants.AGGREGATOR_TYPE``; unknown methods fall back to
           ``clients_avg`` with a warning.
        3. ``config.data.type`` is ``hetero_nlp_tasks`` and
           ``config.federate.atc_vanilla`` is false: ``ATCAggregator``.
        4. ``config.fedopt.use`` is set or the type is ``fedopt``:
           ``core.aggregators.FedOptAggregator``.
        5. Type ``clients_avg``:
           ``core.aggregators.OnlineClientsAvgAggregator`` when ``online``,
           else ``core.aggregators.AsynClientsAvgAggregator`` when
           ``config.asyn.use``, else the aggregator named by
           ``config.aggregator.robust_rule`` (``fedavg``, ``krum``,
           ``median``, ``bulyan``, ``trimmedmean``, ``normbounding``);
           unsupported rules fall back to ``ClientsAvgAggregator`` with a
           warning.
        6. Type ``server_clients_interpolation``:
           ``core.aggregators.ServerClientsInterpolateAggregator``.
        7. Type ``no_communication``:
           ``core.aggregators.NoCommunicationAggregator``.

        Documented ``method`` keys (the authoritative mapping lives in
        ``constants.AGGREGATOR_TYPE``): ``local`` and ``global`` use
        ``NoCommunicationAggregator``; ``fedavg``, ``ditto``,
        ``fedsageplus`` and ``gcflplus`` use the ``clients_avg`` family;
        ``pfedme`` uses ``ServerClientsInterpolateAggregator``; ``fedopt``
        uses ``FedOptAggregator``.
    """
    # ``config`` is dereferenced unconditionally below, so fail early with a
    # clear message instead of an opaque AttributeError on ``None``.
    if config is None:
        raise ValueError(
            'get_aggregator requires a non-None `config` to build an '
            'aggregator.')

    if config.backend == 'tensorflow':
        from federatedscope.cross_backends import FedAvgAggregator
        return FedAvgAggregator(model=model, device=device)
    else:
        # Imported lazily so the tensorflow backend does not need to load
        # these aggregators.
        from federatedscope.core.aggregators import ClientsAvgAggregator, \
            OnlineClientsAvgAggregator, ServerClientsInterpolateAggregator, \
            FedOptAggregator, NoCommunicationAggregator, \
            AsynClientsAvgAggregator, KrumAggregator, \
            MedianAggregator, TrimmedmeanAggregator, \
            BulyanAggregator, NormboundingAggregator

        # Maps ``config.aggregator.robust_rule`` to its aggregator class.
        STR2AGG = {
            'fedavg': ClientsAvgAggregator,
            'krum': KrumAggregator,
            'median': MedianAggregator,
            'bulyan': BulyanAggregator,
            'trimmedmean': TrimmedmeanAggregator,
            'normbounding': NormboundingAggregator,
        }

    method_key = method.lower()
    if method_key in constants.AGGREGATOR_TYPE:
        aggregator_type = constants.AGGREGATOR_TYPE[method_key]
    else:
        aggregator_type = "clients_avg"
        logger.warning(
            'Aggregator for method %s is not implemented. '
            'Will use default one', method)

    # The ATC aggregator takes precedence over every type-based choice.
    if config.data.type.lower() == 'hetero_nlp_tasks' and \
            not config.federate.atc_vanilla:
        from federatedscope.nlp.hetero_tasks.aggregator import ATCAggregator
        return ATCAggregator(model=model, config=config, device=device)

    # ``config.fedopt.use`` overrides whatever type ``method`` resolved to.
    if config.fedopt.use or aggregator_type == 'fedopt':
        return FedOptAggregator(config=config, model=model, device=device)
    elif aggregator_type == 'clients_avg':
        if online:
            return OnlineClientsAvgAggregator(
                model=model,
                device=device,
                config=config,
                src_device=device
                if config.federate.share_local_model else 'cpu')
        elif config.asyn.use:
            return AsynClientsAvgAggregator(model=model,
                                            device=device,
                                            config=config)
        else:
            robust_rule = config.aggregator.robust_rule
            if robust_rule not in STR2AGG:
                # Adjacent literals keep the message on one clean line;
                # backslash continuations inside the literal garbled it.
                logger.warning(
                    'The specified %s aggregation rule has not been '
                    'supported, the vanilla fedavg algorithm will be used '
                    'instead.', robust_rule)
            aggregator_cls = STR2AGG.get(robust_rule, ClientsAvgAggregator)
            return aggregator_cls(model=model, device=device, config=config)
    elif aggregator_type == 'server_clients_interpolation':
        return ServerClientsInterpolateAggregator(
            model=model,
            device=device,
            config=config,
            beta=config.personalization.beta)
    elif aggregator_type == 'no_communication':
        return NoCommunicationAggregator(model=model,
                                         device=device,
                                         config=config)
    else:
        raise NotImplementedError(
            "Aggregator {} is not implemented.".format(aggregator_type))