import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
Auto Encoder ✅
1 Revisiting the MNIST dataset
import torchvision
from torchvision import transforms
import my
# Load MNIST as tensors; `my.DATASET_PATH` is a project-level download location.
to_tensor = transforms.Compose([transforms.ToTensor()])
mnist = torchvision.datasets.mnist.MNIST(my.DATASET_PATH, transform=to_tensor)
#
# Find an example of each digit 0 .. 9
#
dataloader = DataLoader(mnist, batch_size=100)
xs, targets = next(iter(dataloader))
# For each class d, keep the first image in the batch whose label equals d.
digits = [xs[np.where(targets == d)[0][0]] for d in range(10)]
#
# Visualize them
#
import matplotlib.pyplot as pl

# One row of 10 subplots, one per digit, axis ticks suppressed.
for i, x in enumerate(digits):
    pl.subplot(1, 10, i + 1)
    pl.imshow(x.squeeze(), cmap='gray')
    pl.xticks([])
    pl.yticks([])
2 Low dimensional representation with encoder
- Each digit is encoded as a 28x28=784 image. This is a huge dimensionality for 10 classes.
- Can we encode each digit using a much smaller vector, say 2 dimensions?
- This is the function of an encoder.
\[ \mathbf{Encoder} : \mathbb{R}^{28\times 28} \to \mathbb{R}^2 \]
We can implement this using any kind of architecture – for example a simple MLP will do.
def make_encoder(DIM):
    """Build an MLP encoder mapping a (1, 28, 28) image to a DIM-vector.

    The image is flattened to 784 features, passed through one hidden
    layer of 100 ReLU units, and the final Tanh squashes every latent
    coordinate into [-1, 1].
    """
    layers = [
        nn.Flatten(),
        nn.Linear(28 * 28, 100),
        nn.ReLU(),
        nn.Linear(100, DIM),
        nn.Tanh(),
    ]
    return nn.Sequential(*layers)
Questions?
- What is the evaluation of the quality of the encoding in \(\mathbb{R}^2\)? What is the loss function?
Answers
- The encoding is good if it retains enough information for a decoder to recover the original input.
NOTE:
We do not make use of the labels at all in this approach.
3 Reconstruction from low dimensional representation with decoder
A decoder is a function that tries to recover the initial input from its low dimensional encoding. Namely,
\[ \mathbf{Decoder} : \mathbb{R}^2 \to\mathbb{R}^{28\times 28} \]
def make_decoder(DIM):
    """Build an MLP decoder mapping a DIM-vector back to a (1, 28, 28) image.

    The final Sigmoid keeps every pixel in [0, 1], matching the range
    produced by transforms.ToTensor(); Unflatten restores the image shape.
    """
    layers = [
        nn.Linear(DIM, 100),
        nn.ReLU(),
        nn.Linear(100, 28 * 28),
        nn.Sigmoid(),
        nn.Unflatten(1, (1, 28, 28)),
    ]
    return nn.Sequential(*layers)
4 Visualizing encoder-decoder action
def show_enc_dec(encoder, decoder, digits):
    """Plot decoder(encoder(x)) for each example digit in one 1x10 figure row.

    Both modules are moved to the CPU first, and no gradients are tracked
    since this is a visualization pass only.
    """
    encoder = encoder.to('cpu')
    decoder = decoder.to('cpu')
    with torch.no_grad():
        for idx, digit in enumerate(digits):
            # Add a leading batch dimension, reconstruct, then drop it for display.
            recon = decoder(encoder(digit[None, :, :, :]))
            pl.subplot(1, 10, idx + 1)
            pl.imshow(recon.squeeze(), cmap='gray')
            pl.xticks([])
            pl.yticks([])
# An untrained 100-dimensional autoencoder: reconstructions are noise.
encoder = make_encoder(100)
decoder = make_decoder(100)

show_enc_dec(encoder, decoder, digits)
5 Training of encoder-decoder stack
The encoder-decoder stack is trained to recover the original image from its encoding.
\[ \mathbf{Decoder} \circ \mathbf{Encoder} : \mathbb{R}^{28\times 28} \to\mathbb{R}^{28\times 28} \]
We can use the original images as training data for the encoder-decoder stack.
\[ \mathbf{Training} = \{(x, x): x\in\mathbf{Dataset}\} \]
The loss function is just to compare the recovered image with the original image using mean-square error: \(\mathbf{MSE}(\mathbf{decoder}(\mathbf{encoder}(x)), x)\)
def train(encoder, decoder, dataloader, epochs, lr=0.01):
    """Train the encoder-decoder stack to reconstruct its input images.

    Parameters
    ----------
    encoder, decoder : nn.Module
        The two halves of the autoencoder; their parameters are
        optimized jointly.
    dataloader : DataLoader
        Yields (x, label) batches; labels are ignored (unsupervised).
    epochs : int
        Number of full passes over the dataset.
    lr : float
        Learning rate for Adam (default 0.01).
    """
    params = list(encoder.parameters()) + list(decoder.parameters())
    # BUG FIX: `lr` was accepted but never forwarded, so Adam silently
    # ran at its default learning rate (0.001) regardless of the argument.
    optimizer = torch.optim.Adam(params, lr=lr)
    criterion = nn.MSELoss()

    # NOTE(review): `device` is a module-level global defined outside this
    # chunk — confirm it is set before train() is called.
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    for epoch in range(epochs):
        for x, _ in dataloader:
            x = x.to(device)
            x_out = decoder(encoder(x))
            batch_loss = criterion(x_out, x)
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Report the last batch's loss once per epoch.
        print(epoch, batch_loss.item())
# Fresh 100-dimensional autoencoder, trained for one epoch.
encoder = make_encoder(100)
decoder = make_decoder(100)

train(encoder, decoder, dataloader, epochs=1)
0 0.04748028889298439
show_enc_dec(encoder, decoder, digits)

# Alternate one epoch of training with a visual check of the reconstructions.
# Recorded per-round losses: 0.0311, 0.0234, 0.0189, 0.0160, 0.0141.
for _ in range(5):
    train(encoder, decoder, dataloader, epochs=1)
    show_enc_dec(encoder, decoder, digits)
6 Going super low dimensional
# A 3-dimensional bottleneck: the recorded losses stall near 0.0637
# over 5 epochs — too tight a bottleneck for this MLP to make progress.
e3 = make_encoder(3)
d3 = make_decoder(3)
train(e3, d3, dataloader, epochs=5)
show_enc_dec(e3, d3, digits)
# NOTE(review): the names say 10, but the latent dimension used is 50.
# With 50 dims the recorded loss falls steadily (0.0637 -> 0.0281 in 5 epochs).
e10 = make_encoder(50)
d10 = make_decoder(50)
train(e10, d10, dataloader, epochs=5)
show_enc_dec(e10, d10, digits)