Modular Structure of Trainer ✅

import torch
from torch import nn
from torch.utils.data import TensorDataset
import rich
import rich.console
import rich.table
import my  # personal helper module: provides my.mnist() and my.add_method()

1 Training Loop

1.1 A trainer class

from torch.utils.data import (
    DataLoader,
    random_split,
)

class Trainer:
    def __init__(self, model, optimizer, loss_fn,
                 dataset=None, x=None, y=None, batch_size=None,
                 validation_split=0.1,
                 optimizer_options=None,
                ):
        self.model = model
        optimizer_kw = optimizer_options or {}
        self.optimizer = optimizer(model.parameters(), **optimizer_kw)
        self.optimizer_name = type(self.optimizer).__name__
        self.loss_fn = loss_fn
        self.dataset = dataset
        self.batch_size = batch_size or 32

        # raw tensors take precedence over a pre-built dataset
        if (x is not None) and (y is not None):
            self.dataset = TensorDataset(x, y)
        if self.dataset is None:
            raise ValueError('provide either dataset or both x and y')

        # hold out a fraction of the data for validation
        self.validation_split = validation_split
        (train_dataset, val_dataset) = random_split(
            self.dataset, [1 - validation_split, validation_split])
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
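The constructor accepts either a ready-made Dataset or a pair of raw tensors. A minimal smoke test of the tensor path, on a made-up 100-sample toy problem, might look like this (fractional lengths for random_split need a reasonably recent PyTorch):

# hypothetical check of the x/y constructor path: 100 fake 8-feature
# inputs with binary labels, wrapped into a TensorDataset internally
# and split 90/10 into train and validation loaders
x = torch.randn(100, 8)
y = torch.randint(0, 2, (100,))
t = Trainer(nn.Linear(8, 2), torch.optim.SGD, nn.CrossEntropyLoss(),
            x=x, y=y, batch_size=16, optimizer_options={'lr': 0.01})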

1.2 Summarize the trainer configuration

model = nn.Sequential(
    nn.Flatten(),
    nn.LazyLinear(10),
)
dataset = my.mnist()
optimizer = torch.optim.Adam
loss_fn = nn.CrossEntropyLoss()
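One caveat before building the trainer: nn.LazyLinear infers its input size only on its first forward pass, and the PyTorch documentation recommends initializing lazy modules before handing their parameters to an optimizer. A dummy forward does that; the 1x28x28 shape below is an assumption about what my.mnist() yields:

# materialize the lazy layer's weights with one dummy batch
# before Trainer hands model.parameters() to the optimizer
# (assumes MNIST-shaped 1x28x28 inputs)
with torch.no_grad():
    model(torch.zeros(1, 1, 28, 28))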

trainer = Trainer(model, optimizer, loss_fn, dataset, validation_split=0.1)
@my.add_method(Trainer)
def summary(self):
    console = rich.console.Console()
    table = rich.table.Table(show_header=False)
    table.add_row('optimizer', self.optimizer_name)
    table.add_row('loss', str(self.loss_fn))
    table.add_row('dataset', str(len(self.dataset)))
    table.add_row('validation_split', str(self.validation_split))
    table.add_row('batch_size', str(self.batch_size))
    console.print(table)
trainer.summary()
┌──────────────────┬────────────────────┐
│ optimizer        │ Adam               │
│ loss             │ CrossEntropyLoss() │
│ dataset          │ 60000              │
│ validation_split │ 0.1                │
│ batch_size       │ 32                 │
└──────────────────┴────────────────────┘

1.3 Training step

@my.add_method(Trainer)
def train_step(self, batch):
    (x, target) = batch
    y_out = self.model(x)
    loss = self.loss_fn(y_out, target)
    self.optimizer.zero_grad()    # clear gradients from the previous step
    loss.backward()               # backpropagate
    self.optimizer.step()         # update the parameters
    acc = self.categorical_accuracy(y_out, target)
    return (loss.item(), acc)
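Before running full epochs, the step can be sanity-checked on a single hand-pulled batch; the loss should come back as a plain float and the accuracy as a value between 0 and 1:

# one-off smoke test on a single batch
batch = next(iter(trainer.train_loader))
(loss, acc) = trainer.train_step(batch)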

1.4 Metrics

@my.add_method(Trainer)
def categorical_accuracy(self, y_out, target):
    pred = y_out.argmax(dim=1)    # index of the highest logit per sample
    success = (pred == target).sum().item()
    total = len(target)
    return success / total
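A hand-made batch verifies the metric: the argmax predictions for the logits below are [1, 0, 1], so two of the three match the targets:

y_out = torch.tensor([[0.1, 0.9],    # predicts class 1, target 1: correct
                      [0.8, 0.2],    # predicts class 0, target 1: wrong
                      [0.3, 0.7]])   # predicts class 1, target 1: correct
target = torch.tensor([1, 1, 1])
trainer.categorical_accuracy(y_out, target)   # 2/3 ≈ 0.667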

1.5 Validation step

@my.add_method(Trainer)
def validate_step(self, batch):
    with torch.no_grad():    # no gradients needed for validation
        (x, target) = batch
        y_out = self.model(x)
        loss = self.loss_fn(y_out, target)
        acc = self.categorical_accuracy(y_out, target)
        return (loss.item(), acc)
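Note that torch.no_grad() only disables gradient tracking; it does not switch the network out of training mode. That makes no difference for this flatten-plus-linear model, but once the model contains layers such as nn.Dropout or nn.BatchNorm2d, validation should also toggle modes, roughly as sketched here:

# sketch, only needed for models with train/eval-sensitive layers
trainer.model.eval()     # switch dropout/batch-norm to inference behavior
# ... run validate_step over the validation batches ...
trainer.model.train()    # back to training mode for the next epoch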

1.6 Training

import pandas as pd
import random
import numpy as np
from collections import defaultdict
from itertools import islice

@my.add_method(Trainer)
def train(self, epochs, max_batches=None):
    history = defaultdict(list)
    
    for epoch in range(epochs):
        train_loss, train_acc = [], []
        val_loss, val_acc = [], []

        for batch in islice(self.train_loader, 0, max_batches):            
            (loss, acc) = self.train_step(batch)
            train_loss.append(loss)
            train_acc.append(acc)
        
        for batch in islice(self.val_loader, 0, max_batches):
            (loss, acc) = self.validate_step(batch)
            val_loss.append(loss)
            val_acc.append(acc)
        
        # update history for this epoch
        history['train_loss'].append(np.mean(train_loss))
        history['val_loss'].append(np.mean(val_loss))
        
        history['train_acc'].append(np.mean(train_acc))
        history['val_acc'].append(np.mean(val_acc))
                
    return pd.DataFrame(history)
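The max_batches cap works because islice treats a stop of None as unbounded, so the default of None runs the full epoch:

list(islice(range(5), 0, None))   # [0, 1, 2, 3, 4]; stop=None means no limit
list(islice(range(5), 0, 2))      # [0, 1]; capped at two items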

1.7 Train the model

history = trainer.train(epochs=10, max_batches=10)
history
   train_loss  val_loss  train_acc   val_acc
0    0.904273  0.848626   0.803125  0.843750
1    0.810787  0.815212   0.837500  0.828125
2    0.843481  0.766887   0.840625  0.853125
3    0.814210  0.737944   0.803125  0.859375
4    0.792618  0.718360   0.837500  0.850000
5    0.769474  0.694460   0.800000  0.850000
6    0.740073  0.665791   0.803125  0.865625
7    0.679719  0.646025   0.831250  0.862500
8    0.662727  0.627127   0.865625  0.878125
9    0.636068  0.612255   0.862500  0.875000

1.8 Reset the model

@my.add_method(Trainer)
def reset_parameters(self):
    # iterating a nn.Sequential yields its layers in order
    for layer in self.model:
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
trainer.reset_parameters()
history = trainer.train(epochs=10, max_batches=50)
history
   train_loss  val_loss  train_acc   val_acc
0    1.684306  1.151697   0.567500  0.765000
1    0.986109  0.806676   0.801250  0.837500
2    0.758648  0.673942   0.836875  0.841875
3    0.629792  0.593518   0.862500  0.863125
4    0.567320  0.548597   0.870625  0.868125
5    0.541308  0.511265   0.870000  0.871250
6    0.493871  0.488149   0.880000  0.870000
7    0.480564  0.466194   0.879375  0.878125
8    0.470205  0.449064   0.888125  0.876875
9    0.442064  0.434517   0.881250  0.881250
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(1,2,1)
history[['train_loss', 'val_loss']].plot.line(ax=ax)

ax = fig.add_subplot(1,2,2)
history[['train_acc', 'val_acc']].plot.line(ax=ax);
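Iterating over self.model above works because nn.Sequential yields its layers in order. For an arbitrary nn.Module, a more general reset would walk every submodule instead; this variant is an illustrative sketch, not part of the trainer:

# hypothetical generalization: modules() yields the module itself and
# every nested submodule, so layers inside custom containers are
# reset as well
def reset_all_parameters(model):
    for layer in model.modules():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()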

2 Test the model

@my.add_method(Trainer)
def test(self):
    # test_dataset is attached after construction, so guard against
    # it not being set at all
    if getattr(self, 'test_dataset', None) is None:
        return None
    dataloader = DataLoader(self.test_dataset, batch_size=self.batch_size)
    test_loss = []
    test_acc = []
    for batch in dataloader:
        (loss, acc) = self.validate_step(batch)
        test_loss.append(loss)
        test_acc.append(acc)
    return pd.DataFrame({
        'test_loss': [np.mean(test_loss)],
        'test_acc': [np.mean(test_acc)],
    })
trainer.test_dataset = my.mnist(train=False)
trainer.test()
   test_loss  test_acc
0   0.426321  0.893271