Source code for federatedscope.attack.auxiliary.utils

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import logging
import os
import numpy as np
import federatedscope.register as register

logger = logging.getLogger(__name__)


def label_to_onehot(target, num_classes=100):
    return torch.nn.functional.one_hot(target, num_classes)


def cross_entropy_for_onehot(pred, target):
    return torch.mean(torch.sum(-target * F.log_softmax(pred, dim=-1), 1))
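

# Editor's sketch (hypothetical helper, not part of the original module):
# for one-hot targets, cross_entropy_for_onehot should agree with
# F.cross_entropy evaluated on the corresponding integer labels.
def _demo_cross_entropy_for_onehot():
    pred = torch.randn(4, 10)
    labels = torch.randint(0, 10, (4, ))
    onehot = label_to_onehot(labels, num_classes=10).float()
    assert torch.allclose(cross_entropy_for_onehot(pred, onehot),
                          F.cross_entropy(pred, labels))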


def iDLG_trick(original_gradient, num_class, is_one_hot_label=False):
    '''
    Use the iDLG trick to recover the label.
    Paper: "iDLG: Improved Deep Leakage from Gradients",
    link: https://arxiv.org/abs/2001.02610

    Args:
        original_gradient: the gradient of the FL model; Type: list
        num_class: the total number of classes in the data
        is_one_hot_label: whether the dataset's labels are one-hot encoded;
            Type: bool

    Returns:
        The label recovered by the iDLG trick.
    '''
    last_weight_min = torch.argmin(torch.sum(original_gradient[-2], dim=-1),
                                   dim=-1).detach()
    if is_one_hot_label:
        label = label_to_onehot(
            last_weight_min.reshape((1, )).requires_grad_(False), num_class)
    else:
        label = last_weight_min
    return label


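# Editor's sketch (hypothetical helper, not part of the FederatedScope API):
# for cross-entropy with softmax and a single sample with non-negative input
# (e.g. image pixels), the row of the last-layer weight gradient belonging
# to the true class is the only negative one, so argmin over row sums
# recovers the label. This assumes the model's last two parameters are the
# classifier weight and bias, as in a torch.nn.Linear head.
def _demo_iDLG_trick():
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 4)  # toy one-layer classifier
    x, y = torch.rand(1, 8), torch.tensor([2])  # non-negative input
    loss = F.cross_entropy(model(x), y)
    grads = torch.autograd.grad(loss, list(model.parameters()))
    recovered = iDLG_trick(list(grads), num_class=4)
    assert recovered.item() == y.item()

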
def cos_sim(input_gradient, gt_gradient):
    '''Return 1 - cosine_similarity between the two flattened gradients.'''
    total = 1 - torch.nn.functional.cosine_similarity(
        input_gradient.flatten(), gt_gradient.flatten(), 0, 1e-10)
    return total


def total_variation(x):
    """Anisotropic TV, normalized by the total number of elements in x."""
    dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
    dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))
    total = x.numel()
    return (dx + dy) / total


def approximate_func(x, device, C1=20, C2=0.5):
    '''
    Approximate the step function f(x) = 0 if x < 0.5 else 1 by a sigmoid.

    Args:
        x: input data
        device: the device on which the computation is performed
        C1: steepness of the sigmoid; larger values give a sharper step
        C2: the threshold at which the approximated step occurs

    Returns:
        1 / (1 + e^{-C1 * (x - C2)})
    '''
    C1 = torch.tensor(C1).to(torch.device(device))
    C2 = torch.tensor(C2).to(torch.device(device))
    return 1 / (1 + torch.exp(-1 * C1 * (x - C2)))


def get_classifier(classifier: str, model=None):
    if model is not None:
        return model

    if classifier == 'lr':
        from sklearn.linear_model import LogisticRegression
        model = LogisticRegression(random_state=0)
        return model
    elif classifier.lower() == 'randomforest':
        from sklearn.ensemble import RandomForestClassifier
        model = RandomForestClassifier(random_state=0)
        return model
    elif classifier.lower() == 'svm':
        from sklearn.svm import SVC
        from sklearn.preprocessing import StandardScaler
        from sklearn.pipeline import make_pipeline
        model = make_pipeline(StandardScaler(), SVC(gamma='auto'))
        return model
    else:
        raise ValueError(
            'classifier: {} is not supported'.format(classifier))


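# Editor's sketch (hypothetical helper): approximate_func is a smooth
# surrogate for the hard threshold 1[x > 0.5]; with the defaults C1=20,
# C2=0.5 the sigmoid is centered at 0.5 and already quite sharp.
def _demo_approximate_func():
    x = torch.tensor([0.0, 0.4, 0.5, 0.6, 1.0])
    soft = approximate_func(x, device='cpu')
    print(soft)  # approximately [0.0000, 0.1192, 0.5000, 0.8808, 1.0000]

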
def get_data_info(dataset_name):
    '''
    Get the dataset information: the feature dimension, the number of
    classes, and whether the labels are one-hot encoded.

    Args:
        dataset_name: dataset name; Type: str

    :returns: data_feature_dim, num_class, is_one_hot_label
    '''
    if dataset_name.lower() == 'femnist':
        return [1, 28, 28], 36, False
    else:
        raise ValueError(
            'Please provide the data info of {}: data_feature_dim, num_class'.
            format(dataset_name))


def get_data_sav_fn(dataset_name):
    if dataset_name.lower() == 'femnist':
        return sav_femnist_image
    else:
        logger.info(f"Reconstructed data saving function is not provided for "
                    f"dataset: {dataset_name}")
        return None


def sav_femnist_image(data, sav_pth, name):
    _ = plt.figure(figsize=(4, 4))

    # Ensure the data has shape (N, C, H, W) before plotting.
    if len(data.shape) == 2:
        data = torch.unsqueeze(data, 0)
        data = torch.unsqueeze(data, 0)

    # Plot at most the first 16 images in a 4x4 grid.
    ind = min(data.shape[0], 16)
    for i in range(ind):
        plt.subplot(4, 4, i + 1)
        plt.imshow(data[i, 0, :, :] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig(os.path.join(sav_pth, name))
    plt.close()


def get_info_diff_loss(info_diff_type):
    if info_diff_type.lower() == 'l2':
        info_diff_loss = torch.nn.MSELoss(reduction='sum')
    elif info_diff_type.lower() == 'l1':
        info_diff_loss = torch.nn.SmoothL1Loss(reduction='sum', beta=1e-5)
    elif info_diff_type.lower() == 'sim':
        info_diff_loss = cos_sim
    else:
        raise ValueError(
            'info_diff_type: {} is not supported'.format(info_diff_type))
    return info_diff_loss


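# Editor's sketch (hypothetical helper): the three info_diff_type options
# yield interchangeable callables measuring the gap between a candidate
# gradient and the ground-truth gradient.
def _demo_info_diff_loss():
    g_candidate, g_true = torch.randn(10), torch.randn(10)
    for diff_type in ('l2', 'l1', 'sim'):
        loss_fn = get_info_diff_loss(diff_type)
        print(diff_type, float(loss_fn(g_candidate, g_true)))

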
def get_reconstructor(atk_method, **kwargs):
    '''
    Get the reconstructor for the given attack method; currently
    "dlg" (DLG: deep leakage from gradients) and "ig" (IG: inverting
    gradients) are supported.

    Args:
        atk_method: the attack method name; Type: str
        **kwargs: other arguments

    Returns:
        The instantiated reconstructor.
    '''
    if atk_method.lower() == 'dlg':
        from federatedscope.attack.privacy_attacks.reconstruction_opt import\
            DLG
        logger.info(
            '--------- Getting reconstructor: DLG --------------------')
        return DLG(max_ite=kwargs['max_ite'],
                   lr=kwargs['lr'],
                   federate_loss_fn=kwargs['federate_loss_fn'],
                   device=kwargs['device'],
                   federate_lr=kwargs['federate_lr'],
                   optim=kwargs['optim'],
                   info_diff_type=kwargs['info_diff_type'],
                   federate_method=kwargs['federate_method'])
    elif atk_method.lower() == 'ig':
        from federatedscope.attack.privacy_attacks.reconstruction_opt import\
            InvertGradient
        logger.info(
            '------- Getting reconstructor: InvertGradient ------------------')
        return InvertGradient(max_ite=kwargs['max_ite'],
                              lr=kwargs['lr'],
                              federate_loss_fn=kwargs['federate_loss_fn'],
                              device=kwargs['device'],
                              federate_lr=kwargs['federate_lr'],
                              optim=kwargs['optim'],
                              info_diff_type=kwargs['info_diff_type'],
                              federate_method=kwargs['federate_method'],
                              alpha_TV=kwargs['alpha_TV'])
    else:
        raise ValueError(
            "attack method: {} lacks reconstructor implementation".format(
                atk_method))


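# Editor's sketch: wiring up a DLG reconstructor through get_reconstructor.
# The keyword names match those read above; the values (including 'Adam'
# and 'FedAvg') are illustrative placeholders, not recommended settings.
def _demo_get_reconstructor():
    return get_reconstructor('dlg',
                             max_ite=400,
                             lr=0.1,
                             federate_loss_fn=torch.nn.CrossEntropyLoss(),
                             device='cpu',
                             federate_lr=0.05,
                             optim='Adam',
                             info_diff_type='l2',
                             federate_method='FedAvg')

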
def get_generator(dataset_name):
    '''
    Get the dataset's corresponding generator.

    Args:
        dataset_name: the dataset name; Type: str

    :returns: the generator; Type: object
    '''
    if dataset_name == 'femnist':
        from federatedscope.attack.models.gan_based_model import \
            GeneratorFemnist
        return GeneratorFemnist
    else:
        raise ValueError(
            "The generator to generate data like {} is not defined!".format(
                dataset_name))


def get_data_property(ctx):
    # A SHOWCASE for the Femnist dataset:
    # Property := whether the digit contains a circle.
    x, label = [_.to(ctx.device) for _ in ctx.data_batch]

    prop = torch.zeros(label.size())
    positive_labels = [0, 6, 8]
    for ind in range(label.size()[0]):
        if label[ind] in positive_labels:
            prop[ind] = 1
    prop = prop.to(ctx.device)
    return prop


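# Editor's sketch (hypothetical helper): get_data_property only needs an
# object carrying a data_batch and a device, so a SimpleNamespace suffices
# for a dry run outside the FederatedScope trainer context.
def _demo_get_data_property():
    from types import SimpleNamespace
    ctx = SimpleNamespace(device='cpu',
                          data_batch=(torch.randn(4, 1, 28, 28),
                                      torch.tensor([0, 1, 6, 3])))
    print(get_data_property(ctx))  # tensor([1., 0., 1., 0.])

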
def get_passive_PIA_auxiliary_dataset(dataset_name):
    '''
    Get the auxiliary dataset for the passive property inference attack.

    Args:
        dataset_name (str): dataset name

    :returns: the auxiliary dataset for the property inference attack;
        Type: dict {'x': array, 'y': array, 'prop': array}
    '''
    # Prefer a registered auxiliary data loader if one matches the dataset.
    for func in register.auxiliary_data_loader_PIA_dict.values():
        criterion = func(dataset_name)
        if criterion is not None:
            return criterion

    if dataset_name == 'toy':

        def _generate_data(instance_num=1000,
                           feature_num=5,
                           save_data=False):
            """
            Generate synthetic regression data in the Runner format.

            Args:
                instance_num: the number of instances to generate
                feature_num: the feature dimension
                save_data: whether to save the generated data (unused here)

            Returns:
                dict {'x': array, 'y': array, 'prop': array}
            """
            weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
            bias = np.random.normal(loc=0.0, scale=1.0)

            prop_weights = np.random.normal(loc=0.0,
                                            scale=1.0,
                                            size=feature_num)
            prop_bias = np.random.normal(loc=0.0, scale=1.0)

            x = np.random.normal(loc=0.0,
                                 scale=0.5,
                                 size=(instance_num, feature_num))
            y = np.sum(x * weights, axis=-1) + bias
            y = np.expand_dims(y, -1)

            # The binary property is a thresholded sigmoid of a second,
            # independent random linear function of the features.
            prop = np.sum(x * prop_weights, axis=-1) + prop_bias
            prop = 1.0 * ((1 / (1 + np.exp(-1 * prop))) > 0.5)
            prop = np.expand_dims(prop, -1)

            data_train = {'x': x, 'y': y, 'prop': prop}
            return data_train

        return _generate_data()
    else:
        raise ValueError(
            'The data cannot be loaded. Please specify the data load function.'
        )


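# Editor's sketch (hypothetical helper, assumes scikit-learn is installed):
# fetch the toy auxiliary dataset and fit a property classifier on it with
# the get_classifier helper defined above.
def _demo_passive_PIA_auxiliary_dataset():
    aux = get_passive_PIA_auxiliary_dataset('toy')
    clf = get_classifier('lr')
    clf.fit(aux['x'], aux['prop'].ravel())
    print('property train accuracy:',
          clf.score(aux['x'], aux['prop'].ravel()))

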
def plot_mia_loss_compare(loss_in_pth, loss_out_pth, in_round=20):
    loss_in = np.loadtxt(loss_in_pth, delimiter=',')
    loss_out = np.loadtxt(loss_out_pth, delimiter=',')

    # Skip the entry at index in_round in both loss curves.
    loss_in_all = []
    loss_out_all = []
    for i in range(len(loss_in)):
        if i != in_round:
            loss_in_all.append(loss_in[i])
            loss_out_all.append(loss_out[i])

    plt.plot(loss_out_all,
             label='not-in',
             alpha=0.9,
             color='red',
             linewidth=2)
    plt.plot(loss_in_all,
             linestyle=':',
             label='in',
             alpha=0.9,
             linewidth=2,
             color='blue')
    plt.legend()
    plt.xlabel('Round', fontsize=16)
    plt.ylabel('$L_x$', fontsize=16)
    plt.show()