import logging
from typing import Type
import numpy as np
from federatedscope.core.trainers import GeneralTorchTrainer
logger = logging.getLogger(__name__)
[docs]def wrap_benignTrainer(
base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
'''
Warp the benign trainer for backdoor attack:
We just add the normalization operation.
Args:
base_trainer: Type: core.trainers.GeneralTorchTrainer
:returns:
The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
'''
base_trainer.register_hook_in_eval(new_hook=hook_on_fit_end_test_poison,
trigger='on_fit_end',
insert_pos=0)
return base_trainer
def hook_on_fit_end_test_poison(ctx):
"""
Evaluate metrics of poisoning attacks.
"""
ctx['poison_' + ctx.cur_split + '_loader'] = ctx.data['poison_' +
ctx.cur_split]
ctx['poison_' + ctx.cur_split + '_data'] = ctx.data['poison_' +
ctx.cur_split].dataset
ctx['num_poison_' + ctx.cur_split + '_data'] = len(
ctx.data['poison_' + ctx.cur_split].dataset)
setattr(ctx, "poison_{}_y_true".format(ctx.cur_split), [])
setattr(ctx, "poison_{}_y_prob".format(ctx.cur_split), [])
setattr(ctx, "poison_num_samples_{}".format(ctx.cur_split), 0)
for batch_idx, (samples, targets) in enumerate(
ctx['poison_' + ctx.cur_split + '_loader']):
samples, targets = samples.to(ctx.device), targets.to(ctx.device)
pred = ctx.model(samples)
if len(targets.size()) == 0:
targets = targets.unsqueeze(0)
ctx.poison_y_true = targets
ctx.poison_y_prob = pred
ctx.poison_batch_size = len(targets)
ctx.get("poison_{}_y_true".format(ctx.cur_split)).append(
ctx.poison_y_true.detach().cpu().numpy())
ctx.get("poison_{}_y_prob".format(ctx.cur_split)).append(
ctx.poison_y_prob.detach().cpu().numpy())
setattr(
ctx, "poison_num_samples_{}".format(ctx.cur_split),
ctx.get("poison_num_samples_{}".format(ctx.cur_split)) +
ctx.poison_batch_size)
setattr(ctx, "poison_{}_y_true".format(ctx.cur_split),
np.concatenate(ctx.get("poison_{}_y_true".format(ctx.cur_split))))
setattr(ctx, "poison_{}_y_prob".format(ctx.cur_split),
np.concatenate(ctx.get("poison_{}_y_prob".format(ctx.cur_split))))
logger.info('the {} poisoning samples: {:d}'.format(
ctx.cur_split, ctx.get("poison_num_samples_{}".format(ctx.cur_split))))
poison_true = ctx['poison_' + ctx.cur_split + '_y_true']
poison_prob = ctx['poison_' + ctx.cur_split + '_y_prob']
poison_pred = np.argmax(poison_prob, axis=1)
correct = poison_true == poison_pred
poisoning_acc = float(np.sum(correct)) / len(correct)
logger.info('the {} poisoning accuracy: {:f}'.format(
ctx.cur_split, poisoning_acc))