Source code for federatedscope.core.workers.client

import copy
import logging
import sys
import pickle

from federatedscope.core.message import Message
from federatedscope.core.communication import StandaloneCommManager, \
    StandaloneDDPCommManager, gRPCCommManager
from federatedscope.core.monitors.early_stopper import EarlyStopper
from federatedscope.core.auxiliaries.trainer_builder import get_trainer
from federatedscope.core.secret_sharing import AdditiveSecretSharing
from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
    calculate_time_cost
from federatedscope.core.workers.base_client import BaseClient

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Client(BaseClient):
    """
    The Client class, which describes the behaviors of a client in an FL \
    course. The behaviors are described by the handling functions (named \
    ``callback_funcs_for_xxx``).

    Arguments:
        ID: The unique ID of the client, which is assigned by the server
        when joining the FL course
        server_id: (Default) 0
        state: The training round
        config: The configuration
        data: The data owned by the client
        model: The model maintained locally
        device: The device to run local training and evaluation

    Attributes:
        ID: ID of worker
        state: the training round index
        model: the model maintained locally
        cfg: the configuration of the FL course, \
        see ``federatedscope.core.configs``
        mode: the run mode for FL, ``distributed`` or ``standalone``
        monitor: monitors the FL course and records metrics, \
        see ``federatedscope.core.monitors.monitor.Monitor``
        trainer: instantiated trainer, see ``federatedscope.core.trainers``
        best_results: best results ever seen
        history_results: all evaluation results
        early_stopper: determines when to early stop, \
        see ``federatedscope.core.monitors.early_stopper.EarlyStopper``
        ss_manager: secret sharing manager
        msg_buffer: dict buffer for storing messages
        comm_manager: manager for communication, \
        see ``federatedscope.core.communication``
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 is_unseen_client=False,
                 *args,
                 **kwargs):
        super(Client, self).__init__(ID, state, config, model, strategy)

        self.data = data

        # Register message handlers
        self._register_default_handlers()

        # Un-configured worker
        if config is None:
            return

        # is_unseen_client indicates whether this client contributes to the
        # FL process by training on its local data and uploading the local
        # model update, which is useful for checking the participation
        # generalization gap in
        # [ICLR'22, What Do We Mean by Generalization in Federated Learning?]
        self.is_unseen_client = is_unseen_client

        # Parse the attack_id since we support both 'int' (for a single
        # attacker) and 'list' (for multiple attackers) for
        # config.attack.attacker_id
        parsed_attack_ids = list()
        if isinstance(config.attack.attacker_id, int):
            parsed_attack_ids.append(config.attack.attacker_id)
        elif isinstance(config.attack.attacker_id, list):
            parsed_attack_ids = config.attack.attacker_id
        else:
            raise TypeError(f"The expected types of config.attack.attack_id "
                            f"include 'int' and 'list', but we got "
                            f"{type(config.attack.attacker_id)}")

        # Attacks are only supported in standalone mode;
        # a client is an attacker if its ID is in the parsed attacker ids and
        # config.attack.attack_method is provided
        self.is_attacker = ID in parsed_attack_ids and \
            config.attack.attack_method != '' and \
            config.federate.mode == 'standalone'

        # Build Trainer
        # trainer might need configurations other than those of the trainer
        # node
        self.trainer = get_trainer(model=model,
                                   data=data,
                                   device=device,
                                   config=self._cfg,
                                   is_attacker=self.is_attacker,
                                   monitor=self._monitor)
        self.device = device

        # For client-side evaluation
        self.best_results = dict()
        self.history_results = dict()
        # In local or global training mode, we use the local early stopper.
        # Otherwise, we set patience=0 to deactivate the local early-stopper
        patience = self._cfg.early_stop.patience if \
            self._cfg.federate.method in [
                "local", "global"
            ] else 0
        self.early_stopper = EarlyStopper(
            patience, self._cfg.early_stop.delta,
            self._cfg.early_stop.improve_indicator_mode,
            self._monitor.the_larger_the_better)

        # Secret Sharing Manager and message buffer
        self.ss_manager = AdditiveSecretSharing(
            shared_party_num=int(self._cfg.federate.sample_client_num
                                 )) if self._cfg.federate.use_ss else None
        self.msg_buffer = {'train': dict(), 'eval': dict()}

        # Communication and communication ability
        if 'resource_info' in kwargs and kwargs['resource_info'] is not None:
            self.comp_speed = float(
                kwargs['resource_info']['computation']) / 1000.  # (s/sample)
            self.comm_bandwidth = float(
                kwargs['resource_info']['communication'])  # (kbit/s)
        else:
            self.comp_speed = None
            self.comm_bandwidth = None

        if self._cfg.backend == 'torch':
            self.model_size = sys.getsizeof(pickle.dumps(
                self.model)) / 1024.0 * 8.  # kbits
        else:
            # TODO: calculate model size for TF Model
            self.model_size = 1.0
            logger.warning(f'The calculation of model size in backend:'
                           f'{self._cfg.backend} is not provided.')

        # Initialize communication manager
        self.server_id = server_id
        if self.mode == 'standalone':
            comm_queue = kwargs['shared_comm_queue']
            if self._cfg.federate.process_num <= 1:
                self.comm_manager = StandaloneCommManager(
                    comm_queue=comm_queue, monitor=self._monitor)
            else:
                self.comm_manager = StandaloneDDPCommManager(
                    comm_queue=comm_queue, monitor=self._monitor)
            self.local_address = None
        elif self.mode == 'distributed':
            host = kwargs['host']
            port = kwargs['port']
            server_host = kwargs['server_host']
            server_port = kwargs['server_port']
            self.comm_manager = gRPCCommManager(
                host=host,
                port=port,
                client_num=self._cfg.federate.client_num,
                cfg=self._cfg.distribute)
            logger.info('Client: Listen to {}:{}...'.format(host, port))
            self.comm_manager.add_neighbors(neighbor_id=server_id,
                                            address={
                                                'host': server_host,
                                                'port': server_port
                                            })
            self.local_address = {
                'host': self.comm_manager.host,
                'port': self.comm_manager.port
            }

    def _gen_timestamp(self, init_timestamp, instance_number):
        if init_timestamp is None:
            return None

        comp_cost, comm_cost = calculate_time_cost(
            instance_number=instance_number,
            comm_size=self.model_size,
            comp_speed=self.comp_speed,
            comm_bandwidth=self.comm_bandwidth)
        return init_timestamp + comp_cost + comm_cost

    def _calculate_model_delta(self, init_model, updated_model):
        if not isinstance(init_model, list):
            init_model = [init_model]
            updated_model = [updated_model]

        model_deltas = list()
        for model_index in range(len(init_model)):
            model_delta = copy.deepcopy(init_model[model_index])
            for key in init_model[model_index].keys():
                model_delta[key] = updated_model[model_index][
                    key] - init_model[model_index][key]
            model_deltas.append(model_delta)

        if len(model_deltas) > 1:
            return model_deltas
        else:
            return model_deltas[0]

    def join_in(self):
        """
        To send the ``join_in`` message to the server for joining in the FL \
        course.
        """
        self.comm_manager.send(
            Message(msg_type='join_in',
                    sender=self.ID,
                    receiver=[self.server_id],
                    timestamp=0,
                    content=self.local_address))

    def run(self):
        """
        To listen to incoming messages and handle them accordingly (used \
        in distributed mode)
        """
        while True:
            msg = self.comm_manager.receive()
            if self.state <= msg.state:
                self.msg_handlers[msg.msg_type](msg)

            if msg.msg_type == 'finish':
                break

    def run_standalone(self):
        """
        Run in standalone mode
        """
        self.join_in()
        self.run()

    def callback_funcs_for_model_para(self, message: Message):
        """
        The handling function for receiving model parameters, \
        which triggers the local training process. \
        This handling function is widely used in various FL courses.

        Arguments:
            message: The received message
        """
        if 'ss' in message.msg_type:
            # A fragment of the shared secret
            state, content, timestamp = message.state, message.content, \
                message.timestamp
            self.msg_buffer['train'][state].append(content)

            if len(self.msg_buffer['train']
                   [state]) == self._cfg.federate.client_num:
                # Check whether the received fragments are enough
                model_list = self.msg_buffer['train'][state]
                sample_size, first_aggregate_model_para = model_list[0]
                single_model_case = True
                if isinstance(first_aggregate_model_para, list):
                    assert isinstance(first_aggregate_model_para[0], dict), \
                        "aggregate_model_para should be a list of " \
                        "state_dicts for the multi-model case"
                    single_model_case = False
                else:
                    assert isinstance(first_aggregate_model_para, dict), \
                        "aggregate_model_para should be a state_dict " \
                        "for the single-model case"
                    first_aggregate_model_para = [first_aggregate_model_para]
                    model_list = [[model] for model in model_list]

                for sub_model_idx, aggregate_single_model_para in enumerate(
                        first_aggregate_model_para):
                    for key in aggregate_single_model_para:
                        for i in range(1, len(model_list)):
                            aggregate_single_model_para[key] += model_list[i][
                                sub_model_idx][key]

                self.comm_manager.send(
                    Message(msg_type='model_para',
                            sender=self.ID,
                            receiver=[self.server_id],
                            state=self.state,
                            timestamp=timestamp,
                            content=(sample_size,
                                     first_aggregate_model_para[0]
                                     if single_model_case else
                                     first_aggregate_model_para)))

        else:
            round = message.state
            sender = message.sender
            timestamp = message.timestamp
            content = message.content

            # dequantization
            if self._cfg.quantization.method == 'uniform':
                from federatedscope.core.compression import \
                    symmetric_uniform_dequantization
                if isinstance(content, list):  # multiple models
                    content = [
                        symmetric_uniform_dequantization(x) for x in content
                    ]
                else:
                    content = symmetric_uniform_dequantization(content)

            # When clients share the local model, we must set strict=True to
            # ensure all the model params (which might be updated by other
            # clients in the previous local training process) are overwritten
            # and synchronized with the received model
            if self._cfg.federate.process_num > 1:
                for k, v in content.items():
                    content[k] = v.to(self.device)
            self.trainer.update(content,
                                strict=self._cfg.federate.share_local_model)
            self.state = round
            skip_train_isolated_or_global_mode = \
                self.early_stopper.early_stopped and \
                self._cfg.federate.method in ["local", "global"]
            if self.is_unseen_client or skip_train_isolated_or_global_mode:
                # For these cases, (1) unseen client or (2) early-stopped
                # local/global mode, we skip local training and do not upload
                # a local model update
                sample_size, model_para_all, results = \
                    0, self.trainer.get_model_para(), {}
                if skip_train_isolated_or_global_mode:
                    logger.info(
                        f"[Local/Global mode] Client #{self.ID} has been "
                        f"early stopped, we will skip the local training")
                    self._monitor.local_converged()
            else:
                if self.early_stopper.early_stopped and \
                        self._monitor.local_convergence_round == 0:
                    logger.info(
                        f"[Normal FL Mode] Client #{self.ID} has been locally "
                        f"early stopped. "
                        f"The next FL update may result in a negative effect")
                    self._monitor.local_converged()
                sample_size, model_para_all, results = self.trainer.train()
                if self._cfg.federate.share_local_model and not \
                        self._cfg.federate.online_aggr:
                    model_para_all = copy.deepcopy(model_para_all)
                train_log_res = self._monitor.format_eval_res(
                    results,
                    rnd=self.state,
                    role='Client #{}'.format(self.ID),
                    return_raw=True)
                logger.info(train_log_res)
                if self._cfg.wandb.use and self._cfg.wandb.client_train_info:
                    self._monitor.save_formatted_results(train_log_res,
                                                         save_file_name="")

            # Return the feedback to the server after the local update
            if self._cfg.federate.use_ss:
                assert not self.is_unseen_client, \
                    "Using secret sharing for unseen clients is not " \
                    "supported, i.e., you set cfg.federate.use_ss=True and " \
                    "cfg.federate.unseen_clients_rate in (0, 1)"
                single_model_case = True
                if isinstance(model_para_all, list):
                    assert isinstance(model_para_all[0], dict), \
                        "model_para should be a list of state_dicts " \
                        "for the multi-model case"
                    single_model_case = False
                else:
                    assert isinstance(model_para_all, dict), \
                        "model_para should be a state_dict for the " \
                        "single-model case"
                    model_para_all = [model_para_all]
                model_para_list_all = []
                for model_para in model_para_all:
                    for key in model_para:
                        model_para[key] = model_para[key] * sample_size
                    model_para_list = self.ss_manager.secret_split(model_para)
                    model_para_list_all.append(model_para_list)
                frame_idx = 0
                for neighbor in self.comm_manager.neighbors:
                    if neighbor != self.server_id:
                        content_frame = model_para_list_all[0][frame_idx] if \
                            single_model_case else \
                            [model_para_list[frame_idx] for model_para_list
                             in model_para_list_all]
                        self.comm_manager.send(
                            Message(msg_type='ss_model_para',
                                    sender=self.ID,
                                    receiver=[neighbor],
                                    state=self.state,
                                    timestamp=self._gen_timestamp(
                                        init_timestamp=timestamp,
                                        instance_number=sample_size),
                                    content=content_frame))
                        frame_idx += 1
                content_frame = model_para_list_all[0][frame_idx] if \
                    single_model_case else \
                    [model_para_list[frame_idx] for model_para_list
                     in model_para_list_all]
                self.msg_buffer['train'][self.state] = [(sample_size,
                                                         content_frame)]
            else:
                if self._cfg.asyn.use or self._cfg.aggregator.robust_rule in \
                        ['krum', 'normbounding', 'median', 'trimmedmean',
                         'bulyan']:
                    # Return the model delta when using the asynchronous
                    # training protocol, because stale updates might be
                    # discounted, which can cause the sum of the aggregation
                    # weights to differ from 1
                    shared_model_para = self._calculate_model_delta(
                        init_model=content, updated_model=model_para_all)
                else:
                    shared_model_para = model_para_all

                # quantization
                if self._cfg.quantization.method == 'uniform':
                    from federatedscope.core.compression import \
                        symmetric_uniform_quantization
                    nbits = self._cfg.quantization.nbits
                    if isinstance(shared_model_para, list):
                        shared_model_para = [
                            symmetric_uniform_quantization(x, nbits)
                            for x in shared_model_para
                        ]
                    else:
                        shared_model_para = symmetric_uniform_quantization(
                            shared_model_para, nbits)
                self.comm_manager.send(
                    Message(msg_type='model_para',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            timestamp=self._gen_timestamp(
                                init_timestamp=timestamp,
                                instance_number=sample_size),
                            content=(sample_size, shared_model_para)))

    def callback_funcs_for_assign_id(self, message: Message):
        """
        The handling function for receiving the client_ID assigned by the \
        server (during the joining process), which is used in the \
        distributed mode.

        Arguments:
            message: The received message
        """
        content = message.content
        self.ID = int(content)
        logger.info('Client (address {}:{}) is assigned with #{:d}.'.format(
            self.comm_manager.host, self.comm_manager.port, self.ID))

    def callback_funcs_for_join_in_info(self, message: Message):
        """
        The handling function for receiving the request of join in \
        information (such as ``batch_size``, ``num_of_samples``) during \
        the joining process.

        Arguments:
            message: The received message
        """
        requirements = message.content
        timestamp = message.timestamp
        join_in_info = dict()
        for requirement in requirements:
            if requirement.lower() == 'num_sample':
                if self._cfg.train.batch_or_epoch == 'batch':
                    num_sample = self._cfg.train.local_update_steps * \
                        self._cfg.dataloader.batch_size
                else:
                    num_sample = self._cfg.train.local_update_steps * \
                        len(self.trainer.data.train_data)
                join_in_info['num_sample'] = num_sample
                if self._cfg.trainer.type == 'nodefullbatch_trainer':
                    join_in_info['num_sample'] = \
                        self.trainer.data.train_data.x.shape[0]
            elif requirement.lower() == 'client_resource':
                assert self.comm_bandwidth is not None and self.comp_speed \
                    is not None, "The requirement join_in_info " \
                                 "'client_resource' cannot be satisfied " \
                                 "since the client's resource_info is " \
                                 "not provided."
                join_in_info['client_resource'] = self.model_size / \
                    self.comm_bandwidth + self.comp_speed
            else:
                raise ValueError(
                    'Fail to get the join in information with type {}'.format(
                        requirement))
        self.comm_manager.send(
            Message(msg_type='join_in_info',
                    sender=self.ID,
                    receiver=[self.server_id],
                    state=self.state,
                    timestamp=timestamp,
                    content=join_in_info))

    def callback_funcs_for_address(self, message: Message):
        """
        The handling function for receiving other clients' IP addresses, \
        which is used for constructing a complex topology

        Arguments:
            message: The received message
        """
        content = message.content
        for neighbor_id, address in content.items():
            if int(neighbor_id) != self.ID:
                self.comm_manager.add_neighbors(neighbor_id, address)

    def callback_funcs_for_evaluate(self, message: Message):
        """
        The handling function for receiving the request for evaluation

        Arguments:
            message: The received message
        """
        sender, timestamp = message.sender, message.timestamp
        self.state = message.state
        if message.content is not None:
            self.trainer.update(message.content,
                                strict=self._cfg.federate.share_local_model)
        if self.early_stopper.early_stopped and \
                self._cfg.federate.method in ["local", "global"]:
            metrics = list(self.best_results.values())[0]
        else:
            metrics = {}
            if self._cfg.finetune.before_eval:
                self.trainer.finetune()
            for split in self._cfg.eval.split:
                # TODO: The time cost of evaluation is not considered here
                eval_metrics = self.trainer.evaluate(
                    target_data_split_name=split)

                if self._cfg.federate.mode == 'distributed':
                    logger.info(
                        self._monitor.format_eval_res(eval_metrics,
                                                      rnd=self.state,
                                                      role='Client #{}'.format(
                                                          self.ID),
                                                      return_raw=True))

                metrics.update(**eval_metrics)

            formatted_eval_res = self._monitor.format_eval_res(
                metrics,
                rnd=self.state,
                role='Client #{}'.format(self.ID),
                forms=['raw'],
                return_raw=True)
            self._monitor.update_best_result(
                self.best_results,
                formatted_eval_res['Results_raw'],
                results_type=f"client #{self.ID}")
            self.history_results = merge_dict_of_results(
                self.history_results, formatted_eval_res['Results_raw'])
            self.early_stopper.track_and_check(self.history_results[
                self._cfg.eval.best_res_update_round_wise_key])

        self.comm_manager.send(
            Message(msg_type='metrics',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    timestamp=timestamp,
                    content=metrics))

    def callback_funcs_for_finish(self, message: Message):
        """
        The handling function for receiving the signal of finishing the FL \
        course.

        Arguments:
            message: The received message
        """
        logger.info(
            f"================= client {self.ID} received finish message "
            f"=================")

        if message.content is not None:
            self.trainer.update(message.content,
                                strict=self._cfg.federate.share_local_model)

        self._monitor.finish_fl()

    def callback_funcs_for_converged(self, message: Message):
        """
        The handling function for receiving the signal that the FL course \
        converged

        Arguments:
            message: The received message
        """
        self._monitor.global_converged()

    @classmethod
    def get_msg_handler_dict(cls):
        return cls().msg_handlers_str
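

Usage note: the snippet below is a minimal, illustrative sketch (not part of the module source above) of how a Client could be constructed and started in distributed mode. It assumes `cfg` is a fully prepared FederatedScope configuration with cfg.federate.mode == 'distributed', and that `data` and `model` are the client's local dataset and model built elsewhere; the addresses and ports are placeholders. In practice, workers are usually created and driven by the runner (see federatedscope.core.fed_runner) rather than instantiated by hand.

# A minimal sketch, assuming `cfg`, `data`, and `model` are prepared elsewhere
# and cfg.federate.mode == 'distributed'; the addresses/ports are placeholders.
from federatedscope.core.workers.client import Client

client = Client(ID=-1,              # the real ID is assigned by the server
                server_id=0,
                config=cfg,
                data=data,
                model=model,
                device='cpu',
                host='127.0.0.1',   # this client's listening address
                port=50052,
                server_host='127.0.0.1',
                server_port=50051)
client.join_in()  # send the 'join_in' message carrying this client's address
client.run()      # handle incoming messages until a 'finish' message arrives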