Source code for federatedscope.core.auxiliaries.optimizer_builder

import copy
import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
    import torch
except ImportError:
    torch = None

try:
    from federatedscope.contrib.optimizer import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.optimizer`, some modules are not '
        f'available.')


[docs]def get_optimizer(model, type, lr, **kwargs): """ This function returns an instantiated optimizer to optimize the model. Args: model: model to be optimized type: type of optimizer, see \ https://pytorch.org/docs/stable/optim.html lr: learning rate **kwargs: kwargs dict Returns: An instantiated optimizer """ if torch is None: return None # in case of users have not called the cfg.freeze() tmp_kwargs = copy.deepcopy(kwargs) if '__help_info__' in tmp_kwargs: del tmp_kwargs['__help_info__'] if '__cfg_check_funcs__' in tmp_kwargs: del tmp_kwargs['__cfg_check_funcs__'] if 'is_ready_for_run' in tmp_kwargs: del tmp_kwargs['is_ready_for_run'] for func in register.optimizer_dict.values(): optimizer = func(model, type, lr, **tmp_kwargs) if optimizer is not None: return optimizer if isinstance(type, str): if hasattr(torch.optim, type): if isinstance(model, torch.nn.Module): return getattr(torch.optim, type)(model.parameters(), lr, **tmp_kwargs) else: return getattr(torch.optim, type)(model, lr, **tmp_kwargs) else: raise NotImplementedError( 'Optimizer {} not implement'.format(type)) else: raise TypeError()