import logging

import federatedscope.register as register

logger = logging.getLogger(__name__)

# PyTorch is optional: ``nn`` is None when it cannot be imported, in which
# case ``get_criterion`` cannot fall back to torch's built-in losses.
try:
    from torch import nn
except ImportError:
    nn = None

# The project loss modules are imported separately from torch so that a
# failure here does not reset ``nn`` to None and hide torch's built-in
# losses. They are only attempted when torch is available, as before.
# NOTE(review): these star imports presumably register custom criteria in
# ``register.criterion_dict`` as a side effect -- confirm against
# federatedscope.nlp.loss / federatedscope.cl.loss.
if nn is not None:
    try:
        from federatedscope.nlp.loss import *  # noqa: F401,F403
        from federatedscope.cl.loss import *  # noqa: F401,F403
    except ImportError as error:
        logger.warning(
            f'{error} in `federatedscope.nlp.loss` or '
            f'`federatedscope.cl.loss`, some modules are not available.')

try:
    from federatedscope.contrib.loss import *  # noqa: F401,F403
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.loss`, some modules are not '
        f'available.')
def get_criterion(criterion_type, device):
    """
    Build an instance of a loss function (criterion).

    Builders registered in ``register.criterion_dict`` are tried first, in
    dictionary order; the first one returning a non-None value wins. If no
    registered builder handles ``criterion_type``, it is looked up by name
    among PyTorch's built-in losses
    (https://pytorch.org/docs/stable/nn.html#loss-functions) and
    instantiated with no arguments.

    Arguments:
        criterion_type: loss function type, e.g. ``'CrossEntropyLoss'``.
            Registered builders receive it unchanged and may accept
            non-string values; the PyTorch fallback requires a string.
        device: device forwarded to each registered builder (e.g. ``cpu``
            or ``gpu``). The PyTorch fallback does not use it.

    Returns:
        An instance of the requested loss function.

    Raises:
        TypeError: if no registered builder handles ``criterion_type`` and
            it is not a string.
        NotImplementedError: if ``criterion_type`` is a string that no
            registered builder handles and that is not available as an
            attribute of ``torch.nn`` (including when ``torch.nn`` could
            not be imported).
    """
    # Custom (registered) criteria take precedence over torch built-ins.
    for builder in register.criterion_dict.values():
        criterion = builder(criterion_type, device)
        if criterion is not None:
            return criterion

    if not isinstance(criterion_type, str):
        raise TypeError(
            'criterion_type must be a str when no registered criterion '
            'handles it, got {}'.format(type(criterion_type).__name__))

    # ``nn`` is None when the module-level torch import failed; check it
    # explicitly instead of relying on ``hasattr(None, ...)``, which is
    # True for names such as '__class__'.
    if nn is not None and hasattr(nn, criterion_type):
        return getattr(nn, criterion_type)()

    message = 'Criterion {} not implement'.format(criterion_type)
    if nn is None:
        # Make the real cause visible instead of a bare "not implement".
        message += ' (torch.nn is unavailable)'
    raise NotImplementedError(message)