import logging
import copy
import os
import sys
import numpy as np
import pickle
import time
from federatedscope.core.monitors.early_stopper import EarlyStopper
from federatedscope.core.message import Message
from federatedscope.core.communication import StandaloneCommManager, \
StandaloneDDPCommManager, gRPCCommManager
from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator
from federatedscope.core.auxiliaries.sampler_builder import get_sampler
from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
Timeout, merge_param_dict
from federatedscope.core.auxiliaries.trainer_builder import get_trainer
from federatedscope.core.secret_sharing import AdditiveSecretSharing
from federatedscope.core.workers.base_server import BaseServer
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Server(BaseServer):
    """
    The Server class, which describes the behaviors of the server in an \
    FL course. The behaviors are described by the handling functions \
    (named ``callback_funcs_for_xxx``).

    Arguments:
        ID: The unique ID of the server, which is set to 0 by default
        state: The training round index
        config: The configuration
        data: The data owned by the server (for global evaluation)
        model: The model used for aggregation
        client_num: The (expected) number of clients to start the FL course
        total_round_num: The total number of training rounds
        device: The device to run local training and evaluation

    Attributes:
        ID: The ID of the worker
        state: The training round index
        model: The model maintained locally
        cfg: The configuration of the FL course, \
            see ``federatedscope.core.configs``
        mode: The run mode for FL, ``distributed`` or ``standalone``
        monitor: Monitors the FL course and records metrics, \
            see ``federatedscope.core.monitors.monitor.Monitor``
        trainer: The instantiated trainer, \
            see ``federatedscope.core.trainers``
        best_results: The best results ever seen
        history_results: All evaluation results
        early_stopper: Determines when to early stop, \
            see ``federatedscope.core.monitors.early_stopper.EarlyStopper``
        aggregators: The protocols for aggregating all clients' models, \
            see ``federatedscope.core.aggregators``
        sample_client_num: The number of clients sampled in each round
        msg_buffer: The dict buffer for storing received messages
        staled_msg_buffer: The list buffer for storing stale messages
        comm_manager: The manager for communication, \
            see ``federatedscope.core.communication``
    """
def __init__(self,
ID=-1,
state=0,
config=None,
data=None,
model=None,
client_num=5,
total_round_num=10,
device='cpu',
strategy=None,
unseen_clients_id=None,
**kwargs):
super(Server, self).__init__(ID, state, config, model, strategy)
# Register message handlers
self._register_default_handlers()
# Un-configured worker
if config is None:
return
self.data = data
self.device = device
self.best_results = dict()
self.history_results = dict()
self.early_stopper = EarlyStopper(
self._cfg.early_stop.patience, self._cfg.early_stop.delta,
self._cfg.early_stop.improve_indicator_mode,
self._monitor.the_larger_the_better)
if self._cfg.federate.share_local_model \
and not self._cfg.federate.process_num > 1:
            # move the model to the specified device
model.to(device)
# Build aggregator
self.aggregator = get_aggregator(self._cfg.federate.method,
model=model,
device=device,
online=self._cfg.federate.online_aggr,
config=self._cfg)
if self._cfg.federate.restore_from != '':
if not os.path.exists(self._cfg.federate.restore_from):
logger.warning(f'Invalid `restore_from`:'
f' {self._cfg.federate.restore_from}.')
else:
                cur_round = self.aggregator.load_model(
                    self._cfg.federate.restore_from)
                logger.info("Restored the model from the {}-th round's "
                            "ckpt".format(cur_round))
if int(config.model.model_num_per_trainer) != \
config.model.model_num_per_trainer or \
config.model.model_num_per_trainer < 1:
raise ValueError(
f"model_num_per_trainer should be integer and >= 1, "
f"got {config.model.model_num_per_trainer}.")
self.model_num = config.model.model_num_per_trainer
self.models = [self.model]
self.aggregators = [self.aggregator]
if self.model_num > 1:
self.models.extend(
[copy.deepcopy(self.model) for _ in range(self.model_num - 1)])
self.aggregators.extend([
copy.deepcopy(self.aggregator)
for _ in range(self.model_num - 1)
])
# function for recovering shared secret
self.recover_fun = AdditiveSecretSharing(
shared_party_num=int(self._cfg.federate.sample_client_num)
).fixedpoint2float if self._cfg.federate.use_ss else None
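        # When use_ss is enabled, clients upload additive secret shares of
        # their model updates instead of plaintext; summing the shares
        # yields the aggregate in fixed-point form, and `fixedpoint2float`
        # converts it back to floating point during aggregation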
if self._cfg.federate.make_global_eval:
# set up a trainer for conducting evaluation in server
assert self.models is not None
assert self.data is not None
self.trainer = get_trainer(
model=self.models[0],
data=self.data,
device=self.device,
config=self._cfg,
only_for_eval=True,
monitor=self._monitor
) # the trainer is only used for global evaluation
self.trainers = [self.trainer]
if self.model_num > 1:
# By default, the evaluation is conducted by calling
# trainer[i].eval over all internal models
self.trainers.extend([
copy.deepcopy(self.trainer)
for _ in range(self.model_num - 1)
])
# Initialize the number of joined-in clients
self._client_num = client_num
self._total_round_num = total_round_num
self.sample_client_num = int(self._cfg.federate.sample_client_num)
self.join_in_client_num = 0
self.join_in_info = dict()
# the unseen clients indicate the ones that do not contribute to FL
# process by training on their local data and uploading their local
# model update. The splitting is useful to check participation
# generalization gap in
# [ICLR'22, What Do We Mean by Generalization in Federated Learning?]
self.unseen_clients_id = [] if unseen_clients_id is None \
else unseen_clients_id
# Server state
self.is_finish = False
# Sampler
if self._cfg.federate.sampler in ['uniform']:
self.sampler = get_sampler(
sample_strategy=self._cfg.federate.sampler,
client_num=self.client_num,
client_info=None)
else:
            # Some types of samplers are instantiated in trigger_for_start,
            # since they require more information
self.sampler = None
# Current Timestamp
self.cur_timestamp = 0
self.deadline_for_cur_round = 1
# Staleness toleration
self.staleness_toleration = self._cfg.asyn.staleness_toleration if \
self._cfg.asyn.use else 0
self.dropout_num = 0
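        # e.g., with staleness_toleration = 2 at round 10, updates tagged
        # with round 8 or 9 are kept in `staled_msg_buffer` and still
        # aggregated, while older ones are dropped and counted in
        # `dropout_num` (see `callback_funcs_model_para`)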
# Device information
self.resource_info = kwargs['resource_info'] \
if 'resource_info' in kwargs else None
self.client_resource_info = kwargs['client_resource_info'] \
if 'client_resource_info' in kwargs else None
# Initialize communication manager and message buffer
self.msg_buffer = {'train': dict(), 'eval': dict()}
self.staled_msg_buffer = list()
if self.mode == 'standalone':
comm_queue = kwargs.get('shared_comm_queue', None)
if self._cfg.federate.process_num > 1:
id2comm = kwargs.get('id2comm', None)
self.comm_manager = StandaloneDDPCommManager(
comm_queue=comm_queue,
monitor=self._monitor,
id2comm=id2comm)
else:
self.comm_manager = StandaloneCommManager(
comm_queue=comm_queue, monitor=self._monitor)
elif self.mode == 'distributed':
host = kwargs['host']
port = kwargs['port']
self.comm_manager = gRPCCommManager(host=host,
port=port,
client_num=client_num,
cfg=self._cfg.distribute)
logger.info('Server: Listen to {}:{}...'.format(host, port))
# inject noise before broadcast
self._noise_injector = None
@property
def client_num(self):
return self._client_num
@client_num.setter
def client_num(self, value):
self._client_num = value
@property
def total_round_num(self):
return self._total_round_num
@total_round_num.setter
def total_round_num(self, value):
self._total_round_num = value
def register_noise_injector(self, func):
self._noise_injector = func
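    # A hypothetical noise-injector sketch (any callable with this
    # signature works; it is invoked per model in `broadcast_model_para`
    # as `self._noise_injector(cfg, num_sample_clients, model)`):
    #
    #     def gaussian_noise_injector(cfg, num_sample_clients, model):
    #         import torch
    #         for p in model.parameters():
    #             p.data += 1e-3 * torch.randn_like(p.data)
    #
    #     server.register_noise_injector(gaussian_noise_injector)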
    def run(self):
        """
        To start the FL course: listen for and handle messages (for \
        distributed mode).
        """
        # Begin: broadcast model parameters and start the FL training
while self.join_in_client_num < self.client_num:
msg = self.comm_manager.receive()
self.msg_handlers[msg.msg_type](msg)
        # Running: listen for messages (updates from clients),
        # aggregate and broadcast feedback (aggregated model parameters)
min_received_num = self._cfg.asyn.min_received_num \
if self._cfg.asyn.use else self._cfg.federate.sample_client_num
num_failure = 0
time_budget = self._cfg.asyn.time_budget if self._cfg.asyn.use else -1
with Timeout(time_budget) as time_counter:
while self.state <= self.total_round_num:
try:
msg = self.comm_manager.receive()
move_on_flag = self.msg_handlers[msg.msg_type](msg)
if move_on_flag:
time_counter.reset()
except TimeoutError:
logger.info('Time out at the training round #{}'.format(
self.state))
move_on_flag_eval = self.check_and_move_on(
min_received_num=min_received_num,
check_eval_result=True)
move_on_flag = self.check_and_move_on(
min_received_num=min_received_num)
if not move_on_flag and not move_on_flag_eval:
num_failure += 1
                        # Terminate the training if the number of failures
                        # exceeds the maximum number (default value: 10)
if time_counter.exceed_max_failure(num_failure):
logger.info(f'----------- Training fails at round '
f'#{self.state}-------------')
break
# Time out, broadcast the model para and re-start
# the training round
                        logger.info(
                            f'----------- Re-starting the training round ('
                            f'Round #{self.state}), attempt {num_failure} '
                            f'-------------')
# TODO: Clean the msg_buffer
if self.state in self.msg_buffer['train']:
self.msg_buffer['train'][self.state].clear()
self.broadcast_model_para(
msg_type='model_para',
sample_client_num=self.sample_client_num)
else:
num_failure = 0
time_counter.reset()
self.terminate(msg_type='finish')
    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message buffer. When enough messages have been \
        received, some events (such as performing aggregation, \
        evaluation, and moving on to the next training round) would be \
        triggered.

        Arguments:
            check_eval_result (bool): If True, check the message buffer \
                for evaluation; otherwise check the message buffer for \
                training.
            min_received_num: The minimal number of received messages, \
                used for the async mode
        """
if min_received_num is None:
if self._cfg.asyn.use:
min_received_num = self._cfg.asyn.min_received_num
else:
min_received_num = self._cfg.federate.sample_client_num
assert min_received_num <= self.sample_client_num
if check_eval_result and self._cfg.federate.mode.lower(
) == "standalone":
            # in the evaluation stage of standalone simulation mode, we
            # assume strong synchronization, i.e., responses are received
            # from all clients
min_received_num = len(self.comm_manager.get_neighbors().keys())
move_on_flag = True # To record whether moving to a new training
# round or finishing the evaluation
if self.check_buffer(self.state, min_received_num, check_eval_result):
if not check_eval_result:
# Receiving enough feedback in the training process
aggregated_num = self._perform_federated_aggregation()
self.state += 1
if self.state % self._cfg.eval.freq == 0 and self.state != \
self.total_round_num:
# Evaluate
logger.info(f'Server: Starting evaluation at the end '
f'of round {self.state - 1}.')
self.eval()
if self.state < self.total_round_num:
# Move to next round of training
logger.info(
f'----------- Starting a new training round (Round '
f'#{self.state}) -------------')
# Clean the msg_buffer
self.msg_buffer['train'][self.state - 1].clear()
self.msg_buffer['train'][self.state] = dict()
self.staled_msg_buffer.clear()
# Start a new training round
self._start_new_training_round(aggregated_num)
else:
                # Final evaluation
logger.info('Server: Training is finished! Starting '
'evaluation.')
self.eval()
else:
# Receiving enough feedback in the evaluation process
self._merge_and_format_eval_results()
if self.state >= self.total_round_num:
self.is_finish = True
else:
move_on_flag = False
return move_on_flag
    def check_and_save(self):
        """
        To save the results and the model after each evaluation, and \
        check whether to early stop.
        """
# early stopping
if "Results_weighted_avg" in self.history_results and \
self._cfg.eval.best_res_update_round_wise_key in \
self.history_results['Results_weighted_avg']:
should_stop = self.early_stopper.track_and_check(
self.history_results['Results_weighted_avg'][
self._cfg.eval.best_res_update_round_wise_key])
elif "Results_avg" in self.history_results and \
self._cfg.eval.best_res_update_round_wise_key in \
self.history_results['Results_avg']:
should_stop = self.early_stopper.track_and_check(
self.history_results['Results_avg'][
self._cfg.eval.best_res_update_round_wise_key])
else:
should_stop = False
if should_stop:
self._monitor.global_converged()
self.comm_manager.send(
Message(
msg_type="converged",
sender=self.ID,
receiver=list(self.comm_manager.neighbors.keys()),
timestamp=self.cur_timestamp,
state=self.state,
))
self.state = self.total_round_num + 1
if should_stop or self.state == self.total_round_num:
logger.info('Server: Final evaluation is finished! Starting '
'merging results.')
# last round or early stopped
self.save_best_results()
if not self._cfg.federate.make_global_eval:
self.save_client_eval_results()
self.terminate(msg_type='finish')
# Clean the clients evaluation msg buffer
if not self._cfg.federate.make_global_eval:
round = max(self.msg_buffer['eval'].keys())
self.msg_buffer['eval'][round].clear()
if self.state == self.total_round_num:
# break out the loop for distributed mode
self.state += 1
    def _start_new_training_round(self, aggregated_num=0):
"""
The behaviors for starting a new training round
"""
if self._cfg.asyn.use: # for asynchronous training
if self._cfg.asyn.aggregator == "time_up":
# Update the deadline according to the time budget
self.deadline_for_cur_round = \
self.cur_timestamp + self._cfg.asyn.time_budget
if self._cfg.asyn.broadcast_manner == \
'after_aggregating':
if self._cfg.asyn.overselection:
sample_client_num = self.sample_client_num
else:
sample_client_num = aggregated_num + \
self.dropout_num
self.broadcast_model_para(msg_type='model_para',
sample_client_num=sample_client_num)
self.dropout_num = 0
else: # for synchronous training
self.broadcast_model_para(msg_type='model_para',
sample_client_num=self.sample_client_num)
    def save_best_results(self):
        """
        To save the best evaluation results.
        """
if self._cfg.federate.save_to != '':
self.aggregator.save_model(self._cfg.federate.save_to, self.state)
formatted_best_res = self._monitor.format_eval_res(
results=self.best_results,
rnd="Final",
role='Server #',
forms=["raw"],
return_raw=True)
logger.info(formatted_best_res)
self._monitor.save_formatted_results(formatted_best_res)
    def save_client_eval_results(self):
        """
        To save the evaluation results of each client when the FL course \
        is early stopped or terminated.
        """
rnd = max(self.msg_buffer['eval'].keys())
eval_msg_buffer = self.msg_buffer['eval'][rnd]
with open(os.path.join(self._cfg.outdir, "eval_results.log"),
"a") as outfile:
for client_id, client_eval_results in eval_msg_buffer.items():
formatted_res = self._monitor.format_eval_res(
client_eval_results,
rnd=self.state,
role='Client #{}'.format(client_id),
return_raw=True)
logger.info(formatted_res)
outfile.write(str(formatted_res) + "\n")
    def merge_eval_results_from_all_clients(self):
        """
        Merge evaluation results from all clients, update the best \
        results, log the merged results and save them into \
        ``eval_results.log``.

        Returns:
            the formatted merged results
        """
round = max(self.msg_buffer['eval'].keys())
eval_msg_buffer = self.msg_buffer['eval'][round]
eval_res_participated_clients = []
eval_res_unseen_clients = []
for client_id in eval_msg_buffer:
if eval_msg_buffer[client_id] is None:
continue
if client_id in self.unseen_clients_id:
eval_res_unseen_clients.append(eval_msg_buffer[client_id])
else:
eval_res_participated_clients.append(
eval_msg_buffer[client_id])
formatted_logs_all_set = dict()
for merge_type, eval_res_set in [("participated",
eval_res_participated_clients),
("unseen", eval_res_unseen_clients)]:
if eval_res_set != []:
metrics_all_clients = dict()
for client_eval_results in eval_res_set:
for key in client_eval_results.keys():
if key not in metrics_all_clients:
metrics_all_clients[key] = list()
metrics_all_clients[key].append(
float(client_eval_results[key]))
formatted_logs = self._monitor.format_eval_res(
metrics_all_clients,
rnd=round,
role='Server #',
forms=self._cfg.eval.report)
if merge_type == "unseen":
for key, val in copy.deepcopy(formatted_logs).items():
if isinstance(val, dict):
                            # to avoid overriding results that use the same
                            # name, we create new keys with the postfix
                            # `_unseen`: 'Results_weighted_avg' ->
                            # 'Results_weighted_avg_unseen'
formatted_logs[key + "_unseen"] = val
del formatted_logs[key]
logger.info(formatted_logs)
formatted_logs_all_set.update(formatted_logs)
self._monitor.update_best_result(
self.best_results,
metrics_all_clients,
results_type="unseen_client_best_individual"
if merge_type == "unseen" else "client_best_individual")
self._monitor.save_formatted_results(formatted_logs)
for form in self._cfg.eval.report:
if form != "raw":
metric_name = form + "_unseen" if merge_type == \
"unseen" else form
self._monitor.update_best_result(
self.best_results,
formatted_logs[f"Results_{metric_name}"],
results_type=f"unseen_client_summarized_{form}"
if merge_type == "unseen" else
f"client_summarized_{form}")
return formatted_logs_all_set
    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        To broadcast the message to all clients or sampled clients.

        Arguments:
            msg_type: 'model_para' or other user-defined msg_type
            sample_client_num: the number of sampled clients in the \
                broadcast behavior. ``sample_client_num = -1`` denotes \
                broadcasting to all the clients.
            filter_unseen_clients: whether to filter out the unseen \
                clients that do not contribute to the FL process by \
                training on their local data and uploading their local \
                model updates. The splitting is useful to check the \
                participation generalization gap in [ICLR'22, What Do We \
                Mean by Generalization in Federated Learning?]. You may \
                want to set it to False in the evaluation stage.
        """
if filter_unseen_clients:
# to filter out the unseen clients when sampling
self.sampler.change_state(self.unseen_clients_id, 'unseen')
if sample_client_num > 0:
receiver = self.sampler.sample(size=sample_client_num)
else:
# broadcast to all clients
receiver = list(self.comm_manager.neighbors.keys())
if msg_type == 'model_para':
self.sampler.change_state(receiver, 'working')
if self._noise_injector is not None and msg_type == 'model_para':
            # Inject noise only when broadcasting model parameters
for model_idx_i in range(len(self.models)):
num_sample_clients = [
v["num_sample"] for v in self.join_in_info.values()
]
self._noise_injector(self._cfg, num_sample_clients,
self.models[model_idx_i])
skip_broadcast = self._cfg.federate.method in ["local", "global"]
if self.model_num > 1:
model_para = [{} if skip_broadcast else model.state_dict()
for model in self.models]
else:
model_para = {} if skip_broadcast else self.models[0].state_dict()
# quantization
if msg_type == 'model_para' and not skip_broadcast and \
self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_quantization
nbits = self._cfg.quantization.nbits
if self.model_num > 1:
model_para = [
symmetric_uniform_quantization(x, nbits)
for x in model_para
]
else:
model_para = symmetric_uniform_quantization(model_para, nbits)
        # We define that evaluation happens at the end of a training round,
        # so an 'evaluate' message is tagged with the round that just ended
rnd = self.state - 1 if msg_type == 'evaluate' else self.state
self.comm_manager.send(
Message(msg_type=msg_type,
sender=self.ID,
receiver=receiver,
state=min(rnd, self.total_round_num),
timestamp=self.cur_timestamp,
content=model_para))
if self._cfg.federate.online_aggr:
for idx in range(self.model_num):
self.aggregators[idx].reset()
if filter_unseen_clients:
# restore the state of the unseen clients within sampler
self.sampler.change_state(self.unseen_clients_id, 'seen')
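    # Typical calls (both appear elsewhere in this class): broadcast to
    # the sampled clients for training,
    #
    #     self.broadcast_model_para(msg_type='model_para',
    #                               sample_client_num=self.sample_client_num)
    #
    # or to all clients for evaluation, keeping the unseen clients included:
    #
    #     self.broadcast_model_para(msg_type='evaluate',
    #                               filter_unseen_clients=False)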
    def broadcast_client_address(self):
"""
To broadcast the communication addresses of clients (used for \
additive secret sharing)
"""
self.comm_manager.send(
Message(msg_type='address',
sender=self.ID,
receiver=list(self.comm_manager.neighbors.keys()),
state=self.state,
timestamp=self.cur_timestamp,
content=self.comm_manager.get_neighbors()))
    def check_buffer(self,
                     cur_round,
                     min_received_num,
                     check_eval_result=False):
        """
        To check the message buffer.

        Arguments:
            cur_round (int): The current round number
            min_received_num (int): The minimal number of received \
                messages
            check_eval_result (bool): If True, check the buffer of \
                evaluation results; otherwise check the buffer of \
                training results

        Returns:
            bool: Whether enough messages have been received or not
        """
if check_eval_result:
if 'eval' not in self.msg_buffer.keys() or len(
self.msg_buffer['eval'].keys()) == 0:
return False
buffer = self.msg_buffer['eval']
cur_round = max(buffer.keys())
cur_buffer = buffer[cur_round]
return len(cur_buffer) >= min_received_num
else:
if cur_round not in self.msg_buffer['train']:
cur_buffer = dict()
else:
cur_buffer = self.msg_buffer['train'][cur_round]
if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up':
if self.cur_timestamp >= self.deadline_for_cur_round and len(
cur_buffer) + len(self.staled_msg_buffer) == 0:
                    # When the time budget has run out but the server has
                    # not received any feedback
                    logger.warning(
                        f'The server has not received any feedback when the '
                        f'time budget has run out, therefore the server will '
                        f'wait for another {self._cfg.asyn.time_budget} '
                        f'seconds. Maybe you should carefully reset '
                        f'`cfg.asyn.time_budget` to a reasonable value.')
self.deadline_for_cur_round += self._cfg.asyn.time_budget
if self._cfg.asyn.broadcast_manner == \
'after_aggregating' and self.dropout_num != 0:
self.broadcast_model_para(
msg_type='model_para',
sample_client_num=self.dropout_num)
self.dropout_num = 0
return self.cur_timestamp >= self.deadline_for_cur_round
else:
return len(cur_buffer)+len(self.staled_msg_buffer) >= \
min_received_num
    def check_client_join_in(self):
"""
To check whether all the clients have joined in the FL course.
"""
if len(self._cfg.federate.join_in_info) != 0:
return len(self.join_in_info) == self.client_num
else:
return self.join_in_client_num == self.client_num
    def trigger_for_start(self):
        """
        To start the FL course when the expected number of clients has \
        joined.
        """
if self.check_client_join_in():
if self._cfg.federate.use_ss or self._cfg.vertical.use:
self.broadcast_client_address()
# get sampler
if 'client_resource' in self._cfg.federate.join_in_info:
client_resource = [
self.join_in_info[client_index]['client_resource']
for client_index in np.arange(1, self.client_num + 1)
]
else:
if self._cfg.backend == 'torch':
model_size = sys.getsizeof(pickle.dumps(
self.models[0])) / 1024.0 * 8.
else:
# TODO: calculate model size for TF Model
model_size = 1.0
logger.warning(f'The calculation of model size in backend:'
f'{self._cfg.backend} is not provided.')
client_resource = [
model_size / float(x['communication']) +
float(x['computation']) / 1000.
for x in self.client_resource_info
] if self.client_resource_info is not None else None
if self.sampler is None:
self.sampler = get_sampler(
sample_strategy=self._cfg.federate.sampler,
client_num=self.client_num,
client_info=client_resource)
            # change the deadline if the asyn.aggregator is `time_up`
if self._cfg.asyn.use and self._cfg.asyn.aggregator == 'time_up':
self.deadline_for_cur_round = self.cur_timestamp + \
self._cfg.asyn.time_budget
# start feature engineering
self.trigger_for_feat_engr(
self.broadcast_model_para, {
'msg_type': 'model_para',
'sample_client_num': self.sample_client_num
})
logger.info(
'----------- Starting training (Round #{:d}) -------------'.
format(self.state))
    def trigger_for_feat_engr(self,
                              trigger_train_func,
                              kwargs_for_trigger_train_func={}):
        """
        Interface for feature engineering; the default operation is a \
        no-op that directly triggers training.
        """
        trigger_train_func(**kwargs_for_trigger_train_func)
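    # A hypothetical override sketch: a subclass could first collect
    # feature statistics from the clients and defer the training trigger
    # until the statistics have been received, e.g.:
    #
    #     def trigger_for_feat_engr(self, trigger_train_func,
    #                               kwargs_for_trigger_train_func={}):
    #         self._pending_trigger = (trigger_train_func,
    #                                  kwargs_for_trigger_train_func)
    #         ...  # request stats; call the pending trigger in the
    #         ...  # corresponding message handler once replies arrive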
    def trigger_for_time_up(self, check_timestamp=None):
        """
        The handler for time up: update the current timestamp \
        and check the trigger condition.
        """
if self.is_finish:
return False
if check_timestamp is not None and \
check_timestamp < self.deadline_for_cur_round:
return False
self.cur_timestamp = self.deadline_for_cur_round
self.check_and_move_on()
return True
    def terminate(self, msg_type='finish'):
"""
To terminate the FL course
"""
self.is_finish = True
if self.model_num > 1:
model_para = [model.state_dict() for model in self.models]
else:
model_para = self.models[0].state_dict()
self._monitor.finish_fl()
self.comm_manager.send(
Message(msg_type=msg_type,
sender=self.ID,
receiver=list(self.comm_manager.neighbors.keys()),
state=self.state,
timestamp=self.cur_timestamp,
content=model_para))
    def eval(self):
"""
To conduct evaluation. When ``cfg.federate.make_global_eval=True``, \
a global evaluation is conducted by the server.
"""
if self._cfg.federate.make_global_eval:
# By default, the evaluation is conducted one-by-one for all
# internal models;
# for other cases such as ensemble, override the eval function
for i in range(self.model_num):
trainer = self.trainers[i]
                # Perform evaluation in the server
metrics = {}
for split in self._cfg.eval.split:
eval_metrics = trainer.evaluate(
target_data_split_name=split)
metrics.update(**eval_metrics)
formatted_eval_res = self._monitor.format_eval_res(
metrics,
rnd=self.state,
role='Server #',
forms=self._cfg.eval.report,
return_raw=self._cfg.federate.make_global_eval)
self._monitor.update_best_result(
self.best_results,
formatted_eval_res['Results_raw'],
results_type="server_global_eval")
self.history_results = merge_dict_of_results(
self.history_results, formatted_eval_res)
self._monitor.save_formatted_results(formatted_eval_res)
logger.info(formatted_eval_res)
self.check_and_save()
else:
            # Perform evaluation in clients
self.broadcast_model_para(msg_type='evaluate',
filter_unseen_clients=False)
    def callback_funcs_model_para(self, message: Message):
"""
The handling function for receiving model parameters, which triggers \
``check_and_move_on`` (perform aggregation when enough feedback has \
been received). This handling function is widely used in various FL \
courses.
Arguments:
message: The received message.
"""
if self.is_finish:
return 'finish'
round = message.state
sender = message.sender
timestamp = message.timestamp
content = message.content
self.sampler.change_state(sender, 'idle')
# dequantization
if self._cfg.quantization.method == 'uniform':
from federatedscope.core.compression import \
symmetric_uniform_dequantization
            if isinstance(content[1], list):  # multiple models
sample_size = content[0]
quant_model = [
symmetric_uniform_dequantization(x) for x in content[1]
]
else:
sample_size = content[0]
quant_model = symmetric_uniform_dequantization(content[1])
content = (sample_size, quant_model)
        # update the current timestamp according to the received message
        assert timestamp >= self.cur_timestamp  # for test
self.cur_timestamp = timestamp
if round == self.state:
if round not in self.msg_buffer['train']:
self.msg_buffer['train'][round] = dict()
# Save the messages in this round
self.msg_buffer['train'][round][sender] = content
elif round >= self.state - self.staleness_toleration:
            # Save the stale messages
self.staled_msg_buffer.append((round, sender, content))
else:
# Drop the out-of-date messages
            logger.info(f'Drop an out-of-date message from round #{round}')
self.dropout_num += 1
if self._cfg.federate.online_aggr:
self.aggregator.inc(content)
move_on_flag = self.check_and_move_on()
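        # In asynchronous mode with `broadcast_manner == 'after_receiving'`,
        # each received update immediately triggers dispatching the current
        # model to one newly sampled client, instead of waiting for the
        # aggregation to finish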
if self._cfg.asyn.use and self._cfg.asyn.broadcast_manner == \
'after_receiving':
self.broadcast_model_para(msg_type='model_para',
sample_client_num=1)
return move_on_flag
    def callback_funcs_for_join_in(self, message: Message):
        """
        The handling function for receiving the join-in information. The \
        server might request some information (such as \
        ``num_of_samples``) if necessary, and assigns an ID to the \
        joining client. If all the clients have joined in, the training \
        process will be triggered.

        Arguments:
            message: The received message
        """
if 'info' in message.msg_type:
sender, info = message.sender, message.content
for key in self._cfg.federate.join_in_info:
assert key in info
self.join_in_info[sender] = info
logger.info('Server: Client #{:d} has joined in !'.format(sender))
else:
self.join_in_client_num += 1
sender, address = message.sender, message.content
            if int(sender) == -1:  # assign an ID to the new client
sender = self.join_in_client_num
self.comm_manager.add_neighbors(neighbor_id=sender,
address=address)
self.comm_manager.send(
Message(msg_type='assign_client_id',
sender=self.ID,
receiver=[sender],
state=self.state,
timestamp=self.cur_timestamp,
content=str(sender)))
else:
self.comm_manager.add_neighbors(neighbor_id=sender,
address=address)
if len(self._cfg.federate.join_in_info) != 0:
self.comm_manager.send(
Message(msg_type='ask_for_join_in_info',
sender=self.ID,
receiver=[sender],
state=self.state,
timestamp=self.cur_timestamp,
content=self._cfg.federate.join_in_info.copy()))
self.trigger_for_start()
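    # e.g., with `cfg.federate.join_in_info = ['num_sample']`, the server
    # replies with an 'ask_for_join_in_info' message and each client
    # reports its local sample count; the collected values can later be
    # consumed, e.g., by a registered noise injector in
    # `broadcast_model_para`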
    def callback_funcs_for_metrics(self, message: Message):
"""
The handling function for receiving the evaluation results, \
which triggers ``check_and_move_on`` (perform aggregation when \
enough feedback has been received).
Arguments:
message: The received message
"""
rnd = message.state
sender = message.sender
content = message.content
if rnd not in self.msg_buffer['eval'].keys():
self.msg_buffer['eval'][rnd] = dict()
self.msg_buffer['eval'][rnd][sender] = content
return self.check_and_move_on(check_eval_result=True)
@classmethod
def get_msg_handler_dict(cls):
return cls().msg_handlers_str