try:
    import tensorflow as tf
except ImportError:
    tf = None

import numpy as np

from federatedscope.core.trainers import Trainer
from federatedscope.core.trainers.enums import MODE, LIFECYCLE
from federatedscope.core.auxiliaries.utils import batch_iter
from federatedscope.core.trainers.context import CtxVar


class GeneralTFTrainer(Trainer):
    def train(self, target_data_split_name="train", hooks_set=None):
        hooks_set = self.hooks_in_train if hooks_set is None else hooks_set
        self.ctx.check_split(target_data_split_name)

        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)

        # TODO: make the return values more flexible; for now they are
        #  (sample_num, model_para, results={k: v})
        return num_samples, self.ctx.model.state_dict(), self.ctx.eval_metrics
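
    # Illustrative usage (an assumption, not part of the original source):
    #     sample_num, model_para, results = trainer.train()
    # ``model_para`` mirrors the model's ``state_dict()`` and ``results`` is
    # the ``eval_metrics`` dict populated by ``_hook_on_fit_end``.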

    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for different
        modes.
        """
        init_dict = dict()
        if isinstance(data, dict):
            for mode in ["train", "val", "test"]:
                init_dict["{}_data".format(mode)] = None
                init_dict["{}_loader".format(mode)] = None
                init_dict["num_{}_data".format(mode)] = 0
                if data.get(mode, None) is not None:
                    init_dict["{}_data".format(mode)] = data.get(mode)
                    init_dict["num_{}_data".format(mode)] = len(
                        data.get(mode))
        else:
            raise TypeError("Type of data should be dict.")
        return init_dict
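
    # Illustrative note (an assumption, not part of the original source):
    # for ``data = {"train": train_split, "test": test_split}`` the returned
    # dict holds "train_data"/"val_data"/"test_data" (the split or None),
    # "train_loader"/... (always None at this point) and "num_train_data"/...
    # (``len()`` of the split, or 0 when it is missing). The loaders are only
    # built later, in ``_hook_on_epoch_start``.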

    def register_default_hooks_train(self):
        self.register_hook_in_train(self._hook_on_fit_start_init,
                                    "on_fit_start")
        self.register_hook_in_train(self._hook_on_epoch_start,
                                    "on_epoch_start")
        self.register_hook_in_train(self._hook_on_batch_start_init,
                                    "on_batch_start")
        self.register_hook_in_train(self._hook_on_batch_forward,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_forward_regularizer,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_backward,
                                    "on_batch_backward")
        self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")

    def register_default_hooks_eval(self):
        # test/val
        self.register_hook_in_eval(self._hook_on_fit_start_init,
                                   "on_fit_start")
        self.register_hook_in_eval(self._hook_on_epoch_start,
                                   "on_epoch_start")
        self.register_hook_in_eval(self._hook_on_batch_start_init,
                                   "on_batch_start")
        self.register_hook_in_eval(self._hook_on_batch_forward,
                                   "on_batch_forward")
        self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end")
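
    # With the default hooks above, one call to ``train()`` roughly unfolds as
    #   on_fit_start -> [per epoch] on_epoch_start ->
    #   [per batch] on_batch_start -> on_batch_forward (forward, then
    #   regularizer) -> on_batch_backward -> on_batch_end -> on_fit_end,
    # with the looping itself driven by the base ``Trainer`` routine.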

    def _hook_on_fit_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.model``                       Move to ``ctx.device``
            ``ctx.loss_batch_total``            Initialize to 0
            ``ctx.loss_regular_total``          Initialize to 0
            ``ctx.num_samples``                 Initialize to 0
            ``ctx.ys_true``                     Initialize to ``[]``
            ``ctx.ys_prob``                     Initialize to ``[]``
            ==================================  ===========================
        """
        # prepare model
        ctx.model.to(ctx.device)

        # prepare statistics
        ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE)

    def _hook_on_epoch_start(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.{cur_split}_loader``          Initialize DataLoader
            ==================================  ===========================
        """
        # prepare the data loader for the current split
        setattr(ctx, "{}_loader".format(ctx.cur_split),
                batch_iter(ctx.get("{}_data".format(ctx.cur_split))))
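
    # Note on ``batch_iter`` (imported from
    # federatedscope.core.auxiliaries.utils; its exact behavior is an
    # assumption here): it wraps the split's raw data as a generator, so each
    # ``next()`` call in ``_hook_on_batch_start_init`` yields one mini-batch
    # -- a dict with 'x' and 'y' arrays, matching the feed_dict built in
    # ``_hook_on_batch_forward`` -- and StopIteration marks the end of the
    # epoch.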

    def _hook_on_batch_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.data_batch``                  Initialize batch data
            ==================================  ===========================
        """
        # fetch the next mini-batch; the re-raised StopIteration signals to
        # the surrounding routine that the current epoch is exhausted
        try:
            ctx.data_batch = next(ctx.get("{}_loader".format(ctx.cur_split)))
        except StopIteration:
            raise StopIteration

    def _hook_on_batch_forward(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.optimizer``                   Initialize optimizer
            ``ctx.batch_size``                  Calculate batch size
            ``ctx.loss_batch``                  Calculate batch loss
            ``ctx.model``                       Forward propagation
            ``ctx.y_true``                      Get y_true from batch
            ``ctx.y_prob``                      Forward propagation to get
                                                ``y_prob``
            ==================================  ===========================
        """
        ctx.optimizer = ctx.model.optimizer

        # number of samples in the current mini-batch (not the number of keys
        # in the batch dict)
        ctx.batch_size = len(ctx.data_batch['y'])

        with ctx.model.graph.as_default():
            with ctx.model.sess.as_default():
                feed_dict = {
                    ctx.model.input_x: ctx.data_batch['x'],
                    ctx.model.input_y: ctx.data_batch['y']
                }
                # one training step: run the optimizer op and fetch the loss,
                # labels and predictions from the TF1-style graph
                _, batch_loss, y_true, y_prob = ctx.model.sess.run(
                    [
                        ctx.model.train_op, ctx.model.losses,
                        ctx.model.input_y, ctx.model.out
                    ],
                    feed_dict=feed_dict)

                ctx.loss_batch = CtxVar(batch_loss, LIFECYCLE.BATCH)
                ctx.y_true = CtxVar(y_true, LIFECYCLE.BATCH)
                ctx.y_prob = CtxVar(y_prob, LIFECYCLE.BATCH)
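
    # What ``_hook_on_batch_forward`` assumes of ``ctx.model`` (read off the
    # code above; the interface is not declared in this file): a TF1-style
    # wrapper exposing ``graph``, ``sess``, ``input_x``, ``input_y``,
    # ``train_op``, ``losses``, ``out`` and ``optimizer``. A runnable sketch
    # of such a wrapper is given at the end of this file.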

    def _hook_on_batch_forward_regularizer(self, ctx):
        # No explicit regularization term is added for the TF trainer.
        pass

    def _hook_on_batch_backward(self, ctx):
        # The backward/update step is already folded into ``train_op`` in
        # ``_hook_on_batch_forward``, so there is nothing left to do here.
        pass

    def _hook_on_batch_end(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.num_samples``                 Add ``ctx.batch_size``
            ``ctx.loss_batch_total``            Add batch loss
            ``ctx.loss_regular_total``          Add batch regular loss
            ``ctx.ys_true``                     Append ``ctx.y_true``
            ``ctx.ys_prob``                     Append ``ctx.y_prob``
            ==================================  ===========================
        """
        # TODO: the same as in the torch_trainer; consider unifying the two
        # update statistics
        ctx.num_samples += ctx.batch_size
        ctx.loss_batch_total += ctx.loss_batch
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

        # cache labels and predictions for evaluation; ``sess.run`` already
        # returns numpy arrays, so no torch-style conversion is needed
        ctx.ys_true.append(ctx.y_true)
        ctx.ys_prob.append(ctx.y_prob)

    def _hook_on_fit_end(self, ctx):
        """
        Evaluate metrics.

        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.ys_true``                     Convert to ``numpy.array``
            ``ctx.ys_prob``                     Convert to ``numpy.array``
            ``ctx.monitor``                     Evaluate the results
            ``ctx.eval_metrics``                Get evaluated results from
                                                ``ctx.monitor``
            ==================================  ===========================
        """
        ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE)
        results = self.ctx.monitor.eval(ctx)
        setattr(ctx, 'eval_metrics', results)
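
    # Note (an assumption about ``ctx.monitor``, which is provided elsewhere
    # in FederatedScope rather than defined here): ``eval`` computes the
    # configured metrics from ``ctx.ys_true``/``ctx.ys_prob`` and the loss
    # totals, returning a dict keyed per split (e.g. something like
    # ``{'train_avg_loss': ..., 'train_total': ...}``).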

    def update(self, model_parameters, strict=False):
        self.ctx.model.load_state_dict(model_parameters, strict=strict)

    def save_model(self, path, cur_round=-1):
        # Saving to disk is not implemented for the TF trainer.
        pass

    def load_model(self, path):
        # Loading from disk is not implemented for the TF trainer.
        pass

    def discharge_model(self):
        # No device discharge is needed for the TF trainer.
        pass
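

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of FederatedScope): a minimal TF1-style model
# exposing the attributes that GeneralTFTrainer's hooks rely on -- ``graph``,
# ``sess``, ``input_x``, ``input_y``, ``train_op``, ``losses``, ``out`` and
# ``optimizer`` for ``_hook_on_batch_forward``, plus ``to()``,
# ``state_dict()`` and ``load_state_dict()`` used by the other hooks and by
# ``update()``. All names below are assumptions made for this example only.
# ---------------------------------------------------------------------------
if tf is not None:
    # The trainer targets the TF1 graph/session API; on a TensorFlow 2
    # install this maps onto ``tf.compat.v1`` (an assumption of this sketch).
    _tf1 = tf.compat.v1

    class ToyTFLinearRegressor(object):
        """A linear-regression graph wrapped in the interface assumed above."""
        def __init__(self, in_dim=5, lr=0.01):
            self.graph = tf.Graph()
            with self.graph.as_default():
                self.input_x = _tf1.placeholder(tf.float32, [None, in_dim],
                                                name='input_x')
                self.input_y = _tf1.placeholder(tf.float32, [None, 1],
                                                name='input_y')
                w = _tf1.get_variable('w', [in_dim, 1],
                                      initializer=_tf1.zeros_initializer())
                b = _tf1.get_variable('b', [1],
                                      initializer=_tf1.zeros_initializer())
                self.out = tf.matmul(self.input_x, w) + b
                self.losses = tf.reduce_mean(tf.square(self.out -
                                                       self.input_y))
                self.optimizer = _tf1.train.GradientDescentOptimizer(lr)
                self.train_op = self.optimizer.minimize(self.losses)
                self.sess = _tf1.Session(graph=self.graph)
                self.sess.run(_tf1.global_variables_initializer())

        def to(self, device):
            # Device placement is a no-op in this toy example.
            return self

        def state_dict(self):
            # Export variable values as numpy arrays, keyed by variable name.
            with self.graph.as_default():
                return {v.name: self.sess.run(v)
                        for v in _tf1.global_variables()}

        def load_state_dict(self, model_parameters, strict=False):
            # Write received parameters back into the graph's variables.
            with self.graph.as_default():
                for v in _tf1.global_variables():
                    if v.name in model_parameters:
                        self.sess.run(v.assign(model_parameters[v.name]))

    # Hypothetical wiring with the trainer (the constructor signature follows
    # the base ``Trainer``; ``cfg`` and ``monitor`` come from FederatedScope's
    # config and monitoring utilities and are assumed to exist):
    #     model = ToyTFLinearRegressor(in_dim=5)
    #     data = {'train': {'x': np.random.rand(100, 5).astype(np.float32),
    #                       'y': np.random.rand(100, 1).astype(np.float32)}}
    #     trainer = GeneralTFTrainer(model=model, data=data, device='cpu',
    #                                config=cfg, monitor=monitor)
    #     sample_num, model_para, results = trainer.train()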