# Source code for federatedscope.nlp.model.rnn

import torch.nn as nn
import torch.nn.functional as F


class LSTM(nn.Module):
    """Embedding -> stacked LSTM -> linear head, read out at the last step.

    The model optionally embeds its input, runs it through a batch-first
    ``nn.LSTM`` and projects the hidden state of the final time step to
    ``out_channels`` values.

    Args:
        in_channels: Number of rows of the embedding table. When
            ``embed_size`` is falsy it is used as the LSTM input feature
            size instead.
        hidden: Hidden size of the LSTM and input size of the decoder.
        out_channels: Output size of the linear decoder.
        n_layers: Number of stacked LSTM layers.
        embed_size: Embedding dimension. A falsy value (e.g. ``0``) makes
            ``forward`` skip the embedding lookup.
        dropout: Dropout value forwarded to ``nn.LSTM``.
    """

    def __init__(self,
                 in_channels,
                 hidden,
                 out_channels,
                 n_layers=2,
                 embed_size=8,
                 dropout=.0):
        super().__init__()
        self.in_channels = in_channels
        self.hidden = hidden
        self.embed_size = embed_size
        self.out_channels = out_channels
        self.n_layers = n_layers

        # The encoder is always registered, even when ``forward`` will skip
        # it, so the set of parameter names does not depend on embed_size.
        self.encoder = nn.Embedding(in_channels, embed_size)

        # Without an embedding the LSTM consumes ``in_channels`` features
        # per time step directly.
        self.rnn = nn.LSTM(
            input_size=embed_size if embed_size else in_channels,
            hidden_size=hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.decoder = nn.Linear(hidden, out_channels)

    def forward(self, input_):
        """Return the decoder output for the last time step.

        Args:
            input_: Batched, batch-first input. It is passed through the
                embedding when ``embed_size`` is truthy, otherwise it is
                fed to the LSTM unchanged.

        Returns:
            Tensor of shape ``(B, out_channels)``.
        """
        if self.embed_size:
            input_ = self.encoder(input_)
        output, _ = self.rnn(input_)  # (B, T, hidden)
        # Only the final time step is returned, so decode just that step
        # instead of projecting the whole sequence and discarding the rest.
        return self.decoder(output[:, -1, :])