import torch
from torch.utils.data import DataLoader
from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborSampler
from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.monitors import Monitor
from federatedscope.core.trainers.context import CtxVar
from federatedscope.register import register_trainer
from federatedscope.core.trainers import GeneralTorchTrainer
import logging

logger = logging.getLogger(__name__)

MODE2MASK = {
    'train': 'train_edge_mask',
    'val': 'valid_edge_mask',
    'test': 'test_edge_mask'
}
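
# A minimal sketch of the per-client data layout both trainers below assume;
# attribute names follow MODE2MASK, while the shapes and variable names are
# illustrative assumptions, not values taken from FederatedScope:
#
#   from torch_geometric.data import Data
#   graph = Data(x=node_feat,                 # [num_nodes, feat_dim]
#                edge_index=edge_index,       # [2, num_edges]
#                edge_type=edge_labels,       # [num_edges], used as y
#                train_edge_mask=train_mask,  # bool [num_edges]
#                valid_edge_mask=valid_mask,
#                test_edge_mask=test_mask)
#   data = {'data': graph}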


class LinkFullBatchTrainer(GeneralTorchTrainer):
    """Trainer for link-level tasks that keeps the full graph on the
    device and mini-batches only the indices of the masked edges.
    """
def register_default_hooks_eval(self):
super().register_default_hooks_eval()
self.register_hook_in_eval(
new_hook=self._hook_on_epoch_start_data2device,
trigger='on_fit_start',
insert_pos=-1)

    def register_default_hooks_train(self):
super().register_default_hooks_train()
self.register_hook_in_train(
new_hook=self._hook_on_epoch_start_data2device,
trigger='on_fit_start',
insert_pos=-1)

    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for each of
        the train/val/test modes.
        """
init_dict = dict()
if isinstance(data, dict):
for mode in ["train", "val", "test"]:
graph_data = data['data']
edges = graph_data.edge_index.T[graph_data[MODE2MASK[mode]]]
# Use an index loader
index_loader = DataLoader(
range(edges.size(0)),
self.cfg.dataloader.batch_size,
shuffle=self.cfg.dataloader.shuffle
if mode == 'train' else False,
drop_last=self.cfg.dataloader.drop_last
if mode == 'train' else False)
init_dict["{}_loader".format(mode)] = index_loader
init_dict["num_{}_data".format(mode)] = edges.size(0)
init_dict["{}_data".format(mode)] = None
else:
raise TypeError("Type of data should be dict.")
return init_dict
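
    # For reference, a hedged sketch of what parse_data yields for the
    # 'train' split: the loader iterates over integer positions into the
    # masked edge set, so a batch is a LongTensor of edge indices rather
    # than a subgraph (key names mirror the code; shapes are illustrative):
    #
    #   init_dict['train_loader']    # DataLoader over range(num_train_edges)
    #   init_dict['num_train_data']  # == train_edge_mask.sum()
    #   init_dict['train_data']      # None; the full graph lives in ctx.data
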
def _hook_on_epoch_start_data2device(self, ctx):
if isinstance(ctx.data, dict):
ctx.data = ctx.data['data']
ctx.data = ctx.data.to(ctx.device)
        # Use an explicit input_edge_index if the graph provides one;
        # otherwise derive it from the training edge mask
if "input_edge_index" in ctx.data:
ctx.input_edge_index = ctx.data.input_edge_index
else:
ctx.input_edge_index = ctx.data.edge_index.T[
ctx.data.train_edge_mask].T
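
    # Note: edge_index has shape [2, num_edges], so the mask above is applied
    # on the transposed [num_edges, 2] view and then transposed back. A
    # hedged, equivalent one-liner (assuming a boolean mask):
    #
    #   ctx.input_edge_index = ctx.data.edge_index[:, ctx.data.train_edge_mask]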

    def _hook_on_batch_forward(self, ctx):
data = ctx.data
perm = ctx.data_batch
mask = ctx.data[MODE2MASK[ctx.cur_split]]
edges = data.edge_index.T[mask]
if ctx.cur_split in ['train', 'val']:
h = ctx.model((data.x, ctx.input_edge_index))
else:
h = ctx.model((data.x, data.edge_index))
pred = ctx.model.link_predictor(h, edges[perm].T)
label = data.edge_type[mask][perm] # edge_type is y
ctx.loss_batch = ctx.criterion(pred, label)
ctx.batch_size = len(label)
ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
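
    # The hook above assumes the following model contract (a sketch inferred
    # from this file, not a documented FederatedScope API): the GNN forward
    # takes a (node_features, edge_index) tuple and returns node embeddings,
    # and a link_predictor head scores edge pairs given as a [2, batch]
    # LongTensor:
    #
    #   h = ctx.model((data.x, edge_index))        # [num_nodes, hidden_dim]
    #   scores = ctx.model.link_predictor(h, edges[perm].T)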

    def _hook_on_batch_forward_flop_count(self, ctx):
        if not isinstance(self.ctx.monitor, Monitor):
            logger.warning(
                f"The trainer {type(self)} does not contain a valid "
                f"monitor; this may be caused by initializing trainer "
                f"subclasses without passing a valid monitor instance. "
                f"Please check whether this is what you want.")
            return
if self.cfg.eval.count_flops and self.ctx.monitor.flops_per_sample \
== 0:
# calculate the flops_per_sample
try:
data = ctx.data
from fvcore.nn import FlopCountAnalysis
if ctx.cur_split in ['train', 'val']:
flops_one_batch = FlopCountAnalysis(
ctx.model, (data.x, ctx.input_edge_index)).total()
else:
flops_one_batch = FlopCountAnalysis(
ctx.model, (data.x, data.edge_index)).total()
if self.model_nums > 1 and ctx.mirrored_models:
flops_one_batch *= self.model_nums
                    logger.warning(
                        "The flops_per_batch is multiplied by the number "
                        "of internal models since "
                        "self.mirrored_models=True. If this is not what "
                        "you want, please customize the count hook.")
self.ctx.monitor.track_avg_flops(flops_one_batch,
ctx.batch_size)
            except Exception:
                logger.warning(
                    "The current flop count implementation covers the "
                    "general LinkFullBatchTrainer case, i.e., ctx.model "
                    "takes the tuple (data.x, data.edge_index) or "
                    "(data.x, ctx.input_edge_index) as input. "
                    "Please check the forward format or implement your "
                    "own flop-count function.")
                # Set to -1 so the warning is raised only at the first
                # failure
self.ctx.monitor.flops_per_sample = -1
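
# A standalone sketch of the FLOPs accounting used above (fvcore is an
# optional dependency; the model/data names here are illustrative):
#
#   from fvcore.nn import FlopCountAnalysis
#   flops_one_batch = FlopCountAnalysis(
#       model, (data.x, data.edge_index)).total()
#   monitor.track_avg_flops(flops_one_batch, batch_size)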


class LinkMiniBatchTrainer(GeneralTorchTrainer):
    """Trainer for link-level tasks with sampling-based mini-batches.

    Support GraphSAGE with GraphSAINTRandomWalkSampler in the training
    stage ONLY; evaluation runs layer-wise inference over the full graph.
    """

    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for each of
        the train/val/test modes.
        """
init_dict = dict()
if isinstance(data, dict):
for mode in ["train", "val", "test"]:
init_dict["{}_data".format(mode)] = None
init_dict["{}_loader".format(mode)] = None
init_dict["num_{}_data".format(mode)] = 0
                if data.get(mode, None) is not None:
                    if isinstance(data.get(mode),
                                  (NeighborSampler,
                                   GraphSAINTRandomWalkSampler)):
                        if mode == 'train':
                            init_dict["{}_loader".format(mode)] = \
                                data.get(mode)
                            init_dict["num_{}_data".format(mode)] = len(
                                data.get(mode).dataset)
                        else:
                            # Pass the full dataloader to the model for
                            # layer-wise inference during evaluation
                            init_dict["{}_loader".format(mode)] = [
                                data.get(mode)
                            ]
                            init_dict["num_{}_data".format(
                                mode)] = self.cfg.dataloader.batch_size
                    else:
                        raise TypeError("Type {} is not supported.".format(
                            type(data.get(mode))))
else:
raise TypeError("Type of data should be dict.")
return init_dict
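
    # A hedged sketch of building the expected train loader with PyG's
    # random-walk sampler (hyperparameter values are illustrative only):
    #
    #   from torch_geometric.loader import GraphSAINTRandomWalkSampler
    #   train_loader = GraphSAINTRandomWalkSampler(
    #       graph, batch_size=1024, walk_length=2, num_steps=10)
    #   data = {'data': graph,
    #           'train': train_loader,
    #           'val': val_loader,    # e.g. a NeighborSampler
    #           'test': test_loader}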

    def _hook_on_batch_forward(self, ctx):
if ctx.cur_split == 'train':
batch = ctx.data_batch.to(ctx.device)
mask = batch[MODE2MASK[ctx.cur_split]]
edges = batch.edge_index.T[mask].T
h = ctx.model((batch.x, edges))
pred = ctx.model.link_predictor(h, edges)
label = batch.edge_type[mask]
ctx.batch_size = torch.sum(
ctx.data_batch[MODE2MASK[ctx.cur_split]]).item()
else:
# For inference
mask = ctx.data['data'][MODE2MASK[ctx.cur_split]]
subgraph_loader = ctx.data_batch
h = ctx.model.gnn.inference(ctx.data['data'].x, subgraph_loader,
ctx.device).to(ctx.device)
edges = ctx.data['data'].edge_index.T[mask].to(ctx.device)
pred = []
for perm in DataLoader(range(edges.size(0)),
self.cfg.dataloader.batch_size):
edge = edges[perm].T
pred += [ctx.model.link_predictor(h, edge).squeeze()]
pred = torch.cat(pred, dim=0)
label = ctx.data['data'].edge_type[mask].to(ctx.device)
ctx.batch_size = torch.sum(
ctx.data['data'][MODE2MASK[ctx.cur_split]]).item()
ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
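
# The inference branch above assumes ctx.model.gnn exposes a layer-wise
# inference(x_all, subgraph_loader, device) method, as in common GraphSAGE
# implementations, returning embeddings for every node; a hedged sketch:
#
#   h = model.gnn.inference(graph.x, subgraph_loader, device)  # [N, hidden]
#   scores = model.link_predictor(h, eval_edges.T)
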
def call_link_level_trainer(trainer_type):
if trainer_type == 'linkfullbatch_trainer':
trainer_builder = LinkFullBatchTrainer
elif trainer_type == 'linkminibatch_trainer':
trainer_builder = LinkMiniBatchTrainer
else:
trainer_builder = None
return trainer_builder


register_trainer('linkfullbatch_trainer', call_link_level_trainer)
register_trainer('linkminibatch_trainer', call_link_level_trainer)
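
# Usage sketch: with the registrations above, FederatedScope selects one of
# these trainers via the trainer type in the config, e.g. (assuming the
# standard FederatedScope config layout):
#
#   cfg.trainer.type = 'linkfullbatch_trainer'  # or 'linkminibatch_trainer'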