Source code for federatedscope.gfl.trainer.linktrainer

import torch

from torch.utils.data import DataLoader
from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler

from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.monitors import Monitor
from federatedscope.core.trainers.context import CtxVar
from federatedscope.register import register_trainer
from federatedscope.core.trainers import GeneralTorchTrainer

import logging

logger = logging.getLogger(__name__)

MODE2MASK = {
    'train': 'train_edge_mask',
    'val': 'valid_edge_mask',
    'test': 'test_edge_mask'
}
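# The trainers below resolve the boolean edge mask for the current split via
# MODE2MASK, i.e., they expect the PyG graph to carry `train_edge_mask`,
# `valid_edge_mask` and `test_edge_mask` attributes (one bool per edge), as
# well as an `edge_type` tensor that serves as the link label.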


class LinkFullBatchTrainer(GeneralTorchTrainer):
    def register_default_hooks_eval(self):
        super().register_default_hooks_eval()
        self.register_hook_in_eval(
            new_hook=self._hook_on_epoch_start_data2device,
            trigger='on_fit_start',
            insert_pos=-1)

    def register_default_hooks_train(self):
        super().register_default_hooks_train()
        self.register_hook_in_train(
            new_hook=self._hook_on_epoch_start_data2device,
            trigger='on_fit_start',
            insert_pos=-1)
    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for different
        modes.
        """
        init_dict = dict()
        if isinstance(data, dict):
            for mode in ["train", "val", "test"]:
                graph_data = data['data']
                edges = graph_data.edge_index.T[graph_data[MODE2MASK[mode]]]
                # Use an index loader
                index_loader = DataLoader(
                    range(edges.size(0)),
                    self.cfg.dataloader.batch_size,
                    shuffle=self.cfg.dataloader.shuffle
                    if mode == 'train' else False,
                    drop_last=self.cfg.dataloader.drop_last
                    if mode == 'train' else False)
                init_dict["{}_loader".format(mode)] = index_loader
                init_dict["num_{}_data".format(mode)] = edges.size(0)
                init_dict["{}_data".format(mode)] = None
        else:
            raise TypeError("Type of data should be dict.")
        return init_dict
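    # Note (added for clarity): each batch produced by the index loaders
    # above is a tensor of positions into the masked edge list; the forward
    # hook below consumes it as `perm` to slice edges and labels.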
    def _hook_on_epoch_start_data2device(self, ctx):
        if isinstance(ctx.data, dict):
            ctx.data = ctx.data['data']
        ctx.data = ctx.data.to(ctx.device)
        # Handle different dict keys: prefer a precomputed `input_edge_index`,
        # otherwise fall back to the edges selected by the train mask
        if "input_edge_index" in ctx.data:
            ctx.input_edge_index = ctx.data.input_edge_index
        else:
            ctx.input_edge_index = ctx.data.edge_index.T[
                ctx.data.train_edge_mask].T
    def _hook_on_batch_forward(self, ctx):
        data = ctx.data
        perm = ctx.data_batch
        mask = ctx.data[MODE2MASK[ctx.cur_split]]
        edges = data.edge_index.T[mask]
        if ctx.cur_split in ['train', 'val']:
            h = ctx.model((data.x, ctx.input_edge_index))
        else:
            h = ctx.model((data.x, data.edge_index))
        pred = ctx.model.link_predictor(h, edges[perm].T)
        label = data.edge_type[mask][perm]  # edge_type is y
        ctx.loss_batch = ctx.criterion(pred, label)

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
    def _hook_on_batch_forward_flop_count(self, ctx):
        if not isinstance(self.ctx.monitor, Monitor):
            logger.warning(
                f"The trainer {type(self)} does not contain a valid monitor; "
                f"this may be caused by initializing trainer subclasses "
                f"without passing a valid monitor instance. "
                f"Please check whether this is what you want.")
            return

        if self.cfg.eval.count_flops and \
                self.ctx.monitor.flops_per_sample == 0:
            # calculate flops_per_sample
            try:
                data = ctx.data
                from fvcore.nn import FlopCountAnalysis
                if ctx.cur_split in ['train', 'val']:
                    flops_one_batch = FlopCountAnalysis(
                        ctx.model, (data.x, ctx.input_edge_index)).total()
                else:
                    flops_one_batch = FlopCountAnalysis(
                        ctx.model, (data.x, data.edge_index)).total()
                if self.model_nums > 1 and ctx.mirrored_models:
                    flops_one_batch *= self.model_nums
                    logger.warning(
                        "The flops_per_batch is multiplied by the number of "
                        "internal models since self.mirrored_models=True. "
                        "If this is not what you want, "
                        "please customize the count hook.")
                self.ctx.monitor.track_avg_flops(flops_one_batch,
                                                 ctx.batch_size)
            except Exception:
                logger.warning(
                    "The current flop count implementation covers the "
                    "general LinkFullBatchTrainer case, i.e., ctx.model "
                    "takes the tuple (data.x, data.edge_index) or "
                    "(data.x, ctx.input_edge_index) as input. "
                    "Please check the forward format or implement your own "
                    "flop_count function.")
                # warn only at the first failure
                self.ctx.monitor.flops_per_sample = -1
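# A minimal sketch (not part of the upstream module) of the `data` dict that
# LinkFullBatchTrainer.parse_data expects: a single PyG graph under the key
# 'data', carrying the masks named in MODE2MASK plus `edge_type` labels. All
# sizes and split ratios below are hypothetical.
def _example_fullbatch_data():
    from torch_geometric.data import Data
    num_nodes, num_edges = 100, 400
    graph = Data(x=torch.randn(num_nodes, 16),
                 edge_index=torch.randint(0, num_nodes, (2, num_edges)))
    coin = torch.rand(num_edges)
    graph.train_edge_mask = coin < 0.8
    graph.valid_edge_mask = (coin >= 0.8) & (coin < 0.9)
    graph.test_edge_mask = coin >= 0.9
    graph.edge_type = torch.randint(0, 4, (num_edges, ))  # link labels ("y")
    return {'data': graph}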
class LinkMiniBatchTrainer(GeneralTorchTrainer):
    """Support GraphSAGE with GraphSAINTRandomWalkSampler in training ONLY!
    """
    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for different
        modes.
        """
        init_dict = dict()
        if isinstance(data, dict):
            for mode in ["train", "val", "test"]:
                init_dict["{}_data".format(mode)] = None
                init_dict["{}_loader".format(mode)] = None
                init_dict["num_{}_data".format(mode)] = 0
                if data.get(mode, None) is not None:
                    if isinstance(data.get(mode),
                                  (NeighborSampler,
                                   GraphSAINTRandomWalkSampler)):
                        if mode == 'train':
                            init_dict["{}_loader".format(mode)] = \
                                data.get(mode)
                            init_dict["num_{}_data".format(mode)] = len(
                                data.get(mode).dataset)
                        else:
                            # We need to pass the full dataloader to the model
                            init_dict["{}_loader".format(mode)] = [
                                data.get(mode)
                            ]
                            init_dict["num_{}_data".format(
                                mode)] = self.cfg.dataloader.batch_size
                    else:
                        raise TypeError("Type {} is not supported.".format(
                            type(data.get(mode))))
        else:
            raise TypeError("Type of data should be dict.")
        return init_dict
    def _hook_on_batch_forward(self, ctx):
        if ctx.cur_split == 'train':
            batch = ctx.data_batch.to(ctx.device)
            mask = batch[MODE2MASK[ctx.cur_split]]
            edges = batch.edge_index.T[mask].T
            h = ctx.model((batch.x, edges))
            pred = ctx.model.link_predictor(h, edges)
            label = batch.edge_type[mask]
            ctx.batch_size = torch.sum(
                ctx.data_batch[MODE2MASK[ctx.cur_split]]).item()
        else:
            # For inference
            mask = ctx.data['data'][MODE2MASK[ctx.cur_split]]
            subgraph_loader = ctx.data_batch
            h = ctx.model.gnn.inference(ctx.data['data'].x, subgraph_loader,
                                        ctx.device).to(ctx.device)
            edges = ctx.data['data'].edge_index.T[mask].to(ctx.device)
            pred = []
            for perm in DataLoader(range(edges.size(0)),
                                   self.cfg.dataloader.batch_size):
                edge = edges[perm].T
                pred += [ctx.model.link_predictor(h, edge).squeeze()]
            pred = torch.cat(pred, dim=0)
            label = ctx.data['data'].edge_type[mask].to(ctx.device)
            ctx.batch_size = torch.sum(
                ctx.data['data'][MODE2MASK[ctx.cur_split]]).item()

        ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
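# A hedged sketch (not in the upstream module) of the dict that
# LinkMiniBatchTrainer.parse_data accepts: a GraphSAINTRandomWalkSampler for
# 'train' and NeighborSampler loaders for 'val'/'test'. Note that
# `_hook_on_batch_forward` additionally reads the full graph from
# ctx.data['data'] during inference. Sampler hyperparameters below are
# illustrative only.
def _example_minibatch_data():
    graph = _example_fullbatch_data()['data']
    return {
        'data': graph,
        'train': GraphSAINTRandomWalkSampler(graph, batch_size=64,
                                             walk_length=2, num_steps=5),
        'val': NeighborSampler(graph.edge_index, sizes=[-1],
                               batch_size=1024, shuffle=False),
        'test': NeighborSampler(graph.edge_index, sizes=[-1],
                                batch_size=1024, shuffle=False),
    }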
def call_link_level_trainer(trainer_type):
    if trainer_type == 'linkfullbatch_trainer':
        trainer_builder = LinkFullBatchTrainer
    elif trainer_type == 'linkminibatch_trainer':
        trainer_builder = LinkMiniBatchTrainer
    else:
        trainer_builder = None
    return trainer_builder


register_trainer('linkfullbatch_trainer', call_link_level_trainer)
register_trainer('linkminibatch_trainer', call_link_level_trainer)
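# Usage sketch (hypothetical, not part of the module): with the
# registrations above, selecting one of these trainers is a matter of
# setting the trainer type in the FederatedScope config, e.g.
#
#     cfg.trainer.type = 'linkfullbatch_trainer'  # or 'linkminibatch_trainer'
#
# The framework then resolves the string through the registry, which invokes
# call_link_level_trainer and returns the matching trainer class.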