import torch
import torch.nn.functional as F
from torch.nn import Linear, ModuleList
from torch.nn import BatchNorm1d, Identity


class MLP(torch.nn.Module):
    r"""A multilayer perceptron (MLP): a chain of linear layers in which
    every intermediate layer is followed by normalization, a ReLU
    activation and dropout.

    Args:
        channel_list (List[int]): Input, hidden and output dimensions,
            e.g. :obj:`[64, 128, 32]` creates two linear layers.
        dropout (float, optional): Dropout probability applied after each
            intermediate activation. (default: :obj:`0.`)
        batch_norm (bool, optional): If set to :obj:`False`, disables batch
            normalization. (default: :obj:`True`)
        relu_first (bool, optional): If set to :obj:`True`, applies the ReLU
            activation before normalization instead of after it.
            (default: :obj:`False`)
    """
    def __init__(self,
                 channel_list,
                 dropout=0.,
                 batch_norm=True,
                 relu_first=False):
        super().__init__()
        assert len(channel_list) >= 2
        self.channel_list = channel_list
        self.dropout = dropout
        self.relu_first = relu_first

        # One linear layer per consecutive pair of dimensions:
        self.linears = ModuleList()
        for in_channel, out_channel in zip(channel_list[:-1],
                                           channel_list[1:]):
            self.linears.append(Linear(in_channel, out_channel))

        # One normalization layer per *hidden* dimension; the output of
        # the final linear layer is left un-normalized:
        self.norms = ModuleList()
        for hidden_channel in channel_list[1:-1]:
            self.norms.append(
                BatchNorm1d(hidden_channel) if batch_norm else Identity())
    def forward(self, x):
        x = self.linears[0](x)
        for linear, norm in zip(self.linears[1:], self.norms):
            # Normalization, activation and dropout sit *between* linear
            # layers, so the final layer's output stays untouched:
            if self.relu_first:
                x = F.relu(x)
            x = norm(x)
            if not self.relu_first:
                x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = linear(x)
        return x
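

# A minimal usage sketch (an illustrative addition, not part of the module
# above): the dimensions, batch size and dropout value below are assumptions
# chosen only to demonstrate the API.
if __name__ == '__main__':
    model = MLP(channel_list=[16, 32, 8], dropout=0.5)
    x = torch.randn(4, 16)  # a batch of 4 samples with 16 features each
    out = model(x)          # Linear -> BatchNorm -> ReLU -> Dropout -> Linear
    print(out.shape)        # torch.Size([4, 8])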