Source code for federatedscope.core.fed_runner

import abc
import logging

from collections import deque
import heapq

import numpy as np

from federatedscope.core.workers import Server, Client
from federatedscope.core.gpu_manager import GPUManager
from federatedscope.core.auxiliaries.model_builder import get_model
from federatedscope.core.auxiliaries.utils import get_resource_info
from federatedscope.core.auxiliaries.feat_engr_builder import \
    get_feat_engr_wrapper

logger = logging.getLogger(__name__)


class BaseRunner(object):
    """
    This class is a base class to construct an FL course, which includes
    ``_set_up()`` and ``run()``.

    Args:
        data: The data used in the FL courses, which are formatted as \
            ``{'ID':data}`` for standalone mode. More details can be found \
            in federatedscope.core.auxiliaries.data_builder.
        server_class: The server class is used for instantiating a \
            (customized) server.
        client_class: The client class is used for instantiating a \
            (customized) client.
        config: The configurations of the FL course.
        client_configs: The clients' configurations.

    Attributes:
        data: The data used in the FL courses, which are formatted as \
            ``{'ID':data}`` for standalone mode. More details can be found \
            in federatedscope.core.auxiliaries.data_builder.
        server: The instantiated server.
        client: The instantiated client(s).
        cfg: The configurations of the FL course.
        client_cfgs: The clients' configurations.
        mode: The run mode for FL, ``distributed`` or ``standalone``.
        gpu_manager: The manager of GPU resources.
        resource_info: Information of resources.
    """
    def __init__(self,
                 data,
                 server_class=Server,
                 client_class=Client,
                 config=None,
                 client_configs=None):
        self.data = data
        self.server_class = server_class
        self.client_class = client_class
        assert config is not None, \
            "When using Runner, you should specify the `config` parameter"
        if not config.is_ready_for_run:
            config.ready_for_run()
        self.cfg = config
        self.client_cfgs = client_configs
        self.serial_num_for_msg = 0

        self.mode = self.cfg.federate.mode.lower()
        self.gpu_manager = GPUManager(gpu_available=self.cfg.use_gpu,
                                      specified_device=self.cfg.device)
        self.unseen_clients_id = []
        self.feat_engr_wrapper_client, self.feat_engr_wrapper_server = \
            get_feat_engr_wrapper(config)
        if self.cfg.federate.unseen_clients_rate > 0:
            self.unseen_clients_id = np.random.choice(
                np.arange(1, self.cfg.federate.client_num + 1),
                size=max(
                    1,
                    int(self.cfg.federate.unseen_clients_rate *
                        self.cfg.federate.client_num)),
                replace=False).tolist()
        # get resource information
        self.resource_info = get_resource_info(
            config.federate.resource_info_file)

        # Check the completeness of msg_handler.
        self.check()
        # Set up for Runner
        self._set_up()

    @abc.abstractmethod
    def _set_up(self):
        """
        Set up and instantiate the client/server.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _get_server_args(self, resource_info, client_resource_info):
        """
        Get the args for instantiating the server.

        Args:
            resource_info: information of resource
            client_resource_info: information of client's resource

        Returns:
            (server_data, model, kw): None or data which server holds; \
            model to be aggregated; kwargs dict to instantiate the server.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def _get_client_args(self, client_id, resource_info):
        """
        Get the args for instantiating the client.

        Args:
            client_id: ID of client
            resource_info: information of resource

        Returns:
            (client_data, kw): data which client holds; kwargs dict to \
            instantiate the client.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def run(self):
        """
        Launch the FL course.

        Returns:
            dict: best results during the FL course
        """
        raise NotImplementedError

    def _setup_server(self, resource_info=None, client_resource_info=None):
        """
        Set up and instantiate the server.

        Args:
            resource_info: information of resource
            client_resource_info: information of client's resource

        Returns:
            The instantiated server.
        """
        assert self.server_class is not None, \
            "`server_class` cannot be None."
        self.server_id = 0
        server_data, model, kw = self._get_server_args(resource_info,
                                                       client_resource_info)
        self._server_device = self.gpu_manager.auto_choice()
        server = self.server_class(
            ID=self.server_id,
            config=self.cfg,
            data=server_data,
            model=model,
            client_num=self.cfg.federate.client_num,
            total_round_num=self.cfg.federate.total_round_num,
            device=self._server_device,
            unseen_clients_id=self.unseen_clients_id,
            **kw)
        if self.cfg.nbafl.use:
            from federatedscope.core.trainers.trainer_nbafl import \
                wrap_nbafl_server
            wrap_nbafl_server(server)
        if self.cfg.vertical.use:
            from federatedscope.vertical_fl.utils import wrap_vertical_server
            server = wrap_vertical_server(server, self.cfg)
        if self.cfg.fedswa.use:
            from federatedscope.core.workers.wrapper import wrap_swa_server
            server = wrap_swa_server(server)
        logger.info('Server has been set up ... ')
        return self.feat_engr_wrapper_server(server)

    def _setup_client(self,
                      client_id=-1,
                      client_model=None,
                      resource_info=None):
        """
        Set up and instantiate the client.

        Args:
            client_id: ID of client
            client_model: model of client
            resource_info: information of resource

        Returns:
            The instantiated client.
        """
        assert self.client_class is not None, \
            "`client_class` cannot be None"
        self.server_id = 0
        client_data, kw = self._get_client_args(client_id, resource_info)
        client_specific_config = self.cfg.clone()
        if self.client_cfgs:
            client_specific_config.defrost()
            client_specific_config.merge_from_other_cfg(
                self.client_cfgs.get('client_{}'.format(client_id)))
            client_specific_config.freeze()
        client_device = self._server_device if \
            self.cfg.federate.share_local_model else \
            self.gpu_manager.auto_choice()
        client = self.client_class(ID=client_id,
                                   server_id=self.server_id,
                                   config=client_specific_config,
                                   data=client_data,
                                   model=client_model
                                   or get_model(client_specific_config.model,
                                                client_data,
                                                backend=self.cfg.backend),
                                   device=client_device,
                                   is_unseen_client=client_id
                                   in self.unseen_clients_id,
                                   **kw)
        if self.cfg.vertical.use:
            from federatedscope.vertical_fl.utils import wrap_vertical_client
            client = wrap_vertical_client(client, config=self.cfg)
        if client_id == -1:
            logger.info('Client (address {}:{}) has been set up ... '.format(
                self.client_address['host'], self.client_address['port']))
        else:
            logger.info(f'Client {client_id} has been set up ... ')
        return self.feat_engr_wrapper_client(client)

    def check(self):
        """
        Check the completeness of Server and Client.
        """
        if not self.cfg.check_completeness:
            return
        try:
            import os
            import networkx as nx
            import matplotlib.pyplot as plt
            # Build check graph
            G = nx.DiGraph()
            flags = {0: 'Client', 1: 'Server'}
            msg_handler_dicts = [
                self.client_class.get_msg_handler_dict(),
                self.server_class.get_msg_handler_dict()
            ]
            for flag, msg_handler_dict in zip(flags.keys(),
                                              msg_handler_dicts):
                role, oppo = flags[flag], flags[(flag + 1) % 2]
                for msg_in, (handler, msgs_out) in \
                        msg_handler_dict.items():
                    for msg_out in msgs_out:
                        msg_in_key = f'{oppo}_{msg_in}'
                        handler_key = f'{role}_{handler}'
                        msg_out_key = f'{role}_{msg_out}'
                        G.add_node(msg_in_key, subset=1)
                        G.add_node(handler_key, subset=0 if flag else 2)
                        G.add_node(msg_out_key, subset=1)
                        G.add_edge(msg_in_key, handler_key)
                        G.add_edge(handler_key, msg_out_key)
            pos = nx.multipartite_layout(G)
            plt.figure(figsize=(20, 15))
            nx.draw(G,
                    pos,
                    with_labels=True,
                    node_color='white',
                    node_size=800,
                    width=1.0,
                    arrowsize=25,
                    arrowstyle='->')
            fig_path = os.path.join(self.cfg.outdir, 'msg_handler.png')
            plt.savefig(fig_path)
            if nx.has_path(G, 'Client_join_in', 'Server_finish'):
                if nx.is_weakly_connected(G):
                    logger.info(f'Completeness check passes! Save check '
                                f'results in {fig_path}.')
                else:
                    logger.warning(f'Completeness check raises warning for '
                                   f'some handlers not in FL process! Save '
                                   f'check results in {fig_path}.')
            else:
                logger.error(f'Completeness check fails for there is no '
                             f'path from `join_in` to `finish`! Save '
                             f'check results in {fig_path}.')
        except Exception as error:
            logger.warning(f'Completeness check failed for {error}!')
        return

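
# Illustrative sketch (not part of the original module): the idea behind
# ``check()`` above, reduced to a toy message-flow graph.  Incoming messages,
# handlers and outgoing messages become nodes, and the course is considered
# complete when ``Client_join_in`` can reach ``Server_finish`` and no handler
# is left isolated (weak connectivity).  The handler/message names below are
# simplified placeholders, not the real handler tables.
def _example_message_flow_check():
    import networkx as nx

    G = nx.DiGraph()
    G.add_edges_from([
        ('Client_join_in', 'Server_callback_for_join_in'),
        ('Server_callback_for_join_in', 'Server_model_para'),
        ('Server_model_para', 'Client_callback_for_model_para'),
        ('Client_callback_for_model_para', 'Client_model_para'),
        ('Client_model_para', 'Server_callback_for_model_para'),
        ('Server_callback_for_model_para', 'Server_finish'),
    ])
    assert nx.has_path(G, 'Client_join_in', 'Server_finish')
    assert nx.is_weakly_connected(G)
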
class StandaloneRunner(BaseRunner):
    def _set_up(self):
        """
        To set up server and client for standalone mode.
        """
        self.is_run_online = True if self.cfg.federate.online_aggr else False
        self.shared_comm_queue = deque()

        if self.cfg.backend == 'torch':
            import torch
            torch.set_num_threads(1)

        assert self.cfg.federate.client_num != 0, \
            "In standalone mode, self.cfg.federate.client_num should be " \
            "non-zero. " \
            "This is usually caused by using synthetic data while not " \
            "specifying a non-zero value for client_num"

        if self.cfg.federate.method == "global":
            self.cfg.defrost()
            self.cfg.federate.client_num = 1
            self.cfg.federate.sample_client_num = 1
            self.cfg.freeze()

        # sample resource information
        if self.resource_info is not None:
            if len(self.resource_info) < self.cfg.federate.client_num + 1:
                replace = True
                logger.warning(
                    f"Because the number of provided resource information "
                    f"entries {len(self.resource_info)} is less than the "
                    f"number of participants "
                    f"{self.cfg.federate.client_num + 1}, one candidate "
                    f"might be selected multiple times.")
            else:
                replace = False
            sampled_index = np.random.choice(
                list(self.resource_info.keys()),
                size=self.cfg.federate.client_num + 1,
                replace=replace)
            server_resource_info = self.resource_info[sampled_index[0]]
            client_resource_info = [
                self.resource_info[x] for x in sampled_index[1:]
            ]
        else:
            server_resource_info = None
            client_resource_info = None

        self.server = self._setup_server(
            resource_info=server_resource_info,
            client_resource_info=client_resource_info)

        self.client = dict()
        # assume the client-wise data are consistent in their input&output
        # shape
        self._shared_client_model = get_model(
            self.cfg.model, self.data[1], backend=self.cfg.backend
        ) if self.cfg.federate.share_local_model else None

        for client_id in range(1, self.cfg.federate.client_num + 1):
            self.client[client_id] = self._setup_client(
                client_id=client_id,
                client_model=self._shared_client_model,
                resource_info=client_resource_info[client_id - 1]
                if client_resource_info is not None else None)

        # in standalone mode, by default, we print the trainer info only
        # once for better logs readability
        trainer_representative = self.client[1].trainer
        if trainer_representative is not None and hasattr(
                trainer_representative, 'print_trainer_meta_info'):
            trainer_representative.print_trainer_meta_info()

    def _get_server_args(self, resource_info=None, client_resource_info=None):
        if self.server_id in self.data:
            server_data = self.data[self.server_id]
            model = get_model(self.cfg.model,
                              server_data,
                              backend=self.cfg.backend)
        else:
            server_data = None
            data_representative = self.data[1]
            model = get_model(
                self.cfg.model, data_representative, backend=self.cfg.backend
            )  # get the model according to client's data if the server
            # does not own data
        kw = {
            'shared_comm_queue': self.shared_comm_queue,
            'resource_info': resource_info,
            'client_resource_info': client_resource_info
        }
        return server_data, model, kw

    def _get_client_args(self, client_id=-1, resource_info=None):
        client_data = self.data[client_id]
        kw = {
            'shared_comm_queue': self.shared_comm_queue,
            'resource_info': resource_info
        }
        return client_data, kw

    def run(self):
        for each_client in self.client:
            # Launch each client
            self.client[each_client].join_in()

        if self.is_run_online:
            self._run_simulation_online()
        else:
            self._run_simulation()

        # TODO: avoid using private attr
        self.server._monitor.finish_fed_runner(fl_mode=self.mode)

        return self.server.best_results

    def _handle_msg(self, msg, rcv=-1):
        """
        To simulate the message handling process (used only for the \
        standalone mode)
        """
        if rcv != -1:
            # simulate broadcast one-by-one
            self.client[rcv].msg_handlers[msg.msg_type](msg)
            return

        _, receiver = msg.sender, msg.receiver
        download_bytes, upload_bytes = msg.count_bytes()
        if not isinstance(receiver, list):
            receiver = [receiver]
        for each_receiver in receiver:
            if each_receiver == 0:
                self.server.msg_handlers[msg.msg_type](msg)
                self.server._monitor.track_download_bytes(download_bytes)
            else:
                self.client[each_receiver].msg_handlers[msg.msg_type](msg)
                self.client[each_receiver]._monitor.track_download_bytes(
                    download_bytes)

    def _run_simulation_online(self):
        """
        Run for online aggregation.
        Any broadcast operation would be executed client-by-client to \
        avoid the existence of #clients messages at the same time. \
        Currently, only the centralized topology is considered.
        """
        def is_broadcast(msg):
            return len(msg.receiver) >= 1 and msg.sender == 0

        cached_bc_msgs = []
        cur_idx = 0
        while True:
            if len(self.shared_comm_queue) > 0:
                msg = self.shared_comm_queue.popleft()
                if is_broadcast(msg):
                    cached_bc_msgs.append(msg)
                    # assume there is at least one client
                    msg = cached_bc_msgs[0]
                    self._handle_msg(msg, rcv=msg.receiver[cur_idx])
                    cur_idx += 1
                    if cur_idx >= len(msg.receiver):
                        del cached_bc_msgs[0]
                        cur_idx = 0
                else:
                    self._handle_msg(msg)
            elif len(cached_bc_msgs) > 0:
                msg = cached_bc_msgs[0]
                self._handle_msg(msg, rcv=msg.receiver[cur_idx])
                cur_idx += 1
                if cur_idx >= len(msg.receiver):
                    del cached_bc_msgs[0]
                    cur_idx = 0
            else:
                # finished
                break

    def _run_simulation(self):
        """
        Run for standalone simulation (without online aggregation).
        """
        server_msg_cache = list()
        while True:
            if len(self.shared_comm_queue) > 0:
                msg = self.shared_comm_queue.popleft()
                if not self.cfg.vertical.use and msg.receiver == [
                        self.server_id
                ]:
                    # For the server, move the received message to a
                    # cache for reordering the messages according to
                    # the timestamps
                    msg.serial_num = self.serial_num_for_msg
                    self.serial_num_for_msg += 1
                    heapq.heappush(server_msg_cache, msg)
                else:
                    self._handle_msg(msg)
            elif len(server_msg_cache) > 0:
                msg = heapq.heappop(server_msg_cache)
                if self.cfg.asyn.use and self.cfg.asyn.aggregator \
                        == 'time_up':
                    # When the timestamp of the received message is beyond
                    # the deadline of the current round, trigger the
                    # time-up event first and push the message back into
                    # the cache
                    if self.server.trigger_for_time_up(msg.timestamp):
                        heapq.heappush(server_msg_cache, msg)
                    else:
                        self._handle_msg(msg)
                else:
                    self._handle_msg(msg)
            else:
                if self.cfg.asyn.use and self.cfg.asyn.aggregator \
                        == 'time_up':
                    self.server.trigger_for_time_up()
                    if len(self.shared_comm_queue) == 0 and \
                            len(server_msg_cache) == 0:
                        break
                else:
                    # terminate when shared_comm_queue and
                    # server_msg_cache are all empty
                    break

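
# Illustrative sketch (not part of the original module): how StandaloneRunner
# is typically driven.  ``data`` is the ``{'ID': data}`` mapping described in
# BaseRunner (key 0 for optional server data, 1..client_num for the clients)
# and ``config`` is a ready-for-run configuration; both are placeholders
# assumed to come from the usual data/config builders.
def _example_standalone_course(data, config, client_configs=None):
    runner = StandaloneRunner(data=data,
                              server_class=Server,  # default server worker
                              client_class=Client,  # default client worker
                              config=config,
                              client_configs=client_configs)
    # run() lets every client join in, drives the simulated message loop, and
    # returns the best results tracked by the server.
    return runner.run()
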
class DistributedRunner(BaseRunner):
    def _set_up(self):
        """
        To set up server or client for distributed mode.
        """
        # sample resource information
        if self.resource_info is not None:
            sampled_index = np.random.choice(list(self.resource_info.keys()))
            sampled_resource = self.resource_info[sampled_index]
        else:
            sampled_resource = None

        self.server_address = {
            'host': self.cfg.distribute.server_host,
            'port': self.cfg.distribute.server_port
        }
        if self.cfg.distribute.role == 'server':
            self.server = self._setup_server(resource_info=sampled_resource)
        elif self.cfg.distribute.role == 'client':
            # When we set up the client in the distributed mode, we assume
            # the server has been set up and numbered #0
            self.client_address = {
                'host': self.cfg.distribute.client_host,
                'port': self.cfg.distribute.client_port
            }
            self.client = self._setup_client(resource_info=sampled_resource)

    def _get_server_args(self, resource_info, client_resource_info):
        server_data = self.data
        model = get_model(self.cfg.model,
                          server_data,
                          backend=self.cfg.backend)
        kw = self.server_address
        kw.update({'resource_info': resource_info})
        return server_data, model, kw

    def _get_client_args(self, client_id, resource_info):
        client_data = self.data
        kw = self.client_address
        kw['server_host'] = self.server_address['host']
        kw['server_port'] = self.server_address['port']
        kw['resource_info'] = resource_info
        return client_data, kw

    def run(self):
        if self.cfg.distribute.role == 'server':
            self.server.run()
            return self.server.best_results
        elif self.cfg.distribute.role == 'client':
            self.client.join_in()
            self.client.run()

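
# Illustrative sketch (not part of the original module): launching an FL
# course in distributed mode.  DistributedRunner reads ``distribute.role``,
# ``distribute.server_host``/``server_port`` and, for clients, also
# ``distribute.client_host``/``client_port`` from the config, as shown in
# ``_set_up`` above; ``data`` is the single dataset held by this process.
# The config/data objects are placeholders supplied by the caller.
def _example_distributed_course(data, config):
    runner = DistributedRunner(data=data,
                               server_class=Server,
                               client_class=Client,
                               config=config)
    # With distribute.role == 'server', run() enters the server loop and
    # returns the best results; with role == 'client', the client joins the
    # already-running server at server_host:server_port and then runs.
    return runner.run()
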
# TODO: remove FedRunner (keep now for backward compatibility)
class FedRunner(object):
    """
    This class is used to construct an FL course, which includes `_set_up`
    and `run`.

    Arguments:
        data: The data used in the FL courses, which are formatted as \
            ``{'ID':data}`` for standalone mode. More details can be found \
            in federatedscope.core.auxiliaries.data_builder.
        server_class: The server class is used for instantiating a \
            (customized) server.
        client_class: The client class is used for instantiating a \
            (customized) client.
        config: The configurations of the FL course.
        client_configs: The clients' configurations.

    Warnings:
        ``FedRunner`` will be removed in the future, consider using \
        ``StandaloneRunner`` or ``DistributedRunner`` instead!
    """
    def __init__(self,
                 data,
                 server_class=Server,
                 client_class=Client,
                 config=None,
                 client_configs=None):
        logger.warning('`federate.core.fed_runner.FedRunner` will be '
                       'removed in the future, please use '
                       '`federate.core.fed_runner.get_runner` to get '
                       'Runner.')
        self.data = data
        self.server_class = server_class
        self.client_class = client_class
        assert config is not None, \
            "When using FedRunner, you should specify the `config` parameter"
        if not config.is_ready_for_run:
            config.ready_for_run()
        self.cfg = config
        self.client_cfgs = client_configs

        self.mode = self.cfg.federate.mode.lower()
        self.gpu_manager = GPUManager(gpu_available=self.cfg.use_gpu,
                                      specified_device=self.cfg.device)
        self.unseen_clients_id = []
        if self.cfg.federate.unseen_clients_rate > 0:
            self.unseen_clients_id = np.random.choice(
                np.arange(1, self.cfg.federate.client_num + 1),
                size=max(
                    1,
                    int(self.cfg.federate.unseen_clients_rate *
                        self.cfg.federate.client_num)),
                replace=False).tolist()
        # get resource information
        self.resource_info = get_resource_info(
            config.federate.resource_info_file)

        # Check the completeness of msg_handler.
        self.check()

    def setup(self):
        if self.mode == 'standalone':
            self.shared_comm_queue = deque()
            self._setup_for_standalone()
            # in standalone mode, by default, we print the trainer info only
            # once for better logs readability
            trainer_representative = self.client[1].trainer
            if trainer_representative is not None:
                trainer_representative.print_trainer_meta_info()
        elif self.mode == 'distributed':
            self._setup_for_distributed()

    def _setup_for_standalone(self):
        """
        To set up server and client for standalone mode.
        """
        if self.cfg.backend == 'torch':
            import torch
            torch.set_num_threads(1)

        assert self.cfg.federate.client_num != 0, \
            "In standalone mode, self.cfg.federate.client_num should be " \
            "non-zero. " \
            "This is usually caused by using synthetic data while not " \
            "specifying a non-zero value for client_num"

        if self.cfg.federate.method == "global":
            self.cfg.defrost()
            self.cfg.federate.client_num = 1
            self.cfg.federate.sample_client_num = 1
            self.cfg.freeze()

        # sample resource information
        if self.resource_info is not None:
            if len(self.resource_info) < self.cfg.federate.client_num + 1:
                replace = True
                logger.warning(
                    f"Because the number of provided resource information "
                    f"entries {len(self.resource_info)} is less than the "
                    f"number of participants "
                    f"{self.cfg.federate.client_num + 1}, one candidate "
                    f"might be selected multiple times.")
            else:
                replace = False
            sampled_index = np.random.choice(
                list(self.resource_info.keys()),
                size=self.cfg.federate.client_num + 1,
                replace=replace)
            server_resource_info = self.resource_info[sampled_index[0]]
            client_resource_info = [
                self.resource_info[x] for x in sampled_index[1:]
            ]
        else:
            server_resource_info = None
            client_resource_info = None

        self.server = self._setup_server(
            resource_info=server_resource_info,
            client_resource_info=client_resource_info)

        self.client = dict()
        # assume the client-wise data are consistent in their input&output
        # shape
        self._shared_client_model = get_model(
            self.cfg.model, self.data[1], backend=self.cfg.backend
        ) if self.cfg.federate.share_local_model else None

        for client_id in range(1, self.cfg.federate.client_num + 1):
            self.client[client_id] = self._setup_client(
                client_id=client_id,
                client_model=self._shared_client_model,
                resource_info=client_resource_info[client_id - 1]
                if client_resource_info is not None else None)

    def _setup_for_distributed(self):
        """
        To set up server or client for distributed mode.
        """
        # sample resource information
        if self.resource_info is not None:
            sampled_index = np.random.choice(list(self.resource_info.keys()))
            sampled_resource = self.resource_info[sampled_index]
        else:
            sampled_resource = None

        self.server_address = {
            'host': self.cfg.distribute.server_host,
            'port': self.cfg.distribute.server_port
        }
        if self.cfg.distribute.role == 'server':
            self.server = self._setup_server(resource_info=sampled_resource)
        elif self.cfg.distribute.role == 'client':
            # When we set up the client in the distributed mode, we assume
            # the server has been set up and numbered #0
            self.client_address = {
                'host': self.cfg.distribute.client_host,
                'port': self.cfg.distribute.client_port
            }
            self.client = self._setup_client(resource_info=sampled_resource)

    def run(self):
        """
        To run an FL course, which is called after server/client has been
        set up.
        For the standalone mode, a shared message queue will be set up to
        simulate ``receiving message``.
        """
        self.setup()
        if self.mode == 'standalone':
            # trigger the FL course
            for each_client in self.client:
                self.client[each_client].join_in()

            if self.cfg.federate.online_aggr:
                # any broadcast operation would be executed client-by-client
                # to avoid the existence of #clients messages at the same
                # time. currently, only consider centralized topology
                self._run_simulation_online()
            else:
                self._run_simulation()

            self.server._monitor.finish_fed_runner(fl_mode=self.mode)

            return self.server.best_results

        elif self.mode == 'distributed':
            if self.cfg.distribute.role == 'server':
                self.server.run()
                return self.server.best_results
            elif self.cfg.distribute.role == 'client':
                self.client.join_in()
                self.client.run()

    def _run_simulation_online(self):
        def is_broadcast(msg):
            return len(msg.receiver) >= 1 and msg.sender == 0

        cached_bc_msgs = []
        cur_idx = 0
        while True:
            if len(self.shared_comm_queue) > 0:
                msg = self.shared_comm_queue.popleft()
                if is_broadcast(msg):
                    cached_bc_msgs.append(msg)
                    # assume there is at least one client
                    msg = cached_bc_msgs[0]
                    self._handle_msg(msg, rcv=msg.receiver[cur_idx])
                    cur_idx += 1
                    if cur_idx >= len(msg.receiver):
                        del cached_bc_msgs[0]
                        cur_idx = 0
                else:
                    self._handle_msg(msg)
            elif len(cached_bc_msgs) > 0:
                msg = cached_bc_msgs[0]
                self._handle_msg(msg, rcv=msg.receiver[cur_idx])
                cur_idx += 1
                if cur_idx >= len(msg.receiver):
                    del cached_bc_msgs[0]
                    cur_idx = 0
            else:
                # finished
                break

    def _run_simulation(self):
        server_msg_cache = list()
        while True:
            if len(self.shared_comm_queue) > 0:
                msg = self.shared_comm_queue.popleft()
                if msg.receiver == [self.server_id]:
                    # For the server, move the received message to a
                    # cache for reordering the messages according to
                    # the timestamps
                    heapq.heappush(server_msg_cache, msg)
                else:
                    self._handle_msg(msg)
            elif len(server_msg_cache) > 0:
                msg = heapq.heappop(server_msg_cache)
                if self.cfg.asyn.use and self.cfg.asyn.aggregator \
                        == 'time_up':
                    # When the timestamp of the received message is beyond
                    # the deadline of the current round, trigger the
                    # time-up event first and push the message back into
                    # the cache
                    if self.server.trigger_for_time_up(msg.timestamp):
                        heapq.heappush(server_msg_cache, msg)
                    else:
                        self._handle_msg(msg)
                else:
                    self._handle_msg(msg)
            else:
                if self.cfg.asyn.use and self.cfg.asyn.aggregator \
                        == 'time_up':
                    self.server.trigger_for_time_up()
                    if len(self.shared_comm_queue) == 0 and \
                            len(server_msg_cache) == 0:
                        break
                else:
                    # terminate when shared_comm_queue and
                    # server_msg_cache are all empty
                    break

    def _setup_server(self, resource_info=None, client_resource_info=None):
        """
        Set up the server
        """
        self.server_id = 0
        if self.mode == 'standalone':
            if self.server_id in self.data:
                server_data = self.data[self.server_id]
                model = get_model(self.cfg.model,
                                  server_data,
                                  backend=self.cfg.backend)
            else:
                server_data = None
                data_representative = self.data[1]
                model = get_model(
                    self.cfg.model,
                    data_representative,
                    backend=self.cfg.backend
                )  # get the model according to client's data if the server
                # does not own data
            kw = {
                'shared_comm_queue': self.shared_comm_queue,
                'resource_info': resource_info,
                'client_resource_info': client_resource_info
            }
        elif self.mode == 'distributed':
            server_data = self.data
            model = get_model(self.cfg.model,
                              server_data,
                              backend=self.cfg.backend)
            kw = self.server_address
            kw.update({'resource_info': resource_info})
        else:
            raise ValueError('Mode {} is not provided'.format(
                self.cfg.mode.type))

        if self.server_class:
            self._server_device = self.gpu_manager.auto_choice()
            server = self.server_class(
                ID=self.server_id,
                config=self.cfg,
                data=server_data,
                model=model,
                client_num=self.cfg.federate.client_num,
                total_round_num=self.cfg.federate.total_round_num,
                device=self._server_device,
                unseen_clients_id=self.unseen_clients_id,
                **kw)

            if self.cfg.nbafl.use:
                from federatedscope.core.trainers.trainer_nbafl import \
                    wrap_nbafl_server
                wrap_nbafl_server(server)
        else:
            raise ValueError

        logger.info('Server has been set up ... ')

        return server

    def _setup_client(self,
                      client_id=-1,
                      client_model=None,
                      resource_info=None):
        """
        Set up the client
        """
        self.server_id = 0
        if self.mode == 'standalone':
            client_data = self.data[client_id]
            kw = {
                'shared_comm_queue': self.shared_comm_queue,
                'resource_info': resource_info
            }
        elif self.mode == 'distributed':
            client_data = self.data
            kw = self.client_address
            kw['server_host'] = self.server_address['host']
            kw['server_port'] = self.server_address['port']
            kw['resource_info'] = resource_info
        else:
            raise ValueError('Mode {} is not provided'.format(
                self.cfg.mode.type))

        if self.client_class:
            client_specific_config = self.cfg.clone()
            if self.client_cfgs and \
                    self.client_cfgs.get('client_{}'.format(client_id)):
                client_specific_config.defrost()
                client_specific_config.merge_from_other_cfg(
                    self.client_cfgs.get('client_{}'.format(client_id)))
                client_specific_config.freeze()
            client_device = self._server_device if \
                self.cfg.federate.share_local_model else \
                self.gpu_manager.auto_choice()
            client = self.client_class(
                ID=client_id,
                server_id=self.server_id,
                config=client_specific_config,
                data=client_data,
                model=client_model or get_model(client_specific_config.model,
                                                client_data,
                                                backend=self.cfg.backend),
                device=client_device,
                is_unseen_client=client_id in self.unseen_clients_id,
                **kw)
        else:
            raise ValueError

        if client_id == -1:
            logger.info('Client (address {}:{}) has been set up ... '.format(
                self.client_address['host'], self.client_address['port']))
        else:
            logger.info(f'Client {client_id} has been set up ... ')

        return client

    def _handle_msg(self, msg, rcv=-1):
        """
        To simulate the message handling process (used only for the
        standalone mode)
        """
        if rcv != -1:
            # simulate broadcast one-by-one
            self.client[rcv].msg_handlers[msg.msg_type](msg)
            return

        _, receiver = msg.sender, msg.receiver
        download_bytes, upload_bytes = msg.count_bytes()
        if not isinstance(receiver, list):
            receiver = [receiver]
        for each_receiver in receiver:
            if each_receiver == 0:
                self.server.msg_handlers[msg.msg_type](msg)
                self.server._monitor.track_download_bytes(download_bytes)
            else:
                self.client[each_receiver].msg_handlers[msg.msg_type](msg)
                self.client[each_receiver]._monitor.track_download_bytes(
                    download_bytes)

    def check(self):
        """
        Check the completeness of Server and Client.
        """
        if not self.cfg.check_completeness:
            return
        try:
            import os
            import networkx as nx
            import matplotlib.pyplot as plt
            # Build check graph
            G = nx.DiGraph()
            flags = {0: 'Client', 1: 'Server'}
            msg_handler_dicts = [
                self.client_class.get_msg_handler_dict(),
                self.server_class.get_msg_handler_dict()
            ]
            for flag, msg_handler_dict in zip(flags.keys(),
                                              msg_handler_dicts):
                role, oppo = flags[flag], flags[(flag + 1) % 2]
                for msg_in, (handler, msgs_out) in \
                        msg_handler_dict.items():
                    for msg_out in msgs_out:
                        msg_in_key = f'{oppo}_{msg_in}'
                        handler_key = f'{role}_{handler}'
                        msg_out_key = f'{role}_{msg_out}'
                        G.add_node(msg_in_key, subset=1)
                        G.add_node(handler_key, subset=0 if flag else 2)
                        G.add_node(msg_out_key, subset=1)
                        G.add_edge(msg_in_key, handler_key)
                        G.add_edge(handler_key, msg_out_key)
            pos = nx.multipartite_layout(G)
            plt.figure(figsize=(20, 15))
            nx.draw(G,
                    pos,
                    with_labels=True,
                    node_color='white',
                    node_size=800,
                    width=1.0,
                    arrowsize=25,
                    arrowstyle='->')
            fig_path = os.path.join(self.cfg.outdir, 'msg_handler.png')
            plt.savefig(fig_path)
            if nx.has_path(G, 'Client_join_in', 'Server_finish'):
                if nx.is_weakly_connected(G):
                    logger.info(f'Completeness check passes! Save check '
                                f'results in {fig_path}.')
                else:
                    logger.warning(f'Completeness check raises warning for '
                                   f'some handlers not in FL process! Save '
                                   f'check results in {fig_path}.')
            else:
                logger.error(f'Completeness check fails for there is no '
                             f'path from `join_in` to `finish`! Save '
                             f'check results in {fig_path}.')
        except Exception as error:
            logger.warning(f'Completeness check failed for {error}!')
        return

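
# Illustrative sketch (not part of the original module): migrating away from
# the deprecated FedRunner.  Per the Warnings section of its docstring, the
# replacement is to use StandaloneRunner or DistributedRunner directly; the
# mode dispatch below is a hand-rolled stand-in for the ``get_runner`` helper
# mentioned in the deprecation warning, whose definition is not shown in this
# file.
def _example_migrate_from_fed_runner(data, config, client_configs=None):
    runner_cls = StandaloneRunner if \
        config.federate.mode.lower() == 'standalone' else DistributedRunner
    runner = runner_cls(data=data,
                        server_class=Server,
                        client_class=Client,
                        config=config,
                        client_configs=client_configs)
    return runner.run()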