Recommendation

Matrix Factorization

FederatedScope builds in the matrix factorization (MF) task for recommendation, with flexible support for MF models, datasets, and federated settings. In this tutorial, we will introduce:

  • The matrix factorization task in FederatedScope
  • How to implement the matrix factorization task with FederatedScope
  • The privacy-preserving techniques used in FederatedScope

Background

Matrix factorization (MF) [1-3] is a fundamental building block in recommendation systems. In the rating matrix, each row corresponds to a user and each column corresponds to an item. The goal of matrix factorization is to approximate the unobserved ratings by learning a user embedding $U\in{\mathbb{R}^{n\times{d}}}$ and an item embedding $V\in{\mathbb{R}^{m\times{d}}}$.

(Figure: the matrix factorization task)

Supposing $X\in{\mathbb{R}^{n\times{m}}}$ is the target rating matrix and $\mathcal{D}$ is the set of observed entries, the task aims at minimizing the loss function:

\[\frac{1}{|\mathcal{D}|} \sum_{(i,j) \in \mathcal{D}} \mathcal{L}_{i,j}(X,U,V) = \frac{1}{|\mathcal{D}|} \sum_{(i,j) \in \mathcal{D}} \left( X_{i,j} - \langle u_i, v_j \rangle \right)^2\]

where $u_i\in{\mathbb{R}^{d}}$ and $v_j\in{\mathbb{R}^{d}}$ denote the $i$-th row of $U$ (the user vector) and the $j$-th row of $V$ (the item vector), respectively.
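
For intuition, the loss can be evaluated directly over the observed entries. The following is a minimal NumPy sketch; all names (U, V, observed) are illustrative and not part of FederatedScope:

import numpy as np

n, m, d = 4, 5, 3                        # users, items, embedding dimension
rng = np.random.default_rng(0)
U = rng.normal(size=(n, d))              # user embedding
V = rng.normal(size=(m, d))              # item embedding
X = rng.uniform(1, 5, size=(n, m))       # rating matrix (partially observed)

observed = [(0, 1), (2, 3), (3, 0)]      # the index set D of observed ratings
# mean squared error over the observed entries only
loss = np.mean([(X[i, j] - U[i] @ V[j]) ** 2 for i, j in observed])
print(loss)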

MF in Federated Learning

In federated learning, the dataset is distributed across different clients. The vanilla federated matrix factorization algorithm runs as follows (a toy sketch follows the list):

  • Step 1: The server initializes the shared parameters
  • Step 2: The server broadcasts the shared parameters to all participants
  • Step 3: Each participant updates its parameters locally
  • Step 4: Participants upload their shared parameters to the server
  • Step 5: The server aggregates the received parameters and repeats from Step 2 until training is finished
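
A toy NumPy version of this loop may help, here for the case where the item embedding V is the shared parameter and each participant keeps its user embedding private, on fully observed toy ratings. Every name below is illustrative; none of it is FederatedScope API:

import numpy as np

rng = np.random.default_rng(0)

def local_update(V, ratings, U_local, lr=0.1):
    # one local gradient step on the shared item embedding V
    err = ratings - U_local @ V.T        # residual on the local ratings
    return V + lr * err.T @ U_local

m, d, num_clients = 6, 3, 4
V = rng.normal(size=(m, d))              # Step 1: server initializes shared V
local_U = [rng.normal(size=(2, d)) for _ in range(num_clients)]
local_R = [rng.uniform(1, 5, size=(2, m)) for _ in range(num_clients)]

for fl_round in range(10):               # Step 5: repeat until finished
    updates = [local_update(V.copy(), local_R[k], local_U[k])  # Steps 2-3
               for k in range(num_clients)]
    V = np.mean(updates, axis=0)         # Steps 4-5: upload and aggregate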

With different data partitions, matrix factorization has three FL settings: Vertical FL (VFL), Horizontal FL (HFL), and Local FL (LFL).

Vertical FL

In VFL, the set of users is the same across the participants, and each participant only holds a subset of the items. In this setting, the user embedding is shared among all participants, while each client maintains its own item embedding.
(Figure: the VFL setting [3])

Horizontal FL

In HFL, the set of items is the same across the participants, and each participant only shares the item embedding with the coordinating server.
(Figure: the HFL setting)

Local FL

LFL is a special case of HFL in which each client is a single user who holds only her/his own ratings. It is a common scenario on mobile devices.
(Figure: the LFL setting)

Support of MF

To support federated MF, FederatedScope builds in MF models, datasets, and trainers for the different federated learning settings.

MF models

An MF model has two trainable parameters: the user embedding and the item embedding. Depending on the federated setting, a client shares one of them with the other participants and keeps the other local. FederatedScope provides VMFNet and HMFNet to support the VFL and HFL settings, respectively.

class VMFNet(BasicMFNet):
    name_reserve = "embed_item"


class HMFNet(BasicMFNet):
    name_reserve = "embed_user"

The attribute name_reserve specifies the name of the embedding that stays local (a sketch of this mechanism follows the list), and the parent class BasicMFNet defines the common actions, including

  • load/fetch parameters, and
  • forward propagation.
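
How name_reserve could gate the parameter exchange is sketched below. This is an illustrative assumption about the mechanism, not the actual BasicMFNet source:

import torch
from torch.nn import Module, Parameter

class BasicMFNetSketch(Module):
    # hypothetical sketch: the parameter named `name_reserve` stays local,
    # everything else is shared with the server
    name_reserve = ""

    def __init__(self, num_user, num_item, num_hidden):
        super().__init__()
        self.embed_user = Parameter(torch.randn(num_user, num_hidden))
        self.embed_item = Parameter(torch.randn(num_item, num_hidden))

    def shared_state_dict(self):
        # fetch only the parameters that are shared in this FL setting
        return {name: p for name, p in self.named_parameters()
                if name != self.name_reserve}

    def load_shared_state_dict(self, state):
        # load the broadcast parameters, leaving the reserved one untouched
        self.load_state_dict(state, strict=False)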

Note that the rating matrix is usually very sparse. To improve efficiency, FederatedScope builds the target rating matrix and the corresponding mask as sparse tensors.

import numpy as np
import torch
from torch.nn import Module


class BasicMFNet(Module):
    ...
    def forward(self, indices, ratings):
        # dense predictions for all (user, item) pairs
        pred = torch.matmul(self.embed_user, self.embed_item.T)
        # densify the sparse labels and the 0/1 mask of observed entries
        label = torch.sparse_coo_tensor(indices,
                                        ratings,
                                        size=pred.shape,
                                        device=pred.device,
                                        dtype=torch.float32).to_dense()
        mask = torch.sparse_coo_tensor(indices,
                                       np.ones(len(ratings)),
                                       size=pred.shape,
                                       device=pred.device,
                                       dtype=torch.float32).to_dense()

        # `ratio` rescales a mean loss over all n*m entries back to a mean
        # over the observed ratings only
        return mask * pred, label, float(np.prod(pred.size())) / len(ratings)
    ...
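
The effect of the mask-and-ratio trick can be checked in isolation. The snippet below is self-contained and illustrative; the shapes and batch are made up:

import torch
import torch.nn.functional as F

embed_user = torch.randn(3, 8)                   # 3 users, hidden size 8
embed_item = torch.randn(5, 8)                   # 5 items

indices = torch.tensor([[0, 1, 2], [3, 0, 4]])   # (user, item) pairs, k = 3
ratings = torch.tensor([4.0, 3.5, 5.0])

pred = embed_user @ embed_item.T                 # dense 3 x 5 predictions
label = torch.sparse_coo_tensor(indices, ratings,
                                size=pred.shape).to_dense()
mask = torch.sparse_coo_tensor(indices, torch.ones(3),
                               size=pred.shape).to_dense()

# the mean over all 15 entries times `ratio` equals the mean over the 3
# observed entries, since the 12 masked cells contribute zero
ratio = pred.numel() / len(ratings)
loss = F.mse_loss(mask * pred, label) * ratio
print(loss)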

MF Datasets

MovieLens is a series of movie recommendation datasets collected from the MovieLens website.
To satisfy the requirements of different FL settings, FederatedScope splits each dataset into a VFLMovieLens and an HFLMovieLens variant as follows. For example, if you want to use the dataset MovieLens1M in the VFL setting, just set cfg.data.type='VFLMovieLens1M' (a configuration sketch follows the class definitions).

class VFLMovieLens1M(MovieLens1M, VMFDataset):
    """MovieLens1M dataset in VFL setting
    
    """
    pass


class HFLMovieLens1M(MovieLens1M, HMFDataset):
    """MovieLens1M dataset in HFL setting

    """
    pass


class VFLMovieLens10M(MovieLens10M, VMFDataset):
    """MovieLens10M dataset in VFL setting

    """
    pass


class HFLMovieLens10M(MovieLens10M, HMFDataset):
    """MovieLens10M dataset in HFL setting

    """
    pass
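
For instance, the option can be set from Python as follows. This is a minimal sketch assuming the global_cfg entry point from federatedscope.core.configs.config; in practice the option is usually set in the YAML file passed to main.py:

from federatedscope.core.configs.config import global_cfg

cfg = global_cfg.clone()
cfg.data.type = 'VFLMovieLens1M'  # or 'HFLMovieLens1M', 'VFLMovieLens10M', ...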

The parent classes of the above datasets define the data information and the FL setting respectively.

Data information

The first parent classes, MovieLens1M and MovieLens10M, provide the dataset details (e.g., URL, MD5 checksum, filename).

class MovieLens1M(MovieLensData):
    """MoviesLens 1M Dataset
    (https://grouplens.org/datasets/movielens)

    Format:
        UserID::MovieID::Rating::Timestamp

    Arguments:
        root (str): Root directory of dataset where directory
            ``MovieLens1M`` exists or will be saved to if download is set to True.
        config (callable): Parameters related to matrix factorization.
        train_size (float, optional): The proportion of training data.
        test_size (float, optional): The proportion of test data.
        download  (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'MovieLens1M'
    url = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
    filename = "ml-1m"
    zip_md5 = "c4d9eecfca2ab87c1945afe126590906"
    raw_file = "ratings.dat"
    raw_file_md5 = "a89aa3591bc97d6d4e0c89459ff39362"


class MovieLens10M(MovieLensData):
    """MoviesLens 10M Dataset
    (https://grouplens.org/datasets/movielens)

    Format:
        UserID::MovieID::Rating::Timestamp

    Arguments:
        root (str): Root directory of dataset where directory
            ``MovieLens10M`` exists or will be saved to if download is set to True.
        config (callable): Parameters related to matrix factorization.
        train_size (float, optional): The proportion of training data.
        test_size (float, optional): The proportion of test data.
        download  (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'MovieLens10M'
    url = "https://files.grouplens.org/datasets/movielens/ml-10m.zip"
    filename = "ml-10M100K"

    zip_md5 = "ce571fd55effeba0271552578f2648bd"
    raw_file = "ratings.dat"
    raw_file_md5 = "3f317698625386f66177629fa5c6b2dc"

FL Setting

VMFDataset and HMFDataset specify how an MF dataset is split (VFL or HFL).

import numpy as np
from numpy.random import shuffle
from scipy.sparse import csc_matrix


class VMFDataset:
    """Dataset of the matrix factorization task in vertical federated learning.

    """
    def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
                                test_portion: float):
        # assign a disjoint slice of the shuffled items (columns) to each client
        id_item = np.arange(self.n_item)
        shuffle(id_item)
        items_per_client = np.array_split(id_item, num_client)
        data = dict()
        for client_id, items in enumerate(items_per_client):
            client_ratings = ratings[:, items]
            train_ratings, test_ratings = self._split_train_test_ratings(
                client_ratings, test_portion)
            data[client_id + 1] = {"train": train_ratings, "test": test_ratings}
        self.data = data


class HMFDataset:
    """Dataset of the matrix factorization task in horizontal federated learning.

    """
    def _split_n_clients_rating(self, ratings: csc_matrix, num_client: int,
                                test_portion: float):
        # assign a disjoint slice of the shuffled users (rows) to each client
        id_user = np.arange(self.n_user)
        shuffle(id_user)
        users_per_client = np.array_split(id_user, num_client)
        data = dict()
        for client_id, users in enumerate(users_per_client):
            client_ratings = ratings[users, :]
            train_ratings, test_ratings = self._split_train_test_ratings(
                client_ratings, test_portion)
            data[client_id + 1] = {"train": train_ratings, "test": test_ratings}
        self.data = data
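
The splitting logic is easy to see on a toy example; the snippet below only mimics the shuffle-and-split step of HMFDataset:

import numpy as np
from numpy.random import shuffle

n_user, num_client = 7, 3
id_user = np.arange(n_user)
shuffle(id_user)                                   # e.g. [3 0 6 2 5 1 4]
users_per_client = np.array_split(id_user, num_client)
# three nearly equal user groups; client k keeps only the rows of its users
for client_id, users in enumerate(users_per_client, start=1):
    print(client_id, users)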

MF Trainer

Since the target rating matrix is large and sparse, FederatedScope provides MFTrainer to support MF tasks during federated training.

class MFTrainer(GeneralTrainer):
    """Trainer for MF models.

    Arguments:
        model (torch.nn.Module): the MF model.
        data (dict): the input data.
        device (str): the device to run on.
    """

    def _hook_on_fit_end(self, ctx):
        results = {
            "{}_avg_loss".format(ctx.cur_mode): ctx.get("loss_batch_total_{}".format(ctx.cur_mode)) /
            ctx.get("num_samples_{}".format(ctx.cur_mode)),
            "{}_total".format(ctx.cur_mode): ctx.get("num_samples_{}".format(ctx.cur_mode))
        }
        setattr(ctx, 'eval_metrics', results)

    def _hook_on_batch_end(self, ctx):
        # update statistics
        setattr(
            ctx, "loss_batch_total_{}".format(ctx.cur_mode),
            ctx.get("loss_batch_total_{}".format(ctx.cur_mode)) +
            ctx.loss_batch.item() * ctx.batch_size)

        if ctx.get("loss_regular", None) is None or ctx.loss_regular == 0:
            loss_regular = 0.
        else:
            loss_regular = ctx.loss_regular.item()
        setattr(
            ctx, "loss_regular_total_{}".format(ctx.cur_mode),
            ctx.get("loss_regular_total_{}".format(ctx.cur_mode)) +
            loss_regular)
        setattr(
            ctx, "num_samples_{}".format(ctx.cur_mode),
            ctx.get("num_samples_{}".format(ctx.cur_mode)) + ctx.batch_size)

        # clean temp ctx
        ctx.data_batch = None
        ctx.batch_size = None
        ctx.loss_task = None
        ctx.loss_batch = None
        ctx.loss_regular = None
        ctx.y_true = None
        ctx.y_prob = None

    def _hook_on_batch_forward(self, ctx):
        indices, ratings = ctx.data_batch
        pred, label, ratio = ctx.model(indices, ratings)
        # `ratio` rescales the mean criterion over all matrix entries back
        # to a mean over the observed ratings only (see BasicMFNet.forward)
        ctx.loss_batch = ctx.criterion(pred, label) * ratio

        ctx.batch_size = len(ratings)

Start an Example

Taking the dataset MovieLens1M in the VFL setting as an example, the running command is as follows.

python main.py --cfg federatedscope/mf/baseline/fedavg_vfl_fedavg_standalone_on_movielens1m.yaml

More running scripts can be found in federatedscope/scripts. Partial experimental results are shown as follows.

Federated setting   Dataset       Number of clients   Loss
VFL                 MovieLens1M   5                   1.16
HFL                 MovieLens1M   5                   1.13

Privacy Protection

To protect user privacy, FederatedScope implements two differential privacy algorithms from [3], VFL-SGDMF and HFL-SGDMF, as plug-ins.

VFL-SGDMF

VFL-SGDMF is a DP-based algorithm for privacy preservation in the VFL setting. It satisfies $(\epsilon, \delta)$-differential privacy by injecting noise into the embedding matrices; for more details, please refer to [3]. The related parameters are shown as follows.

# ------------------------------------------------------------------------ #
# VFL-SGDMF(dp) related options
# ------------------------------------------------------------------------ #
cfg.sgdmf = CN()

cfg.sgdmf.use = False    # whether to use SGDMF
cfg.sgdmf.R = 5.         # the upper bound of ratings
cfg.sgdmf.epsilon = 4.   # \epsilon in DP
cfg.sgdmf.delta = 0.5    # \delta in DP
cfg.sgdmf.constant = 1.  # constant
cfg.sgdmf.theta = -1     # -1 means per-rating privacy, otherwise per-user privacy
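
The noise scale used during training is derived from these options. The exact calibration of VFL-SGDMF follows [3]; purely as an illustration, the classic Gaussian-mechanism scale for a known sensitivity looks like this (a sketch, not FederatedScope's actual formula):

import math

def gaussian_noise_scale(epsilon, delta, sensitivity):
    # classic Gaussian mechanism:
    # sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon
    return sensitivity * math.sqrt(2 * math.log(1.25 / delta)) / epsilon

print(gaussian_noise_scale(epsilon=4., delta=0.5, sensitivity=1.))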

VFL-SGDMF is implemented as a plug-in in federatedscope/mf/trainer/trainer_sgdmf.py. Similar to the other plug-in algorithms, it initializes and registers hook functions in the function wrap_MFTrainer.

from typing import Type


def wrap_MFTrainer(base_trainer: Type[MFTrainer]) -> Type[MFTrainer]:
    """Build `SGDMFTrainer` in a plug-in manner, by registering new hook
    functions into a specific `MFTrainer`.

    """

    # ---------------- attribute-level plug-in -----------------------
    init_sgdmf_ctx(base_trainer)

    # ---------------- action-level plug-in -----------------------
    base_trainer.replace_hook_in_train(new_hook=hook_on_batch_backward,
                                       target_trigger="on_batch_backward",
                                       target_hook_name="_hook_on_batch_backward")

    return base_trainer
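
The hook-replacement pattern itself can be illustrated with a generic toy class (not FederatedScope code): a wrapper swaps one named hook for another under the same trigger, and the training loop later fires whatever is registered.

class ToyTrainer:
    # generic toy version of the plug-in pattern used by wrap_MFTrainer
    def __init__(self):
        self.train_hooks = {"on_batch_backward": self._hook_on_batch_backward}

    def _hook_on_batch_backward(self, ctx):
        print("plain backward step")

    def replace_hook_in_train(self, new_hook, target_trigger, target_hook_name):
        # drop the old hook and register the new one under the same trigger
        self.train_hooks[target_trigger] = new_hook

def private_backward(ctx):
    print("backward step with DP noise")

trainer = ToyTrainer()
trainer.replace_hook_in_train(new_hook=private_backward,
                              target_trigger="on_batch_backward",
                              target_hook_name="_hook_on_batch_backward")
trainer.train_hooks["on_batch_backward"](ctx=None)  # -> backward step with DP noise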

The embedding clipping and noise injection are performed in the new hook function hook_on_batch_backward.

def hook_on_batch_backward(ctx):
    """Private local updates in SGDMF

    """
    ctx.optimizer.zero_grad()
    ctx.loss_task.backward()
    
    # Inject Gaussian noise into the gradients
    # (ctx.scale is the noise scale prepared in init_sgdmf_ctx)
    ctx.model.embed_user.grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_user.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_user.device)
    ctx.model.embed_item.grad.data += get_random(
        "Normal",
        sample_shape=ctx.model.embed_item.shape,
        params={
            "loc": 0,
            "scale": ctx.scale
        },
        device=ctx.model.embed_item.device)
    ctx.optimizer.step()

    # Embedding clipping
    with torch.no_grad():
        embedding_clip(ctx.model.embed_user, ctx.sgdmf_R)
        embedding_clip(ctx.model.embed_item, ctx.sgdmf_R)
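
The helper embedding_clip bounds the embeddings after each step so that predicted ratings stay within the valid range; its exact form follows [3]. A plausible entry-wise version is sketched below, under the assumption that clipping entries to [0, sqrt(R/d)] suffices to bound every inner product by R; this is an illustration, not the actual FederatedScope implementation:

import torch

def embedding_clip_sketch(embedding, R):
    # if every entry lies in [0, sqrt(R/d)], the inner product of any two
    # d-dimensional rows is at most d * (R/d) = R
    d = embedding.shape[1]
    embedding.data.clamp_(0., (R / d) ** 0.5)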

Start an Example

Similarly, taking MovieLens1M as an example, the running command is shown as follows.

python federatedscope/main.py --cfg federatedscope/mf/baseline/vfl-sgdmf_fedavg_standalone_on_movielens1m.yaml

Evaluation

Taking the dataset MovieLens1M as an example, the detailed settings are listed in federatedscope/mf/baseline/vfl_fedavg_standalone_on_movielens1m.yaml and federatedscope/mf/baseline/vfl-sgdmf_fedavg_standalone_on_movielens1m.yaml. VFL-SGDMF is evaluated as follows.

Algo        $\epsilon$   $\delta$   Loss
VFL         -            -          1.16
VFL-SGDMF   4            0.75       1.47
VFL-SGDMF   4            0.25       1.54
VFL-SGDMF   2            0.75       1.55
VFL-SGDMF   2            0.25       1.56
VFL-SGDMF   0.5          0.75       1.68
VFL-SGDMF   0.5          0.25       1.84

HFL-SGDMF

Similarly, HFL-SGDMF protects privacy in the HFL setting in the same way, and it shares the same set of parameters with VFL-SGDMF.

# ------------------------------------------------------------------------ #
# HFL-SGDMF(dp) related options
# ------------------------------------------------------------------------ #
cfg.sgdmf = CN()

cfg.sgdmf.use = False    # if use sgdmf
cfg.sgdmf.R = 5.         # The upper bound of rating
cfg.sgdmf.epsilon = 4.   # \epsilon in dp
cfg.sgdmf.delta = 0.5    # \delta in dp
cfg.sgdmf.constant = 1.  # constant
cfg.sgdmf.theta = -1     # -1 means per-rating privacy, otherwise per-user privacy

Start an Example

Run an example of HFL-SGDMF by the following command.

python federatedscope/main.py --cfg federatedscope/mf/baseline/hfl-sgdmf_fedavg_standalone_on_movielens1m.yaml

Evaluation

The evaluation results of HFL-SGDMF on the dataset MovieLens1M are shown as follows.

Algo        $\epsilon$   $\delta$   Loss
HFL         -            -          1.13
HFL-SGDMF   4            0.75       1.56
HFL-SGDMF   4            0.25       1.62
HFL-SGDMF   2            0.75       1.60
HFL-SGDMF   2            0.25       1.64
HFL-SGDMF   0.5          0.75       1.66
HFL-SGDMF   0.5          0.25       1.73

References

[1] Ma H, Yang H, Lyu M R, et al. “SoRec: social recommendation using probabilistic matrix factorization”. Proceedings of the ACM Conference on Information and Knowledge Management, 2008.

[2] Jamali M, Ester M. “A matrix factorization technique with trust propagation for recommendation in social networks”. Proceedings of the ACM Conference on Recommender Systems, 2010.

[3] Li Z, Ding B, Zhang C, et al. “Federated matrix factorization with privacy guarantee”. Proceedings of the VLDB Endowment, 2022.
