Sequence Learning Using Long Short-Term Memory (LSTM) [draft] ✅

import numpy as np
import torch
import torch.nn as nn
import torch.functional as functional
from torch.optim import Adam

import mytext
from torchtext.data import get_tokenizer
from torch.utils.data import TensorDataset, DataLoader

# Tokenize the IMDB reviews and turn each one into a row of vocabulary ids.
# NOTE(review): the return types of the mytext helpers are not visible here;
# presumably `reviews` is a list of token lists and `targets` a list of
# integer class labels — TODO confirm against mytext.
tokenizer = get_tokenizer('basic_english')
reviews, targets = mytext.imdb_reviews(tokenizer)
voc = mytext.build_vocab(reviews)
reviews_tensor = mytext.build_tensor(reviews, voc)

# Keep only the first MAX_LEN token ids of every row so all samples share
# one sequence length (the slice below requires at least a 2-D tensor).
MAX_LEN = 200
inputs_tensor = reviews_tensor[:, :MAX_LEN]

# int64 class indices: the target dtype nn.CrossEntropyLoss expects.
targets_tensor = torch.tensor(targets, dtype=torch.long)

# Shuffled mini-batches of (token ids, label) pairs for training.
BATCH_SIZE = 64
dataset = TensorDataset(inputs_tensor, targets_tensor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
class SeqClassifierWithLSTM(nn.Module):
    """Sequence classifier: token embedding -> single-layer LSTM -> linear head.

    Args:
        voc_size: number of rows in the embedding table (vocabulary size).
        input_dim: embedding width, i.e. the LSTM input size.
        state_dim: width of the LSTM hidden and cell states.
        output_dim: number of logits produced per sequence.
    """

    def __init__(self, voc_size, input_dim, state_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(voc_size, input_dim)
        self.lstm = nn.LSTM(input_dim, state_dim, batch_first=True)
        self.fc = nn.Linear(state_dim, output_dim)

    def forward(self, tokens):
        """Return classification logits for a batch of token-id sequences.

        Args:
            tokens: integer token ids of shape (batch, seq_len); a single
                unbatched sequence of shape (seq_len,) also works.

        Returns:
            Logits of shape (batch, output_dim), or (output_dim,) for an
            unbatched input.
        """
        x = self.embedding(tokens)  # (batch, seq_len, input_dim)
        # y:   (batch, seq_len, state_dim) -- hidden state at every step
        # h_n: (num_layers, batch, state_dim) -- final hidden state
        # c_n: (num_layers, batch, state_dim) -- final cell state (unused)
        _, (h_n, _) = self.lstm(x)
        # Classify from the final hidden state of the top layer (== y[:, -1]).
        # The cell state is internal memory that the output gate and tanh
        # have not bounded, so it is not the LSTM's output.
        # NOTE(review): if the inputs are right-padded, h_n is the state after
        # the padding tokens; consider pack_padded_sequence — TODO confirm
        # how mytext pads.
        return self.fc(h_n[-1])
# Hyper-parameters.
epochs = 5
input_dim = 32   # embedding width fed into the LSTM
state_dim = 64   # LSTM hidden/cell state width
n_classes = 2

model = SeqClassifierWithLSTM(len(voc), input_dim, state_dim, n_classes)
loss = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters())

for epoch in range(epochs):
    epoch_losses = []
    for batch_tokens, batch_labels in dataloader:
        logits = model(batch_tokens)
        batch_loss = loss(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Clear the gradients so the next batch does not accumulate onto them.
        optimizer.zero_grad()
        epoch_losses.append(batch_loss.item())
    # Report the mean of the per-batch losses for this epoch.
    print("{}: loss={:.4f}".format(epoch, np.mean(epoch_losses)))
0: loss=0.6943
1: loss=0.6850
2: loss=0.6682
3: loss=0.6208
4: loss=0.5540
#
# Evaluate
#
# NOTE(review): Accuracy is imported but not used in this cell.
from torchmetrics import Accuracy

# NOTE(review): `dataloader` is the training loader built earlier, so this
# reports accuracy on the training data, not on held-out reviews. Build a
# separate test DataLoader to measure generalization.
was_training = model.training
# Evaluation mode: the current model has no dropout or batch-norm, but this
# keeps the measurement correct if such layers are added later.
model.eval()
with torch.no_grad():
    success = 0
    total = 0
    for x, target in dataloader:
        y = model(x)
        pred = y.argmax(dim=1)
        # .item() keeps the running counts as plain Python ints.
        success += (pred == target).sum().item()
        total += target.shape[0]
    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    print("Accuracy = {:.2f}".format(success / total))
# Restore whatever mode the model was in before evaluation.
model.train(was_training)
Accuracy = 0.75