IMDB Movie Review Dataset
This is a library to download and parse IMDB's Large Movie Review Dataset, plus a demo of a transformer-based model. The dataset has 25K training and 25K test examples, plus 50K unlabeled examples.
It's inspired by Keras' Text classification with Transformer demo.
Environment Set Up
Let's set up go.mod
to use the local copy of GoMLX, so the dataset code can be developed jointly with the model. That's often how data pre-processing and model code are developed together during experimentation.
If you are not changing code, feel free to simply skip this cell. Or, if you use a different directory for your projects, change it below.
Notice the directory ${HOME}/Projects/gomlx
is where the GoMLX code is copied by default in its Docker image.
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gopjrt"
%goworkfix
- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".
import (
	"flag"
	"os"

	"github.com/gomlx/gomlx/examples/imdb"
	"github.com/gomlx/gomlx/pkg/support/fsutil"
	"github.com/janpfeifer/must"

	_ "github.com/gomlx/gomlx/backends/default"
)
var (
flagDataDir = flag.String("data", "~/tmp/imdb", "Directory to cache downloaded and generated dataset files.")
flagEval = flag.Bool("eval", true, "Whether to evaluate the model on the validation data in the end.")
flagVerbosity = flag.Int("verbosity", 1, "Level of verbosity, the higher the more verbose.")
flagCheckpoint = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")
)
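// AssertDownloaded expands the "~" in --data, creates the directory if needed, and
// downloads and pre-processes the IMDB dataset into it (cached after the first run).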
func AssertDownloaded() {
*flagDataDir = must.M1(fsutil.ReplaceTildeInDir(*flagDataDir))
if !fsutil.MustFileExists(*flagDataDir) {
must.M(os.MkdirAll(*flagDataDir, 0777))
}
must.M(imdb.Download(*flagDataDir))
}
%%
AssertDownloaded()
> Loading previously generated preprocessed binary file. Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Sampling some examples
It creates a small dataset and prints out some random examples.
It also defines the DType, used for all internal representations of the model, and the flag --max_len, which defines the maximum number of tokens used per observation. This will be used in the modeling later.
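The cell below only imports the package and prints a few samples. For reference, here is a minimal sketch of how a DType and a --max_len flag like the ones mentioned above could be declared; the names and defaults are assumptions, not the notebook's original cell (the default of 200 mirrors the imdb_content_max_len hyperparameter shown later).
import (
	"flag"

	"github.com/gomlx/gopjrt/dtypes"
)

// DType used for all internal (floating point) representations of the model.
// Float32 here is an assumption.
var DType = dtypes.Float32

// --max_len caps the number of tokens taken from each review; name and default
// value are illustrative only.
var flagMaxLen = flag.Int("max_len", 200, "Maximum number of tokens used per observation.")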
import "github.com/gomlx/gomlx/examples/imdb"
%%
AssertDownloaded()
imdb.PrintSample(3)
> Loading previously generated preprocessed binary file. Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.

[Sample 0 - label 0]
also dresses provocatively in every scene yet curses like a guy she gets very emotional when she fights with will smith for a while while trying to defend all women from bad guys everywhere the men are no better in this movie what we see is a bunch of idiots trying to do anything they can to win a date the males in this movie are concerned with getting either sexual favours or unable to speak clearly when face to face with a woman men in real life do not behave this way what we see in this movie is a product of culture gone awry everything is flip flopped guys act like girls girls act like guys all this is done while keeping the extreme predilections of the sexes very much a part of the story men are shown as soft and stupid but only interested in sex most of the time while women are shown as macho and overbearing but only as a veneer for their emotional insecurities this movie would be good if it wasn t presented obnoxiously to the audience the content is not the culprit it is the manner in which the content is presented

[Sample 1 - label 0]
<START> some directors take 2 and a half hours to tell a story david lynch takes 2 and a half hours to piece together scenes with clues and his trademark oddity but there s never a story no plot no progression of the characters unless you find revealed delusion a progression it amazes me how anyone can call lynch s garbage art but if beauty rests in the eye of the beholder so be it lynch s movie and tv work in the 1980 s came off as avant garde and alternative fine 20 years later work like mulholland drive comes off as a 2 5 hour david lynch masturbation piece it s embarrasing i ve finally seen the movie that takes my top spot as the worst ever at least the people churning out godzilla and rodan weren t passing them off as art

[Sample 2 - label 0]
<START> for many year i saw this movie as a real movie of ninjas but after study more about this culture i can only think this is just another karate film a black shinobi and some weapons doesn t make a ninja it s much more than that the ninja are the most dangerous warrior of the japan because they are trained in every aspect of life to survive to anything killing whatever try to stop them this movie is not a about a ninja warrior just about a clown trying to be something he cannot even understand
Training
We will create 3 different types of models for this demo: Bag of Words ("bow"), Convolutional ("cnn"), and Transformer ("transformer").
Model Configuration
As with other demos we leverage the context.Context
object to store all model and training parameters.
One can set specific parameters using the -set
command line flag.
The imdb.CreateDefaultContext()
method sets the default values for all the hyperparameters that may be used by any of the 3 model types. The parameter "model" specifies the model type.
import (
"golang.org/x/exp/maps"
"github.com/gomlx/gomlx/pkg/ml/context"
)
// settings is bound to a "-set" flag to be used to set context hyperparameters.
var settings = commandline.CreateContextSettingsFlag(imdb.CreateDefaultContext(), "set")
// ContextFromSettings is the default context (createDefaultContext) changed by -set flag.
// It also returns the list of parameters changed by -set in paramsSet: we use this later to avoid loading over the values from checkpoints.
func ContextFromSettings() (ctx *context.Context, paramsSet []string) {
ctx = imdb.CreateDefaultContext()
paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))
return ctx, paramsSet
}
%% -set="model=cnn"
fmt.Printf("Model types: %q\n", maps.Keys(imdb.ValidModels))
ctx, _ := ContextFromSettings()
fmt.Println(commandline.SprintContextSettings(ctx))
Model types: ["bow" "cnn" "transformer"]
"/activation": (string)
"/adam_dtype": (string)
"/adam_epsilon": (float64) 1e-07
"/batch_size": (int) 32
"/cnn_dropout_rate": (float64) 0.5
"/cnn_normalization": (string)
"/cnn_num_layers": (float64) 5
"/cosine_schedule_steps": (int) 0
"/dropout_rate": (float64) 0.1
"/eval_batch_size": (int) 200
"/fnn_dropout_rate": (float64) 0.3
"/fnn_normalization": (string)
"/fnn_num_hidden_layers": (int) 2
"/fnn_num_hidden_nodes": (int) 32
"/fnn_residual": (bool) true
"/imdb_content_max_len": (int) 200
"/imdb_include_separators": (bool) false
"/imdb_mask_word_task_weight": (float64) 0
"/imdb_max_vocab": (int) 20000
"/imdb_token_embedding_size": (int) 32
"/imdb_use_unsupervised": (bool) false
"/imdb_word_dropout_rate": (float64) 0
"/l1_regularization": (float64) 0
"/l2_regularization": (float64) 0
"/learning_rate": (float64) 0.0001
"/model": (string) cnn
"/normalization": (string) layer
"/num_checkpoints": (int) 3
"/optimizer": (string) adamw
"/plots": (bool) true
"/train_steps": (int) 5000
"/transformer_att_key_size": (int) 8
"/transformer_dropout_rate": (float64) -1
"/transformer_max_att_len": (int) 200
"/transformer_num_att_heads": (int) 2
"/transformer_num_att_layers": (int) 1
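Inside the model code, these hyperparameters are read back from the context with context.GetParamOr, the same accessor used by the models below. A minimal sketch, assuming the context package imported above and the ContextFromSettings helper:
%%
ctx, _ := ContextFromSettings()
// Read hyperparameters back from the context; the default is used only if the key was never set.
learningRate := context.GetParamOr(ctx, "learning_rate", 0.001)
modelType := context.GetParamOr(ctx, "model", "bow")
fmt.Printf("model=%q, learning_rate=%g\n", modelType, learningRate)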
Bag Of Words Model (bow)
This is the simplest model we are going to train: it embeds each token of the sentence (the default embedding size is 32), pools the embeddings over the sequence (a reduce-max in the code below), and passes the result through an FNN.
The code in imdb.BagOfWordsModelGraph
looks like this:
// BagOfWordsModelGraph builds the computation graph for the "bag of words" model: it pools (reduce-max)
// the embeddings of the tokens included.
func BagOfWordsModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
embed, _ := EmbedTokensGraph(ctx, inputs[0])
// Take the max over the content length, and put an FNN on top.
// Shape transformation: [batch_size, content_len, embed_size] -> [batch_size, embed_size]
embed = ReduceMax(embed, 1)
logits := fnn.New(ctx, embed, 1).Done()
return []*Node{logits}
}
We played a bit with the hyperparameters to get to ~85% accuracy on the validation data.
The code for imdb.TrainModel
is here.
It's a straightforward GoMLX training loop.
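Since imdb.TrainModel is not reproduced here, below is a minimal sketch of what a generic GoMLX training loop looks like. It is an illustration only: the import paths follow the pkg/ layout used elsewhere in this notebook, modelFn and trainDS are placeholders for the model graph function and the dataset, and the exact constructor names are assumptions that may differ from the current GoMLX API.
import (
	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/gomlx/pkg/ml/context"
	"github.com/gomlx/gomlx/pkg/ml/train"
	"github.com/gomlx/gomlx/pkg/ml/train/losses"
	"github.com/gomlx/gomlx/pkg/ml/train/optimizers"
)

// trainLoopSketch wires a model graph, loss and optimizer into a training loop and runs it.
func trainLoopSketch(backend backends.Backend, ctx *context.Context,
	modelFn train.ModelFn, trainDS train.Dataset, numSteps int) error {
	trainer := train.NewTrainer(backend, ctx, modelFn,
		losses.BinaryCrossentropyLogits, // binary sentiment labels, logits output
		optimizers.FromContext(ctx),     // optimizer chosen by the "optimizer" hyperparameter
		nil, nil)                        // no extra train/eval metrics in this sketch
	loop := train.NewLoop(trainer)
	_, err := loop.RunSteps(trainDS, numSteps)
	return err
}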
%% --set="model=bow;l2_regularization=1e-3;learning_rate=1e-4;normalization=none;train_steps=10000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)
> Loading previously generated preprocessed binary file. Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total. Backend "stablehlo": stablehlo:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO] Model: bow
7% [=>......................................] (639 steps/s) [0s:14s] [step=719] [loss+=0.785] [~loss+=0.794] [~loss=0.682] [~acc=55.80%]
100% [========================================] (2365 steps/s) [step=9999] [loss+=0.32] [~loss+=0.33] [~loss=0.279] [~acc=88.72%]
[Plots omitted: training and evaluation curves for the "accuracy" and "loss" metrics.]
[Step 10000] median train step: 217 microseconds
Results on train-eval:
  Mean Loss+Regularization (#loss+): 0.245
  Mean Loss (#loss): 0.194
  Mean Accuracy (#acc): 93.39%
Results on test-eval:
  Mean Loss+Regularization (#loss+): 0.384
  Mean Loss (#loss): 0.333
  Mean Accuracy (#acc): 85.35%
Convolution Model (cnn)
The function imdb.CnnModelGraph
creates a 1D convolution model with an arbitrary number of convolution layers. After the convolutions, it behaves the same way as the Bag of Words model.
The core of the convolution model looks like this:
// 1D Convolution: embed is [batch_size, content_len, embed_size].
numConvolutions := context.GetParamOr(ctx, "cnn_num_layers", 5)
logits := embed
for convIdx := range numConvolutions {
ctx := ctx.Inf("%03d_conv", convIdx)
residual := logits
if convIdx > 0 {
logits = NormalizeSequence(ctx, logits)
}
logits = layers.Convolution(ctx, embed).KernelSize(7).Filters(embedSize).Strides(1).Done()
logits = activations.ApplyFromContext(ctx, logits)
if dropoutNode != nil {
logits = layers.Dropout(ctx, logits, dropoutNode)
}
if residual.Shape().Equal(logits.Shape()) {
logits = Add(logits, residual)
}
}
// Take the max over the content length, and put an FNN on top.
// Shape transformation: [batch_size, content_len, embed_size] -> [batch_size, embed_size]
logits = ReduceMax(logits, 1)
logits = fnn.New(ctx, logits, 1).Done()
logits.AssertDims(batchSize, 1)
Notice how well it can overfit the training data, but that doesn't help the test results. To improve this, one needs some careful regularization; see the hedged example after the results below.
%% --set="model=cnn;l2_regularization=1e-3;learning_rate=1e-4;normalization=layer;train_steps=10000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)
> Loading previously generated preprocessed binary file. Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total. Backend "stablehlo": stablehlo:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO] Model: cnn
7% [=>......................................] (425 steps/s) [1s:21s] [step=719] [loss+=1.39] [~loss+=1.39] [~loss=0.736] [~acc=52.33%]
100% [========================================] (942 steps/s) [step=9999] [loss+=0.25] [~loss+=0.273] [~loss=0.0894] [~acc=97.09%]
[Plots omitted: training and evaluation curves for the "accuracy" and "loss" metrics.]
[Step 10000] median train step: 580 microseconds
Results on train-eval:
  Mean Loss+Regularization (#loss+): 0.218
  Mean Loss (#loss): 0.0358
  Mean Accuracy (#acc): 99.08%
Results on test-eval:
  Mean Loss+Regularization (#loss+): 0.893
  Mean Loss (#loss): 0.71
  Mean Accuracy (#acc): 83.41%
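One direction (not run here) is to strengthen the regularization hyperparameters already exposed in the context, for instance raising the L2 penalty and the dropout rates. The specific values below are only suggestions:
%% --set="model=cnn;l2_regularization=1e-2;cnn_dropout_rate=0.6;fnn_dropout_rate=0.5;normalization=layer;learning_rate=1e-4;train_steps=10000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)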
Transformer Model
Finally, a Transformer version of the model, as defined in the famous "Attention Is All You Need" paper.
Notice it's not significantly better than our previous simple Bag-of-Words model, likely because there is not enough data for the transformer to make a difference. The success of transformers in large language models is in large part due to training on huge amounts of unsupervised (or self-supervised) data, but that is beyond the scope of this small test.
The code is in imdb.TransformerModelGraph
, and the core of it looks like this:
...
// Add the requested number of attention layers.
numAttLayers := context.GetParamOr(ctx, "transformer_num_att_layers", 1)
numAttHeads := context.GetParamOr(ctx, "transformer_num_att_heads", 2)
attKeySize := context.GetParamOr(ctx, "transformer_att_key_size", 8)
for layerNum := range numAttLayers {
// Each layer in its own scope.
ctx := ctx.Inf("%03d_attention_layer", layerNum)
residual := embed
embed = layers.MultiHeadAttention(ctx.In("000_attention"), embed, embed, embed, numAttHeads, attKeySize).
SetKeyMask(mask).SetQueryMask(mask).
SetOutputDim(embedSize).
SetValueHeadDim(embedSize).Done()
if dropoutNode != nil {
embed = layers.Dropout(ctx.In("001_dropout"), embed, dropoutNode)
}
embed = NormalizeSequence(ctx.In("002_normalization"), embed)
attentionOutput := embed
// Transformers recipe: 2 dense layers after attention.
embed = fnn.New(ctx.In("003_fnn"), embed, embedSize).NumHiddenLayers(1, embedSize).Done()
if dropoutNode != nil {
embed = layers.Dropout(ctx.In("004_dropout"), embed, dropoutNode)
}
embed = Add(embed, attentionOutput)
embed = NormalizeSequence(ctx.In("005_normalization"), embed)
// Residual connection:
if layerNum > 0 {
embed = Add(residual, embed)
}
}
...
With only 5000 steps we got ~87% on the test data -- and significant overfitting as well.
%% --set="model=transformer;normalization=none;activation=swish;l2_regularization=1e-3;cnn_dropout_rate=0.5;fnn_dropout_rate=0.3;learning_rate=1e-4;train_steps=5000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)
> Loading previously generated preprocessed binary file. Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total. Backend "stablehlo": stablehlo:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.76 [StableHLO] Model: transformer
14% [====>...................................] (418 steps/s) [2s:10s] [step=724] [loss+=1.34] [~loss+=1.54] [~loss=1.14] [~acc=50.92%]
100% [========================================] (697 steps/s) [step=4999] [loss+=0.541] [~loss+=0.626] [~loss=0.286] [~acc=89.05%]
[Plots omitted: training and evaluation curves for the "accuracy" and "loss" metrics.]
[Step 5000] median train step: 764 microseconds
Results on train-eval:
  Mean Loss+Regularization (#loss+): 0.616
  Mean Loss (#loss): 0.278
  Mean Accuracy (#acc): 90.83%
Results on test-eval:
  Mean Loss+Regularization (#loss+): 0.685
  Mean Loss (#loss): 0.346
  Mean Accuracy (#acc): 86.06%
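As a closing note, the default context already exposes hyperparameters for a masked-word task over the unlabeled examples ("imdb_use_unsupervised" and "imdb_mask_word_task_weight" in the settings listed earlier), which is the kind of self-supervised signal mentioned above. A hedged, not-run example of enabling it for the transformer (the task weight is only a guess):
%% --set="model=transformer;imdb_use_unsupervised=true;imdb_mask_word_task_weight=0.1;learning_rate=1e-4;train_steps=5000"
ctx, paramsSet := ContextFromSettings()
imdb.TrainModel(ctx, *flagDataDir, *flagCheckpoint, paramsSet, *flagEval, *flagVerbosity)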