Deep Convolutional Networks ✅

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchsummaryX import summary
import numpy as np
import my
device(type='cuda', index=0)

1 Working with the MNIST dataset

mnist = my.mnist()
dataloader = DataLoader(mnist, batch_size=128)
# Peek at one batch to confirm the tensor shapes and dtypes.
xs, ys = next(iter(dataloader))
for _label, _tensor in (("xs", xs), ("ys", ys)):
    print(f"{_label} is", _tensor.shape, _tensor.dtype)
xs is torch.Size([128, 1, 28, 28]) torch.float32
ys is torch.Size([128]) torch.int64

2 A single convolutional layer as an nn.Module

class ConvMaxPool(nn.Module):
    """One convolution -> max-pool stage with an optional activation.

    The convolution uses ``padding='same'`` so it keeps the spatial size.
    The following ``MaxPool2d(pool_size)`` then divides each spatial
    dimension by ``pool_size`` (rounding down).

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of convolution filters (output channels).
        kernel_size: Convolution kernel size.
        pool_size: Max-pool window size (the stride defaults to the same value).
        activation: Optional callable, e.g. ``nn.ReLU()``, applied after
            pooling. ``None`` means no activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, pool_size=2, activation=None):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding='same',
        )
        self.pool = nn.MaxPool2d(kernel_size=pool_size)
        self.activation = activation

    def forward(self, x):
        """Apply the convolution, then pooling, then the activation if given."""
        x = self.pool(self.conv(x))
        # Explicit None check instead of truthiness: objects that define
        # __len__/__bool__ (e.g. container modules) can be falsy and would
        # otherwise be skipped silently.
        if self.activation is not None:
            x = self.activation(x)
        return x
# 1 input channel -> 5 feature maps, 3x3 kernel, no activation.
model = ConvMaxPool(in_channels=1, out_channels=5, kernel_size=3)
summary(model, xs);
=========================================================
        Kernel Shape      Output Shape Params Mult-Adds
Layer                                                  
0_conv  [1, 5, 3, 3]  [128, 5, 28, 28]   50.0    35.28k
1_pool             -  [128, 5, 14, 14]      -         -
---------------------------------------------------------
                      Totals
Total params            50.0
Trainable params        50.0
Non-trainable params     0.0
Mult-Adds             35.28k
=========================================================

3 Stacked convolutional pooling layers

class DeepConvMaxPool(nn.Module):
    """Three stacked ConvMaxPool stages, each followed by its own ReLU.

    Channels grow 1 -> 8 -> 16 -> 32. With the default pool size of 2, every
    stage halves each spatial dimension (rounding down).
    """

    def __init__(self):
        super().__init__()
        # Stages are constructed in pipeline order; each gets a fresh ReLU module.
        self.conv1 = ConvMaxPool(in_channels=1, out_channels=8, kernel_size=5, activation=nn.ReLU())
        self.conv2 = ConvMaxPool(in_channels=8, out_channels=16, kernel_size=5, activation=nn.ReLU())
        self.conv3 = ConvMaxPool(in_channels=16, out_channels=32, kernel_size=3, activation=nn.ReLU())

    def forward(self, x):
        """Run *x* through the three stages in order."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        return x
# Inspect the three-stage convolutional feature extractor on the sample batch.
model = DeepConvMaxPool()
summary(model, xs);
=============================================================================
                           Kernel Shape       Output Shape  Params Mult-Adds
Layer                                                                       
0_conv1.Conv2d_conv        [1, 8, 5, 5]   [128, 8, 28, 28]   208.0    156.8k
1_conv1.MaxPool2d_pool                -   [128, 8, 14, 14]       -         -
2_conv1.ReLU_activation               -   [128, 8, 14, 14]       -         -
3_conv2.Conv2d_conv       [8, 16, 5, 5]  [128, 16, 14, 14]  3.216k    627.2k
4_conv2.MaxPool2d_pool                -    [128, 16, 7, 7]       -         -
5_conv2.ReLU_activation               -    [128, 16, 7, 7]       -         -
6_conv3.Conv2d_conv      [16, 32, 3, 3]    [128, 32, 7, 7]   4.64k  225.792k
7_conv3.MaxPool2d_pool                -    [128, 32, 3, 3]       -         -
8_conv3.ReLU_activation               -    [128, 32, 3, 3]       -         -
-----------------------------------------------------------------------------
                         Totals
Total params             8.064k
Trainable params         8.064k
Non-trainable params        0.0
Mult-Adds             1.009792M
=============================================================================

4 The complete deep network for MNIST classification

class MNISTClassifier(nn.Module):
    """Convolutional feature extractor followed by a two-layer MLP head.

    Outputs 10 unnormalised scores (logits) per sample, one per digit class.
    This is the form ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.deepconv = DeepConvMaxPool()
        self.flatten = nn.Flatten()
        # LazyLinear infers in_features from the first forward pass
        # (288 for 28x28 MNIST inputs), so run one batch before relying on
        # its parameter shapes.
        self.fc1 = nn.LazyLinear(100)
        # Nonlinearity between the two linear layers. Without it,
        # fc2(fc1(x)) collapses to a single affine map and the hidden layer
        # adds parameters but no expressive power.
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        """Return class logits for the input batch *x*."""
        x = self.deepconv(x)
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
# Full classifier. The forward pass summary() performs also sizes the LazyLinear layer.
model = MNISTClassifier()
summary(model, xs);
=======================================================================================
                                    Kernel Shape       Output Shape  Params  \
Layer                                                                         
0_deepconv.conv1.Conv2d_conv        [1, 8, 5, 5]   [128, 8, 28, 28]   208.0   
1_deepconv.conv1.MaxPool2d_pool                -   [128, 8, 14, 14]       -   
2_deepconv.conv1.ReLU_activation               -   [128, 8, 14, 14]       -   
3_deepconv.conv2.Conv2d_conv       [8, 16, 5, 5]  [128, 16, 14, 14]  3.216k   
4_deepconv.conv2.MaxPool2d_pool                -    [128, 16, 7, 7]       -   
5_deepconv.conv2.ReLU_activation               -    [128, 16, 7, 7]       -   
6_deepconv.conv3.Conv2d_conv      [16, 32, 3, 3]    [128, 32, 7, 7]   4.64k   
7_deepconv.conv3.MaxPool2d_pool                -    [128, 32, 3, 3]       -   
8_deepconv.conv3.ReLU_activation               -    [128, 32, 3, 3]       -   
9_flatten                                      -         [128, 288]       -   
10_fc1                                [288, 100]         [128, 100]   28.9k   
11_fc2                                 [100, 10]          [128, 10]   1.01k   

                                 Mult-Adds  
Layer                                       
0_deepconv.conv1.Conv2d_conv        156.8k  
1_deepconv.conv1.MaxPool2d_pool          -  
2_deepconv.conv1.ReLU_activation         -  
3_deepconv.conv2.Conv2d_conv        627.2k  
4_deepconv.conv2.MaxPool2d_pool          -  
5_deepconv.conv2.ReLU_activation         -  
6_deepconv.conv3.Conv2d_conv      225.792k  
7_deepconv.conv3.MaxPool2d_pool          -  
8_deepconv.conv3.ReLU_activation         -  
9_flatten                                -  
10_fc1                               28.8k  
11_fc2                                1.0k  
---------------------------------------------------------------------------------------
                         Totals
Total params            37.974k
Trainable params        37.974k
Non-trainable params        0.0
Mult-Adds             1.039592M
=======================================================================================

5 Prepare for training

model = MNISTClassifier().to(device)

# Dry run: materialise the LazyLinear parameters (in_features inferred from
# this batch) on the target device before handing the parameters to the
# optimizer, as recommended for lazy modules.
with torch.no_grad():
    model(xs.to(device))

optimizer = torch.optim.Adam(model.parameters())
loss_f = nn.CrossEntropyLoss()
(train_data, val_data) = random_split(mnist, (0.9, 0.1))

#
# Training data loader
#
# shuffle=True draws a new batch order every epoch. Without it, every epoch
# visits the training samples in the same fixed order.
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

#
# Validation data loader (order does not affect accuracy, so no shuffling)
#
val_dataloader = DataLoader(val_data, batch_size=128)
def get_accuracy(model, dataloader, *, on_device=None):
    """Return the fraction of samples in *dataloader* that *model* classifies correctly.

    The predicted class is the argmax over the last dimension of the model
    output. It is compared element-wise with the targets from the loader.

    Args:
        model: Module mapping an input batch to per-class scores.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        on_device: Device to run on. Defaults to the notebook-global ``device``.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: If *dataloader* yields no samples.
    """
    run_device = device if on_device is None else on_device
    was_training = model.training
    # Evaluate in eval mode so mode-dependent layers (dropout, batch-norm)
    # behave deterministically, and restore the caller's mode afterwards.
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = inputs.to(run_device)
                targets = targets.to(run_device)
                preds = model(inputs).argmax(dim=-1)
                correct += (preds == targets).sum().item()
                total += len(targets)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("get_accuracy: dataloader yielded no samples")
    return correct / total
# Accuracy before any training; for 10 classes this should sit near chance (~0.1).
get_accuracy(model, dataloader=val_dataloader)
0.07033333333333333

6 Training loop

from tqdm.notebook import trange, tqdm
import time

def train(epochs=1):
    """Train the notebook-global ``model`` for *epochs* passes over ``train_dataloader``.

    Uses the global ``optimizer``, ``loss_f`` and ``device``. Prints the loss
    every 100 optimisation steps and the validation accuracy after each epoch.

    Args:
        epochs: Number of passes over the training set.

    Returns:
        dict with ``'losses'`` -> list of ``(step, loss)`` and
        ``'accuracy'`` -> list of ``(step, val_accuracy)``. ``step`` counts
        optimiser updates across all epochs.
    """
    history = {
        'losses': [],
        'accuracy': [],
    }
    total_datasize = len(train_data)
    step = 0
    for epoch in range(epochs):
        # Set training mode explicitly each epoch rather than relying on the
        # mode earlier cells happened to leave the model in.
        model.train()
        processed = 0
        start = time.perf_counter()

        for (x, target) in tqdm(train_dataloader):
            x, target = x.to(device), target.to(device)
            optimizer.zero_grad()
            y = model(x)
            loss = loss_f(y, target)
            loss.backward()
            optimizer.step()
            # .item() copies the value to the host (a device sync); do it once per step.
            loss_value = loss.item()
            processed += len(x)
            if step % 100 == 0:
                progress = processed / total_datasize * 100
                print("({:.1f}%) Loss: {:.4f}".format(progress, loss_value))
            history['losses'].append((step, loss_value))
            step += 1
        val_acc = get_accuracy(model, val_dataloader)
        history['accuracy'].append((step, val_acc))
        duration = time.perf_counter() - start
        print("(epoch {}) Accuracy: {:.4f}, took {:.2f} seconds".format(epoch, val_acc, duration))
    return history
# A single epoch over the training split.
history = train(epochs=1)
(0.2%) Loss: 2.3020
(23.9%) Loss: 0.2657
(47.6%) Loss: 0.1738
(71.3%) Loss: 0.0929
(95.1%) Loss: 0.1230
(epoch 0) Accuracy: 0.9678, took 8.24 seconds
# One (step, loss) row per optimisation step.
loss_history = np.asarray(history['losses'])
loss_history.shape
(422, 2)
import matplotlib.pyplot as pyplot

# Split the (step, loss) columns and plot loss against step.
steps, losses = loss_history.T
pyplot.plot(steps, losses);