# Source code for federatedscope.core.auxiliaries.criterion_builder

import logging
import federatedscope.register as register

logger = logging.getLogger(__name__)

try:
    from torch import nn
    from federatedscope.nlp.loss import *
    from federatedscope.cl.loss import *
except ImportError:
    nn = None

try:
    from federatedscope.contrib.loss import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.loss`, some modules are not '
        f'available.')


def get_criterion(criterion_type, device):
    """
    Build an instance of a loss function (criterion).

    Lookup order:

    1. Every builder registered in ``register.criterion_dict`` is called as
       ``func(criterion_type, device)``; the first non-``None`` result is
       returned.
    2. Otherwise, if ``criterion_type`` is a string, it is looked up as an
       attribute of ``torch.nn`` (see
       "https://pytorch.org/docs/stable/nn.html#loss-functions") and that
       attribute is instantiated with no arguments.

    Arguments:
        criterion_type: loss function type, e.g. ``'CrossEntropyLoss'``
        device: target device (``cpu`` or ``gpu``); it is forwarded to the
            registered builders only -- the ``torch.nn`` fallback does not
            move the returned instance.

    Returns:
        An instance of loss functions.

    Raises:
        TypeError: if no registered builder handles ``criterion_type`` and
            it is not a string.
        NotImplementedError: if no registered builder handles
            ``criterion_type`` and ``torch.nn`` has no attribute of that
            name, or ``torch`` cannot be imported at all.
    """
    # Registered (custom / contrib) criteria take precedence over torch.nn.
    for func in register.criterion_dict.values():
        criterion = func(criterion_type, device)
        if criterion is not None:
            return criterion

    if not isinstance(criterion_type, str):
        raise TypeError(
            'Criterion type must be a str when no registered criterion '
            'matches it, got {}'.format(type(criterion_type).__name__))

    torch_nn = nn
    if torch_nn is None:
        # The module-level guarded import sets ``nn`` to None when *any* of
        # its imports fails (e.g. the nlp/cl loss modules), even if torch
        # itself is installed. Retry torch alone so built-in losses work.
        try:
            from torch import nn as torch_nn
        except ImportError as err:
            raise NotImplementedError(
                'Criterion {} not implement: torch is not available'.format(
                    criterion_type)) from err

    if hasattr(torch_nn, criterion_type):
        return getattr(torch_nn, criterion_type)()
    raise NotImplementedError(
        'Criterion {} not implement'.format(criterion_type))