
Index

Overfitting And What To Do About It

Overfitting and detection with cross-validation

Kernel Regularization and Dropout Layers

2 Overfitting

On Overfitting

In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.models as models
import tensorflow.keras.layers as layers
import tensorflow.keras.losses as losses
import tensorflow.keras.optimizers as optimizers
import tensorflow.keras.datasets as datasets

import numpy as np
import matplotlib.pyplot as pl

Dataset

Let's consider a scenario where we have only a small amount of training data.

In [2]:
(data_train, data_test) = tf.keras.datasets.mnist.load_data(path='/data/shared/datasets/mnist.npz')

x_train, y_train = data_train
x_test, y_test = data_test

# Scale pixel intensities from [0, 255] down to [0, 1].
x_train = x_train / 255
x_test = x_test / 255

# Keep only a small random subset of 100 training examples.
N = x_train.shape[0]
N_train = 100

I = np.arange(N)
np.random.shuffle(I)
x_train = x_train[I[:N_train]]
y_train = y_train[I[:N_train]]

Model

Let's build a fairly large, densely connected feed-forward neural network.

In [3]:
model = models.Sequential([
    layers.Input(shape=(28, 28)),
    layers.Reshape((-1,)),                  # flatten each 28x28 image into a 784-vector
    layers.Dense(100, activation='relu'),
    layers.Dense(100, activation='relu'),
    layers.Dense(100, activation='relu'),
    layers.Dense(10, activation='softmax')  # class probabilities for the 10 digits
])
# Save the freshly initialized weights so later experiments can
# restart from the same random initialization.
model.save_weights('overfitting_model.init.h5')

model.compile(
    loss=losses.SparseCategoricalCrossentropy(),
    optimizer=optimizers.Adam(),
    metrics=['acc'],
)

Training And Testing

We will train the model for 10 epochs.

Note that we first reload the saved initial weights, so training starts from the same random initialization every time this cell is run.

In [6]:
model.load_weights('overfitting_model.init.h5')
history = model.fit(x_train, y_train, epochs=10, verbose=2)
Epoch 1/10
4/4 - 0s - loss: 2.2853 - acc: 0.1200
Epoch 2/10
4/4 - 0s - loss: 2.1497 - acc: 0.2400
Epoch 3/10
4/4 - 0s - loss: 1.9989 - acc: 0.5600
Epoch 4/10
4/4 - 0s - loss: 1.8123 - acc: 0.5800
Epoch 5/10
4/4 - 0s - loss: 1.6280 - acc: 0.5700
Epoch 6/10
4/4 - 0s - loss: 1.3980 - acc: 0.8000
Epoch 7/10
4/4 - 0s - loss: 1.2080 - acc: 0.8200
Epoch 8/10
4/4 - 0s - loss: 0.9806 - acc: 0.8400
Epoch 9/10
4/4 - 0s - loss: 0.7226 - acc: 0.9500
Epoch 10/10
4/4 - 0s - loss: 0.5848 - acc: 0.9500

Even though the training accuracy is well over 90%, the test accuracy is quite poor: about 69%.

This phenomenon occurs commonly and is known as overfitting.

Overfitting happens when the model is too complex for the amount of variation in the training data.
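We can make this complexity mismatch concrete by comparing the number of trainable parameters to the number of training examples. This check is a small sketch added for illustration; the parameter count in the comment is derived from the layer sizes defined above.

In [ ]:
# The three hidden layers and the output layer contribute
# 784*100+100 + 100*100+100 + 100*100+100 + 100*10+10 = 99,710
# trainable parameters -- roughly a thousand per training example,
# ample capacity to memorize all 100 examples.
print('trainable parameters:', model.count_params())
print('training examples:   ', N_train)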

In [7]:
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
313/313 - 3s - loss: 1.0629 - acc: 0.6894
In [8]:
pl.plot(history.history['acc'])
pl.plot([0, 10], [test_acc, test_acc], '--')  # test accuracy, for reference
pl.ylim(0, 1.2)
pl.legend(['Training Accuracy', 'Test Accuracy']);

Cross Validation

We are not allowed to use the test data in any way during training.

So, how do we detect overfitting?

Cross-validation is the answer: we hold out part of the training data as a validation set and monitor the model's performance on it during training. In Keras, passing validation_split=0.1 to fit reserves the last 10% of the training samples for this purpose. (Strictly speaking, a single held-out split like this is a validation split; full k-fold cross-validation would average over several such splits.)

In [9]:
model.load_weights('overfitting_model.init.h5')
history = model.fit(x_train, y_train, epochs=40, verbose=2,
                    validation_split=0.1)
Epoch 1/40
3/3 - 1s - loss: 2.2906 - acc: 0.1333 - val_loss: 2.2282 - val_acc: 0.1000
Epoch 2/40
3/3 - 0s - loss: 2.1928 - acc: 0.2000 - val_loss: 2.1784 - val_acc: 0.1000
Epoch 3/40
3/3 - 0s - loss: 2.0639 - acc: 0.3889 - val_loss: 2.1169 - val_acc: 0.3000
Epoch 4/40
3/3 - 0s - loss: 1.9116 - acc: 0.5000 - val_loss: 2.0431 - val_acc: 0.5000
Epoch 5/40
3/3 - 0s - loss: 1.7690 - acc: 0.6000 - val_loss: 1.9523 - val_acc: 0.5000
Epoch 6/40
3/3 - 0s - loss: 1.6036 - acc: 0.6556 - val_loss: 1.8473 - val_acc: 0.5000
Epoch 7/40
3/3 - 0s - loss: 1.4228 - acc: 0.7556 - val_loss: 1.7209 - val_acc: 0.6000
Epoch 8/40
3/3 - 0s - loss: 1.2384 - acc: 0.8333 - val_loss: 1.6003 - val_acc: 0.6000
Epoch 9/40
3/3 - 0s - loss: 1.0577 - acc: 0.8889 - val_loss: 1.4825 - val_acc: 0.7000
Epoch 10/40
3/3 - 0s - loss: 0.8745 - acc: 0.9222 - val_loss: 1.3857 - val_acc: 0.7000
Epoch 11/40
3/3 - 0s - loss: 0.7170 - acc: 0.9444 - val_loss: 1.2968 - val_acc: 0.7000
Epoch 12/40
3/3 - 0s - loss: 0.5644 - acc: 0.9778 - val_loss: 1.2046 - val_acc: 0.8000
Epoch 13/40
3/3 - 0s - loss: 0.4458 - acc: 0.9778 - val_loss: 1.1458 - val_acc: 0.8000
Epoch 14/40
3/3 - 0s - loss: 0.3456 - acc: 0.9778 - val_loss: 1.0967 - val_acc: 0.8000
Epoch 15/40
3/3 - 0s - loss: 0.2703 - acc: 0.9889 - val_loss: 1.0527 - val_acc: 0.8000
Epoch 16/40
3/3 - 0s - loss: 0.2040 - acc: 0.9889 - val_loss: 1.0557 - val_acc: 0.8000
Epoch 17/40
3/3 - 0s - loss: 0.1625 - acc: 0.9889 - val_loss: 1.0886 - val_acc: 0.8000
Epoch 18/40
3/3 - 0s - loss: 0.1231 - acc: 1.0000 - val_loss: 1.1287 - val_acc: 0.8000
Epoch 19/40
3/3 - 0s - loss: 0.0985 - acc: 1.0000 - val_loss: 1.1244 - val_acc: 0.7000
Epoch 20/40
3/3 - 0s - loss: 0.0783 - acc: 1.0000 - val_loss: 1.1261 - val_acc: 0.7000
Epoch 21/40
3/3 - 0s - loss: 0.0635 - acc: 1.0000 - val_loss: 1.1453 - val_acc: 0.7000
Epoch 22/40
3/3 - 0s - loss: 0.0525 - acc: 1.0000 - val_loss: 1.1926 - val_acc: 0.7000
Epoch 23/40
3/3 - 0s - loss: 0.0439 - acc: 1.0000 - val_loss: 1.2427 - val_acc: 0.7000
Epoch 24/40
3/3 - 0s - loss: 0.0363 - acc: 1.0000 - val_loss: 1.2489 - val_acc: 0.7000
Epoch 25/40
3/3 - 0s - loss: 0.0313 - acc: 1.0000 - val_loss: 1.2383 - val_acc: 0.7000
Epoch 26/40
3/3 - 0s - loss: 0.0274 - acc: 1.0000 - val_loss: 1.2352 - val_acc: 0.7000
Epoch 27/40
3/3 - 0s - loss: 0.0240 - acc: 1.0000 - val_loss: 1.2478 - val_acc: 0.7000
Epoch 28/40
3/3 - 0s - loss: 0.0212 - acc: 1.0000 - val_loss: 1.2805 - val_acc: 0.7000
Epoch 29/40
3/3 - 0s - loss: 0.0189 - acc: 1.0000 - val_loss: 1.2955 - val_acc: 0.7000
Epoch 30/40
3/3 - 0s - loss: 0.0171 - acc: 1.0000 - val_loss: 1.3165 - val_acc: 0.7000
Epoch 31/40
3/3 - 0s - loss: 0.0156 - acc: 1.0000 - val_loss: 1.3221 - val_acc: 0.7000
Epoch 32/40
3/3 - 0s - loss: 0.0144 - acc: 1.0000 - val_loss: 1.3211 - val_acc: 0.7000
Epoch 33/40
3/3 - 0s - loss: 0.0132 - acc: 1.0000 - val_loss: 1.3312 - val_acc: 0.7000
Epoch 34/40
3/3 - 0s - loss: 0.0123 - acc: 1.0000 - val_loss: 1.3341 - val_acc: 0.7000
Epoch 35/40
3/3 - 0s - loss: 0.0115 - acc: 1.0000 - val_loss: 1.3408 - val_acc: 0.7000
Epoch 36/40
3/3 - 0s - loss: 0.0108 - acc: 1.0000 - val_loss: 1.3482 - val_acc: 0.7000
Epoch 37/40
3/3 - 0s - loss: 0.0102 - acc: 1.0000 - val_loss: 1.3585 - val_acc: 0.7000
Epoch 38/40
3/3 - 0s - loss: 0.0096 - acc: 1.0000 - val_loss: 1.3656 - val_acc: 0.7000
Epoch 39/40
3/3 - 0s - loss: 0.0091 - acc: 1.0000 - val_loss: 1.3720 - val_acc: 0.7000
Epoch 40/40
3/3 - 0s - loss: 0.0086 - acc: 1.0000 - val_loss: 1.3768 - val_acc: 0.7000
In [13]:
pl.plot(history.history['loss'])
pl.plot(history.history['val_loss']);
pl.legend(['Training Loss', 'Cross Validation Loss']);
In [14]:
pl.plot(history.history['acc'])
pl.plot(history.history['val_acc'])
pl.plot([0, 40], [test_acc, test_acc], '--')
pl.legend(['Training Accuracy', 'Cross Validation Accuracy', 'Test Accuracy']);
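The curves show the classic overfitting signature: training loss keeps falling while validation loss bottoms out around epoch 15 and then rises. One common way to act on this signal, shown here only as a sketch (this cell is not part of the original run), is Keras's EarlyStopping callback, which halts training when the validation loss stops improving:

In [ ]:
# Stop once val_loss has not improved for 5 consecutive epochs,
# and roll back to the best weights observed so far.
early_stop = keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True)

model.load_weights('overfitting_model.init.h5')
history = model.fit(x_train, y_train, epochs=40, verbose=2,
                    validation_split=0.1, callbacks=[early_stop])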