import logging
import federatedscope.register as register
from federatedscope.nlp.hetero_tasks.metric import *
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# The contrib metrics are optional: if the wildcard import fails with an
# ImportError, emit a warning and keep loading this module instead of
# propagating the error.
# NOTE(review): presumably the import is done for its side effects
# (user-contributed metrics registering themselves) -- confirm.
try:
    from federatedscope.contrib.metrics import *
except ImportError as error:
    logger.warning(
        f'{error} in `federatedscope.contrib.metrics`, some modules are not '
        'available.')
def get_metric(types):
    """
    Build the metric calculators for the requested metric names.

    Every builder stored in ``register.metric_dict`` is called with
    ``types``. A builder that returns ``None`` is skipped; any other result
    is unpacked as ``(name, metric, the_larger_the_better)`` and stored
    under ``name``. If several builders return the same ``name``, the one
    visited last overwrites the earlier entries.

    Args:
        types: list of metric names

    Returns:
        A metric calculator dict whose keys are metric names and whose
        values are ``(metric, the_larger_the_better)`` tuples, such as
        ``{'loss': (eval_loss, False), 'acc': (eval_acc, True), ...}``

    Note:
        The built-in metrics, their source functions and the
        ``the_larger_the_better`` sign are shown below. Every listed
        function lives in ``monitors.metric_calculator``; ``std`` has no
        source function.

        ================  ==================  =====================
        Metric name       Function            The larger the better
        ================  ==================  =====================
        ``loss``          ``eval_loss``       False
        ``avg_loss``      ``eval_avg_loss``   False
        ``total``         ``eval_total``      False
        ``correct``       ``eval_correct``    True
        ``acc``           ``eval_acc``        True
        ``ap``            ``eval_ap``         True
        ``f1``            ``eval_f1_score``   True
        ``roc_auc``       ``eval_roc_auc``    True
        ``rmse``          ``eval_rmse``       False
        ``mse``           ``eval_mse``        False
        ``loss_regular``  ``eval_regular``    False
        ``imp_ratio``     ``eval_imp_ratio``  True
        ``std``           ``None``            False
        ``hits@{n}``      ``eval_hits``       True
        ================  ==================  =====================
    """
    metrics = {}
    for builder in register.metric_dict.values():
        res = builder(types)
        if res is None:
            # The builder produced nothing for ``types``; skip it.
            continue
        name, metric, the_larger_the_better = res
        metrics[name] = (metric, the_larger_the_better)
    return metrics