IMDB Movie Review Dataset¶

This is a library to download and parse IMDB's Large Movie Review Dataset, along with a demo of a transformer-based model. The dataset has 25K training and 25K test examples, plus 50K unlabeled examples.

It's inspired by Keras' Text classification with Transformer demo.

Environment Set Up¶

Let's set up go.work to use the local copy of GoMLX, so the dataset code can be developed jointly with the model. That's often how data pre-processing and model code are developed together during experimentation.

If you are not changing the GoMLX code, feel free to simply skip this cell. Or if you use a different directory for your projects, change it below.

Notice that the directory ${HOME}/Projects/gomlx is where the GoMLX code is copied by default in its Docker image.

In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gopjrt"
%goworkfix
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".

Data Preparation¶

Downloading data files¶

To download, uncompress, and untar the dataset to the local directory, simply run the following. Notice that if it was already downloaded to the given --data directory, it returns immediately.

In [2]:
import (
    "flag"
    "os"

    "github.com/gomlx/gomlx/examples/imdb"
    "github.com/gomlx/gomlx/ml/data"
    "github.com/janpfeifer/must"

    _ "github.com/gomlx/gomlx/backends/xla"
)

var (
	flagDataDir    = flag.String("data", "~/tmp/imdb", "Directory to cache downloaded and generated dataset files.")
	flagEval       = flag.Bool("eval", true, "Whether to evaluate the model on the validation data in the end.")
	flagVerbosity  = flag.Int("verbosity", 1, "Level of verbosity, the higher the more verbose.")
	flagCheckpoint = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")
)

func AssertDownloaded() {
    *flagDataDir = data.ReplaceTildeInDir(*flagDataDir)
    if !data.FileExists(*flagDataDir) {
        must.M(os.MkdirAll(*flagDataDir, 0777))
    }
    must.M(imdb.Download(*flagDataDir))
}

%%
AssertDownloaded()
> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
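
If you prefer to keep the dataset files somewhere else, pass a different --data directory to the cell. The path below is only an illustration:

%% --data="${HOME}/datasets/imdb"
AssertDownloaded()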

Sampling some examples¶

This creates a small dataset and prints out some random examples.

It also defines the DType used for all internal representations of the model, and the flag --max_len, which defines the maximum number of tokens used per observation. These will be used in the modeling later.

In [3]:
import "github.com/gomlx/gomlx/examples/imdb"

%%
AssertDownloaded()
imdb.PrintSample(3)
> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
┌────────────────────────────────────────────────────────────┐
│                                                            │
│    [Sample 0 - label 1]                                    │
│    attempt to raise the funds to save the club from        │
│    becoming a poker machine haven a familiar and           │
│    successful formula that is handled well there is no     │
│    denying that the film owes it s success to the great    │
│    casting of molloy he seemed to have a great rapport     │
│    with samuel johnson and excellent chemistry with        │
│    judith lucy and while the character is probably not     │
│    a far stretch from his own personality you can t        │
│    help but wonder why he hadn t tried his arm at film     │
│    earlier to smooth out the in experienced cast the       │
│    delightful frank wilson and bill hunter support and     │
│    often steal their scenes they are two fine actors       │
│    and the pair cruise through their roles with ease       │
│    had it not been for the huge success of my big fat      │
│    greek wedding crackerjack would have made it to         │
│    number 1 at the australian box office but when you      │
│    consider what he film is about and who is involved      │
│    even making it to number 2 was an outstanding effort    │
│    all in all a witty feel good movie great cast great     │
│    crew and a great soundtrack combine to make one of      │
│    the better australian films of 2002 7 10                │
│                                                            │
│                                                            │
└────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────┐
│                                                            │
│    [Sample 1 - label 0]                                    │
│    <START> i read about this film on line and after        │
│    seeing the generally positive reviews it has            │
│    received and viewing the trailer i decided to check     │
│    it out for myself what a disappointment it starts       │
│    out well enough the opening scene was actually          │
│    pretty tense but from there it s all downhill i can     │
│    see that the filmmakers were trying to do something     │
│    different with this movie but by doing so they took     │
│    all the enjoyment out of watching it those choices      │
│    combined with the c s i editing use of music and        │
│    montage lack of suspense scares or humor really drag    │
│    this film down there s too much foreshadowing and to    │
│    many subtle clues so when the first twist arrives       │
│    early on you already know how the movie is going to     │
│    end i gave the movie three stars because i think the    │
│    cast did a good job other than that i can t             │
│    recommend this movie                                    │
│                                                            │
│                                                            │
└────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────┐
│                                                            │
│    [Sample 2 - label 1]                                    │
│    <START> the idea is not original if you have seen       │
│    such kind of story before you would know what the       │
│    ending would come out after watching for the first      │
│    twenty minutes the script the positioning of the        │
│    actors and the screening is too obvious if you haven    │
│    t seen such story before it is definitely a good        │
│    experience you will enjoy the twist at the end don t    │
│    forget to watch it again after you know the truth       │
│    you will even more enjoy the plots even though i        │
│    have a right guess at the very beginning i still        │
│    couldn t help stick on my seat till the end             │
│    conclusion a must see this one from korea is better     │
│    than any recent movies of the genre from japan          │
│    forget hollywood don t miss it                          │
│                                                            │
│                                                            │
└────────────────────────────────────────────────────────────┘

Training¶

We will create 3 different types of models for this demo: Bag of Words ("bow"), Convolutional ("cnn"), and Transformer ("transformer").

Model Configuration¶

As with other demos, we leverage the context.Context object to store all model and training parameters. One can set specific parameters using the -set command-line flag.

The imdb.CreateDefaultContext() method sets the default values for all the hyperparameters that may be used by any of the 3 model types. The parameter "model" specifies the model type.

In [4]:
import (
    "golang.org/x/exp/maps"
    "github.com/gomlx/gomlx/ml/context"
)

// settings is bound to a "-set" flag to be used to set context hyperparameters.
var settings = commandline.CreateContextSettingsFlag(imdb.CreateDefaultContext(), "set")

// ContextFromSettings is the default context (imdb.CreateDefaultContext) modified by the -set flag.
// It also returns the list of parameters changed by -set in paramsSet: we use this later to avoid having checkpoint loading override values explicitly set on the command line.
func ContextFromSettings() (ctx *context.Context, paramsSet []string) {
    ctx = imdb.CreateDefaultContext()
    paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))
    return ctx, paramsSet
}

%% -set="model=cnn"
fmt.Printf("Model types: %q\n", maps.Keys(imdb.ValidModels))
ctx, _ := ContextFromSettings()
fmt.Println(commandline.SprintContextSettings(ctx))
Model types: ["bow" "cnn" "transformer"]
Context hyperparameters:
	"activation": (string) 
	"adam_dtype": (string) 
	"adam_epsilon": (float64) 1e-07
	"batch_size": (int) 32
	"cnn_dropout_rate": (float64) 0.5
	"cnn_normalization": (string) 
	"cnn_num_layers": (float64) 5
	"cosine_schedule_steps": (int) 0
	"dropout_rate": (float64) 0.1
	"eval_batch_size": (int) 200
	"fnn_dropout_rate": (float64) 0.3
	"fnn_normalization": (string) 
	"fnn_num_hidden_layers": (int) 2
	"fnn_num_hidden_nodes": (int) 32
	"fnn_residual": (bool) true
	"imdb_content_max_len": (int) 200
	"imdb_include_separators": (bool) false
	"imdb_mask_word_task_weight": (float64) 0
	"imdb_max_vocab": (int) 20000
	"imdb_token_embedding_size": (int) 32
	"imdb_use_unsupervised": (bool) false
	"imdb_word_dropout_rate": (float64) 0
	"l1_regularization": (float64) 0
	"l2_regularization": (float64) 0
	"learning_rate": (float64) 0.0001
	"model": (string) cnn
	"normalization": (string) layer
	"num_checkpoints": (int) 3
	"optimizer": (string) adamw
	"plots": (bool) true
	"train_steps": (int) 5000
	"transformer_att_key_size": (int) 8
	"transformer_dropout_rate": (float64) -1
	"transformer_max_att_len": (int) 200
	"transformer_num_att_heads": (int) 2
	"transformer_num_att_layers": (int) 1

Bag Of Words Model (bow)¶

This is the simplest model we are going to train: it embeds each token of the sentence (the default embedding size is 32), reduces the embeddings across the tokens (a max-reduction, see ReduceMax below), and passes the result through an FNN (feed-forward network).

The code in imdb.BagOfWordsModelGraph looks like this:

// BagOfWordsModelGraph builds the computation graph for the "bag of words" model: simply a reduction (max)
// over the embeddings of the tokens included.
func BagOfWordsModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	embed, _ := EmbedTokensGraph(ctx, inputs[0])

	// Take the max over the content length, and put an FNN on top.
	// Shape transformation: [batch_size, content_len, embed_size] -> [batch_size, embed_size]
	embed = ReduceMax(embed, 1)
	logits := fnn.New(ctx, embed, 1).Done()
	return []*Node{logits}
}

We played a bit with the hyperparameters to get to ~85% accuracy on the validation data.

The code is in imdb.TrainModel: it's a straightforward GoMLX training loop.
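
If you don't want to dig into the source, here is a rough sketch of the shape of such a loop. This is not the actual imdb.TrainModel code: the function name, the nil metrics, and the exact import paths and signatures are assumptions that may differ between GoMLX versions, and the real code additionally handles checkpoints, plotting and evaluation.

import (
	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/gomlx/examples/imdb"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/train"
	"github.com/gomlx/gomlx/ml/train/losses"
	"github.com/gomlx/gomlx/ml/train/optimizers"
	"github.com/gomlx/gomlx/ui/commandline"
)

// sketchTrainLoop is an illustrative training loop; trainDS is assumed to be a
// train.Dataset yielding (tokens, labels) batches built by the imdb package.
func sketchTrainLoop(backend backends.Backend, ctx *context.Context, trainDS train.Dataset, steps int) error {
	trainer := train.NewTrainer(backend, ctx,
		imdb.BagOfWordsModelGraph,       // or imdb.CnnModelGraph / imdb.TransformerModelGraph
		losses.BinaryCrossentropyLogits, // binary sentiment labels
		optimizers.FromContext(ctx),     // honors the "optimizer" hyperparameter ("adamw" by default)
		nil, nil)                        // extra train/eval metrics omitted in this sketch
	loop := train.NewLoop(trainer)
	commandline.AttachProgressBar(loop) // the progress bar seen in the outputs below
	_, err := loop.RunSteps(trainDS, steps)
	return err
}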

In [5]:
%% --set="model=bow;l2_regularization=1e-3;learning_rate=1e-4;normalization=none;train_steps=10000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)
> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Backend "xla":	xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54
Model: bow
Training (10000 steps):    7% [=>......................................] (421 steps/s) [1s:22s] [step=719] [loss+=0.695] [~loss+=0.692] [~loss=0.691] [~acc=51.86%]        
Training (10000 steps):  100% [========================================] (1807 steps/s) [step=9999] [loss+=0.257] [~loss+=0.335] [~loss=0.320] [~acc=86.35%]        

Metric: accuracy

Metric: loss

	[Step 10000] median train step: 287 microseconds

Results on train-eval:
	Mean Loss+Regularization (#loss+): 0.250
	Mean Loss (#loss): 0.234
	Mean Accuracy (#acc): 91.26%
Results on test-eval:
	Mean Loss+Regularization (#loss+): 0.352
	Mean Loss (#loss): 0.337
	Mean Accuracy (#acc): 85.42%

Convolution Model (cnn)¶

The function imdb.CnnModelGraph creates a 1D convolution model with an arbitrary number of convolution layers. After the convolutions, it behaves the same way as the Bag of Words model.

The core of the convolution model looks like this:

// 1D Convolution: embed is [batch_size, content_len, embed_size].
	numConvolutions := context.GetParamOr(ctx, "cnn_num_layers", 5)
	logits := embed
	for convIdx := range numConvolutions {
		ctx := ctx.Inf("%03d_conv", convIdx)
		residual := logits
		if convIdx > 0 {
			logits = NormalizeSequence(ctx, logits)
		}
		logits = layers.Convolution(ctx, embed).KernelSize(7).Filters(embedSize).Strides(1).Done()
		logits = activations.ApplyFromContext(ctx, logits)
		if dropoutNode != nil {
			logits = layers.Dropout(ctx, logits, dropoutNode)
		}
		if residual.Shape().Equal(logits.Shape()) {
			logits = Add(logits, residual)
		}
	}

	// Take the max over the content length, and put an FNN on top.
	// Shape transformation: [batch_size, content_len, embed_size] -> [batch_size, embed_size]
	logits = ReduceMax(logits, 1)
	logits = fnn.New(ctx, logits, 1).Done()
	logits.AssertDims(batchSize, 1)

Notice how well it can overfit the training data ... but that doesn't help the test results. To improve this, one needs some careful regularization; one possible direction is sketched after the results below.

In [6]:
%% --set="model=cnn;l2_regularization=1e-3;learning_rate=1e-4;normalization=layer;train_steps=10000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)
> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Backend "xla":	xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54
Model: cnn
Training (10000 steps):    7% [=>......................................] (183 steps/s) [4s:50s] [step=719] [loss+=0.725] [~loss+=0.713] [~loss=0.693] [~acc=52.01%]        
Training (10000 steps):  100% [========================================] (551 steps/s) [step=9999] [loss+=0.267] [~loss+=0.101] [~loss=0.084] [~acc=97.33%]        

Metric: accuracy

Metric: loss

	[Step 10000] median train step: 999 microseconds

Results on train-eval:
	Mean Loss+Regularization (#loss+): 0.063
	Mean Loss (#loss): 0.046
	Mean Accuracy (#acc): 98.88%
Results on test-eval:
	Mean Loss+Regularization (#loss+): 0.719
	Mean Loss (#loss): 0.702
	Mean Accuracy (#acc): 83.02%
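
To push back on this overfitting, one direction is to strengthen regularization purely from the command line, using the hyperparameters listed earlier. The values below are illustrative only, not tuned:

%% --set="model=cnn;l2_regularization=1e-2;cnn_dropout_rate=0.6;dropout_rate=0.2;learning_rate=1e-4;normalization=layer;train_steps=10000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)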

Transformer Model¶

Finally, a Transformer version of the model, as defined in the famous "Attention Is All You Need" paper.

Notice that it's not significantly better than our previous, simpler Bag-of-Words model, likely because there isn't enough data for the transformer to make a difference. The success of transformers in large language models is in large part due to training on huge amounts of unsupervised (or self-supervised) data, but that is beyond the scope of this small test.

The code is in imdb.TransformerModelGraph, and the core of it looks like this:

...
	// Add the requested number of attention layers.
	numAttLayers := context.GetParamOr(ctx, "transformer_num_att_layers", 1)
	numAttHeads := context.GetParamOr(ctx, "transformer_num_att_heads", 2)
	attKeySize := context.GetParamOr(ctx, "transformer_att_key_size", 8)
	for layerNum := range numAttLayers {
		// Each layer in its own scope.
		ctx := ctx.Inf("%03d_attention_layer", layerNum)
		residual := embed
		embed = layers.MultiHeadAttention(ctx.In("000_attention"), embed, embed, embed, numAttHeads, attKeySize).
			SetKeyMask(mask).SetQueryMask(mask).
			SetOutputDim(embedSize).
			SetValueHeadDim(embedSize).Done()
		if dropoutNode != nil {
			embed = layers.Dropout(ctx.In("001_dropout"), embed, dropoutNode)
		}
		embed = NormalizeSequence(ctx.In("002_normalization"), embed)
		attentionOutput := embed

		// Transformers recipe: 2 dense layers after attention.
		embed = fnn.New(ctx.In("003_fnn"), embed, embedSize).NumHiddenLayers(1, embedSize).Done()
		if dropoutNode != nil {
			embed = layers.Dropout(ctx.In("004_dropout"), embed, dropoutNode)
		}
		embed = Add(embed, attentionOutput)
		embed = NormalizeSequence(ctx.In("005_normalization"), embed)

		// Residual connection:
		if layerNum > 0 {
			embed = Add(residual, embed)
		}
	}
    ...

With only 5000 steps we got ~87% accuracy on the test data -- and significant overfitting as well.

In [7]:
%% --set="model=transformer;normalization=none;activation=swish;l2_regularization=1e-3;cnn_dropout_rate=0.5;fnn_dropout_rate=0.3;learning_rate=1e-4;train_steps=5000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)
> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Backend "xla":	xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.54
Model: transformer
Training (5000 steps):   14% [====>...................................] (138 steps/s) [8s:31s] [step=724] [loss+=0.692] [~loss+=0.693] [~loss=0.692] [~acc=51.51%]        
Training (5000 steps):  100% [========================================] (146 steps/s) [step=4999] [loss+=0.271] [~loss+=0.218] [~loss=0.181] [~acc=93.53%]        

Metric: accuracy

Metric: loss

	[Step 5000] median train step: 5180 microseconds

Results on train-eval:
	Mean Loss+Regularization (#loss+): 0.200
	Mean Loss (#loss): 0.162
	Mean Accuracy (#acc): 94.19%
Results on test-eval:
	Mean Loss+Regularization (#loss+): 0.350
	Mean Loss (#loss): 0.313
	Mean Accuracy (#acc): 87.20%
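
A natural next experiment is a deeper and wider transformer, again configured purely via -set, optionally saving checkpoints with --checkpoint. The hyperparameter values and the checkpoint path below are illustrative, not tuned:

%% --checkpoint="~/tmp/imdb_transformer" --set="model=transformer;transformer_num_att_layers=2;transformer_num_att_heads=4;transformer_att_key_size=16;activation=swish;l2_regularization=1e-3;learning_rate=1e-4;train_steps=5000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)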