{:check ["true"]}
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.models as models
import tensorflow.keras.layers as layers
import tensorflow.keras.losses as losses
import tensorflow.keras.optimizers as optimizers
import tensorflow.keras.datasets as datasets
import numpy as np
import matplotlib.pyplot as pl
Let's consider a scenario where we have only a small amount of training data.
# Load MNIST from the shared cache and unpack train/test splits directly.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='/data/shared/datasets/mnist.npz')

# Scale pixel intensities from [0, 255] into [0, 1].
x_train = x_train / 255
x_test = x_test / 255

# Keep only a tiny random subset of the training data to provoke overfitting.
N = x_train.shape[0]
N_train = 100
I = np.arange(N)
np.random.shuffle(I)
subset = I[:N_train]
x_train, y_train = x_train[subset], y_train[subset]
Let's build a fairly large feed-forward neural network that is densely connected.
# A deliberately over-parameterized dense network: 28x28 images flattened,
# three hidden ReLU layers of 100 units each, softmax over the 10 digits.
model = models.Sequential()
model.add(layers.Input(shape=(28, 28)))
model.add(layers.Reshape((-1,)))
for _ in range(3):
    model.add(layers.Dense(100, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
# Snapshot the freshly initialized weights so every experiment below can
# restart training from the same random initialization.
model.save_weights('overfitting_model.init.h5')

# Sparse labels (integer class ids), Adam optimizer, track accuracy.
loss_fn = losses.SparseCategoricalCrossentropy()
opt = optimizers.Adam()
model.compile(optimizer=opt, loss=loss_fn, metrics=['acc'])
We will train the model for 10 epochs.
Note that we ensure that the model loads the initial randomized parameter weights.
# Restore the saved random initialization so this run is comparable to others.
model.load_weights('overfitting_model.init.h5')
history = model.fit(x=x_train, y=y_train, epochs=10, verbose=2)
Even though the training accuracy is well over 90%, the test accuracy is quite poor: 68%.
This phenomenon is commonly occurring, and is known as overfitting.
Overfitting is caused by the model being too complex for the amount of variation in the training data.
# Measure generalization on the held-out test set.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)

# Training accuracy curve vs. a dashed horizontal line at the test accuracy —
# the gap between them visualizes the overfitting.
acc_curve = history.history['acc']
pl.plot(acc_curve)
pl.plot([0, 10], [test_acc, test_acc], '--')
pl.ylim(0, 1.2);
We are not allowed to use test data in any way during training.
So, how do we detect the occurrence of overfitting?
Cross validation is the answer.
# Restart from the same initialization, but hold out 10% of the (already tiny)
# training set as a cross-validation split to detect overfitting without
# touching the test data.
model.load_weights('overfitting_model.init.h5')
history = model.fit(x_train, y_train, epochs=40, verbose=2,
                    validation_split=0.1)

# Training vs. validation loss: the validation loss turning upward while the
# training loss keeps falling is the signature of overfitting.
pl.plot(history.history['loss'])
pl.plot(history.history['val_loss']);
pl.legend(['Training Loss', 'Cross Validation Loss']);

# Same comparison in accuracy terms, with the test accuracy as a dashed
# reference line.
pl.plot(history.history['acc'])
pl.plot(history.history['val_acc'])
pl.plot([0, 40], [test_acc, test_acc], '--')
# Fixed legend typo: "Accuacy" -> "Accuracy".
pl.legend(['Training Accuracy', 'Cross Validation Accuracy', 'Test Accuracy']);