import copy
import logging
import federatedscope.register as register
logger = logging.getLogger(__name__)
try:
import torch
except ImportError:
torch = None
try:
from federatedscope.contrib.scheduler import *
except ImportError as error:
logger.warning(
f'{error} in `federatedscope.contrib.scheduler`, some modules are not '
f'available.')
[docs]def get_scheduler(optimizer, type, **kwargs):
"""
This function builds an instance of scheduler.
Args:
optimizer: optimizer to be scheduled
type: type of scheduler
**kwargs: kwargs dict
Returns:
An instantiated scheduler.
Note:
Please follow ``contrib.scheduler.example`` to implement your own \
scheduler.
"""
# in case of users have not called the cfg.freeze()
tmp_kwargs = copy.deepcopy(kwargs)
if '__help_info__' in tmp_kwargs:
del tmp_kwargs['__help_info__']
if '__cfg_check_funcs__' in tmp_kwargs:
del tmp_kwargs['__cfg_check_funcs__']
if 'is_ready_for_run' in tmp_kwargs:
del tmp_kwargs['is_ready_for_run']
if 'warmup_ratio' in tmp_kwargs:
del tmp_kwargs['warmup_ratio']
if 'warmup_steps' in tmp_kwargs:
warmup_steps = tmp_kwargs['warmup_steps']
del tmp_kwargs['warmup_steps']
if 'total_steps' in tmp_kwargs:
total_steps = tmp_kwargs['total_steps']
del tmp_kwargs['total_steps']
for func in register.scheduler_dict.values():
scheduler = func(optimizer, type, **tmp_kwargs)
if scheduler is not None:
return scheduler
if type == 'warmup_step':
from torch.optim.lr_scheduler import LambdaLR
def lr_lambda(cur_step):
if cur_step < warmup_steps:
return float(cur_step) / float(max(1, warmup_steps))
return max(
0.0,
float(total_steps - cur_step) /
float(max(1, total_steps - warmup_steps)))
return LambdaLR(optimizer, lr_lambda)
elif type == 'warmup_noam':
from torch.optim.lr_scheduler import LambdaLR
def lr_lambda(cur_step):
return min(
max(1, cur_step)**(-0.5),
max(1, cur_step) * warmup_steps**(-1.5))
return LambdaLR(optimizer, lr_lambda)
if torch is None or type == '':
return None
if isinstance(type, str):
if hasattr(torch.optim.lr_scheduler, type):
return getattr(torch.optim.lr_scheduler, type)(optimizer,
**tmp_kwargs)
else:
raise NotImplementedError(
'Scheduler {} not implement'.format(type))
else:
raise TypeError()