Source code for federatedscope.attack.trainer.gaussian_attack_trainer

import logging
from typing import Type

import torch

from federatedscope.core.trainers import GeneralTorchTrainer

logger = logging.getLogger(__name__)


def wrap_GaussianAttackTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    '''
    Turn a trainer into a Gaussian-attack trainer.

    The hook named ``_hook_on_batch_backward`` registered under the
    ``on_batch_backward`` trigger is swapped for
    ``hook_on_batch_backward_generate_gaussian_noise_gradient`` through
    ``base_trainer.replace_hook_in_train``.

    Args:
        base_trainer: the trainer to wrap (annotated as
            ``core.trainers.GeneralTorchTrainer``). It is modified in
            place.
            NOTE(review): the annotation says ``Type[...]`` but
            ``replace_hook_in_train`` is invoked directly on the argument;
            confirm whether callers pass an instance or a class.

    Returns:
        The very same ``base_trainer`` object, after the hook replacement.
    '''
    noise_hook = hook_on_batch_backward_generate_gaussian_noise_gradient
    base_trainer.replace_hook_in_train(
        new_hook=noise_hook,
        target_trigger='on_batch_backward',
        target_hook_name='_hook_on_batch_backward',
    )
    return base_trainer
def hook_on_batch_backward_generate_gaussian_noise_gradient(ctx):
    '''
    Backward hook that replaces the real gradients with Gaussian noise.

    Steps:
      1. Zero the optimizer's gradients and back-propagate
         ``ctx.loss_task``.
      2. Flatten and concatenate the gradients of every targeted
         parameter, then take their mean (shifted by ``+ 0.1``) and their
         standard deviation.
      3. Overwrite each targeted parameter's ``.grad`` with a sample from
         ``Normal(mean, std)`` of the same shape.
      4. Call ``ctx.optimizer.step()``.

    A parameter is targeted when its name does not contain ``'bn'`` and it
    actually received a gradient. All other parameters keep whatever
    gradient (or ``None``) the backward pass produced.

    Args:
        ctx: object exposing ``optimizer``, ``loss_task`` and ``model``
            attributes; ``model`` must provide ``named_parameters()``.

    Returns:
        None. The model parameters and their ``.grad`` fields are modified
        in place.
    '''
    ctx.optimizer.zero_grad()
    ctx.loss_task.backward()

    # Parameters unused by the loss (or frozen) have ``grad is None``;
    # they carry no statistics and cannot be overwritten, so skip them
    # instead of failing on ``None.detach()``.
    targets = [
        param for name, param in ctx.model.named_parameters()
        if 'bn' not in name and param.grad is not None
    ]

    if not targets:
        # Nothing to poison (``torch.cat`` would reject an empty list);
        # still apply the step so the remaining parameters are updated.
        ctx.optimizer.step()
        return

    # ``reshape`` (unlike ``view``) also handles non-contiguous gradients.
    grad_values = torch.cat(
        [param.grad.detach().cpu().reshape(-1) for param in targets])
    # Plain Python floats select the ``normal(float, float, size)``
    # overload explicitly.
    mean_for_gaussian_noise = (torch.mean(grad_values) + 0.1).item()
    std_for_gaussian_noise = torch.std(grad_values).item()

    for param in targets:
        generated_grad = torch.normal(mean=mean_for_gaussian_noise,
                                      std=std_for_gaussian_noise,
                                      size=param.grad.shape)
        # The noise is sampled on CPU in the default dtype; a gradient
        # must match its parameter's device AND dtype, otherwise the
        # assignment raises for e.g. float64 or half-precision models.
        param.grad = generated_grad.to(device=param.grad.device,
                                       dtype=param.grad.dtype)

    ctx.optimizer.step()