import torch.nn as nn


class LSTM(nn.Module):
    """Embedding -> multi-layer LSTM -> linear decoder; returns the
    prediction at the final time step."""

    def __init__(self,
                 in_channels,
                 hidden,
                 out_channels,
                 n_layers=2,
                 embed_size=8,
                 dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.hidden = hidden
        self.embed_size = embed_size
        self.out_channels = out_channels
        self.n_layers = n_layers
        # Build the embedding layer only when embed_size is non-zero;
        # otherwise raw inputs of size in_channels go straight to the LSTM.
        self.encoder = nn.Embedding(in_channels, embed_size) if embed_size else None
        self.rnn = nn.LSTM(
            input_size=embed_size if embed_size else in_channels,
            hidden_size=hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.decoder = nn.Linear(hidden, out_channels)

    def forward(self, input_):
        # (B, T) integer ids -> (B, T, embed_size) when an encoder is present.
        if self.embed_size:
            input_ = self.encoder(input_)
        output, _ = self.rnn(input_)      # (B, T, hidden)
        output = self.decoder(output)     # (B, T, out_channels)
        output = output.permute(0, 2, 1)  # change dimension to (B, C, T)
        final_word = output[:, :, -1]     # keep only the last time step
        return final_word