# Source code for federatedscope.core.auxiliaries.trainer_builder

import logging
import importlib

import federatedscope.register as register
from federatedscope.core.trainers import Trainer

# Module-level logger shared by the builder functions below.
logger = logging.getLogger(__name__)

# Optional user-contributed trainers. Importing this package is presumably
# what adds them to ``register.trainer_dict`` (NOTE(review): confirm in
# ``federatedscope.contrib.trainer``), which ``get_trainer`` consults as a
# fallback for unknown ``cfg.trainer.type`` values. A failed import only
# logs a warning, so the built-in trainers stay usable.
try:
    from federatedscope.contrib.trainer import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.trainer`, some modules are not '
        f'available.')

# Lower-cased ``cfg.trainer.type`` of each built-in trainer -> name of the
# class implementing it. ``get_trainer`` fetches the class by this name
# (via ``getattr``) from the module that defines the trainer.
TRAINER_CLASS_DICT = dict(
    cvtrainer="CVTrainer",
    nlptrainer="NLPTrainer",
    graphminibatch_trainer="GraphMiniBatchTrainer",
    linkfullbatch_trainer="LinkFullBatchTrainer",
    linkminibatch_trainer="LinkMiniBatchTrainer",
    nodefullbatch_trainer="NodeFullBatchTrainer",
    nodeminibatch_trainer="NodeMiniBatchTrainer",
    flitplustrainer="FLITPlusTrainer",
    flittrainer="FLITTrainer",
    fedvattrainer="FedVATTrainer",
    fedfocaltrainer="FedFocalTrainer",
    mftrainer="MFTrainer",
    cltrainer="CLTrainer",
    lptrainer="LPTrainer",
    atc_trainer="ATCTrainer",
)


# Module (importable path) that defines each built-in trainer class named in
# ``TRAINER_CLASS_DICT``; keys are the lower-cased ``cfg.trainer.type``.
_TRAINER_MODULE_DICT = {
    'cvtrainer': 'federatedscope.cv.trainer.trainer',
    'nlptrainer': 'federatedscope.nlp.trainer.trainer',
    'cltrainer': 'federatedscope.cl.trainer.trainer',
    'lptrainer': 'federatedscope.cl.trainer.trainer',
    'graphminibatch_trainer': 'federatedscope.gfl.trainer.graphtrainer',
    'linkfullbatch_trainer': 'federatedscope.gfl.trainer.linktrainer',
    'linkminibatch_trainer': 'federatedscope.gfl.trainer.linktrainer',
    'nodefullbatch_trainer': 'federatedscope.gfl.trainer.nodetrainer',
    'nodeminibatch_trainer': 'federatedscope.gfl.trainer.nodetrainer',
    'flitplustrainer': 'federatedscope.gfl.flitplus.trainer',
    'flittrainer': 'federatedscope.gfl.flitplus.trainer',
    'fedvattrainer': 'federatedscope.gfl.flitplus.trainer',
    'fedfocaltrainer': 'federatedscope.gfl.flitplus.trainer',
    'mftrainer': 'federatedscope.mf.trainer.trainer',
    'atc_trainer': 'federatedscope.nlp.hetero_tasks.trainer',
}


def get_trainer(model=None,
                data=None,
                device=None,
                config=None,
                only_for_eval=False,
                is_attacker=False,
                monitor=None):
    """
    This function builds an instance of trainer.

    Arguments:
        model: model used in FL course
        data: data used in FL course
        device: where to train model (``cpu`` or ``gpu``)
        config: configurations for FL, see ``federatedscope.core.configs``
        only_for_eval: ``True`` or ``False``, if ``True``, ``train``
            routine will be removed in this trainer
        is_attacker: ``True`` or ``False`` to determine whether this client
            is an attacker
        monitor: an instance of ``federatedscope.core.monitors.Monitor`` to
            observe the evaluation and system metrics

    Returns:
        An instance of trainer, or ``None`` when ``cfg.trainer.type`` is
        ``'none'``. A trainer that does not inherit from
        ``federatedscope.core.trainers.Trainer`` is returned without any
        plug-in wrapper applied.

    Raises:
        ValueError: if ``cfg.trainer.type`` is ``'general'`` and
            ``cfg.backend`` is neither ``'torch'`` nor ``'tensorflow'``, or
            if ``cfg.trainer.type`` matches no built-in trainer and no
            registered trainer.

    Note:
      Values of ``cfg.trainer.type`` and the trainers they select (paths
      relative to ``federatedscope``). ``general`` and ``none`` are matched
      exactly; the other built-in types are matched case-insensitively:

        - ``general``: ``core.trainers.GeneralTorchTrainer`` (backend
          ``torch``) or ``core.trainers.GeneralTFTrainer`` (backend
          ``tensorflow``)
        - ``none``: no trainer, ``None`` is returned
        - ``cvtrainer``: ``cv.trainer.trainer.CVTrainer``
        - ``nlptrainer``: ``nlp.trainer.trainer.NLPTrainer``
        - ``cltrainer``: ``cl.trainer.trainer.CLTrainer``
        - ``lptrainer``: ``cl.trainer.trainer.LPTrainer``
        - ``graphminibatch_trainer``:
          ``gfl.trainer.graphtrainer.GraphMiniBatchTrainer``
        - ``linkfullbatch_trainer``:
          ``gfl.trainer.linktrainer.LinkFullBatchTrainer``
        - ``linkminibatch_trainer``:
          ``gfl.trainer.linktrainer.LinkMiniBatchTrainer``
        - ``nodefullbatch_trainer``:
          ``gfl.trainer.nodetrainer.NodeFullBatchTrainer``
        - ``nodeminibatch_trainer``:
          ``gfl.trainer.nodetrainer.NodeMiniBatchTrainer``
        - ``flitplustrainer``: ``gfl.flitplus.trainer.FLITPlusTrainer``
        - ``flittrainer``: ``gfl.flitplus.trainer.FLITTrainer``
        - ``fedvattrainer``: ``gfl.flitplus.trainer.FedVATTrainer``
        - ``fedfocaltrainer``: ``gfl.flitplus.trainer.FedFocalTrainer``
        - ``mftrainer``: ``mf.trainer.trainer.MFTrainer``
        - ``atc_trainer``: ``nlp.hetero_tasks.trainer.ATCTrainer``
        - ``verticaltrainer``:
          ``vertical_fl.tree_based_models.trainer.utils.get_vertical_trainer``
        - any other value is looked up in ``register.trainer_dict``, e.g.
          ``mytorchtrainer`` ->
          ``contrib.trainer.torch_example.MyTorchTrainer``

      Wrapper functions applied to ``Trainer`` subclasses, in this order:

        - ``cfg.nbafl.use``: ``core.trainers.wrap_nbafl_trainer``
        - ``cfg.sgdmf.use``: ``mf.trainer.wrap_MFTrainer``
        - ``cfg.federate.method`` ``pfedme``:
          ``core.trainers.wrap_pFedMeTrainer``
        - ``cfg.federate.method`` ``ditto``:
          ``core.trainers.wrap_DittoTrainer``
        - ``cfg.federate.method`` ``fedem``: ``core.trainers.FedEMTrainer``
        - ``cfg.federate.method`` ``fedrep``:
          ``core.trainers.wrap_FedRepTrainer``
        - ``backdoor`` in ``cfg.attack.attack_method``:
          ``attack.trainer.wrap_benignTrainer``
        - ``is_attacker``:
          ``attack.auxiliary.attack_trainer_builder.wrap_attacker_trainer``
        - ``cfg.fedprox.use``: ``core.trainers.wrap_fedprox_trainer``
        - ``cfg.finetune.before_eval`` and ``cfg.finetune.simple_tuning``:
          ``core.trainers.wrap_Simple_tuning_Trainer``
    """
    if config.trainer.type == 'none':
        return None

    trainer = _build_base_trainer(model, data, device, config, only_for_eval,
                                  monitor)

    if not isinstance(trainer, Trainer):
        # The plug-ins below hook into ``Trainer``; anything else is handed
        # back unwrapped.
        logger.warning(f'Hook-like plug-in functions cannot be enabled when '
                       f'using {trainer}. If you want use our wrapper '
                       f'functions for your trainer please consider '
                       f'inheriting from '
                       f'`federatedscope.core.trainers.Trainer` instead.')
        return trainer

    return _apply_trainer_plugins(trainer, config, is_attacker)


def _build_base_trainer(model, data, device, config, only_for_eval, monitor):
    """Instantiate the unwrapped trainer selected by ``cfg.trainer.type``.

    Raises:
        ValueError: for an unsupported ``general`` backend, or when the type
            is neither built-in nor registered.
    """
    trainer_type = config.trainer.type
    # Keyword arguments shared by every trainer constructor below.
    trainer_kwargs = dict(model=model,
                          data=data,
                          device=device,
                          config=config,
                          only_for_eval=only_for_eval,
                          monitor=monitor)

    if trainer_type == 'general':
        if config.backend == 'torch':
            from federatedscope.core.trainers import GeneralTorchTrainer
            return GeneralTorchTrainer(**trainer_kwargs)
        if config.backend == 'tensorflow':
            from federatedscope.core.trainers import GeneralTFTrainer
            return GeneralTFTrainer(**trainer_kwargs)
        raise ValueError(
            f'Backend {config.backend} is not supported by the `general` '
            f'trainer, expected `torch` or `tensorflow`.')

    lowered_type = trainer_type.lower()
    if lowered_type in TRAINER_CLASS_DICT:
        module_path = _TRAINER_MODULE_DICT.get(lowered_type)
        if module_path is None:
            # A class name was added to TRAINER_CLASS_DICT without the module
            # that defines it.
            raise ValueError(
                f'No module path is known for trainer type {trainer_type}.')
        # Import lazily so optional dependencies of unused trainers are not
        # required.
        trainer_cls = getattr(importlib.import_module(name=module_path),
                              TRAINER_CLASS_DICT[lowered_type])
        return trainer_cls(**trainer_kwargs)

    if lowered_type == 'verticaltrainer':
        from federatedscope.vertical_fl.tree_based_models.trainer.utils \
            import get_vertical_trainer
        return get_vertical_trainer(config=config,
                                    model=model,
                                    data=data,
                                    device=device,
                                    monitor=monitor)

    return _build_registered_trainer(trainer_type, trainer_kwargs)


def _build_registered_trainer(trainer_type, trainer_kwargs):
    """Instantiate a user-registered trainer for ``trainer_type``.

    Every lookup function in ``register.trainer_dict`` is queried with the
    raw (not lower-cased) type. When several return a class, the last one
    wins, and only that class is instantiated.

    Raises:
        ValueError: if no registered lookup provides a trainer.
    """
    trainer_cls = None
    for lookup in register.trainer_dict.values():
        candidate = lookup(trainer_type)
        if candidate is not None:
            trainer_cls = candidate

    trainer = None if trainer_cls is None else trainer_cls(**trainer_kwargs)
    if trainer is None:
        raise ValueError('Trainer {} is not provided'.format(trainer_type))
    return trainer


def _apply_trainer_plugins(trainer, config, is_attacker):
    """Wrap ``trainer`` with every plug-in enabled in ``config``.

    The wrapping order matters because later wrappers see the earlier ones.
    """
    # differential privacy plug-in
    if config.nbafl.use:
        from federatedscope.core.trainers import wrap_nbafl_trainer
        trainer = wrap_nbafl_trainer(trainer)
    if config.sgdmf.use:
        from federatedscope.mf.trainer import wrap_MFTrainer
        trainer = wrap_MFTrainer(trainer)

    # personalization plug-in
    method = config.federate.method.lower()
    if method == "pfedme":
        from federatedscope.core.trainers import wrap_pFedMeTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_pFedMeTrainer(trainer)
    elif method == "ditto":
        from federatedscope.core.trainers import wrap_DittoTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_DittoTrainer(trainer)
    elif method == "fedem":
        from federatedscope.core.trainers import FedEMTrainer
        # copy construct style: instance a (class A) -> instance b (class B)
        trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer,
                               base_trainer=trainer)
    elif method == "fedrep":
        from federatedscope.core.trainers import wrap_FedRepTrainer
        # wrap style: instance a (class A) -> instance a (class A)
        trainer = wrap_FedRepTrainer(trainer)

    # attacker plug-in
    is_backdoor = 'backdoor' in config.attack.attack_method
    if is_backdoor:
        from federatedscope.attack.trainer import wrap_benignTrainer
        trainer = wrap_benignTrainer(trainer)

    if is_attacker:
        if is_backdoor:
            logger.info('--------This client is a backdoor attacker --------')
        else:
            logger.info('-------- This client is an privacy attacker --------')
        from federatedscope.attack.auxiliary.attack_trainer_builder \
            import wrap_attacker_trainer
        trainer = wrap_attacker_trainer(trainer, config)
    elif is_backdoor:
        logger.info(
            '----- This client is a benign client for backdoor attacks -----')

    # fed algorithm plug-in
    if config.fedprox.use:
        from federatedscope.core.trainers import wrap_fedprox_trainer
        trainer = wrap_fedprox_trainer(trainer)

    # different fine-tuning
    if config.finetune.before_eval and config.finetune.simple_tuning:
        from federatedscope.core.trainers import wrap_Simple_tuning_Trainer
        trainer = wrap_Simple_tuning_Trainer(trainer)

    return trainer