Source code for federatedscope.gfl.trainer.nodetrainer

import torch
from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler

from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.monitors import Monitor
from federatedscope.core.trainers.context import CtxVar
from federatedscope.register import register_trainer
from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.auxiliaries.ReIterator import ReIterator
import logging

logger = logging.getLogger(__name__)


class NodeFullBatchTrainer(GeneralTorchTrainer):
    """Trainer for node-level tasks whose loaders hold a single graph.

    The nodes that belong to the current split are selected through the
    ``"{split}_mask"`` entry of the batch.
    """

    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for different
        modes

        Args:
            data (dict): maps ``"train"`` / ``"val"`` / ``"test"`` to the
                loader of that split; a missing split yields a ``None``
                loader.

        Returns:
            dict: the ``"{}_loader"``, ``"{}_data"`` and ``"num_{}_data"``
            entries for every mode.

        Raises:
            TypeError: if ``data`` is not a dict.
        """
        if not isinstance(data, dict):
            raise TypeError("Type of data should be dict.")

        init_dict = dict()
        for mode in ["train", "val", "test"]:
            init_dict["{}_loader".format(mode)] = data.get(mode)
            init_dict["{}_data".format(mode)] = None
            # For node-level task dataloader contains one graph
            init_dict["num_{}_data".format(mode)] = 1
        return init_dict

    def _hook_on_batch_forward(self, ctx):
        """Run the model on the batch and record loss, labels and
        predictions for the nodes selected by the current split's mask.

        Writes ``ctx.batch_size`` (number of masked nodes) and the
        batch-lifecycle variables ``ctx.loss_batch``, ``ctx.y_true`` and
        ``ctx.y_prob``.
        """
        batch = ctx.data_batch.to(ctx.device)
        mask_key = '{}_mask'.format(ctx.cur_split)
        pred = ctx.model(batch)[batch[mask_key]]
        label = batch.y[batch[mask_key]]
        # The batch size is the number of nodes selected by the split mask.
        ctx.batch_size = torch.sum(ctx.data_batch[mask_key]).item()

        ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)

    def _hook_on_batch_forward_flop_count(self, ctx):
        """Accumulate the FLOPs of the current batch into the monitor.

        The per-sample FLOPs are measured once with ``fvcore`` (when
        ``cfg.eval.count_flops`` is set and no value has been recorded yet)
        and afterwards reused for every batch. If the measurement fails, a
        warning is logged and ``flops_per_sample`` is set to ``-1`` so the
        measurement is not attempted again.
        """
        if not isinstance(self.ctx.monitor, Monitor):
            logger.warning(
                f"The trainer {type(self)} does not contain a valid monitor, "
                f"this may be caused by "
                f"initializing trainer subclasses without passing a valid "
                f"monitor instance."
                f"Plz check whether this is you want.")
            return

        if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \
                == 0:
            # calculate the flops_per_sample
            try:
                batch = ctx.data_batch.to(ctx.device)
                from torch_geometric.data import Data
                if not isinstance(batch, Data):
                    # Only ``Data`` batches provide the (x, edge_index)
                    # pair used below; the handler reports the failure.
                    raise TypeError(
                        "Type {} is not supported.".format(type(batch)))
                x, edge_index = batch.x, batch.edge_index
                from fvcore.nn import FlopCountAnalysis
                flops_one_batch = FlopCountAnalysis(ctx.model,
                                                    (x, edge_index)).total()
                if self.model_nums > 1 and ctx.mirrored_models:
                    flops_one_batch *= self.model_nums
                    logger.warning(
                        "the flops_per_batch is multiplied by "
                        "internal model nums as self.mirrored_models=True."
                        "if this is not the case you want, "
                        "please customize the count hook")
                self.ctx.monitor.track_avg_flops(flops_one_batch,
                                                 ctx.batch_size)
            except Exception:
                # Best-effort measurement: any ordinary failure (missing
                # fvcore, unsupported forward format, ...) degrades to a
                # warning, while KeyboardInterrupt / SystemExit propagate.
                logger.warning(
                    "current flop count implementation is for general "
                    "NodeFullBatchTrainer case: "
                    "1) the ctx.model takes only batch = ctx.data_batch as "
                    "input."
                    "Please check the forward format or implement your own "
                    "flop_count function")
                # warning at the first failure
                self.ctx.monitor.flops_per_sample = -1

        # by default, we assume the data has the same input shape,
        # thus simply multiply the flops to avoid redundant forward
        self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
            ctx.batch_size
class NodeMiniBatchTrainer(GeneralTorchTrainer):
    """Trainer for node-level tasks fed by ``NeighborSampler`` or
    ``GraphSAINTRandomWalkSampler`` loaders.
    """

    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for different
        modes

        Args:
            data (dict): maps ``"train"`` / ``"val"`` / ``"test"`` to a
                ``NeighborSampler`` or ``GraphSAINTRandomWalkSampler``;
                splits that are missing or ``None`` keep their empty
                defaults.

        Returns:
            dict: the ``"{}_data"``, ``"{}_loader"`` and ``"num_{}_data"``
            entries for every mode.

        Raises:
            TypeError: if ``data`` is not a dict, or a split holds an
                unsupported loader type.
        """
        if not isinstance(data, dict):
            raise TypeError("Type of data should be dict.")

        parsed = {}
        for split in ("train", "val", "test"):
            # Defaults used when the split is absent.
            parsed[f"{split}_data"] = None
            parsed[f"{split}_loader"] = None
            parsed[f"num_{split}_data"] = 0

            sampler = data.get(split, None)
            if sampler is None:
                continue
            if not isinstance(
                    sampler, (NeighborSampler, GraphSAINTRandomWalkSampler)):
                raise TypeError("Type {} is not supported.".format(
                    type(sampler)))

            if split == 'train':
                parsed[f"{split}_loader"] = sampler
                parsed[f"num_{split}_data"] = len(sampler.dataset)
            else:
                # We need to pass Full Dataloader to model
                parsed[f"{split}_loader"] = [sampler]
                parsed[f"num_{split}_data"] = self.cfg.dataloader.batch_size
        return parsed

    def _hook_on_epoch_start(self, ctx):
        """Wrap the current split's loader in a ``ReIterator`` the first
        time it is seen and remember whether it is a ``NeighborSampler``.

        For a ``NeighborSampler`` the features and labels stored in
        ``ctx.data['data']`` are moved to ``ctx.device`` as well.
        """
        loader_key = "{}_loader".format(ctx.cur_split)
        loader = ctx.get(loader_key)
        if isinstance(loader, ReIterator):
            # Already wrapped: nothing to prepare.
            return

        # NOTE(review): this flag is only refreshed when a loader gets
        # wrapped for the first time -- confirm it cannot go stale when
        # the trainer alternates between splits.
        self.is_NeighborSampler = isinstance(loader, NeighborSampler)
        if self.is_NeighborSampler:
            ctx.data['data'].x = ctx.data['data'].x.to(ctx.device)
            ctx.data['data'].y = ctx.data['data'].y.to(ctx.device)
        setattr(ctx, loader_key, ReIterator(loader))

    def _hook_on_batch_forward(self, ctx):
        """Compute predictions, labels and loss for the current batch.

        Training uses the sampled mini-batch (``NeighborSampler`` tuples or
        sub-graph batches); other splits call ``ctx.model.inference`` and
        keep the nodes selected by the split mask. Writes
        ``ctx.batch_size`` and the batch-lifecycle variables
        ``ctx.loss_batch``, ``ctx.y_true`` and ``ctx.y_prob``.
        """
        if ctx.cur_split != 'train':
            # For inference
            subgraph_loader = ctx.data_batch
            mask = ctx.data['data']['{}_mask'.format(ctx.cur_split)]
            pred = ctx.model.inference(ctx.data['data'].x, subgraph_loader,
                                       ctx.device)[mask]
            label = ctx.data['data'].y[mask]
            ctx.batch_size = torch.sum(mask).item()
        elif self.is_NeighborSampler:
            # For NeighborSamper
            batch_size, n_id, adjs = ctx.data_batch
            adjs = [adj.to(ctx.device) for adj in adjs]
            pred = ctx.model(ctx.data['data'].x[n_id], adjs=adjs)
            label = ctx.data['data'].y[n_id[:batch_size]]
            ctx.batch_size = batch_size
        else:
            # For GraphSAINTRandomWalkSampler or PyGDataLoader
            batch = ctx.data_batch.to(ctx.device)
            split_mask = batch['{}_mask'.format(ctx.cur_split)]
            pred = ctx.model((batch.x, batch.edge_index))[split_mask]
            label = batch.y[split_mask]
            ctx.batch_size = torch.sum(ctx.data_batch['train_mask']).item()

        ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
def call_node_level_trainer(trainer_type):
    """Return the node-level trainer class registered under
    ``trainer_type``, or ``None`` when the name is not handled here.
    """
    if trainer_type == 'nodefullbatch_trainer':
        return NodeFullBatchTrainer
    if trainer_type == 'nodeminibatch_trainer':
        return NodeMiniBatchTrainer
    return None


# Register the builder under both trainer names it can resolve.
register_trainer('nodefullbatch_trainer', call_node_level_trainer)
register_trainer('nodeminibatch_trainer', call_node_level_trainer)