Source code for federatedscope.core.auxiliaries.scheduler_builder

import copy
import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

# torch is treated as an optional dependency: when the import fails the name
# is bound to None, and get_scheduler() checks for that before it resolves a
# scheduler class by name from torch.optim.lr_scheduler.
try:
    import torch
except ImportError:
    torch = None

# Pull in the contributed schedulers. An ImportError here is only logged as a
# warning, so this module still loads when something in the contrib package is
# missing.
# NOTE(review): presumably importing the contrib package registers its
# schedulers in register.scheduler_dict as a side effect -- confirm against
# federatedscope.contrib.scheduler.
try:
    from federatedscope.contrib.scheduler import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.scheduler`, some modules are not '
        f'available.')


def get_scheduler(optimizer, type, **kwargs):
    """
    This function builds an instance of scheduler.

    Resolution order:
      1. every builder in ``register.scheduler_dict`` is tried, and the first
         non-``None`` result is returned;
      2. the built-in ``'warmup_step'`` and ``'warmup_noam'`` types are built
         on top of ``torch.optim.lr_scheduler.LambdaLR``;
      3. otherwise ``type`` is looked up by name in
         ``torch.optim.lr_scheduler`` and called with the remaining kwargs.

    Args:
        optimizer: optimizer to be scheduled
        type: type of scheduler (a string); ``''`` means "no scheduler"
        **kwargs: kwargs dict. The config bookkeeping keys
            (``__help_info__``, ``__cfg_check_funcs__``,
            ``is_ready_for_run``, ``warmup_ratio``) are dropped, and
            ``warmup_steps`` / ``total_steps`` are consumed here instead of
            being forwarded to the scheduler constructor.

    Returns:
        An instantiated scheduler, or ``None`` when no registered builder
        matched and either torch is unavailable or ``type`` is ``''``.

    Raises:
        ValueError: a warmup scheduler is requested without the step counts
            it needs (``warmup_steps`` and, for ``'warmup_step'``,
            ``total_steps``).
        NotImplementedError: ``type`` is a string that does not name anything
            in ``torch.optim.lr_scheduler``.
        TypeError: ``type`` is not a string.

    Note:
        Please follow ``contrib.scheduler.example`` to implement your own
        scheduler.
    """
    # Work on a deep copy so the caller's kwargs are never mutated; this also
    # covers the case where users have not called cfg.freeze().
    tmp_kwargs = copy.deepcopy(kwargs)
    for cfg_only_key in ('__help_info__', '__cfg_check_funcs__',
                         'is_ready_for_run', 'warmup_ratio'):
        tmp_kwargs.pop(cfg_only_key, None)

    # Always bind these names (None when absent) so that the warmup branches
    # below can validate them, instead of failing later with a NameError
    # raised from inside the lr_lambda closure.
    warmup_steps = tmp_kwargs.pop('warmup_steps', None)
    total_steps = tmp_kwargs.pop('total_steps', None)

    # Registered builders take precedence over the built-in types.
    for func in register.scheduler_dict.values():
        scheduler = func(optimizer, type, **tmp_kwargs)
        if scheduler is not None:
            return scheduler

    if type == 'warmup_step':
        if warmup_steps is None or total_steps is None:
            raise ValueError(
                "Scheduler 'warmup_step' requires both `warmup_steps` and "
                "`total_steps`.")
        from torch.optim.lr_scheduler import LambdaLR

        def lr_lambda(cur_step):
            # Linear ramp from 0 to 1 over the warmup phase ...
            if cur_step < warmup_steps:
                return float(cur_step) / float(max(1, warmup_steps))
            # ... then linear decay that is clamped at 0 from total_steps on.
            return max(
                0.0,
                float(total_steps - cur_step) /
                float(max(1, total_steps - warmup_steps)))

        return LambdaLR(optimizer, lr_lambda)
    elif type == 'warmup_noam':
        if warmup_steps is None:
            raise ValueError(
                "Scheduler 'warmup_noam' requires `warmup_steps`.")
        from torch.optim.lr_scheduler import LambdaLR

        # A non-positive warmup length would make `warmup_steps**(-1.5)`
        # divide by zero (0) or turn complex (negative). Treat it as a
        # one-step warmup, in line with the max(1, warmup_steps) guard used
        # by 'warmup_step'. Positive values are used unchanged.
        noam_warmup = warmup_steps if warmup_steps > 0 else 1

        def lr_lambda(cur_step):
            step = max(1, cur_step)
            return min(step**(-0.5), step * noam_warmup**(-1.5))

        return LambdaLR(optimizer, lr_lambda)

    if torch is None or type == '':
        return None

    if not isinstance(type, str):
        # `type` shadows the builtin inside this function, hence __class__.
        raise TypeError(
            f'Scheduler type must be a str, but got '
            f'{type.__class__.__name__}')
    if not hasattr(torch.optim.lr_scheduler, type):
        raise NotImplementedError(
            'Scheduler {} not implement'.format(type))
    return getattr(torch.optim.lr_scheduler, type)(optimizer, **tmp_kwargs)