Source code for federatedscope.mf.model.model

from torch.nn import Module

import numpy as np
import torch


class BasicMFNet(Module):
    """Basic model for MF task

    Holds one embedding table for users and one for items; ``forward``
    scores a (user, item) pair as the dot product of the two embeddings.

    Subclasses are expected to define the class attribute ``name_reserve``
    (the attribute name of one of the two embeddings, e.g. ``"embed_item"``).
    The weight of that embedding is removed from ``state_dict()`` and is
    kept unchanged by ``load_state_dict()``. This base class does not define
    ``name_reserve`` itself.

    Arguments:
        num_user (int): the number of users
        num_item (int): the number of items
        num_hidden (int): the dimension of embedding vector
    """
    def __init__(self, num_user, num_item, num_hidden):
        super().__init__()

        self.num_user, self.num_item = num_user, num_item

        self.embed_user = torch.nn.Embedding(num_user,
                                             num_hidden,
                                             sparse=True)
        self.embed_item = torch.nn.Embedding(num_item,
                                             num_hidden,
                                             sparse=True)

    def forward(self, indices, ratings):
        """Predict ratings for the given (user, item) index pairs.

        Arguments:
            indices: array-like where ``indices[0]`` holds user ids and
                ``indices[1]`` holds item ids (converted via ``np.array``).
            ratings: array-like of target ratings (converted via
                ``np.array``).

        Returns:
            tuple: ``(pred, label)`` where ``pred`` is the sum over dim 1 of
            the element-wise product of the user and item embeddings, and
            ``label`` is ``ratings`` as a tensor on the model's device.
        """
        # Follow whichever device the embedding tables currently live on.
        device = self.embed_user.weight.device

        indices = torch.tensor(np.array(indices)).to(device)
        user_embedding = self.embed_user(indices[0])
        item_embedding = self.embed_item(indices[1])

        pred = (user_embedding * item_embedding).sum(dim=1)
        label = torch.tensor(np.array(ratings)).to(device)

        return pred, label

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load ``state_dict`` while keeping the reserved embedding as is.

        The weight named by ``name_reserve`` is always taken from this
        module, so an incoming dict without it still loads under
        ``strict=True``, and an incoming value for it is ignored.

        Arguments:
            state_dict (dict): parameters to load; it is not modified.
            strict (bool): forwarded to ``Module.load_state_dict``.

        Returns:
            The value returned by ``Module.load_state_dict`` (missing and
            unexpected keys).
        """
        from collections import OrderedDict

        # Work on a shallow copy so the local (reserved) parameter is never
        # written into the caller's dict.
        merged = OrderedDict(state_dict)
        metadata = getattr(state_dict, '_metadata', None)
        if metadata is not None:
            # Keep the version metadata that torch attaches to state dicts.
            merged._metadata = metadata

        merged[self.name_reserve + '.weight'] = getattr(
            getattr(self, self.name_reserve), 'weight')
        return super().load_state_dict(merged, strict)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the module state without the reserved embedding weight.

        Arguments:
            destination: forwarded to ``Module.state_dict``.
            prefix (str): forwarded to ``Module.state_dict``; also used to
                locate the reserved key so a non-empty prefix works.
            keep_vars (bool): forwarded to ``Module.state_dict``.

        Returns:
            dict: the state dict with ``prefix + name_reserve + '.weight'``
            removed.
        """
        # Keyword arguments: positional ones are deprecated by recent torch.
        state_dict = super().state_dict(destination=destination,
                                        prefix=prefix,
                                        keep_vars=keep_vars)
        # Mask the reserved embedding
        del state_dict[prefix + self.name_reserve + '.weight']
        return state_dict
class VMFNet(BasicMFNet):
    """Matrix-factorization model used in the vertical federated setting.

    Reserves the item embedding: its weight is the one ``BasicMFNet``
    leaves out of ``state_dict()`` and keeps local in ``load_state_dict()``.
    """
    # Attribute name of the embedding that BasicMFNet masks.
    name_reserve = 'embed_item'
class HMFNet(BasicMFNet):
    """Matrix-factorization model used in the horizontal federated setting.

    Reserves the user embedding: its weight is the one ``BasicMFNet``
    leaves out of ``state_dict()`` and keeps local in ``load_state_dict()``.
    """
    # Attribute name of the embedding that BasicMFNet masks.
    name_reserve = 'embed_user'