Source code for federatedscope.core.trainers.torch_trainer

import os
import logging

import numpy as np
try:
    import torch
    from torch.utils.data import DataLoader, Dataset
except ImportError:
    torch = None
    DataLoader = None
    Dataset = None

from federatedscope.core.trainers.enums import MODE, LIFECYCLE
from federatedscope.core.trainers.trainer import Trainer
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler
from federatedscope.core.data import ClientData
from federatedscope.core.data.wrap_dataset import WrapDataset
from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
from federatedscope.core.auxiliaries.ReIterator import ReIterator
from federatedscope.core.auxiliaries.utils import param2tensor, \
    merge_param_dict
from federatedscope.core.monitors.monitor import Monitor

logger = logging.getLogger(__name__)


class GeneralTorchTrainer(Trainer):
    def get_model_para(self):
        if self.cfg.federate.process_num > 1:
            return self._param_filter(self.ctx.model.state_dict())
        else:
            return self._param_filter(
                self.ctx.model.state_dict() if self.cfg.federate.
                share_local_model else self.ctx.model.cpu().state_dict())

    def setup_data(self, ctx):
        """
        Initialize data by ``cfg``.
        """
        if isinstance(ctx.data, ClientData):
            ctx.data.setup(ctx.cfg)
        else:
            logger.warning(f'The data type should be `ClientData` to '
                           f'enable new `config`, but got '
                           f'{type(ctx.data)} instead.')

    def parse_data(self, data):
        """
        Populate "${split}_data", "${split}_loader" and "num_${split}_data"
        for different data splits.
        """
        init_dict = dict()
        if isinstance(data, dict):
            for split in data.keys():
                if split not in ['train', 'val', 'test']:
                    continue
                init_dict["{}_data".format(split)] = None
                init_dict["{}_loader".format(split)] = None
                init_dict["num_{}_data".format(split)] = 0
                if data.get(split, None) is not None:
                    if isinstance(data.get(split), Dataset):
                        init_dict["{}_data".format(split)] = data.get(split)
                        init_dict["num_{}_data".format(split)] = len(
                            data.get(split))
                    elif isinstance(data.get(split), DataLoader):
                        init_dict["{}_loader".format(split)] = data.get(split)
                        init_dict["num_{}_data".format(split)] = len(
                            data.get(split).dataset)
                    elif isinstance(data.get(split), dict):
                        init_dict["{}_data".format(split)] = data.get(split)
                        init_dict["num_{}_data".format(split)] = len(
                            data.get(split)['y'])
                    else:
                        raise TypeError("Type {} is not supported.".format(
                            type(data.get(split))))
        else:
            raise TypeError("Type of data should be dict.")
        return init_dict

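    # Illustrative sketch (not part of the class contract): given a
    # hypothetical split dict such as
    #     data = {'train': train_loader}   # train_loader: a DataLoader
    # parse_data returns
    #     {'train_data': None,
    #      'train_loader': train_loader,
    #      'num_train_data': len(train_loader.dataset)}
    # so later hooks can look up "{split}_data", "{split}_loader" and
    # "num_{split}_data" uniformly, whatever form the raw split took.
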
    def update(self, model_parameters, strict=False):
        """
        Called by the FL client to update the model parameters.

        Arguments:
            model_parameters (dict): PyTorch Module object's state_dict.
        """
        for key in model_parameters:
            model_parameters[key] = param2tensor(model_parameters[key])
        # Due to lazy load, we merge the two state dicts
        merged_param = merge_param_dict(self.ctx.model.state_dict().copy(),
                                        self._param_filter(model_parameters))
        self.ctx.model.load_state_dict(merged_param, strict=strict)

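    # Hedged usage sketch: a client typically merges the server's broadcast
    # weights into its local model via, e.g.,
    #     trainer.update(global_state_dict, strict=False)
    # where `global_state_dict` is an assumed dict of parameters received
    # from the server; param2tensor converts each entry back to a
    # torch.Tensor before the merged state_dict is loaded.
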
    def evaluate(self, target_data_split_name="test"):
        with torch.no_grad():
            super(GeneralTorchTrainer, self).evaluate(target_data_split_name)

        return self.ctx.eval_metrics

    def register_default_hooks_train(self):
        self.register_hook_in_train(self._hook_on_fit_start_init,
                                    "on_fit_start")
        self.register_hook_in_train(
            self._hook_on_fit_start_calculate_model_size, "on_fit_start")
        self.register_hook_in_train(self._hook_on_epoch_start,
                                    "on_epoch_start")
        self.register_hook_in_train(self._hook_on_batch_start_init,
                                    "on_batch_start")
        self.register_hook_in_train(self._hook_on_batch_forward,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_forward_regularizer,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_forward_flop_count,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_backward,
                                    "on_batch_backward")
        self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")

    def register_default_hooks_ft(self):
        self.register_hook_in_ft(self._hook_on_fit_start_init, "on_fit_start")
        self.register_hook_in_ft(self._hook_on_fit_start_calculate_model_size,
                                 "on_fit_start")
        self.register_hook_in_ft(self._hook_on_epoch_start, "on_epoch_start")
        self.register_hook_in_ft(self._hook_on_batch_start_init,
                                 "on_batch_start")
        self.register_hook_in_ft(self._hook_on_batch_forward,
                                 "on_batch_forward")
        self.register_hook_in_ft(self._hook_on_batch_forward_regularizer,
                                 "on_batch_forward")
        self.register_hook_in_ft(self._hook_on_batch_forward_flop_count,
                                 "on_batch_forward")
        self.register_hook_in_ft(self._hook_on_batch_backward,
                                 "on_batch_backward")
        self.register_hook_in_ft(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_ft(self._hook_on_fit_end, "on_fit_end")

    def register_default_hooks_eval(self):
        # test/val
        self.register_hook_in_eval(self._hook_on_fit_start_init,
                                   "on_fit_start")
        self.register_hook_in_eval(self._hook_on_epoch_start,
                                   "on_epoch_start")
        self.register_hook_in_eval(self._hook_on_batch_start_init,
                                   "on_batch_start")
        self.register_hook_in_eval(self._hook_on_batch_forward,
                                   "on_batch_forward")
        self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end")

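    # Illustrative sketch (using the same hook API as the registrations
    # above): user code can extend these defaults by appending its own
    # callable at a trigger point, e.g.
    #     def my_hook(ctx):
    #         ...  # assumed user-defined; reads/writes the context object
    #     trainer.register_hook_in_train(my_hook, "on_batch_end")
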
    def _hook_on_fit_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.model``                       Move to ``ctx.device``
            ``ctx.optimizer``                   Initialize by ``ctx.cfg``
            ``ctx.scheduler``                   Initialize by ``ctx.cfg``
            ``ctx.loss_batch_total``            Initialize to 0
            ``ctx.loss_regular_total``          Initialize to 0
            ``ctx.num_samples``                 Initialize to 0
            ``ctx.ys_true``                     Initialize to ``[]``
            ``ctx.ys_prob``                     Initialize to ``[]``
            ==================================  ===========================
        """
        # prepare model and optimizer
        ctx.model.to(ctx.device)

        if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]:
            # Initialize optimizer here to avoid the reuse of optimizers
            # across different routines
            ctx.optimizer = get_optimizer(ctx.model,
                                          **ctx.cfg[ctx.cur_mode].optimizer)
            ctx.scheduler = get_scheduler(ctx.optimizer,
                                          **ctx.cfg[ctx.cur_mode].scheduler)

        # TODO: the numbers of batches and epochs are decided by the current
        #  mode and data split, so they should be initialized at the
        #  beginning of the routine

        # prepare statistics
        ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE)

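    # Illustrative note: with a config such as
    #     cfg.train.optimizer.type = 'SGD'
    #     cfg.train.optimizer.lr = 0.01
    # the hook above builds a fresh optimizer (and scheduler) over
    # ctx.model at the start of every training/finetuning routine, so no
    # optimizer state leaks across routines.
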
    def _hook_on_fit_start_calculate_model_size(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.monitor``                     Track model size
            ==================================  ===========================
        """
        if not isinstance(ctx.monitor, Monitor):
            logger.warning(
                f"The trainer {type(self)} does not contain a valid "
                f"monitor; this may be caused by initializing trainer "
                f"subclasses without passing a valid monitor instance. "
                f"Please check whether this is what you want.")
            return
        if ctx.monitor.total_model_size == 0:
            ctx.monitor.track_model_size(ctx.models)

    def _hook_on_epoch_start(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.{ctx.cur_split}_loader``      Initialize DataLoader
            ==================================  ===========================
        """
        # prepare dataloader
        if ctx.get("{}_loader".format(ctx.cur_split)) is None:
            loader = get_dataloader(
                WrapDataset(ctx.get("{}_data".format(ctx.cur_split))),
                self.cfg, ctx.cur_split)
            setattr(ctx, "{}_loader".format(ctx.cur_split),
                    ReIterator(loader))
        elif not isinstance(ctx.get("{}_loader".format(ctx.cur_split)),
                            ReIterator):
            setattr(ctx, "{}_loader".format(ctx.cur_split),
                    ReIterator(ctx.get("{}_loader".format(ctx.cur_split))))
        else:
            ctx.get("{}_loader".format(ctx.cur_split)).reset()

    def _hook_on_batch_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.data_batch``                  Initialize batch data
            ==================================  ===========================
        """
        # prepare the next data batch; StopIteration propagates to signal
        # the end of the current epoch
        ctx.data_batch = CtxVar(
            next(ctx.get("{}_loader".format(ctx.cur_split))),
            LIFECYCLE.BATCH)

    def _hook_on_batch_forward(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.y_true``                      Move to ``ctx.device``
            ``ctx.y_prob``                      Forward propagation to get
                                                ``y_prob``
            ``ctx.loss_batch``                  Calculate the loss
            ``ctx.batch_size``                  Get the batch_size
            ==================================  ===========================
        """
        x, label = [_.to(ctx.device) for _ in ctx.data_batch]
        pred = ctx.model(x)
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
        ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
        ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH)

    def _hook_on_batch_forward_flop_count(self, ctx):
        """
        The monitoring hook to calculate the flops during the FL course.

        Note:
          For customized cases where the forward process is not only \
          based on ``ctx.model``, please override this function \
          (inheritance case) or replace this hook (plug-in case).

          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.monitor``                     Track average flops
            ==================================  ===========================
        """
        if not isinstance(ctx.monitor, Monitor):
            logger.warning(
                f"The trainer {type(self)} does not contain a valid "
                f"monitor; this may be caused by initializing trainer "
                f"subclasses without passing a valid monitor instance. "
                f"Please check whether this is what you want.")
            return

        if self.cfg.eval.count_flops and ctx.monitor.flops_per_sample == 0:
            # calculate the flops_per_sample
            try:
                x, y = [_.to(ctx.device) for _ in ctx.data_batch]
                from fvcore.nn import FlopCountAnalysis
                flops_one_batch = FlopCountAnalysis(ctx.model, x).total()
                if self.model_nums > 1 and ctx.mirrored_models:
                    flops_one_batch *= self.model_nums
                    logger.warning(
                        "flops_per_batch is multiplied by the internal "
                        "model nums since self.mirrored_models=True. "
                        "If this is not what you want, "
                        "please customize the count hook.")
                ctx.monitor.track_avg_flops(flops_one_batch, ctx.batch_size)
            except Exception:
                # Raise a warning at the first failure
                logger.warning(
                    "The current flop count implementation is for the "
                    "general trainer case: "
                    "1) ctx.data_batch = [x, y]; and "
                    "2) ctx.model takes only x as input. "
                    "Please check the forward format or implement your own "
                    "flop_count function.")
                ctx.monitor.flops_per_sample = -1

        # by default, we assume the data has the same input shape,
        # thus simply multiply the flops to avoid redundant forwards
        ctx.monitor.total_flops += ctx.monitor.flops_per_sample * \
            ctx.batch_size

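    # Hedged sketch of the "plug-in case" mentioned in the docstring above:
    # a task whose forward pass is not a plain ctx.model(x) can swap this
    # hook out, e.g.
    #     def my_flop_count_hook(ctx):
    #         ...  # assumed user-defined counting logic
    #     trainer.replace_hook_in_train(my_flop_count_hook,
    #                                   "on_batch_forward",
    #                                   "_hook_on_batch_forward_flop_count")
    # (`replace_hook_in_train` refers to the base Trainer's hook-replacement
    # facility; the hook body is a placeholder.)
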
    def _hook_on_batch_forward_regularizer(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.loss_regular``                Calculate the regular loss
            ``ctx.loss_task``                   Sum ``ctx.loss_regular`` \
                                                and ``ctx.loss_batch``
            ==================================  ===========================
        """
        ctx.loss_regular = CtxVar(
            self.cfg.regularizer.mu * ctx.regularizer(ctx), LIFECYCLE.BATCH)
        ctx.loss_task = CtxVar(ctx.loss_batch + ctx.loss_regular,
                               LIFECYCLE.BATCH)

    def _hook_on_batch_backward(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.optimizer``                   Update by gradient
            ``ctx.loss_task``                   Backward propagation
            ``ctx.scheduler``                   Update the learning rate
            ==================================  ===========================
        """
        ctx.optimizer.zero_grad()
        ctx.loss_task.backward()
        if ctx.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
                                           ctx.grad_clip)

        ctx.optimizer.step()
        if ctx.scheduler is not None:
            ctx.scheduler.step()

    def _hook_on_batch_end(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.num_samples``                 Add ``ctx.batch_size``
            ``ctx.loss_batch_total``            Add batch loss
            ``ctx.loss_regular_total``          Add batch regular loss
            ``ctx.ys_true``                     Append ``ctx.y_true``
            ``ctx.ys_prob``                     Append ``ctx.y_prob``
            ==================================  ===========================
        """
        # update statistics
        ctx.num_samples += ctx.batch_size
        ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

        # cache labels and predictions for evaluation
        ctx.ys_true.append(ctx.y_true.detach().cpu().numpy())
        ctx.ys_prob.append(ctx.y_prob.detach().cpu().numpy())

    def _hook_on_fit_end(self, ctx):
        """
        Evaluate metrics.

        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.ys_true``                     Convert to ``numpy.array``
            ``ctx.ys_prob``                     Convert to ``numpy.array``
            ``ctx.monitor``                     Evaluate the results
            ``ctx.eval_metrics``                Get evaluated results from \
                                                ``ctx.monitor``
            ==================================  ===========================
        """
        ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE)
        results = ctx.monitor.eval(ctx)
        setattr(ctx, 'eval_metrics', results)

    def save_model(self, path, cur_round=-1):
        assert self.ctx.model is not None

        ckpt = {'cur_round': cur_round, 'model': self.ctx.model.state_dict()}
        torch.save(ckpt, path)

    def load_model(self, path):
        assert self.ctx.model is not None

        if os.path.exists(path):
            ckpt = torch.load(path, map_location=self.ctx.device)
            self.ctx.model.load_state_dict(ckpt['model'])
            return ckpt['cur_round']
        else:
            raise ValueError("The file {} does NOT exist".format(path))

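    # Hedged usage sketch: checkpointing between FL rounds, e.g.
    #     trainer.save_model('/tmp/fs_ckpt.pt', cur_round=10)
    #     resumed_round = trainer.load_model('/tmp/fs_ckpt.pt')
    # The path and round number here are hypothetical placeholders.
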
    def discharge_model(self):
        """
        Discharge the model from the GPU device.
        """
        # Avoid memory leak
        if not self.cfg.federate.share_local_model and torch is not None:
            self.ctx.model.to(torch.device("cpu"))
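
# End-to-end sketch (hedged; `cfg`, `model`, `data` and `monitor` are
# assumed to be prepared by the caller, and the exact return values follow
# the base Trainer's routine API):
#     trainer = GeneralTorchTrainer(model=model, data=data, device='cuda:0',
#                                   config=cfg, monitor=monitor)
#     num_samples, model_para, train_metrics = trainer.train()
#     eval_metrics = trainer.evaluate(target_data_split_name='test')
#     trainer.discharge_model()  # move the model back to CPU between rounds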