Source code for federatedscope.attack.trainer.GAN_trainer

import logging
from typing import Type

from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.attack.privacy_attacks.GAN_based_attack import GANCRA

logger = logging.getLogger(__name__)


def wrap_GANTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    '''
    Wrap the trainer for the GAN-based class representative attack.

    The target label index from ``cfg.attack`` is copied onto the trainer
    context, a ``GANCRA`` attacker is attached to the context, and the
    attack hooks defined in this module are registered on the trainer.

    Args:
        base_trainer: the trainer to be wrapped;
            Type: core.trainers.GeneralTorchTrainer

    :returns:
        The same trainer object, with the attack state and hooks plugged in;
        Type: core.trainers.GeneralTorchTrainer
    '''
    cfg = base_trainer.cfg
    ctx = base_trainer.ctx

    # ---------------- attribute-level plug-in -----------------------
    ctx.target_label_ind = cfg.attack.target_label_ind
    ctx.gan_cra = GANCRA(
        cfg.attack.target_label_ind,
        ctx.model,
        dataset_name=cfg.data.type,
        device=ctx.device,
        sav_pth=cfg.outdir,
    )

    # ---------------- action-level plug-in --------------------------
    # (hook, trigger) pairs, listed in the order they must be registered.
    attack_hooks = (
        (hood_on_fit_start_generator, 'on_fit_start'),
        (hook_on_gan_cra_train, 'on_batch_start'),
        (hook_on_batch_injected_data_generation, 'on_batch_start'),
        (hook_on_batch_forward_injected_data, 'on_batch_forward'),
        (hook_on_data_injection_sav_data, 'on_fit_end'),
    )
    for hook_fn, trigger_name in attack_hooks:
        base_trainer.register_hook_in_train(new_hook=hook_fn,
                                            trigger=trigger_name,
                                            insert_mode=-1)

    return base_trainer
def hood_on_fit_start_generator(ctx):
    '''
    Increment the attacker's round counter and log the new round number.

    Args:
        ctx: trainer context; must expose ``gan_cra`` with a numeric
            ``round_num`` attribute.

    Returns:
        None
    '''
    attacker = ctx.gan_cra
    attacker.round_num += 1

    message = '----- Round {}: GAN training ............'.format(
        attacker.round_num)
    logger.info(message)
def hook_on_batch_forward_injected_data(ctx):
    '''
    Add the loss of the injected (generated) data to the batch task loss.

    The injected sample and its label are moved to ``ctx.device``, passed
    through ``ctx.model``, and the resulting ``ctx.criterion`` value is
    accumulated into ``ctx.loss_task``. The label and prediction are also
    stored on the context as ``y_true_injected`` / ``y_prob_injected``.

    Args:
        ctx: trainer context; reads ``injected_data``, ``device``,
            ``model``, ``criterion`` and ``loss_task``.

    Returns:
        None
    '''
    inputs, targets = tuple(
        item.to(ctx.device) for item in ctx.injected_data)
    outputs = ctx.model(inputs)

    # A 0-dim label is promoted to shape (1,) before the loss is computed.
    if targets.dim() == 0:
        targets = targets.unsqueeze(0)

    ctx.loss_task += ctx.criterion(outputs, targets)
    ctx.y_true_injected = targets
    ctx.y_prob_injected = outputs
def hook_on_batch_injected_data_generation(ctx):
    '''
    Generate the data to be injected and store it on the context.

    Args:
        ctx: trainer context; ``ctx.gan_cra.generate_fake_data()`` is called
            and its result is saved as ``ctx.injected_data``.

    Returns:
        None
    '''
    fake_batch = ctx.gan_cra.generate_fake_data()
    ctx.injected_data = fake_batch
def hook_on_gan_cra_train(ctx):
    '''
    Run one attacker training step.

    The attacker's discriminator is updated from ``ctx.model`` first, then
    the generator is trained.

    Args:
        ctx: trainer context; reads ``gan_cra`` and ``model``.

    Returns:
        None
    '''
    attacker = ctx.gan_cra
    attacker.update_discriminator(ctx.model)
    attacker.generator_train()


def hook_on_data_injection_sav_data(ctx):
    '''
    Ask the attacker to generate and save images.

    Args:
        ctx: trainer context; reads ``gan_cra``.

    Returns:
        None
    '''
    attacker = ctx.gan_cra
    attacker.generate_and_save_images()