{:check ["true"]}

Index

Text Encoding And The IMDB Movie Review Dataset

1 Text Learning

In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.datasets as datasets
import tensorflow.keras.layers as layers
import numpy as np
In [2]:
#
# Load the IMDB movie review dataset
#
data = datasets.imdb.load_data()
<__array_function__ internals>:5: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
/opt/conda/lib/python3.8/site-packages/tensorflow/python/keras/datasets/imdb.py:159: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
/opt/conda/lib/python3.8/site-packages/tensorflow/python/keras/datasets/imdb.py:160: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
In [3]:
#
# Like the MNIST dataset, it comes pre-split into training and testing sets
#
(x_train, y_train), (x_test, y_test) = data
In [8]:
#
# Load the word_index table
#
index = datasets.imdb.get_word_index()
In [9]:
#
# Build the id -> word lookup table. load_data shifts every raw index
# up by 3 to reserve ordinals 0-3 for the special tokens below.
#
id_to_word = dict((i+3, w) for (w,i) in index.items())
id_to_word.update({
    0: '<PAD>',
    1: '<START>',
    2: '<UNKNOWN>',
    3: '<UNUSED>',
})
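A quick sanity check of the offset (a sketch, assuming the standard Keras word index in which 'the' has raw index 1):

# 'the' has raw index 1, so after the +3 shift it decodes from ordinal 4
assert index['the'] == 1
assert id_to_word[4] == 'the'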
In [13]:
np.array(x_train[0])[:10]
Out[13]:
array([   1,   14,   22,   16,   43,  530,  973, 1622, 1385,   65])
In [16]:
" ".join([id_to_word[i] for i in x_train[0][:10]])
Out[16]:
'<START> this film was just brilliant casting location scenery story'
In [21]:
#
# The vocabulary size is
#
len(index)
Out[21]:
88584
In [17]:
def decode(ordinals):
    # Fall back to <UNKNOWN> for any ordinal missing from the table
    return " ".join(id_to_word.get(i, '<UNKNOWN>') for i in ordinals)
In [20]:
decode(x_train[2]), y_train[2]
Out[20]:
("<START> this has to be one of the worst films of the 1990s when my friends i were watching this film being the target audience it was aimed at we just sat watched the first half an hour with our jaws touching the floor at how bad it really was the rest of the time everyone else in the theatre just started talking to each other leaving or generally crying into their popcorn that they actually paid money they had earnt working to watch this feeble excuse for a film it must have looked like a great idea on paper but on film it looks like no one in the film has a clue what is going on crap acting crap costumes i can't get across how embarrasing this is to watch save yourself an hour a bit of your life",
 0)

A reduced dataset

load_data can restrict the vocabulary: num_words=1000 keeps only the 1000 most frequent words, and skip_top=5 additionally drops the 5 most frequent ones. Every excluded word is replaced by ordinal 2, which decodes as <UNKNOWN>.

In [22]:
data = datasets.imdb.load_data(
    num_words=1000,
    skip_top=5
)

(x_train, y_train), (x_test, y_test) = data
In [23]:
decode(x_train[0])
Out[23]:
"<UNKNOWN> this film was just brilliant casting <UNKNOWN> <UNKNOWN> story direction <UNKNOWN> really <UNKNOWN> <UNKNOWN> part they played and you could just imagine being there robert <UNKNOWN> is an amazing actor and now <UNKNOWN> same being director <UNKNOWN> father came from <UNKNOWN> same <UNKNOWN> <UNKNOWN> as myself so i loved <UNKNOWN> fact there was a real <UNKNOWN> with this film <UNKNOWN> <UNKNOWN> <UNKNOWN> throughout <UNKNOWN> film were great it was just brilliant so much that i <UNKNOWN> <UNKNOWN> film as soon as it was released for <UNKNOWN> and would recommend it to everyone to watch and <UNKNOWN> <UNKNOWN> <UNKNOWN> was amazing really <UNKNOWN> at <UNKNOWN> end it was so sad and you know what they say if you <UNKNOWN> at a film it must have been good and this definitely was also <UNKNOWN> to <UNKNOWN> two little <UNKNOWN> that played <UNKNOWN> <UNKNOWN> of <UNKNOWN> and paul they were just brilliant children are often left out of <UNKNOWN> <UNKNOWN> <UNKNOWN> i think because <UNKNOWN> stars that play them all <UNKNOWN> up are such a big <UNKNOWN> for <UNKNOWN> whole film but these children are amazing and should be <UNKNOWN> for what they have done don't you think <UNKNOWN> whole story was so <UNKNOWN> because it was true and was <UNKNOWN> life after all that was <UNKNOWN> with us all"
In [24]:
decode(x_train[1])
Out[24]:
"<UNKNOWN> big <UNKNOWN> big <UNKNOWN> bad music and a <UNKNOWN> <UNKNOWN> <UNKNOWN> these are <UNKNOWN> words to best <UNKNOWN> this terrible movie i love cheesy horror movies and i've seen <UNKNOWN> but this had got to be on of <UNKNOWN> worst ever made <UNKNOWN> plot is <UNKNOWN> <UNKNOWN> and ridiculous <UNKNOWN> acting is an <UNKNOWN> <UNKNOWN> script is completely <UNKNOWN> <UNKNOWN> best is <UNKNOWN> end <UNKNOWN> with <UNKNOWN> <UNKNOWN> and how he worked out who <UNKNOWN> killer is it's just so <UNKNOWN> <UNKNOWN> written <UNKNOWN> <UNKNOWN> are <UNKNOWN> and funny in <UNKNOWN> <UNKNOWN> <UNKNOWN> <UNKNOWN> is big lots of <UNKNOWN> <UNKNOWN> men <UNKNOWN> those cut <UNKNOWN> <UNKNOWN> that show off their <UNKNOWN> <UNKNOWN> that men actually <UNKNOWN> them and <UNKNOWN> music is just <UNKNOWN> <UNKNOWN> that plays over and over again in almost every scene there is <UNKNOWN> music <UNKNOWN> and <UNKNOWN> taking away <UNKNOWN> and <UNKNOWN> <UNKNOWN> still doesn't close for <UNKNOWN> all <UNKNOWN> <UNKNOWN> this is a truly bad film whose only <UNKNOWN> is to look back on <UNKNOWN> <UNKNOWN> that was <UNKNOWN> <UNKNOWN> and have a good old laugh at how bad everything was back then"
In [25]:
#
# To encode text to ordinals (at the application level), we compute the reverse lookup table
#
word_to_id = {
    w:i for (i,w) in id_to_word.items()
}
In [26]:
def encode(text):
    # Unknown words map to ordinal 2, i.e. <UNKNOWN>
    return [word_to_id.get(w, 2) for w in text.split()]
In [27]:
encode("this film was just brilliant casting")
Out[27]:
[14, 22, 16, 43, 530, 973]
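As a quick sanity check (a sketch, not part of the original run), encode and decode should round-trip for in-vocabulary words:

# decode(encode(s)) reproduces s as long as every word is in the index
decode(encode("this film was just brilliant casting"))
# => 'this film was just brilliant casting'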

2 Sequence Layers

In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.datasets as datasets
import tensorflow.keras.models as models
import tensorflow.keras.layers as layers
import tensorflow.keras.preprocessing.sequence as sequence
import numpy as np

Load the IMDB Movie Review dataset

In [2]:
data = datasets.imdb.load_data(num_words=10000, skip_top=10)
word_index = datasets.imdb.get_word_index()
In [3]:
i2w = dict((i+3, w) for (w,i) in word_index.items())
i2w.update({
    0: '<PAD>',
    1: '<START>',
    2: '<OOV>',
    3: '<?>',
})
In [4]:
(x_train, y_train), (x_test, y_test) = data

The ordinal2text function decodes the ordinal numbers to text.

In [5]:
def ordinal2text(seq):
    return " ".join([i2w[i] for i in seq])
In [6]:
ordinal2text(x_train[0])
Out[6]:
"<OOV> this film was just brilliant casting location scenery story direction everyone's really suited <OOV> part they played <OOV> you could just imagine being there robert <OOV> <OOV> an amazing actor <OOV> now <OOV> same being director <OOV> father came from <OOV> same scottish island as myself so i loved <OOV> fact there was <OOV> real connection with this film <OOV> witty remarks throughout <OOV> film were great it was just brilliant so much that i bought <OOV> film as soon as it was released for <OOV> <OOV> would recommend it <OOV> everyone <OOV> watch <OOV> <OOV> fly fishing was amazing really cried at <OOV> end it was so sad <OOV> you know what they say if you cry at <OOV> film it must have been good <OOV> this definitely was also <OOV> <OOV> <OOV> two little boy's that played <OOV> <OOV> <OOV> norman <OOV> paul they were just brilliant children are often left out <OOV> <OOV> <OOV> list i think because <OOV> stars that play them all grown up are such <OOV> big profile for <OOV> whole film but these children are amazing <OOV> should be praised for what they have done don't you think <OOV> whole story was so lovely because it was true <OOV> was someone's life after all that was shared with us all"

The embedding layer maps ordinal indices to embedding vectors.

  • Input shape is:

    (batch_size, sequence_length)

  • Output shape is:

    (batch_size, sequence_length, dimension)

In [7]:
embedding = layers.Embedding(10000, 20)
In [8]:
input_seq = np.array([[1, 2, 3, 3, 2], [1, 2, 0, 0, 0]])

embedding_vectors = embedding(input_seq)

print(input_seq.shape, "=>", embedding_vectors.shape)
(2, 5) => (2, 5, 20)
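Note that, by default, the padding ordinal 0 also receives an embedding vector and takes part in the downstream reduction. A sketch of the mask_zero option, which reserves 0 as padding:

# mask_zero=True makes the layer attach a mask so that downstream
# layers (e.g. SimpleRNN) skip the padded time steps
masked_embedding = layers.Embedding(10000, 20, mask_zero=True)
masked_vectors = masked_embedding(input_seq)  # same shape, but carries a mask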

The SimpleRNN layer reduces a sequence of input vectors to a single vector by carrying a state vector across the time steps.

By default it outputs the final output vector.

It's possible to obtain the final state vector using return_state=True.

It's also possible to obtain the sequence of output vectors during the reduction using return_sequences=True.
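A minimal sketch of the two flags (not part of the original run). With return_sequences=True the layer emits one output vector per time step, and with return_state=True it also returns the final state; for SimpleRNN the final state equals the last output vector.

rnn_both = layers.SimpleRNN(7, return_sequences=True, return_state=True)
outputs, state = rnn_both(embedding_vectors)
print(outputs.shape)  # (2, 5, 7): one output vector per time step
print(state.shape)    # (2, 7): the final state vector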

In [9]:
rnn = layers.SimpleRNN(7)
In [10]:
output_vector = rnn(embedding_vectors)

print(embedding_vectors.shape, "=>", output_vector.shape)
(2, 5, 20) => (2, 7)

We can use a dense layer to convert the RNN output into the probability of a positive movie review.

In [11]:
dense = layers.Dense(1, activation='sigmoid')
In [12]:
output = dense(output_vector)

print(output_vector.shape, "=>", output.shape)
(2, 7) => (2, 1)

Padding and truncation

We need all input sequences in a batch to have the same length so that they can be stacked into a single tensor for the embedding layer.

This requires us to pad the sequences that are too short and truncate the ones that are too long.

The keras.preprocessing.sequence module provides the pad_sequences function to make this task easier.

In [13]:
batch = x_train[:5]

[len(x) for x in batch]
Out[13]:
[218, 189, 141, 550, 147]
In [14]:
padded_batch = sequence.pad_sequences(batch, maxlen=200)
[len(x) for x in padded_batch]
Out[14]:
[200, 200, 200, 200, 200]
In [15]:
padded_batch
Out[15]:
array([[   2,   25,  100,   43,  838,  112,   50,  670,    2,    2,   35,
         480,  284,    2,  150,    2,  172,  112,  167,    2,  336,  385,
          39,    2,  172, 4536, 1111,   17,  546,   38,   13,  447,    2,
         192,   50,   16,    2,  147, 2025,   19,   14,   22,    2, 1920,
        4613,  469,    2,   22,   71,   87,   12,   16,   43,  530,   38,
          76,   15,   13, 1247,    2,   22,   17,  515,   17,   12,   16,
         626,   18,    2,    2,   62,  386,   12,    2,  316,    2,  106,
           2,    2, 2223, 5244,   16,  480,   66, 3785,   33,    2,  130,
          12,   16,   38,  619,    2,   25,  124,   51,   36,  135,   48,
          25, 1415,   33,    2,   22,   12,  215,   28,   77,   52,    2,
          14,  407,   16,   82,    2,    2,    2,  107,  117, 5952,   15,
         256,    2,    2,    2, 3766,    2,  723,   36,   71,   43,  530,
         476,   26,  400,  317,   46,    2,    2,    2, 1029,   13,  104,
          88,    2,  381,   15,  297,   98,   32, 2071,   56,   26,  141,
           2,  194, 7486,   18,    2,  226,   22,   21,  134,  476,   26,
         480,    2,  144,   30, 5535,   18,   51,   36,   28,  224,   92,
          25,  104,    2,  226,   65,   16,   38, 1334,   88,   12,   16,
         283,    2,   16, 4472,  113,  103,   32,   15,   16, 5345,   19,
         178,   32],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           2,  194, 1153,  194, 8255,   78,  228,    2,    2, 1463, 4369,
        5012,  134,   26,    2,  715,    2,  118, 1634,   14,  394,   20,
          13,  119,  954,  189,  102,    2,  207,  110, 3103,   21,   14,
          69,  188,    2,   30,   23,    2,    2,  249,  126,   93,    2,
         114,    2, 2300, 1523,    2,  647,    2,  116,    2,   35, 8163,
           2,  229,    2,  340, 1322,    2,  118,    2,    2,  130, 4901,
          19,    2, 1002,    2,   89,   29,  952,   46,   37,    2,  455,
           2,   45,   43,   38, 1543, 1905,  398,    2, 1649,   26, 6853,
           2,  163,   11, 3215,    2,    2, 1153,    2,  194,  775,    2,
        8255,    2,  349, 2637,  148,  605,    2, 8003,   15,  123,  125,
          68,    2, 6853,   15,  349,  165, 4362,   98,    2,    2,  228,
           2,   43,    2, 1157,   15,  299,  120,    2,  120,  174,   11,
         220,  175,  136,   50,    2, 4373,  228, 8255,    2,    2,  656,
         245, 2350,    2,    2, 9837,  131,  152,  491,   18,    2,   32,
        7464, 1212,   14,    2,    2,  371,   78,   22,  625,   64, 1382,
           2,    2,  168,  145,   23,    2, 1690,   15,   16,    2, 1355,
           2,   28,    2,   52,  154,  462,   33,   89,   78,  285,   16,
         145,   95],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    2,   14,   47,    2,   30,   31,    2,
           2,  249,  108,    2,    2, 5974,   54,   61,  369,   13,   71,
         149,   14,   22,  112,    2, 2401,  311,   12,   16, 3711,   33,
          75,   43, 1829,  296,    2,   86,  320,   35,  534,   19,  263,
        4821, 1301,    2, 1873,   33,   89,   78,   12,   66,   16,    2,
         360,    2,    2,   58,  316,  334,   11,    2, 1716,   43,  645,
         662,    2,  257,   85, 1200,   42, 1228, 2578,   83,   68, 3912,
          15,   36,  165, 1539,  278,   36,   69,    2,  780,    2,  106,
          14, 6905, 1338,   18,    2,   22,   12,  215,   28,  610,   40,
           2,   87,  326,   23, 2300,   21,   23,   22,   12,  272,   40,
          57,   31,   11,    2,   22,   47,    2, 2307,   51,    2,  170,
          23,  595,  116,  595, 1352,   13,  191,   79,  638,   89,    2,
          14,    2,    2,  106,  607,  624,   35,  534,    2,  227,    2,
         129,  113],
       [   2,  341,    2,   27,  846,   10,   10,   29,    2, 1906,    2,
          97,    2,  236,    2, 1311,    2,    2,    2,    2,   31,    2,
           2,   91,    2, 3987,   70,    2,  882,   30,  579,   42,    2,
          12,   32,   11,  537,   10,   10,   11,   14,   65,   44,  537,
          75,    2, 1775, 3353,    2, 1846,    2,    2,    2,  154,    2,
           2,  518,   53,    2,    2,    2, 3211,  882,   11,  399,   38,
          75,  257, 3807,   19,    2,   17,   29,  456,    2,   65,    2,
          27,  205,  113,   10,   10,    2,    2,    2,    2,    2,  242,
           2,   91, 1202,    2,    2, 2070,  307,   22,    2, 5168,  126,
          93,   40,    2,   13,  188, 1076, 3222,   19,    2,    2,    2,
        2348,  537,   23,   53,  537,   21,   82,   40,    2,   13,    2,
          14,  280,   13,  219,    2,    2,  431,  758,  859,    2,  953,
        1052,    2,    2, 5991,    2,   94,   40,   25,  238,   60,    2,
           2,    2,  804,    2,    2,    2, 9941,  132,    2,   67,    2,
          22,   15,    2,  283,    2, 5168,   14,   31,    2,  242,  955,
          48,   25,  279,    2,   23,   12, 1685,  195,   25,  238,   60,
         796,    2,    2,  671,    2, 2804,    2,    2,  559,  154,  888,
           2,  726,   50,   26,   49, 7008,   15,  566,   30,  579,   21,
          64, 2574],
       [   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    2,  249,
        1323,    2,   61,  113,   10,   10,   13, 1637,   14,   20,   56,
          33, 2401,   18,  457,   88,   13, 2626, 1400,   45, 3171,   13,
          70,   79,   49,  706,  919,   13,   16,  355,  340,  355, 1696,
          96,  143,    2,   22,   32,  289,    2,   61,  369,   71, 2359,
           2,   13,   16,  131, 2073,  249,  114,  249,  229,  249,   20,
          13,   28,  126,  110,   13,  473,    2,  569,   61,  419,   56,
         429,    2, 1513,   18,   35,  534,   95,  474,  570,    2,   25,
         124,  138,   88,   12,  421, 1543,   52,  725, 6397,   61,  419,
          11,   13, 1571,   15, 1543,   20,   11,    2,    2,    2,  296,
          12, 3524,    2,   15,  421,  128,   74,  233,  334,  207,  126,
         224,   12,  562,  298, 2167, 1272,    2, 2601,    2,  516,  988,
          43,    2,   79,  120,   15,  595,   13,  784,   25, 3171,   18,
         165,  170,  143,   19,   14,    2, 7224,    2,  226,  251,    2,
          61,  113]], dtype=int32)
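Note that pad_sequences pads and truncates at the front by default (padding='pre' and truncating='pre'), which is why the zeros appear at the beginning of the shorter rows above. A sketch of the alternative:

# Pad and truncate at the end instead of the beginning
post_padded = sequence.pad_sequences(batch, maxlen=200, padding='post', truncating='post')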

Building an end-to-end network

In [16]:
maxlen = 200
inputs = layers.Input(shape=(maxlen,))

x = embedding(inputs)
x = rnn(x)
sentiment_output = dense(x)
In [17]:
model = models.Model(inputs=inputs, outputs=sentiment_output)
In [18]:
keras.utils.plot_model(model)
Out[18]:
[model plot: InputLayer -> Embedding -> SimpleRNN -> Dense]
In [19]:
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
In [20]:
x_train_padded = sequence.pad_sequences(x_train, maxlen=maxlen)
model.fit(x_train_padded, y_train, epochs=5, validation_split=0.2)
Epoch 1/5
625/625 [==============================] - 32s 50ms/step - loss: 0.6826 - acc: 0.5432 - val_loss: 0.5354 - val_acc: 0.7558
Epoch 2/5
625/625 [==============================] - 30s 48ms/step - loss: 0.4742 - acc: 0.7805 - val_loss: 0.5094 - val_acc: 0.7670
Epoch 3/5
625/625 [==============================] - 30s 48ms/step - loss: 0.3608 - acc: 0.8498 - val_loss: 0.4959 - val_acc: 0.7804
Epoch 4/5
625/625 [==============================] - 30s 48ms/step - loss: 0.2646 - acc: 0.8940 - val_loss: 0.5261 - val_acc: 0.7790
Epoch 5/5
625/625 [==============================] - 30s 48ms/step - loss: 0.1882 - acc: 0.9376 - val_loss: 0.5668 - val_acc: 0.7898
Out[20]:
<tensorflow.python.keras.callbacks.History at 0x7f0bd4ca4cd0>
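The original run stops at training; a sketch of evaluating on the held-out test split would look like this:

# Pad the test sequences the same way, then report loss and accuracy
x_test_padded = sequence.pad_sequences(x_test, maxlen=maxlen)
model.evaluate(x_test_padded, y_test)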

We can use the trained model to score the sentiment of new text.

In [24]:
w2i = {w:i for (i,w) in i2w.items()}

def text2ordinal(text):
    return [w2i.get(w, 2) for w in text.split()]
In [34]:
text_ordinals = text2ordinal("this film was just brilliant")
model.predict(sequence.pad_sequences([text_ordinals], maxlen=maxlen))
Out[34]:
array([[0.9585831]], dtype=float32)
In [46]:
text_ordinals = text2ordinal("rendered terrible flat flat flat performances")
model.predict(sequence.pad_sequences([text_ordinals], maxlen=maxlen))
Out[46]:
array([[0.4182482]], dtype=float32)
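The two steps above can be folded into a small convenience wrapper (a hypothetical helper, not defined in the original notebook):

def predict_sentiment(text):
    # Encode, pad to the model's expected length, and return P(positive)
    ordinals = text2ordinal(text)
    padded = sequence.pad_sequences([ordinals], maxlen=maxlen)
    return float(model.predict(padded)[0, 0])

predict_sentiment("this film was just brilliant")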