```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
```
Imad Dabbura
December 10, 2022
Long Short-Term Memory (LSTM) is a recurrent neural network (RNN) architecture that was introduced by Hochreiter and Schmidhuber in 1997 to solve the problem of vanishing gradients that RNNs suffer from on long sequences. This issue is the result of repeated multiplication by the same weight matrix, since the weights are shared across all timesteps. Instead of having one hidden state, as is the case for RNNs, we have two hidden states: the cell state, which is responsible for retaining long-term memory, and the hidden state, which is focused on predicting the next word.
RNNs and LSTMs are sequential models, which means they can only take one input at a time to produce one output, because the output at time \(t\) depends not only on \(x_t\) but also on the hidden state(s) from \(t-1\). In other words, we can’t parallelize the forward pass and need to iterate over all the timesteps to get all the results.
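To make this concrete, here is a small sketch (my own, with illustrative sizes; `step` is a hypothetical stand-in for one recurrent update) showing why the forward pass must proceed one timestep at a time:

```python
# Each step consumes the hidden state produced by the previous step,
# so the loop over timesteps cannot be parallelized.
T, B, H_sz = 5, 2, 3
x = torch.randn(T, B, H_sz)
h = torch.zeros(B, H_sz)
step = nn.Linear(2 * H_sz, H_sz)  # stand-in for one recurrent update
for t in range(T):  # must run in order
    h = torch.tanh(step(torch.cat([x[t], h], dim=-1)))
```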
In this post, we will focus on implementing LSTM from scratch and compare it with PyTorch to check our implementation. Along the way, we will consider performance issues and some ways to optimize our implementation. Hopefully this will help us better understand LSTMs, since the only way to really understand something is to build it yourself from scratch.
Let’s first cover, at a high level, how LSTM works:
LSTM mitigates the vanishing/exploding gradients problem by making it easier to preserve information across many timesteps.
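To see why repeated multiplication by shared weights is the culprit, here is a toy experiment (mine, not the author's; the sizes and the 0.1 scale are arbitrary) in which the gradient with respect to the earliest state all but disappears:

```python
# Repeatedly applying the same small weight matrix shrinks the gradient
# flowing back to the earliest timestep toward zero.
W = 0.1 * torch.randn(10, 10)
h0 = torch.ones(10, requires_grad=True)
h = h0
for _ in range(50):  # 50 "timesteps" sharing the same weights
    h = torch.tanh(h @ W)
h.sum().backward()
print(h0.grad.norm())  # tiny: the gradient has vanished
```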
We will first start with implementing LSTMCell, which operates on one input at a time. Next, we will implement an LSTM module that wraps the LSTMCell to work on a sequence of inputs and, optionally, stacks multiple layers on top of each other to increase the capacity of the model, with some regularization using dropout.
Let’s take a look at the equations for an LSTMCell (each gate has the same dimension as the hidden state):

\[
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
\]

Where:

- \(i_t\) is the input gate, which controls how much of the candidate cell state gets written to the cell state
- \(f_t\) is the forget gate, which controls how much of the previous cell state is kept
- \(g_t\) is the candidate cell state
- \(o_t\) is the output gate, which controls how much of the cell state is exposed through the hidden state
- \(\sigma\) is the sigmoid function and \(\odot\) denotes element-wise multiplication
Even though we have 4 gates, we actually implement them using one matrix multiplication to speed up the computation. Later, we split the output to compute the corresponding gates.
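For instance (a standalone sketch with illustrative sizes), a single fused projection can be split into the four gate pre-activations like this:

```python
# One matmul produces the pre-activations of all 4 gates at once;
# torch.split then carves the result into hidden_sz-wide chunks.
B, input_sz, hidden_sz = 4, 20, 100
W = torch.randn(input_sz, 4 * hidden_sz)
fused = torch.randn(B, input_sz) @ W  # B x (4 * hidden_sz)
i, f, g, o = torch.split(fused, hidden_sz, dim=-1)
print(i.shape)  # torch.Size([4, 100])
```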
Let’s implement LSTMCell and check its correctness with PyTorch.
```python
# Long version
class LSTMCellNew(nn.Module):
    def __init__(self, input_sz, hidden_sz, bias=True):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.weight_ih = nn.Parameter(torch.randn((input_sz, hidden_sz * 4)))
        self.weight_hh = nn.Parameter(torch.randn((hidden_sz, hidden_sz * 4)))
        self.bias_ih = nn.Parameter(torch.zeros(hidden_sz * 4))
        self.bias_hh = nn.Parameter(torch.zeros(hidden_sz * 4))

    def forward(self, x, h, c):
        # x: B x input_sz; h, c: B x hidden_sz
        out = x @ self.weight_ih + h @ self.weight_hh + self.bias_ih + self.bias_hh
        # Split the fused projection into the four gate pre-activations
        i, f, g, o = torch.split(out, self.hidden_sz, dim=-1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c_t = f * c + i * g
        h_t = o * torch.tanh(c_t)
        return h_t, c_t
```
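As a quick sanity check (the sizes here are illustrative), we can feed the cell a single timestep of random data and confirm the output shapes:

```python
B, input_sz, hidden_sz = 4, 20, 100
cell = LSTMCellNew(input_sz, hidden_sz)
h = c = torch.zeros(B, hidden_sz)
h_t, c_t = cell(torch.randn(B, input_sz), h, c)
print(h_t.shape, c_t.shape)  # torch.Size([4, 100]) torch.Size([4, 100])
```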
```python
# Short version utilizing the linear layer module
class LSTMCellNew(nn.Module):
    def __init__(self, input_sz, hidden_sz, bias=True):
        super().__init__()
        self.hidden_sz = hidden_sz
        self.ih = nn.Linear(input_sz, hidden_sz * 4, bias=bias)
        self.hh = nn.Linear(hidden_sz, hidden_sz * 4, bias=bias)

    def forward(self, x, h, c):
        out = self.ih(x) + self.hh(h)
        i, f, g, o = torch.split(out, self.hidden_sz, dim=-1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c_t = f * c + i * g
        h_t = o * torch.tanh(c_t)
        return h_t, c_t
```
```python
input_sz, hidden_sz = 20, 100  # values inferred from the shapes below
pytorch_cell = nn.LSTMCell(input_sz, hidden_sz, bias=True)
(
    pytorch_cell.weight_hh.shape,
    pytorch_cell.weight_ih.shape,
    pytorch_cell.bias_ih.shape,
    pytorch_cell.bias_hh.shape,
)
```

```
(torch.Size([400, 100]),
 torch.Size([400, 20]),
 torch.Size([400]),
 torch.Size([400]))
```
```python
cell = LSTMCellNew(input_sz, hidden_sz)
# To make sure PyTorch and our implementation both have
# the same weights so we can compare their outputs
cell.ih.weight.data = pytorch_cell.weight_ih.data
cell.hh.weight.data = pytorch_cell.weight_hh.data
cell.ih.bias.data = pytorch_cell.bias_ih.data
cell.hh.bias.data = pytorch_cell.bias_hh.data
```
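With the weights tied, a minimal check (the batch size and tolerance are my choices) is to run both cells on the same random input and compare the outputs:

```python
B = 4
x = torch.randn(B, input_sz)
h0 = torch.zeros(B, hidden_sz)
c0 = torch.zeros(B, hidden_sz)
h1, c1 = cell(x, h0, c0)
h2, c2 = pytorch_cell(x, (h0, c0))
print(torch.allclose(h1, h2, atol=1e-6), torch.allclose(c1, c2, atol=1e-6))
# True True
```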
There are a few things worth mentioning about our LSTM implementation, as well as implementations in common libraries:

- The input is expected to be of shape T x B x input_sz, i.e. the timestep dimension comes first and the batch dimension second.
- Multiple layers can be stacked via the num_layers argument. This builds multiple LSTM layers, each with its own LSTMCell that is shared across all timesteps within that layer, which increases the capacity of the model.

```python
class LSTMNew(nn.Module):
    def __init__(self, input_sz, hidden_sz, num_layers=1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_sz = hidden_sz
        # The first layer maps input_sz -> hidden_sz; deeper layers map
        # hidden_sz -> hidden_sz since they consume the previous layer's output
        self.cells = nn.ModuleList(
            [
                LSTMCellNew(input_sz, hidden_sz)
                if i == 0
                else LSTMCellNew(hidden_sz, hidden_sz)
                for i in range(self.num_layers)
            ]
        )

    def forward(self, x, h_t, c_t):
        # x  : T x B x input_sz
        # h_t: num_layers x B x hidden_sz
        # c_t: num_layers x B x hidden_sz
        T, B, _ = x.shape
        h_n, c_n = [], []
        for i, cell in enumerate(self.cells):
            h, c = h_t[i], c_t[i]
            outputs = []
            for t in range(T):
                h, c = cell(x[t], h, c)
                outputs.append(h)
            # This layer's outputs become the next layer's inputs
            x = torch.stack(outputs)
            # Keep the last hidden/cell state of each layer
            h_n.append(h)
            c_n.append(c)
        H = x
        h_t, c_t = torch.stack(h_n), torch.stack(c_n)
        # Truncated BPTT: detach the states so gradients don't flow across batches
        return H, (h_t.detach(), c_t.detach())
```
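To compare against PyTorch, we need a reference nn.LSTM and some random inputs. The sizes below, and the names pytorch_lstm, X, h_0, and c_0, are assumptions chosen to make the comparison runnable:

```python
# Assumed setup (sizes are illustrative)
T, B, num_layers = 5, 4, 2
pytorch_lstm = nn.LSTM(input_sz, hidden_sz, num_layers=num_layers)
X = torch.randn(T, B, input_sz)              # T x B x input_sz
h_0 = torch.zeros(num_layers, B, hidden_sz)  # initial hidden state
c_0 = torch.zeros(num_layers, B, hidden_sz)  # initial cell state
```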
```python
lstm = LSTMNew(input_sz, hidden_sz, num_layers=num_layers)
# Copy PyTorch's weights so both implementations are identical
for i in range(num_layers):
    lstm.cells[i].ih.weight.data = getattr(pytorch_lstm, f"weight_ih_l{i}").data
    lstm.cells[i].hh.weight.data = getattr(pytorch_lstm, f"weight_hh_l{i}").data
    lstm.cells[i].ih.bias.data = getattr(pytorch_lstm, f"bias_ih_l{i}").data
    lstm.cells[i].hh.bias.data = getattr(pytorch_lstm, f"bias_hh_l{i}").data

H, (h_t, c_t) = lstm(X, h_0, c_0)
```
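Finally, we can verify that our outputs match PyTorch's (the tolerance is a judgment call for float32):

```python
H_pt, (h_pt, c_pt) = pytorch_lstm(X, (h_0, c_0))
print(
    torch.allclose(H, H_pt, atol=1e-5),
    torch.allclose(h_t, h_pt, atol=1e-5),
    torch.allclose(c_t, c_pt, atol=1e-5),
)  # True True True
```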
LSTMs were for a long time the solution to the vanishing/exploding gradient problems that vanilla RNNs have. They were the backbone models used in many NLP tasks such as machine translation and classification. In this post, we implemented both LSTM and LSTMCell. Hopefully, working through the implementation step by step made LSTMs a little easier and less intimidating to understand.
The key takeaways are:

- LSTM keeps two states: a cell state that retains long-term information and a hidden state that drives the output at each timestep.
- The four gates are computed with one fused matrix multiplication and then split, which is faster than four separate multiplications.
- RNNs/LSTMs are inherently sequential: the forward pass can’t be parallelized across timesteps because each step depends on the previous hidden state.