Source code for federatedscope.core.mlp

import torch
import torch.nn.functional as F
from torch.nn import Linear, ModuleList
from torch.nn import BatchNorm1d, Identity


class MLP(torch.nn.Module):
    """Multilayer perceptron built from a list of channel sizes.

    One ``Linear`` layer is created for every consecutive pair in
    ``channel_list``. Between two linear layers the forward pass applies a
    normalization layer, a ReLU and dropout; nothing is applied after the
    final linear layer.

    Args:
        channel_list: Sizes of the input, hidden and output features. Must
            contain at least two entries (input size and output size).
        dropout: Dropout probability used between linear layers. Dropout is
            only active while the module is in training mode.
        batch_norm: If true, ``BatchNorm1d`` is used as the normalization
            layer; otherwise ``Identity`` is used.
        relu_first: If true, ReLU is applied before the normalization
            layer; otherwise it is applied after it.
    """

    def __init__(self,
                 channel_list,
                 dropout: float = 0.,
                 batch_norm: bool = True,
                 relu_first: bool = False):
        super().__init__()
        assert len(channel_list) >= 2, \
            "channel_list needs at least an input and an output size"
        self.channel_list = channel_list
        self.dropout = dropout
        self.relu_first = relu_first

        self.linears = ModuleList()
        self.norms = ModuleList()
        # One norm is registered per linear layer. ``forward`` never uses
        # the last one, but it is still created so that the module's
        # parameter layout (``state_dict`` keys) stays unchanged.
        for in_channel, out_channel in zip(channel_list[:-1],
                                           channel_list[1:]):
            self.linears.append(Linear(in_channel, out_channel))
            self.norms.append(
                BatchNorm1d(out_channel) if batch_norm else Identity())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through all layers and return the final linear output.

        Args:
            x: Input tensor whose last dimension equals ``channel_list[0]``.

        Returns:
            The output of the last linear layer (no activation, norm or
            dropout is applied to it).
        """
        x = self.linears[0](x)
        # Pair every remaining linear layer with the norm that belongs to
        # the output of the layer preceding it.
        for layer, norm in zip(self.linears[1:], self.norms[:-1]):
            if self.relu_first:
                x = F.relu(x)
            x = norm(x)
            if not self.relu_first:
                x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            # Call the module itself rather than ``layer.forward`` so that
            # hooks registered on the layer are executed.
            x = layer(x)
        return x