Source code for federatedscope.core.trainers.tf_trainer

try:
    import tensorflow as tf
except ImportError:
    tf = None

import numpy as np
from federatedscope.core.trainers import Trainer
from federatedscope.core.trainers.enums import MODE
from federatedscope.core.auxiliaries.utils import batch_iter
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.trainers.enums import LIFECYCLE


class GeneralTFTrainer(Trainer):
    def train(self, target_data_split_name="train", hooks_set=None):
        hooks_set = self.hooks_in_train if hooks_set is None else hooks_set

        self.ctx.check_split(target_data_split_name)

        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)

        # TODO: The return values should be more flexible? Now: sample_num,
        #  model_para, results={k:v}
        return num_samples, self.ctx.model.state_dict(), self.ctx.eval_metrics

    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for different
        modes
        """
        init_dict = dict()
        if isinstance(data, dict):
            for mode in ["train", "val", "test"]:
                init_dict["{}_data".format(mode)] = None
                init_dict["{}_loader".format(mode)] = None
                init_dict["num_{}_data".format(mode)] = 0
                if data.get(mode, None) is not None:
                    init_dict["{}_data".format(mode)] = data.get(mode)
                    init_dict["num_{}_data".format(mode)] = len(
                        data.get(mode))
        else:
            raise TypeError("Type of data should be dict.")
        return init_dict
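
    # Illustrative sketch, not part of the original module: for a data dict
    # such as {"train": train_samples, "test": test_samples} holding, say,
    # 100 and 20 items respectively, ``parse_data`` would return roughly:
    #   {"train_data": train_samples, "train_loader": None,
    #    "num_train_data": 100,
    #    "val_data": None, "val_loader": None, "num_val_data": 0,
    #    "test_data": test_samples, "test_loader": None,
    #    "num_test_data": 20}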

    def register_default_hooks_train(self):
        self.register_hook_in_train(self._hook_on_fit_start_init,
                                    "on_fit_start")
        self.register_hook_in_train(self._hook_on_epoch_start,
                                    "on_epoch_start")
        self.register_hook_in_train(self._hook_on_batch_start_init,
                                    "on_batch_start")
        self.register_hook_in_train(self._hook_on_batch_forward,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_forward_regularizer,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_backward,
                                    "on_batch_backward")
        self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")

    def register_default_hooks_eval(self):
        # test/val
        self.register_hook_in_eval(self._hook_on_fit_start_init,
                                   "on_fit_start")
        self.register_hook_in_eval(self._hook_on_epoch_start,
                                   "on_epoch_start")
        self.register_hook_in_eval(self._hook_on_batch_start_init,
                                   "on_batch_start")
        self.register_hook_in_eval(self._hook_on_batch_forward,
                                   "on_batch_forward")
        self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end")

    def _hook_on_fit_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.model``                       Move to ``ctx.device``
            ``ctx.loss_batch_total``            Initialize to 0
            ``ctx.loss_regular_total``          Initialize to 0
            ``ctx.num_samples``                 Initialize to 0
            ``ctx.ys_true``                     Initialize to ``[]``
            ``ctx.ys_prob``                     Initialize to ``[]``
            ==================================  ===========================
        """
        # prepare model
        ctx.model.to(ctx.device)

        # prepare statistics
        ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE)

    def _hook_on_epoch_start(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.{cur_split}_loader``          Initialize DataLoader
            ==================================  ===========================
        """
        # prepare dataloader
        setattr(ctx, "{}_loader".format(ctx.cur_split),
                batch_iter(ctx.get("{}_data".format(ctx.cur_split))))

    def _hook_on_batch_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.data_batch``                  Initialize batch data
            ==================================  ===========================
        """
        # prepare data batch
        try:
            ctx.data_batch = next(
                ctx.get("{}_loader".format(ctx.cur_split)))
        except StopIteration:
            raise StopIteration

    def _hook_on_batch_forward(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.optimizer``                   Initialize optimizer
            ``ctx.batch_size``                  Calculate batch size
            ``ctx.loss_batch``                  Calculate batch loss
            ``ctx.model``                       Forward propagation
            ``ctx.y_true``                      Get y_true from batch
            ``ctx.y_prob``                      Forward propagation to get \
            ``y_prob``
            ==================================  ===========================
        """
        ctx.optimizer = ctx.model.optimizer
        ctx.batch_size = len(ctx.data_batch)
        with ctx.model.graph.as_default():
            with ctx.model.sess.as_default():
                feed_dict = {
                    ctx.model.input_x: ctx.data_batch['x'],
                    ctx.model.input_y: ctx.data_batch['y']
                }
                _, batch_loss, y_true, y_prob = ctx.model.sess.run(
                    [
                        ctx.model.train_op, ctx.model.losses,
                        ctx.model.input_y, ctx.model.out
                    ],
                    feed_dict=feed_dict)
                ctx.loss_batch = batch_loss
                ctx.y_true = CtxVar(y_true, LIFECYCLE.BATCH)
                ctx.y_prob = CtxVar(y_prob, LIFECYCLE.BATCH)
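
    # Editorial note: the TF1-style ``sess.run`` in ``_hook_on_batch_forward``
    # runs ``train_op`` together with the forward pass, so the parameter
    # update already happens there; presumably that is why the regularizer
    # and backward hooks below are left as no-ops.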

    def _hook_on_batch_forward_regularizer(self, ctx):
        pass

    def _hook_on_batch_backward(self, ctx):
        pass

    def _hook_on_batch_end(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.num_samples``                 Add ``ctx.batch_size``
            ``ctx.loss_batch_total``            Add batch loss
            ``ctx.loss_regular_total``          Add batch regular loss
            ``ctx.ys_true``                     Append ``ctx.y_true``
            ``ctx.ys_prob``                     Append ``ctx.y_prob``
            ==================================  ===========================
        """
        # TODO: the same with the torch_trainer
        # update statistics
        ctx.num_samples += ctx.batch_size
        ctx.loss_batch_total += ctx.loss_batch
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

        # cache labels for evaluation; ``sess.run`` already returns numpy
        # arrays, so no torch-style ``.detach().cpu().numpy()`` is needed
        ctx.ys_true.append(ctx.y_true)
        ctx.ys_prob.append(ctx.y_prob)

    def _hook_on_fit_end(self, ctx):
        """
        Evaluate metrics.

        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.ys_true``                     Convert to ``numpy.array``
            ``ctx.ys_prob``                     Convert to ``numpy.array``
            ``ctx.monitor``                     Evaluate the results
            ``ctx.eval_metrics``                Get evaluated results from \
            ``ctx.monitor``
            ==================================  ===========================
        """
        ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE)
        results = self.ctx.monitor.eval(ctx)
        setattr(ctx, 'eval_metrics', results)

    def update(self, model_parameters, strict=False):
        self.ctx.model.load_state_dict(model_parameters, strict=strict)

    def save_model(self, path, cur_round=-1):
        pass

    def load_model(self, path):
        pass

    def discharge_model(self):
        pass
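

# A minimal usage sketch, not part of the original module. It assumes the
# base ``Trainer`` accepts ``model``, ``data``, ``device``, ``config`` and
# ``monitor`` arguments (check the actual FederatedScope API), and that
# ``my_tf_model`` exposes the ``graph``, ``sess``, ``input_x``, ``input_y``,
# ``train_op``, ``losses`` and ``out`` attributes used by
# ``_hook_on_batch_forward`` above; ``train_samples``, ``test_samples``,
# ``cfg`` and ``monitor`` are placeholders.
#
# trainer = GeneralTFTrainer(model=my_tf_model,
#                            data={'train': train_samples,
#                                  'test': test_samples},
#                            device='cpu',
#                            config=cfg,
#                            monitor=monitor)
# sample_num, model_para, eval_metrics = trainer.train()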