Cifar Library and Demo
This is a library to download and parse the Cifar datasets (Cifar-10 and Cifar-100), and a very small demo of an FNN (Feedforward Neural Network) with GoMLX. FNNs are notoriously bad for images, but it's only a demo. See the Resnet50 model for a more serious image classification model (old but still good; as of this writing, the best results are obtained with ViT models).
CIFAR-10 and CIFAR-100 are labeled subsets of the 80 million tiny images dataset. They were collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. See more details on its homepage.
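Parsing the binary version of the dataset is simple. As a minimal sketch (not the cifar package's actual code), here is how one CIFAR-10 record can be decoded: each record is 1 label byte followed by 3072 pixel bytes, stored as three 32x32 planes (red, then green, then blue) in row-major order; CIFAR-100 records are the same except that they start with two label bytes (coarse, then fine). The hypothetical decodeCifar10Record converts the planes to height x width x channel order, the [32, 32, 3] image layout used by the models in this notebook, and scales values to [0, 1]; that scaling is an assumption of the sketch.

```go
package main

import (
	"errors"
	"fmt"
)

const (
	height     = 32
	width      = 32
	depth      = 3
	planeSize  = height * width      // 1024 bytes per color channel.
	recordSize = 1 + depth*planeSize // 1 label byte + 3072 pixel bytes (CIFAR-10).
)

// decodeCifar10Record is a hypothetical helper: it splits one CIFAR-10 binary record into
// its label and a [height][width][depth] image with values scaled to [0, 1].
func decodeCifar10Record(record []byte) (label int, img [height][width][depth]float32, err error) {
	if len(record) != recordSize {
		return 0, img, errors.New("invalid record size")
	}
	label = int(record[0])
	pixels := record[1:]
	for c := 0; c < depth; c++ { // Channel planes are stored as R, then G, then B.
		for y := 0; y < height; y++ {
			for x := 0; x < width; x++ {
				v := pixels[c*planeSize+y*width+x]
				img[y][x][c] = float32(v) / 255.0 // Planar (CHW) to interleaved (HWC).
			}
		}
	}
	return label, img, nil
}

func main() {
	// Synthetic record: label 7, red plane all 255, green and blue planes all 0.
	record := make([]byte, recordSize)
	record[0] = 7
	for i := 0; i < planeSize; i++ {
		record[1+i] = 255
	}
	label, img, err := decodeCifar10Record(record)
	if err != nil {
		panic(err)
	}
	fmt.Println(label, img[0][0], img[31][31])
}
```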
This notebook serves as documentation and example for the github.com/gomlx/gomlx/examples/cifar library; the complete demo code can be found in .../examples/cifar/demo/.
Environment Setup
Let's set up go.mod to use the local copy of GoMLX, so the dataset code can be developed jointly with the model. Data pre-processing and model code are often developed together like this, during experimentation.
If you are not changing the code, feel free to skip this cell. If you use a different directory for your projects, change it below.
Notice that ${HOME}/Projects/gomlx is where the GoMLX code is copied by default in its Docker image.
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx"
%goworkfix
- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
import (
"flag"
"os"

"github.com/gomlx/gomlx/examples/cifar"
"github.com/gomlx/gomlx/ml/data"
"github.com/janpfeifer/must"
)
var flagDataDir = flag.String("data", "~/work/cifar", "Directory to cache downloaded and generated dataset files.")
func AssertDownloaded() {
*flagDataDir = data.ReplaceTildeInDir(*flagDataDir)
if !data.FileExists(*flagDataDir) {
must.M(os.MkdirAll(*flagDataDir, 0777))
}
must.M(cifar.DownloadCifar10(*flagDataDir))
must.M(cifar.DownloadCifar100(*flagDataDir))
}
%%
AssertDownloaded()
!ls -lh ~/work/cifar/
total 52K
-rw-r--r-- 1 janpf janpf 358 Aug 2 10:32 a
-rw-r--r-- 1 janpf janpf 3.7K Aug 2 10:25 b
drwxr-x--- 2 janpf janpf 4.0K Nov 12 17:04 base_cnn_model
drwxr-x--- 2 janpf janpf 4.0K Nov 12 17:06 base_fnn_model
drwxr-x--- 2 janpf janpf 4.0K Nov 12 16:56 base_kan_model
drwxr-xr-x 2 janpf janpf 4.0K Feb 20 2010 cifar-100-binary
drwxr-xr-x 2 janpf janpf 4.0K Jun 4 2009 cifar-10-batches-bin
drwxr-x--- 2 janpf janpf 4.0K Aug 2 09:51 cnn
drwxr-x--- 2 janpf janpf 4.0K Aug 1 17:28 cnn_layer
drwxr-x--- 2 janpf janpf 4.0K Jul 31 19:19 cnn_nonorm
drwxr-x--- 2 janpf janpf 4.0K Jul 31 08:28 fnn_batchnorm_0
drwxr-x--- 2 janpf janpf 4.0K Aug 1 14:24 fnn_layer
drwxr-x--- 2 janpf janpf 4.0K Nov 12 16:41 r001
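AssertDownloaded relies on data.ReplaceTildeInDir to expand the "~" in the default -data flag before creating the directory. The sketch below shows what such an expansion typically does, using only the standard library; expandTilde is a hypothetical name, and handling only a leading "~" or "~/" is an assumption, not necessarily how the gomlx function behaves.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

// expandTilde is a hypothetical helper: it replaces a leading "~" (alone or followed by
// "/") with the given home directory. Other paths are returned unchanged.
func expandTilde(dir, home string) string {
	if dir == "~" {
		return home
	}
	if strings.HasPrefix(dir, "~/") {
		return filepath.Join(home, dir[2:])
	}
	return dir
}

func main() {
	// Use a temporary directory as a stand-in for $HOME, so the example leaves nothing behind.
	home, err := os.MkdirTemp("", "home")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(home)

	dataDir := expandTilde("~/work/cifar", home)
	// Create the directory (and any parents) if needed, as AssertDownloaded does.
	if err := os.MkdirAll(dataDir, 0o777); err != nil {
		panic(err)
	}
	fmt.Println(dataDir)
}
```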
Sample some images
The cifar.NewDataset function creates a data.InMemoryDataset that can be used for training, for evaluation, or just to sample a few examples, which we do below:
import (
"fmt"
"strings"

"github.com/gomlx/gopjrt/dtypes"
"github.com/gomlx/gomlx/backends"
"github.com/gomlx/gomlx/examples/cifar"
"github.com/gomlx/gomlx/ml/train"
"github.com/gomlx/gomlx/types/shapes"
"github.com/gomlx/gomlx/types/tensors/images"
"github.com/janpfeifer/gonb/gonbui"
_ "github.com/gomlx/gomlx/backends/xla"
)
var (
// Model DType, used everywhere.
DType = dtypes.Float32
)
// sampleToNotebook generates a sample of Cifar-10 and Cifar-100 in a GoNB Jupyter Notebook.
func sampleToNotebook() {
// Load data into tensors.
backend := backends.New()
ds10 := cifar.NewDataset(backend, "Samples Cifar-10", *flagDataDir, cifar.C10, DType, cifar.Train).Shuffle()
ds100 := cifar.NewDataset(backend, "Samples Cifar-100", *flagDataDir, cifar.C100, DType, cifar.Train).Shuffle()
sampleImages(ds10, 8, cifar.C10Labels)
sampleImages(ds100, 8, cifar.C100FineLabels)
}
// sampleImages generates and outputs one HTML table with numImages samples yielded by ds, captioned using labelNames.
func sampleImages(ds train.Dataset, numImages int, labelNames []string) {
gonbui.DisplayHTML(fmt.Sprintf("<p>%s</p>\n", ds.Name()))
parts := make([]string, 0, numImages+5) // Leave last part empty.
parts = append(parts, "<table><tr>")
for ii := 0; ii < numImages; ii++ {
_, inputs, labels := must.M3(ds.Yield())
imgTensor := inputs[0]
img := images.ToImage().Single(imgTensor)
label := labels[0].Value().([]int64)
labelStr := labelNames[label[0]]
imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))
size := imgTensor.Shape().Dimensions[0]
parts = append(
parts,
fmt.Sprintf(`<td><figure style="padding:4px;text-align: center;"><img width="%d" height="%d" src="%s">` +
`<figcaption style="text-align: center;">%s (%d)</figcaption></figure></td>`,
size*2, size*2, imgSrc, labelStr, label[0]),
)
}
parts = append(parts, "</tr></table>", "")
gonbui.DisplayHTML(strings.Join(parts, "\n"))
}
%%
AssertDownloaded()
sampleToNotebook()
Samples Cifar-10
Samples Cifar-100
import (
"flag"

"github.com/gomlx/gomlx/examples/cifar"
"github.com/gomlx/gomlx/ml/context"
"github.com/gomlx/gomlx/ml/layers"
"github.com/gomlx/gomlx/ml/layers/activations"
"github.com/gomlx/gomlx/ml/layers/fnn"
"github.com/gomlx/gomlx/ml/layers/kan"
"github.com/gomlx/gomlx/ml/layers/regularizers"
"github.com/gomlx/gomlx/ml/train/optimizers"
"github.com/gomlx/gomlx/ui/commandline"
"github.com/gomlx/gomlx/ui/plotly"
)
var (
// ValidModels is the list of model types supported.
ValidModels = []string{"fnn", "kan", "cnn"}
flagEval = flag.Bool("eval", true, "Whether to evaluate the model on the validation data in the end.")
flagVerbosity = flag.Int("verbosity", 1, "Level of verbosity, the higher the more verbose.")
)
// settings is bound to a "-set" flag to be used to set context hyperparameters.
var settings = commandline.CreateContextSettingsFlag(createDefaultContext(), "set")
// createDefaultContext sets the context with default hyperparameters
func createDefaultContext() *context.Context {
ctx := context.New()
ctx.RngStateReset()
ctx.SetParams(map[string]any{
// Model type to use: valid values are fnn, kan and cnn.
"model": cifar.C10ValidModels[0],
"checkpoint": "",
"num_checkpoints": 3,
"train_steps": 3000,
// batch_size for training.
"batch_size": 64,
// eval_batch_size can be larger than training, it's more efficient.
"eval_batch_size": 200,
// "plots" triggers generating intermediary eval data for plotting and, if running in GoNB,
// actually drawing the plot with Plotly.
plotly.ParamPlots: true,
// If "normalization" is set, it overrides "fnn_normalization" and "cnn_normalization".
layers.ParamNormalization: "none",
optimizers.ParamOptimizer: "adamw",
optimizers.ParamLearningRate: 1e-4,
optimizers.ParamAdamEpsilon: 1e-7,
optimizers.ParamAdamDType: "",
activations.ParamActivation: "swish",
layers.ParamDropoutRate: 0.0,
regularizers.ParamL2: 1e-5,
regularizers.ParamL1: 1e-5,
// FNN network parameters:
fnn.ParamNumHiddenLayers: 8,
fnn.ParamNumHiddenNodes: 128,
fnn.ParamResidual: true,
fnn.ParamNormalization: "", // Set to none for no normalization, otherwise it falls back to layers.ParamNormalization.
fnn.ParamDropoutRate: -1.0, // Set to 0.0 for no dropout, otherwise it falls back to layers.ParamDropoutRate.
// KAN network parameters:
kan.ParamNumControlPoints: 10, // Number of control points
kan.ParamNumHiddenNodes: 64,
kan.ParamNumHiddenLayers: 4,
kan.ParamBSplineDegree: 2,
kan.ParamBSplineMagnitudeL1: 1e-5,
kan.ParamBSplineMagnitudeL2: 0.0,
kan.ParamDiscrete: false,
kan.ParamDiscreteSoftness: 0.1,
kan.ParamResidual: true,
})
return ctx
}
// ContextFromSettings is the default context (createDefaultContext) changed by -set flag.
func ContextFromSettings() (ctx *context.Context, paramsSet []string) {
ctx = createDefaultContext()
paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))
return
}
// Let's test that we can set hyperparameters with the "-set" flag:
%% -set="batch_size=17;fnn_num_hidden_layers=12"
fmt.Printf("Model types: %q\n", cifar.C10ValidModels)
ctx, parametersSet := ContextFromSettings()
fmt.Printf("Parameters set (-set): %q\n", parametersSet)
fmt.Println(commandline.SprintContextSettings(ctx))
Model types: ["fnn" "kan" "cnn"]
Parameters set (-set): ["batch_size" "fnn_num_hidden_layers"]
"/activation": (string) swish
"/adam_dtype": (string)
"/adam_epsilon": (float64) 1e-07
"/batch_size": (int) 17
"/checkpoint": (string)
"/dropout_rate": (float64) 0
"/eval_batch_size": (int) 200
"/fnn_dropout_rate": (float64) -1
"/fnn_normalization": (string)
"/fnn_num_hidden_layers": (int) 12
"/fnn_num_hidden_nodes": (int) 128
"/fnn_residual": (bool) true
"/kan_bspline_degree": (int) 2
"/kan_bspline_magnitude_l1": (float64) 1e-05
"/kan_bspline_magnitude_l2": (float64) 0
"/kan_discrete": (bool) false
"/kan_discrete_softness": (float64) 0.1
"/kan_num_hidden_layers": (int) 4
"/kan_num_hidden_nodes": (int) 64
"/kan_num_points": (int) 10
"/kan_residual": (bool) true
"/l1_regularization": (float64) 1e-05
"/l2_regularization": (float64) 1e-05
"/learning_rate": (float64) 0.0001
"/model": (string) fnn
"/normalization": (string) none
"/num_checkpoints": (int) 3
"/optimizer": (string) adamw
"/plots": (bool) true
"/train_steps": (int) 3000
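The -set value is a ";"-separated list of key=value pairs. The sketch below shows one way such a string can be parsed against a map of defaults, using each default's Go type to decide how to parse the value. parseSettings is hypothetical and is not the commandline package's implementation, which works on a *context.Context:

```go
package main

import (
	"fmt"
	"sort"
	"strconv"
	"strings"
)

// parseSettings is a hypothetical helper: it applies a "key1=value1;key2=value2" string to a
// copy of defaults, parsing each value according to the type of its default value.
// It returns the updated parameters and the sorted list of keys that were set.
func parseSettings(defaults map[string]any, settings string) (map[string]any, []string, error) {
	params := make(map[string]any, len(defaults))
	for k, v := range defaults {
		params[k] = v
	}
	var keys []string
	for _, part := range strings.Split(settings, ";") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		key, value, found := strings.Cut(part, "=")
		if !found {
			return nil, nil, fmt.Errorf("setting %q is not in key=value format", part)
		}
		old, ok := params[key]
		if !ok {
			return nil, nil, fmt.Errorf("unknown hyperparameter %q", key)
		}
		var err error
		switch old.(type) {
		case int:
			var n int64
			// Base 0 accepts underscores such as "50_000" (and reads a leading "0" as octal).
			n, err = strconv.ParseInt(value, 0, 0)
			params[key] = int(n)
		case float64:
			params[key], err = strconv.ParseFloat(value, 64)
		case bool:
			params[key], err = strconv.ParseBool(value)
		case string:
			params[key] = value
		default:
			err = fmt.Errorf("unsupported type %T", old)
		}
		if err != nil {
			return nil, nil, fmt.Errorf("parsing %q: %w", key, err)
		}
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return params, keys, nil
}

func main() {
	defaults := map[string]any{"batch_size": 64, "fnn_num_hidden_layers": 8, "learning_rate": 1e-4, "plots": true}
	params, keys, err := parseSettings(defaults, "batch_size=17;fnn_num_hidden_layers=12")
	if err != nil {
		panic(err)
	}
	fmt.Println(keys, params["batch_size"], params["fnn_num_hidden_layers"], params["learning_rate"])
}
```

Parsing integers with base 0 is what lets a value like train_steps=50_000, used further below, be accepted by this sketch.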
Simple FNN model
A trivial model that can easily get to ~50% accuracy (a random model would get 10%), but hardly much more than that.
Later we define a CNN model to compare against; for now this serves as a placeholder model.
Note: the code here is just an example; training actually uses the same code from the cifar package.
import (
"flag"
. "github.com/gomlx/gomlx/graph"
"github.com/gomlx/gomlx/examples/cifar"
"github.com/gomlx/gomlx/ml/context"
"github.com/gomlx/gomlx/ml/train/optimizers"
"github.com/gomlx/gomlx/types/shapes"
)
var _ = NewGraph // Make sure the graph package is in use.
// C10PlainModelGraph implements train.ModelFn, and returns the logit Node, given the input image.
// It's a basic FNN (Feedforward Neural Network), so no convolutions. It is meant only as an example.
func C10PlainModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
_ = spec // Not needed, the input type is always the same.
ctx = ctx.In("model")
batchedImages := inputs[0]
batchSize := batchedImages.Shape().Dimensions[0]
// Flatten each [height, width, depth] image into a single feature vector.
logits := Reshape(batchedImages, batchSize, -1)
numClasses := len(cifar.C10Labels)
modelType := context.GetParamOr(ctx, "model", cifar.C10ValidModels[0])
if modelType == "kan" {
// Configuration of the KAN layer(s) uses the context hyperparameters.
logits = kan.New(ctx, logits, numClasses).Done()
} else {
// Configuration of the FNN layer(s) uses the context hyperparameters.
logits = fnn.New(ctx, logits, numClasses).Done()
}
logits.AssertDims(batchSize, numClasses)
return []*Node{logits}
}
%% -set="batch_size=7"
// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.
AssertDownloaded()
ctx, _ := ContextFromSettings()
g := NewGraph(backends.New(), "placeholder")
batchSize := context.GetParamOr(ctx, "batch_size", int(100))
logits := C10PlainModelGraph(ctx, nil, []*Node{Parameter(g, "images", shapes.Make(DType, batchSize, cifar.Height, cifar.Width, cifar.Depth))})
fmt.Printf("Logits shape for batch_size=%d: %s\n", batchSize, logits[0].Shape())
Logits shape for batch_size=7: (Float32)[7 10]
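Conceptually, the FNN flattens each 32x32x3 image into a vector of 3072 values, passes it through hidden layers (fnn_num_hidden_layers of them, each with fnn_num_hidden_nodes units, using the activation from the "activation" hyperparameter, swish by default here, and with fnn_residual a residual connection), and ends with a linear layer producing one logit per class. The plain-Go sketch below shows one such forward pass for a single example with tiny, made-up dimensions and constant stand-in weights. It is only an illustration, not how GoMLX's fnn package is implemented (which builds a computation graph and also handles normalization, dropout and regularization); the input projection, the placement of the residual connection and the weight layout are assumptions.

```go
package main

import (
	"fmt"
	"math"
)

// dense computes W·x + b, with W stored as [outputs][inputs].
func dense(w [][]float64, b, x []float64) []float64 {
	y := make([]float64, len(w))
	for i, row := range w {
		sum := b[i]
		for j, v := range row {
			sum += v * x[j]
		}
		y[i] = sum
	}
	return y
}

// swish applies swish(x) = x * sigmoid(x) element-wise.
func swish(x []float64) []float64 {
	y := make([]float64, len(x))
	for i, v := range x {
		y[i] = v / (1 + math.Exp(-v))
	}
	return y
}

// constMatrix returns a rows x cols matrix filled with v (a stand-in for trained weights).
func constMatrix(rows, cols int, v float64) [][]float64 {
	m := make([][]float64, rows)
	for i := range m {
		m[i] = make([]float64, cols)
		for j := range m[i] {
			m[i][j] = v
		}
	}
	return m
}

// forwardFNN is a hypothetical sketch of a residual FNN forward pass for one flattened example:
// an input projection to numHidden units, numLayers residual hidden layers, and an output layer.
func forwardFNN(x []float64, numHidden, numLayers, numClasses int) []float64 {
	h := swish(dense(constMatrix(numHidden, len(x), 0.01), make([]float64, numHidden), x))
	for l := 0; l < numLayers; l++ {
		update := swish(dense(constMatrix(numHidden, numHidden, 0.01), make([]float64, numHidden), h))
		for i := range h {
			h[i] += update[i] // Residual connection: add the layer's output to its input.
		}
	}
	return dense(constMatrix(numClasses, numHidden, 0.01), make([]float64, numClasses), h)
}

func main() {
	const height, width, depth = 32, 32, 3
	x := make([]float64, height*width*depth) // One flattened image: 3072 values.
	for i := range x {
		x[i] = 0.5
	}
	logits := forwardFNN(x, 4, 2, 10)
	fmt.Println(len(x), len(logits))
}
```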
Training Loop
With a model function defined, we use the training loop created for Cifar-10.
The trainer is provided in the cifar package. It is straightforward (and almost the same for every project), and does the following for us:
- If a checkpoint is given (--checkpoint) and it contains a previously saved model, loads its hyperparameters and trained variables.
- Creates the trainer with the selected model function (see the Simple FNN model and CNN model for Cifar-10 sections), optimizer, loss and metrics.
- Creates a train.Loop and attaches to it a progress bar, a periodic checkpoint saver and a plotter (--set="plots=true").
- Trains for the selected number of train steps.
- Reports results.
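The loop-with-attachments pattern in the list above can be pictured with the following dependency-free sketch: a loop that calls registered hooks after every step, where one hook fires only every N steps, like a periodic checkpoint saver. The names Loop, OnStep and EveryNSteps are made up for illustration and are not the GoMLX train.Loop API.

```go
package main

import "fmt"

// Loop is a hypothetical minimal training loop that runs a step function and then
// calls each registered hook with the step number and the step's loss.
type Loop struct {
	hooks []func(step int, loss float64)
}

// OnStep registers a hook called after every training step (e.g. a progress bar).
func (l *Loop) OnStep(hook func(step int, loss float64)) {
	l.hooks = append(l.hooks, hook)
}

// EveryNSteps registers a hook called only every n steps (e.g. to save a checkpoint).
func (l *Loop) EveryNSteps(n int, hook func(step int, loss float64)) {
	l.OnStep(func(step int, loss float64) {
		if (step+1)%n == 0 {
			hook(step, loss)
		}
	})
}

// Run executes trainStep numSteps times, calling the hooks after each step.
func (l *Loop) Run(numSteps int, trainStep func(step int) float64) {
	for step := 0; step < numSteps; step++ {
		loss := trainStep(step)
		for _, hook := range l.hooks {
			hook(step, loss)
		}
	}
}

func main() {
	loop := &Loop{}
	var checkpoints []int
	loop.EveryNSteps(20, func(step int, loss float64) {
		checkpoints = append(checkpoints, step) // Stand-in for saving a checkpoint.
	})
	loop.OnStep(func(step int, loss float64) {
		if step == 49 {
			fmt.Printf("[step=%d] loss=%.3f\n", step, loss) // Stand-in for a progress bar.
		}
	})
	// Fake step function: the "loss" simply decays with the step number.
	loop.Run(50, func(step int) float64 { return 2.3 / float64(step+1) })
	fmt.Println("checkpoints at steps:", checkpoints)
}
```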
Below we train for 50 steps with the default settings, just to check that things are working.
var flagCheckpoint = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")
// trainModel with hyperparameters configured with `-set=...`.
func trainModel() {
ctx, paramsSet := ContextFromSettings()
cifar.TrainCifar10Model(ctx, *flagDataDir, *flagCheckpoint, *flagEval, *flagVerbosity, paramsSet)
}
// Train 50 steps, only to test things are working. No plots.
%% --set="train_steps=50;plots=false"
trainModel()
Backend "xla": xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.55
Training (50 steps): 100% [========================================] (1250 steps/s) [step=49] [loss+=2.230] [~loss+=2.361] [~loss=2.229] [~acc=18.09%]
[Step 50] median train step: 773 microseconds
Results on Validation:
Mean Loss+Regularization (#loss+): 2.271
Mean Loss (#loss): 2.139
Mean Accuracy (#acc): 24.39%
Results on Training:
Mean Loss+Regularization (#loss+): 2.269
Mean Loss (#loss): 2.138
Mean Accuracy (#acc): 24.14%
FNN Model Training
Let's train the FNN for real this time.
- Note: The FNN model quickly overfits to the data.
// Remove a previously trained model. Skip this if you want to continue training a previous model.
!rm -rf ~/work/cifar/base_fnn_model
%% --checkpoint=base_fnn_model --set="model=fnn;train_steps=50_000;plots=true"
trainModel()
Backend "xla": xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.55
Checkpointing model to "/home/janpf/work/cifar/base_fnn_model"
Training (50000 steps): 1% [........................................] (455 steps/s) [0s:1m48s] [step=699] [loss+=1.824] [~loss+=1.928] [~loss=1.798] [~acc=36.38%]
Training (50000 steps): 100% [========================================] (1191 steps/s) [step=49999] [loss+=0.960] [~loss+=0.913] [~loss=0.770] [~acc=72.39%]
Metric: accuracy
Metric: loss
[Step 50000] median train step: 718 microseconds
Results on Validation:
Mean Loss+Regularization (#loss+): 1.597
Mean Loss (#loss): 1.454
Mean Accuracy (#acc): 54.70%
Results on Training:
Mean Loss+Regularization (#loss+): 0.857
Mean Loss (#loss): 0.714
Mean Accuracy (#acc): 74.90%
CNN model for Cifar-10
Let's now properly define a CNN model to compare.
The model was built following a Keras model in Kaggle (thanks @ektasharma), which uses hardcoded values for all layers of the model, so it doesn't make use of the hyperparameters set in the context. It achieves ~86% on the validation set after 80,000 steps of batch size 64 (~100 epochs).
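The step-to-epoch conversion is simple arithmetic: CIFAR-10 has 50,000 training images, so one epoch at batch size 64 takes 50,000 / 64 ≈ 781 steps. A small Go helper (the name epochs is just for illustration) computing it for the two long runs in this notebook, both of which use the default batch size of 64:

```go
package main

import "fmt"

// cifar10TrainSize is the number of images in the CIFAR-10 training set.
const cifar10TrainSize = 50_000

// epochs returns how many passes over a dataset of datasetSize examples are made
// when training numSteps steps with the given batch size.
func epochs(numSteps, batchSize, datasetSize int) float64 {
	return float64(numSteps) * float64(batchSize) / float64(datasetSize)
}

func main() {
	fmt.Printf("FNN run: %.1f epochs\n", epochs(50_000, 64, cifar10TrainSize))
	fmt.Printf("CNN run: %.1f epochs\n", epochs(80_000, 64, cifar10TrainSize))
}
```

This gives 64 epochs for the FNN run (50,000 steps) and 102.4 epochs for the CNN run (80,000 steps), matching the "~100 epochs" above.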
Notice that since it uses batch normalization, the training process will, at the end, update the moving averages of mean and variance: the running estimates kept during training were accumulated while the weights were still changing, and recomputing them with the final weights improves them. This interesting blog post explains it in more detail.
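A rough illustration of why that final update helps: during training, a running statistic is commonly tracked as an exponential moving average over batches, so it still carries contributions from batches computed with older weights. The sketch below shows this for the mean only, with a hypothetical momentum of 0.99 and synthetic activations whose mean drifts as "training" progresses; it is a simplified model of the idea, not the batchnorm package's code.

```go
package main

import "fmt"

// ema updates a running average with a new observation: avg = momentum*avg + (1-momentum)*x.
func ema(avg, x, momentum float64) float64 {
	return momentum*avg + (1-momentum)*x
}

// batchMean returns the mean of a batch of values.
func batchMean(batch []float64) float64 {
	sum := 0.0
	for _, v := range batch {
		sum += v
	}
	return sum / float64(len(batch))
}

func main() {
	const momentum = 0.99 // Hypothetical momentum value.
	const numSteps = 200
	running := 0.0
	var lastBatch []float64
	for step := 0; step < numSteps; step++ {
		// Synthetic layer activations whose mean drifts as the weights change during training.
		shift := float64(step) / numSteps
		batch := []float64{shift - 0.1, shift, shift + 0.1}
		running = ema(running, batchMean(batch), momentum)
		lastBatch = batch
	}
	// After training the weights are frozen: in this synthetic setup every batch now looks like
	// the last one, so a single batch is enough to re-estimate the mean.
	fmt.Printf("running average: %.3f, re-estimated with final weights: %.3f\n", running, batchMean(lastBatch))
}
```

The running average lags well behind the mean produced by the final weights, which is the gap the end-of-training update closes.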
import "github.com/gomlx/gomlx/ml/layers/batchnorm"

// ConvolutionModelGraph implements train.ModelFn and returns the logit Node, given the input image.
// It's a straightforward CNN (Convolutional Neural Network) model.
//
// This is modeled after the Keras example in Kaggle:
// https://www.kaggle.com/code/ektasharma/simple-cifar10-cnn-keras-code-with-88-accuracy
// (Thanks @ektasharma)
func ConvolutionModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
ctx = ctx.In("model")
batchedImages := inputs[0]
g := batchedImages.Graph()
dtype := batchedImages.DType()
batchSize := batchedImages.Shape().Dimensions[0]
logits := batchedImages
layerIdx := 0
nextCtx := func(name string) *context.Context {
newCtx := ctx.Inf("%03d_%s", layerIdx, name)
layerIdx++
return newCtx
}
logits = layers.Convolution(nextCtx("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
logits.AssertDims(batchSize, 32, 32, 32)
logits = activations.Relu(logits)
logits = batchnorm.New(nextCtx("batchnorm"), logits, -1).Done()
logits = layers.Convolution(nextCtx("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
logits = activations.Relu(logits)
logits = batchnorm.New(nextCtx("batchnorm"), logits, -1).Done()
logits = MaxPool(logits).Window(2).Done()
logits = layers.DropoutNormalize(nextCtx("dropout"), logits, Scalar(g, dtype, 0.3), true)
logits.AssertDims(batchSize, 16, 16, 32)
logits = layers.Convolution(nextCtx("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
logits.AssertDims(batchSize, 16, 16, 64)
logits = activations.Relu(logits)
logits = batchnorm.New(nextCtx("batchnorm"), logits, -1).Done()
logits = layers.Convolution(nextCtx("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
logits.AssertDims(batchSize, 16, 16, 64)
logits = activations.Relu(logits)
logits = batchnorm.New(nextCtx("batchnorm"), logits, -1).Done()
logits = MaxPool(logits).Window(2).Done()
logits = layers.DropoutNormalize(nextCtx("dropout"), logits, Scalar(g, dtype, 0.5), true)
logits.AssertDims(batchSize, 8, 8, 64)
logits = layers.Convolution(nextCtx("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
logits.AssertDims(batchSize, 8, 8, 128)
logits = activations.Relu(logits)
logits = batchnorm.New(nextCtx("batchnorm"), logits, -1).Done()
logits = layers.Convolution(nextCtx("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
logits.AssertDims(batchSize, 8, 8, 128)
logits = activations.Relu(logits)
logits = batchnorm.New(nextCtx("batchnorm"), logits, -1).Done()
logits = MaxPool(logits).Window(2).Done()
logits = layers.DropoutNormalize(nextCtx("dropout"), logits, Scalar(g, dtype, 0.5), true)
logits.AssertDims(batchSize, 4, 4, 128)
// Flatten the feature maps to [batchSize, 4*4*128] and finish with dense layers.
logits = Reshape(logits, batchSize, -1)
logits = layers.Dense(nextCtx("dense"), logits, true, 128)
logits = activations.Relu(logits)
logits = batchnorm.New(nextCtx("batchnorm"), logits, -1).Done()
logits = layers.DropoutNormalize(nextCtx("dropout"), logits, Scalar(g, dtype, 0.5), true)
numClasses := len(cifar.C10Labels)
logits = layers.Dense(nextCtx("dense"), logits, true, numClasses)
return []*Node{logits}
}
%% -set="batch_size=11"
// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.
AssertDownloaded()
ctx, _ := ContextFromSettings()
g := NewGraph(backends.New(), "placeholder")
batchSize := context.GetParamOr(ctx, "batch_size", int(100))
logits := ConvolutionModelGraph(ctx, nil, []*Node{Parameter(g, "images", shapes.Make(DType, batchSize, cifar.Height, cifar.Width, cifar.Depth))})
fmt.Printf("Logits shape for batch_size=%d: %s\n", batchSize, logits[0].Shape())
Logits shape for batch_size=11: (Float32)[11 10]
Training the CNN model
CNNs have a much better inductive bias for images: this model easily achieves >80% accuracy on the training data, and somewhat less on the validation data, due to overfitting.
It would likely benefit from pre-training on a larger unlabeled dataset; see the "Dogs vs Cats" example for transfer learning in action with an image model.
!rm -rf ~/work/cifar/base_cnn_model
%% --checkpoint=base_cnn_model --set="model=cnn;learning_rate=1e-3;l2_regularization=0;l1_regularization=0;train_steps=80000"
trainModel()
Backend "xla": xla:cuda - PJRT "cuda" plugin (/usr/local/lib/gomlx/pjrt/pjrt_c_api_cuda_plugin.so) v0.55
Checkpointing model to "/home/janpf/work/cifar/base_cnn_model"
Training (80000 steps): 0% [........................................] (175 steps/s) [4s:7m32s] [step=719] [loss+=1.832] [~loss+=1.745] [~loss=1.745] [~acc=35.49%]
Training (80000 steps): 100% [========================================] (240 steps/s) [step=79999] [loss+=0.870] [~loss+=0.641] [~loss=0.641] [~acc=78.02%]
Metric: accuracy
Metric: loss
[Step 80000] median train step: 3742 microseconds
Results on Validation:
Mean Loss+Regularization (#loss+): 0.584
Mean Loss (#loss): 0.584
Mean Accuracy (#acc): 80.74%
Results on Training:
Mean Loss+Regularization (#loss+): 0.369
Mean Loss (#loss): 0.369
Mean Accuracy (#acc): 87.19%
Inference
Inference, or serving the model, uses the same code that was used to train it. That is, currently the way to save the model is to export the Go model-building function, along with the checkpoint holding the learned weights.
We created a small library, cifar/classifier, that takes an image as input, converts it to a tensor and calls the trained model.
import (
"bytes"
"encoding/base64"
"image/png"
"io"

"github.com/gomlx/gomlx/examples/cifar/classifier"
// We also must import the engine that will execute the model.
// Currently only XLA is supported.
_ "github.com/gomlx/gomlx/backends/xla"
)
%%
// Decode and print PNG image.
imgBase64 := bytes.NewBufferString(
"iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI5UlEQVR4nCxWSZMc1dV988upMmvo6uqWWvrUgFDwGWEMAZhwEAZ7YXvrrX+CN1765+GVg43DCwIhbAdCotVTVeX88o33OarxWWfcOPnuGS77458+Q4xWy0U1r7wF75CZIEaOEUUIDcNQ1zWlNIBnjAYIR0dHi9VKTZr5iAMYBEJKjpCx02azmi/mVrtJaWttjLGc5WyVPAwYUc1FWFo3OuMRIt4CISgEO01T13Xeey54VVVlXiUsGXZtbxSO2Acvi9wo58dJqR7jQCjBETvnxnFUatrvt+ze8SOPgAkZMFijjI7WGIxpWeYyyX3wp6enXdcpbTCi3gTlh7Efsipbnhw5FLU1hUwV4DxPEMa73Y5i2rUDAHDO0zRjxTw/TMWx67px1M55732aJsYq65AQAgAwxk3TqVFvqoVHlAG8tb7nGL1otlab0werfC26vhGCpFnS1v00TTHGIs+r+Zw5ZExQ+6692e66znjviiIDIEoF57yU0vsAENNE1tt9GHTGkv9/6+2TxfHXL//VDQ1F5NWLl6v1ilJKCBFCIjTEiBjn5XzBOCdmckM/9Z1q26FpGoyJ9zBNpm0H52A+X3KeVOX8yfnDk9W8HdX8wfkHX/z+P5fbq6axMWDOt7vm6nKbJrM0LTHiQqSE8ojZvu2ud3vGuXDWq1FBgKIoZrNZCMFab61zzl9dXRNCnPOJkKcPz4NY/fKLP3zy+e+e/fMfTT/gFCOHY8RdNxjtACBCZIzneaGN2W13CCHW32FUIwAImXjvrbV3+0kRQtbasiwRQsoBYunx2fGgyYvXzRSinkxCeKCwXK5RjJeXN3mRZJls+65rO8EYBrRarch2ux3GASEUY0QIOXcg4r03+gBKKQAYY6xz2tim7V+8uLq57SxCRrv2tnHa+RDrup2UybNZ2w510xymeZCUIetYjHD3JtYYmxclQhhhhCKCGKdpAgCtdZ7nszIbVJhUNKO22iXFzFrnxyEC3e46PaknTx5vNqfW2n1XV/OKuIC073c7+vS9pz4AIgRjioE7G4zWd/SnEBzGiDGWSJ6nSSJTpT0iPMuzJCFGjRSR4EPbDYQy66wQsppXMYLTJjpPIDKC6Re/+e3q6LgoyxiIVbFv+6Ht1DCg6DECKXmepcH7u5cjL1/92PZ7jIMU+e3tHkFs673RBlHmnEsSWVYlAdTtmuAsZ0xKwYw2Zhh7ddDo0I6jGkMwjBFKCKOcYu7twZMR0/2ujtqCMeryqmWcU7IfBmsdwXgaRylZs9+1y/J0vTbjkCVilsk8TVj0h8QZ+8EZbf0Qo8UkMk6dA8Z5kc8PciIkUjpNbc6kHVWom5rhPE8XVdnvbiilpeSr1YJitLu+miUcYVBqEBQQOPr+06eT1sZo66zzh6AIHsVI8qzcbO4vl2uCGaUiIIIxAeNU01qtgEQEQRz+lKVSSEE5p3kmrZmGoevapqn3nGHOKBv7wXsPHhgmgifWxBABxwRFMQ6aURXv9Cvy9PTk/qWDtqmbtvENkpzYg9igH3pGQAoavJvUMIxBCi45JQQRghiR0gfoxskHiIitlqdt21tjtzdbNYyJkPTgZEcJdoTM8kwfLfuhmYYGryrOBCMoOQzzlCEpudYMvMMEECYBRxM8sQhdbbej0s5BkhTL5TocXKa8NUZNzXan+sFpjazptrcMoaPFIhUSeevN5K1eVLN79zZpLiY9dn0XABjjANH64AFjJpkZVfSO4sAoWVV5V98EZ4TgEMIwdgHc5vgYIeTDwRllWXIuAnhKKcYYIjjns0LM0CzLMn8HSuSkpjxPpJilScmCtlWaQsJWq2VRyLa+KQoJmEUUrFHD0GIMaZomqdjttk27n81m3tsqz7xzwzBmARWzo7IsASBJkhijFKm1Xggxn8+rcsaIhypL15t7y0X546sLisysSPvpwCWAizF0XTOOfV6kITgAb8xkjd4
s5oSQpm4Q5ofgskoIwRhPEqmUgkOBCwBr3UTPVieUACPBWbXf3gBCiPB+0sPQoxhmZRERaDMJwU5PT4TkxugQfJ5lj99+ks/Kru+l5HW7k1IeAt7am9vX1zcX1ukAVqmeea9PVuuEQd/U2mqgiXa2GYaA/HozJ4RhErWZ8jJbHi+MMUmRQADwcFnfzBdL0bBtty3nM0cRZoCJK2a5dwaCF+yQpPTs3oYSoNHXTV0rjXliPLSjWi7K9apKk1SNkxrV6f2TYpYTSmMEhOKuri9eXxCKN5u1sdrYgy4IBs5JdAGFKIVMZJImKauOl5PXHBFeljISQqgzapYmp+t1Kgmj4ofvX2VptqzmyEPCeVrN67qOzj16cHa0WHCC33njzZffv8wR3SyP8iyJER0aUErOOaGUHW+OIvhU8izJqm6ajDXu8vH9B5vjlRCEUYmBfPfdv9948H/r9YoQIqV89uxZSunHn3xSzmbGmkzmD49Ov33+za8+/DjNJMKE/CRiAIQR++C9pygiQinngrP0xQ8/YEQ+//VnAD7EkMjs/Xd/8eWXf9vfbD96/+eUsn29d2r69KMP33v3Z9aarusikOLRo+fPv7m4uDg/f5im+U91K4TgnNO//uXP82qxXh1V1XK9Ptnebpfz6s3zc++sPHwiBePHR+tn33w9tF2WpV/9/SvB2aeffoQjaD3dbRIhQrWZur6hjNX7elJ6HIbt7c4YyyLCnDGEMYlYj8N6udA6ubi4MEYLISnVAA0h5PHbbzVte7PbdePw7vk7r68ujdGMHXomIpSm+f2zY4AgBItAnAdt7Diqvh/Zt8+/i3fAmNT7drFYtm3b9/2dJw8NhTEehgFhRBm9eH3x4OEDLuT1zbUxuqrmVVUyxjFGhHBKJaVM8JySO/cZjzFhXdcZY0IIMaLXF9fX1zdFUQghQggBwkEwMVprXfBZli2XyyRJtJ44EwRTSph3gWCGCfYu6EkbY+dzRDBTSmGMhBSMMfa/jUd8dnYWQqCUpmn604Hkvc/zXEoJKBZFcTiQlOIHylQIPJ9XhJAYo1JKa51lWTnLrLHWqjzPGWPe+/8GAAD//7swHNFia9prAAAAAElFTkSuQmCC")
imgPNG := must.M1(io.ReadAll(base64.NewDecoder(base64.StdEncoding, imgBase64)))
img := must.M1(png.Decode(bytes.NewBuffer(imgPNG)))
size := img.Bounds()
gonbui.DisplayHTML(fmt.Sprintf("<p>Image: (%d x %d)</p>", size.Dx(), size.Dy()))
gonbui.DisplayPNG(imgPNG)
// Classify:
c10Classifier := must.M1(classifier.New("~/work/cifar/base_fnn_model"))
classID := must.M1(c10Classifier.Classify(img))
className := cifar.C10Labels[classID]
gonbui.DisplayHTML(fmt.Sprintf("<p>Class: <b>%s (%d)</b></p>", className, classID))
Image: (32 x 32)
Class: horse (7)
Generate a random image as base64
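// Sample one random Cifar-10 training image, print it as base64 (e.g. to use in the inference example above) and display it.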
%%
backend := backends.New()
ds := cifar.NewDataset(backend, "Samples Cifar-10", *flagDataDir, cifar.C10, DType, cifar.Train).Shuffle()
_, inputs, _ := must.M3(ds.Yield())
imgTensor := inputs[0]
img := images.ToImage().Single(imgTensor)
imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))
imgBase64 := imgSrc[22:] // Strip the "data:image/png;base64," prefix used in an <img> src attribute.
fmt.Printf("%s\n\n", imgBase64)
size := imgTensor.Shape().Dimensions[0]
gonbui.DisplayHTML(fmt.Sprintf(`<img width="%d" height="%d" src="%s"/>`, size*2, size*2, imgSrc))
iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAKPklEQVR4nDyV2ZNdx0HGezt9tnvufu/cWe7M6M5oFskz8kjCkuINZDkiCdjYlCtFyENY/AfwDEUVSwEPvMALCfBkEooHCAaSAlJJOastWcJaGFkzmhnNaLa7b+eevTfquqh8Vd0P/dBd3d/v+5r80XffXrD1VxfXfcze+9c7t2/Hy2tZRyPbWzRL9VpheiAzH999eLLzcLZcqmbsNECWgics/rQ3nDnnXL+21Ghsvnaj9j9bT1OFiQfbhy+sXVCcPNs6HMTxMJDIa8udntuI/dEAtPfD+UnLRljjuoWRikILcRMDHdOik8kbRporksWBERPAFY8Dr18qjhSxWgBxbPdbid8BiYvOlGfK9mLZXnDbdQJ4QSX9nd1toOUSPfHjtmPYFBearl4wrECRUMDBKOgH0rFUluLyACmps1z++rVXoeknZpcq6vVbbuBSmh7GsQ+9ZnKA85oF5OAuJOmi6jQTMVEYeV2Q2KFbJ9DEAJxfPUeiECsRu52sFsG8Y5tYqIRTzBjXWDCNcWl+JZXdsQD33EEhuxQNfRTWNZwcH1Hs06Z/amALl9crQgxy5RLiyYc/Ol46a1Occ9Sy24uibt9CsHl8eutnP+206g4leY04jqlsWxKj26t36/VSNSSOvds+HQzc7km3PJHNZFPJYNay5LPWnoZSpJBLTWWpjMKIJf1+SKhz0mgAZ27/WWNxojZZndbtPM2U2q2mhfD03PTi4nQIFQbItKQmaRRuZfLUbumfHu/oNMPCoBI5D+/dW1otIoTDYEhqpQXPe0KKaO8QVCbp5Gzu8e7Tg+0fp7NLwzhiMpldrL30K++EYfB0c7N9/Ky9fTgadaafP3du7flctlTvVTXTY4kJVBYgDal4trQwmGkSiYup/NFwixTM1WLG6sQPOr2hldKabY9HBCJoOVYYhTvbW9ORmlu9bJimPxqe3Lr/uRuvXF+/oZWmy9UJTGGuvKRAeKb62usvAim4gjHA/NJzYRxFo+j0Qq1BUnY5nUn1mo9KFcj7qXt32/0jN2GoUAghYzq12odHP/juv0MIjjY3Z85MXvn1X+W26dfdZtt1I06Up5CUCiglKaUAEAgJQppulItWNWtHJJdOUWKO+sTr0XarHQQRVFRxub35OI3M2tra02fbn7z/Xytnl3/tCy9efPFa2clwQuSsnig2PPUYU5KDIPEkBCwIJIeYIQWZFAxDLYlDYigUjHzCzzy4fXdY55U5p+V3HTMd+AIRKCBiSmGkvvTOy1oW1eMubLd0ImwnTahupUyKqOJQ51TXCOdKAKYTEwooeKAUCsOQePGoeXo0M7GoApotQIiBTslo6Bl6DgAlpZAcYwvvxXeO7vtLy+Gs75gUTJWnquWKEtgdBfs7p8XJdDrtaBoZRuK41Yd9v7ZQdjKmJjh5srcTdptWqTpsxzML9nDgxUGCgE6xJhPOGAMSAKq5zNYJswxH6hZKaYEksYRcSMGh22Pd3rFSWNet2MYHPhp8/Mh1a7YDLpyvojsf3t3f3VUJ1zT92ZOO35c8ghrSOEuoRqQSQEHDzvCkXLSmaJLauv/IbfeYn/SGfYAVgDIJYeJpgQuPj7oDj/eYmioVAjfqNFwhBDrcPeo2W5xFtmOPupAH5JWXb3LGoYJASc6442Q2NjbOn60tLW1kMpO3P/jR/sNHMIq5iACUAPLN/93udX0pRMpRVorGkk1XHBZEaSMnpSIMJKaeQxDBWGNCWFr54tUv3v3Jh5qgtmVcvvLy2bPLZ88tZ8o5zuXuSXjz9W4ydGOW+P7QGneHvrA6xTgyLSIYLTrmuqEXkfKS2MlbCCFCCaQ5R0gRuDGWUIY88pSpZwVnX3333Tfefsu2TCFByLgAKoWjX/rF1/qDgW5qAWsoqDSKr9/cCOMIQUUQgbpaQbol4+pcTtM0HvrEtiwtbQGpICBJmLQazbs/ucdiQChqBcGP
79y9uLFBse6HUavfj9vdp22mDL1iGEyYCo2/BU0nNsUIYSk54AIJjjSsYaFUBAHEv/zGlzQTz1dm3RG/8sLFds/r1yMhB1ba4FQHWFtbXzMp0TFOWHJy3CjNLqQcJ/TcerdfLpcwxJIrwRFUVAipAGUYM4l5IjCiQAn0yvUbSMBSrpCr5J+7tFqcnszatq7jfKFgmSnLtJUESPGMSecniySV9yOfeUPLwAnnYZQkDAIggOJCCQG5F448vzMa9hgLxwgAQCrlKYOalcrUb33lywf7e++89cX7H/xs7erXwjBpAo4hxYgjlfr+93/4yf2HN97+jVEiLKpHzMfUaLlhxcG6DgGQEDIghUbh++99k0D4ztfehZBJyAkQaH/v1PPClz93OWPapXIpaWz/zu/+5q1bd0KT1AeDvX/5B1GZ/4Pf/+u9vV2lZ15/402MsUIojlzHntcpVkABgJRSQCFv5P7n+98hWH7+rS8Xc3nOIhK44fmVixox4iTJZDOUkNXV5wiSyyu1h1vbifCajx5vP35yfNjQNOOD//5ebemsN/Lave5+vfvC+eeBZAwgJSBSCiJ0cvCMahoAsn56lE+XMTLI1V+49NJLVwlQnMX9fr9YyJ9b39CI4Q7Y7Y82j06PcqnU5idPsWOldPPo4PAv//QvGo1WypmcqC7EX4iohRHGSgEIkZKqcXqazecxNTqdThgGgkXINgnRFMGYMdZqtYKIPdo9YkJSA8/PTfZOmz+8v/do9+DazZtvfuWrQkq37dkkzcNw0DoZhF4iMUvUyPOHQ3c4HoOu23OKk1EswmgUJwFJBI9jRj57RPXZNZ+/dBlCmUnrF1bOzP3ebzcOBn/19a8vnl0GCdNMK/HjKAwED5DBwihOGj0IkW7oBIxporr93JXXqtUzL13ZWKzNmIZG/vk7HxQy6eWlasY2EUK2qacsQ0jRa47+/m/ec6NBsVB2B67bOt3feaoU7PTqScKw4k7B8dudSnEiX8l/dPtWMGxdurD6vf/41szq58N+6+q5Kcdxxpj6fX5hdbJSLokk0cb+ACmlksrJpqaqM3/7J98yDVoolTWNTNcWnjx+8Oq1VyFEE1OZmWqxWsAYdLY++vAbf/5ngql7axf3Htwf9Xkja/7xHw5WVlbXX7gCn+w1SqW0aRCg1Mh1C4XC/58BwJMne2+8+XZlomiZ+fTEbCpX+cG//dO1S2tCqoizwG02223P80ZewLgUSiWCjTsYpQnBlpWtTs9cvLwOuRAAyDHKCiCEFADj2EjxeHv/G3/3zW9/+x91CnhMADERdoadehh0YhYDgAGQCFuEEKQZRDfS2XQ2l5uYmDgzP1urVZcWV+bnZgs5g0CEEEBgvP9Y6LMZI2SnMwLYVmqm1ToAgukGytj20uq6ZuiZTKZ2ZmbghzyGExPlarUyV52amMzn8vliIZ/LaDoGPxf86c7hOCOfxRFCOAZpvKx0XQtCfuvjh3t7e4LjdNopFnOVyYplmUkSL1TzCeNIKdPQia4TqkmgYi4451KMLYRSgnGXyv8LAAD//5qlqJXHo301AAAAAElFTkSuQmCC