Cifar Library and Demo

This is a library to download and parse the Cifar datasets (Cifar-10 and Cifar-100), and a very small demo of an FNN (Feedforward Neural Network) with GoMLX. FNNs are notoriously bad for images, but this is only a demo. Look for the Resnet50 model for a more serious image classification model (old but still good; as of this writing, the best results are achieved with ViT models).

The CIFAR-10 and CIFAR-100 datasets are labeled subsets of the 80 million tiny images dataset. They were collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. See more details on its homepage.
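For reference, the binary versions of the datasets (the cifar-10-batches-bin and cifar-100-binary directories listed further below) store fixed-size records: one label byte for CIFAR-10, or a coarse and a fine label byte for CIFAR-100, followed by 3072 pixel bytes laid out as all red values, then all green, then all blue, each in row-major 32x32 order. The sketch below decodes one such record into the channels-last layout the models in this notebook consume. The parseRecord name and the scaling of pixels to [0, 1] are illustrative assumptions, not necessarily what the cifar package does.

```go
package main

import "fmt"

const (
	height, width, depth = 32, 32, 3
	pixelBytes           = height * width * depth // 3072 bytes per image.
)

// parseRecord decodes one record of the CIFAR binary format (hypothetical helper).
// numLabelBytes is 1 for CIFAR-10 (label) and 2 for CIFAR-100 (coarse, fine).
// Pixels are stored channel-major (all red, then green, then blue); they are
// returned channels-last, i.e. at index (y*width+x)*depth+c, scaled to [0, 1].
func parseRecord(record []byte, numLabelBytes int) (labels []int, pixels []float32, err error) {
	if len(record) != numLabelBytes+pixelBytes {
		return nil, nil, fmt.Errorf("record has %d bytes, expected %d", len(record), numLabelBytes+pixelBytes)
	}
	for _, b := range record[:numLabelBytes] {
		labels = append(labels, int(b))
	}
	raw := record[numLabelBytes:]
	pixels = make([]float32, pixelBytes)
	for c := 0; c < depth; c++ {
		for y := 0; y < height; y++ {
			for x := 0; x < width; x++ {
				src := c*height*width + y*width + x // Channel-major source index.
				dst := (y*width+x)*depth + c        // Channels-last destination index.
				pixels[dst] = float32(raw[src]) / 255
			}
		}
	}
	return labels, pixels, nil
}

func main() {
	record := make([]byte, 1+pixelBytes)
	record[0] = 7       // CIFAR-10 label 7 is "horse".
	record[1] = 255     // Red value of pixel (0, 0).
	record[1+1024] = 51 // Green value of pixel (0, 0).
	labels, pixels, err := parseRecord(record, 1)
	if err != nil {
		panic(err)
	}
	fmt.Println(labels, pixels[0:3])
}
```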

This notebook serves as documentation and example for the github.com/gomlx/gomlx/examples/cifar library; the complete demo code, in one piece, can be found in .../examples/cifar/demo/

Data Preparation

Downloading data files

To download, uncompress and untar the datasets into the local directory, simply run the following. If the files are already in the given --data directory, it returns immediately.
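Returning immediately amounts to an existence check before any network access. A minimal sketch of that pattern in plain Go, assuming the download counts as complete once the extracted directory (e.g. cifar-10-batches-bin, as listed below) exists. ensureDownloaded and its download callback are hypothetical names, not part of the cifar package API.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// dirExists reports whether path exists and is a directory.
func dirExists(path string) bool {
	info, err := os.Stat(path)
	return err == nil && info.IsDir()
}

// ensureDownloaded creates dataDir if needed and calls download only when
// dataDir/extractedDir is missing. It reports whether download was called.
func ensureDownloaded(dataDir, extractedDir string, download func(dataDir string) error) (bool, error) {
	if err := os.MkdirAll(dataDir, 0o755); err != nil {
		return false, err
	}
	if dirExists(filepath.Join(dataDir, extractedDir)) {
		return false, nil // Already downloaded and extracted: return immediately.
	}
	if err := download(dataDir); err != nil {
		return false, err
	}
	return true, nil
}

func main() {
	dataDir, err := os.MkdirTemp("", "cifar")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(dataDir)

	// Stand-in for fetching and untarring the real archive.
	fakeDownload := func(dir string) error {
		return os.Mkdir(filepath.Join(dir, "cifar-10-batches-bin"), 0o755)
	}
	first, _ := ensureDownloaded(dataDir, "cifar-10-batches-bin", fakeDownload)
	second, _ := ensureDownloaded(dataDir, "cifar-10-batches-bin", fakeDownload)
	fmt.Println(first, second) // true false
}
```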

In [1]:
import (
    "flag"
    "os"

    "github.com/gomlx/gomlx/examples/cifar"
    "github.com/gomlx/gomlx/support/fsutil"
    "github.com/janpfeifer/must"
)

var flagDataDir = flag.String("data", "~/work/cifar", "Directory to cache downloaded and generated dataset files.")

func AssertDownloaded() {
    *flagDataDir = must.M1(fsutil.ReplaceTildeInDir(*flagDataDir))
    if !fsutil.MustFileExists(*flagDataDir) {
        must.M(os.MkdirAll(*flagDataDir, 0777))
    }

    must.M(cifar.DownloadCifar10(*flagDataDir))
    must.M(cifar.DownloadCifar100(*flagDataDir))
}

%%
AssertDownloaded()
In [2]:
!ls -lh ~/work/cifar/
total 48K
drwxr-x--- 2 janpf janpf 4.0K Jun  7 10:23 base_cnn_model
drwxr-x--- 2 janpf janpf 4.0K Jun  7 10:23 base_fnn_model
drwxr-x--- 2 janpf janpf 4.0K Nov 12  2024 base_kan_model
drwxr-xr-x 2 janpf janpf 4.0K Jun  4  2009 cifar-10-batches-bin
drwxr-xr-x 2 janpf janpf 4.0K Feb 20  2010 cifar-100-binary
drwxr-x--- 2 janpf janpf 4.0K Aug  2  2024 cnn
drwxr-x--- 2 janpf janpf 4.0K Aug  1  2024 cnn_layer
drwxr-x--- 2 janpf janpf 4.0K Jul 31  2024 cnn_nonorm
drwxr-x--- 2 janpf janpf 4.0K Jul 31  2024 fnn_batchnorm_0
drwxr-x--- 2 janpf janpf 4.0K Aug  1  2024 fnn_layer
drwxr-x--- 2 janpf janpf 4.0K Nov 12  2024 kan
drwxr-x--- 2 janpf janpf 4.0K Nov 12  2024 r001

Sample some images

cifar.NewDataset creates a data.InMemoryDataset that can be used for training, for evaluation, or just to sample a few examples, which is what we do below:

In [3]:
import (
    "fmt"
    "strings"

    "github.com/gomlx/compute"
    "github.com/gomlx/compute/dtypes"
    "github.com/gomlx/compute/shapes"
    "github.com/gomlx/gomlx/core/tensors/images"
    "github.com/gomlx/gomlx/examples/cifar"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/janpfeifer/gonb/gonbui"

    _ "github.com/gomlx/gomlx/backends/default"
)

var (
    // Model DType, used everywhere.
    DType = dtypes.Float32
)

// sampleToNotebook generates a sample of Cifar-10 and Cifar-100 in a GoNB Jupyter Notebook.
func sampleToNotebook() {
    // Load data into tensors.
    backend := compute.MustNew()
    ds10 := cifar.NewDataset(backend, "Samples Cifar-10", *flagDataDir, cifar.C10, DType, cifar.Train).Shuffle()
    ds100 := cifar.NewDataset(backend, "Samples Cifar-100", *flagDataDir, cifar.C100, DType, cifar.Train).Shuffle()
    sampleImages(ds10, 8, cifar.C10Labels)
    sampleImages(ds100, 8, cifar.C100FineLabels)
}

// sampleImages generates and outputs one HTML table row with numImages samples taken from ds, captioned using labelNames.
func sampleImages(ds train.Dataset, numImages int, labelNames []string) {
    gonbui.DisplayHTML(fmt.Sprintf("<p>%s</p>\n", ds.Name()))
    
    parts := make([]string, 0, numImages+5) // Room for the opening tags, one cell per image, the closing tags and a trailing empty part.
    parts = append(parts, "<table><tr>")
    ii := 0
    for batch, err := range ds.Iter() {
        if ii >= numImages {
            break
        }
        must.M(err)
        inputs, labels := batch.Inputs, batch.Labels
        imgTensor := inputs[0]
        img := images.ToImage().Single(imgTensor)
        label := labels[0].Value().([]int64)
        labelStr := labelNames[label[0]]
    
        imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))
        size := imgTensor.Shape().Dimensions[0]
        parts = append(
            parts, 
            fmt.Sprintf(`<td><figure style="padding:4px;text-align: center;"><img width="%d" height="%d" src="%s">` + 
                        `<figcaption style="text-align: center;">%s (%d)</figcaption></figure></td>`, 
                        size*2, size*2, imgSrc, labelStr, label),
        )
        ii++
    }
    parts = append(parts, "</tr></table>", "")
    gonbui.DisplayHTML(strings.Join(parts, "\n"))
}

%%
AssertDownloaded()
sampleToNotebook()

Samples Cifar-10

bird ([2])
airplane ([0])
horse ([7])
airplane ([0])
deer ([4])
airplane ([0])
cat ([3])
airplane ([0])

Samples Cifar-100

orange ([53])
tulip ([92])
cup ([28])
rocket ([69])
beetle ([7])
fox ([34])
elephant ([31])
bed ([5])

Training on Cifar-10

Models Support

  1. The "model" hyperparameter defines the model type, out of the cifar.C10ValidModels options.
  2. createModelStore creates a model.Store and sets the default values for the CIFAR models.
  3. StoreFromSettings uses createModelStore and incorporates changes passed by the -set flag.
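The -set flag takes a semicolon-separated list of key=value pairs, as in -set="batch_size=17;fnn_num_hidden_layers=12", and each value has to be converted to the type of the corresponding default. The sketch below shows one way to do that with the standard library, using a plain map in place of model.Store. applySettings is a hypothetical helper, not how commandline.ParseSettings is implemented. In this sketch, parsing integers with strconv.ParseInt in base 0 is what accepts values written with underscores, such as train_steps=50_000 used later on.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// applySettings parses a "-set" style string and overwrites entries of params,
// converting each value to the type of the existing default value.
// It returns the keys that were set, in order.
func applySettings(params map[string]any, set string) (keys []string, err error) {
	for _, part := range strings.Split(set, ";") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		key, value, found := strings.Cut(part, "=")
		if !found {
			return nil, fmt.Errorf("setting %q is not in the form key=value", part)
		}
		def, ok := params[key]
		if !ok {
			return nil, fmt.Errorf("unknown hyperparameter %q", key)
		}
		var v any
		switch def.(type) {
		case int:
			var n int64
			n, err = strconv.ParseInt(value, 0, 64) // Base 0 accepts "50_000".
			v = int(n)
		case float64:
			v, err = strconv.ParseFloat(value, 64)
		case bool:
			v, err = strconv.ParseBool(value)
		case string:
			v = value
		default:
			err = fmt.Errorf("unsupported type %T", def)
		}
		if err != nil {
			return nil, fmt.Errorf("parsing %q: %w", part, err)
		}
		params[key] = v
		keys = append(keys, key)
	}
	return keys, nil
}

func main() {
	params := map[string]any{
		"model": "fnn", "batch_size": 64, "train_steps": 3000,
		"learning_rate": 1e-4, "plots": true,
	}
	keys, err := applySettings(params, "model=cnn;learning_rate=1e-3;train_steps=20_000;plots=false")
	if err != nil {
		panic(err)
	}
	fmt.Printf("Parameters set: %q\n", keys)
	fmt.Println(params)
}
```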
In [4]:
import (
    "flag"
    
    "github.com/gomlx/gomlx/ml/layers"
    "github.com/gomlx/gomlx/ml/layers/fnn"
    "github.com/gomlx/gomlx/ml/layers/kan"
    "github.com/gomlx/gomlx/ml/layers/regularizer"
    "github.com/gomlx/gomlx/ui/commandline"
    "github.com/gomlx/gomlx/ml/train/optimizer"
    "github.com/gomlx/gomlx/examples/cifar"
    "github.com/gomlx/gomlx/ml/model"
)

var (
	// ValidModels is the list of model types supported.
	ValidModels   = []string{"fnn", "kan", "cnn"}
	flagEval      = flag.Bool("eval", true, "Whether to evaluate the model on the validation data at the end.")
	flagVerbosity = flag.Int("verbosity", 1, "Level of verbosity, the higher the more verbose.")
)

// settings is bound to a "-set" flag to be used to set model hyperparameters.
var settings = commandline.CreateSettingsFlag(createModelStore(), "set")

// createModelStore sets the store with default hyperparameters
func createModelStore() *model.Store {
	store := model.NewStore()
	store.SetParams(map[string]any{
		// Model type to use: valid values are fnn, kan and cnn.
		"model":           cifar.C10ValidModels[0],
		"checkpoint":      "",
		"num_checkpoints": 3,
		"train_steps":     3000,

		// batch_size for training.
		"batch_size": 64,

		// eval_batch_size can be larger than training, it's more efficient.
		"eval_batch_size": 200,

		// "plots" trigger generating intermediary eval data for plotting, and if running in GoNB, to actually
		// draw the plot with Plotly.
		plotly.ParamPlots: true,

		// If "normalization" is set, it overrides "fnn_normalization" and "cnn_normalization".
		layers.ParamNormalization: "none",

		optimizer.ParamOptimizer:           "adamw",
		optimizer.ParamLearningRate:        1e-4,
		optimizer.ParamAdamEpsilon:         1e-7,
		optimizer.ParamAdamDType:           "",
		activation.ParamActivation:         "swish",
		layers.ParamDropoutRate:             0.0,
		regularizer.ParamL2:                1e-5,
		regularizer.ParamL1:                1e-5,

		// FNN network parameters:
		fnn.ParamNumHiddenLayers: 8,
		fnn.ParamNumHiddenNodes:  128,
		fnn.ParamResidual:        true,
		fnn.ParamNormalization:   "",   // Set to none for no normalization, otherwise it falls back to layers.ParamNormalization.
		fnn.ParamDropoutRate:     -1.0, // Set to 0.0 for no dropout, otherwise it falls back to layers.ParamDropoutRate.

		// KAN network parameters:
		kan.ParamNumControlPoints:   10, // Number of control points
		kan.ParamNumHiddenNodes:     64,
		kan.ParamNumHiddenLayers:    4,
		kan.ParamBSplineDegree:      2,
		kan.ParamBSplineMagnitudeL1: 1e-5,
		kan.ParamBSplineMagnitudeL2: 0.0,
		kan.ParamDiscrete:           false,
		kan.ParamDiscreteSoftness:   0.1,
		kan.ParamResidual:           true,
	})
	return store
}

// StoreFromSettings is the default store (createModelStore) changed by -set flag.
func StoreFromSettings() (store *model.Store, paramsSet []string) {
    store = createModelStore()
    paramsSet = must.M1(commandline.ParseSettings(store, *settings))
    return
}

// Let's test that we can set hyperparameters by passing them in the "-set" flag:
%% -set="batch_size=17;fnn_num_hidden_layers=12"
fmt.Printf("Model types: %q\n", cifar.C10ValidModels)
store, parametersSet := StoreFromSettings()
fmt.Printf("Parameters set (-set): %q\n", parametersSet)
fmt.Println(commandline.SprintSettings(store.RootScope()))
Model types: ["fnn" "kan" "cnn"]
Parameters set (-set): ["batch_size" "fnn_num_hidden_layers"]
	"/activation": (string) swish
	"/adam_dtype": (string) 
	"/adam_epsilon": (float64) 1e-07
	"/batch_size": (int) 17
	"/checkpoint": (string) 
	"/dropout_rate": (float64) 0
	"/eval_batch_size": (int) 200
	"/fnn_dropout_rate": (float64) -1
	"/fnn_normalization": (string) 
	"/fnn_num_hidden_layers": (int) 12
	"/fnn_num_hidden_nodes": (int) 128
	"/fnn_residual": (bool) true
	"/kan_bspline_degree": (int) 2
	"/kan_bspline_magnitude_l1": (float64) 1e-05
	"/kan_bspline_magnitude_l2": (float64) 0
	"/kan_discrete": (bool) false
	"/kan_discrete_softness": (float64) 0.1
	"/kan_num_hidden_layers": (int) 4
	"/kan_num_hidden_nodes": (int) 64
	"/kan_num_points": (int) 10
	"/kan_residual": (bool) true
	"/l1_regularization": (float64) 1e-05
	"/l2_regularization": (float64) 1e-05
	"/learning_rate": (float64) 0.0001
	"/model": (string) fnn
	"/normalization": (string) none
	"/num_checkpoints": (int) 3
	"/optimizer": (string) adamw
	"/plots": (bool) true
	"/train_steps": (int) 3000

Simple FNN model

A trivial model that can easily reach ~50% accuracy (a random model would get 10%), but hardly more than that.
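The 10% figure is the accuracy of guessing uniformly among the 10 classes. The matching reference point for the loss is the cross-entropy of a model that assigns probability 1/N to every class, which is -log(1/N) = log(N), about 2.30 for Cifar-10. The short sketch below computes both; it is plain arithmetic and does not use GoMLX. It helps when reading the training logs: a mean loss of about 2.3, like the 2.36 reported after the 50-step smoke test further below, indicates predictions still about as uncertain as a uniform guess.

```go
package main

import (
	"fmt"
	"math"
)

// chanceAccuracy is the expected accuracy of guessing uniformly at random.
func chanceAccuracy(numClasses int) float64 { return 1 / float64(numClasses) }

// uniformCrossEntropy is the loss of a model that predicts probability
// 1/numClasses for every class: -log(1/numClasses) = log(numClasses).
func uniformCrossEntropy(numClasses int) float64 { return math.Log(float64(numClasses)) }

func main() {
	for _, n := range []int{10, 100} { // Cifar-10 and Cifar-100 (fine labels).
		fmt.Printf("%3d classes: chance accuracy %.0f%%, uniform loss %.3f\n",
			n, 100*chanceAccuracy(n), uniformCrossEntropy(n))
	}
}
```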

Later we are going to define a CNN model for comparison; for now, the model function below only handles the "fnn" and "kan" model types.

Note:

  • The code below is here only as an example: training actually uses the same code from the cifar package.
In [5]:
import (
    "flag"
    . "github.com/gomlx/gomlx/core/graph"
    "github.com/gomlx/gomlx/examples/cifar"
    "github.com/gomlx/gomlx/ml/model"
    "github.com/gomlx/gomlx/ml/train/optimizer"
    "github.com/gomlx/compute/shapes"
)

var _ = NewGraph  // Make sure the graph package is in use.

// C10PlainModel is compatible with train.ModelFn and returns the logits Node, given the input images.
// It's a basic FNN (Feedforward Neural Network), so no convolutions. It is meant only as an example.
func C10PlainModel(scope *model.Scope, batchedImages *Node) *Node {
	scope = scope.In("model")
	batchSize := batchedImages.Shape().Dimensions[0]
	// Flatten each image into one feature vector: [batchSize, Height*Width*Depth].
	logits := Reshape(batchedImages, batchSize, -1)
	numClasses := len(cifar.C10Labels)
	modelType := model.GetParamOr(scope, "model", cifar.C10ValidModels[0])
	switch modelType {
	case "kan":
		// Configuration of the KAN layer(s) uses the store hyperparameters.
		logits = kan.New(scope, logits, numClasses).Done()
	case "fnn":
		// Configuration of the FNN layer(s) uses the store hyperparameters.
		logits = fnn.New(scope, logits, numClasses).Done()
	default:
		panic(fmt.Sprintf("C10PlainModel: unsupported model type %q", modelType))
	}
	logits.AssertDims(batchSize, numClasses)
	return logits
}

%% -set="batch_size=7"
// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.
AssertDownloaded()
store, _ := StoreFromSettings()
g := NewGraph(compute.MustNew(), "placeholder")
model.SetStore(g, store)
batchSize := model.GetRootParamOr(store, "batch_size", int(100))
logits := C10PlainModel(store.RootScope(), Parameter(g, "images", shapes.Make(DType, batchSize, cifar.Height, cifar.Width, cifar.Depth)))
fmt.Printf("Logits shape for batch_size=%d: %s\n", batchSize, logits.Shape())
Logits shape for batch_size=7: (Float32)[7, 10]

Training Loop

With a model function defined, we use the training loop created for Cifar-10.

The trainer is provided in the cifar package. It is straightforward (and almost the same for every project) and does the following for us:

  • If a checkpoint is given (--checkpoint) and it holds a previously saved model, it loads the hyperparameters and trained variables.
  • Create the trainer with the selected model function (see the Simple FNN model and CNN model for Cifar-10 sections), optimizer, loss and metrics.
  • Create a train.Loop and attach to it a progress bar, a periodic checkpoint saver and a plotter (--set="plots=true").
  • Train the selected number of train steps.
  • Report results.
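Stripped of the progress bar and plotting, the steps above reduce to a loop that runs one optimizer step at a time and saves a checkpoint periodically, starting from the step restored from a checkpoint when there is one. The sketch below shows only that control flow. stepper, runLoop and constStepper are hypothetical names, not the train.Loop or cifar.TrainCifar10 API.

```go
package main

import "fmt"

// stepper is a hypothetical trainer: TrainStep runs one optimizer step and
// returns the loss on that step's batch.
type stepper interface {
	TrainStep(step int) (loss float64, err error)
}

// runLoop trains from startStep (0, or the step restored from a checkpoint)
// up to numSteps, calling checkpoint every checkpointEvery steps and after
// the last step. It returns the loss of the last step.
func runLoop(t stepper, startStep, numSteps, checkpointEvery int, checkpoint func(step int) error) (float64, error) {
	var loss float64
	for step := startStep; step < numSteps; step++ {
		var err error
		if loss, err = t.TrainStep(step); err != nil {
			return 0, fmt.Errorf("step %d: %w", step, err)
		}
		done := step + 1
		if checkpoint != nil && (done == numSteps || (checkpointEvery > 0 && done%checkpointEvery == 0)) {
			if err := checkpoint(done); err != nil {
				return 0, fmt.Errorf("checkpoint at step %d: %w", done, err)
			}
		}
	}
	return loss, nil
}

// constStepper is a toy stepper that always reports the same loss.
type constStepper float64

func (c constStepper) TrainStep(int) (float64, error) { return float64(c), nil }

func main() {
	saveCheckpoint := func(step int) error {
		fmt.Println("checkpoint at step", step)
		return nil
	}
	loss, err := runLoop(constStepper(2.3), 0, 50, 20, saveCheckpoint)
	if err != nil {
		panic(err)
	}
	fmt.Println("final loss:", loss)
}
```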

Below we train 50 steps with the default settings just to check things are working.

In [6]:
var flagCheckpoint = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")

// trainModel with hyperparameters configured with `-set=...`.
func trainModel() {
    store, paramsSet := StoreFromSettings()
    cifar.TrainCifar10(store, *flagDataDir, *flagCheckpoint, *flagEval, *flagVerbosity, paramsSet)
}

// Train 50 steps, only to test things are working. No plots.
%% --set="train_steps=50;plots=false"
trainModel()
Backend "xla":	xla:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/go-xla/nvidia/pjrt_c_api_cuda_plugin.so) v0.100 [StableHLO] [1 device(s)]
       100% [========================================] (2245 steps/s) [step=49] [loss=2.36] [~loss=2.98] [~acc=17.53%]                 
	[Step 50] median train step: 363 microseconds

Results on Validation:
	Mean Loss (#loss): 2.36
	Mean Accuracy (#acc): 22.67%
Results on Training:
	Mean Loss (#loss): 2.36
	Mean Accuracy (#acc): 23.16%

FNN Model Training

Let's train the FNN for real this time.

  • Note: the FNN model quickly overfits the training data.
In [7]:
// Remove a previously trained model. Skip this if you want to continue training a previous model.
!rm -rf ~/work/cifar/base_fnn_model  
In [8]:
%% --checkpoint=base_fnn_model --set="model=fnn;train_steps=50_000;plots=true"
trainModel()
Backend "xla":	xla:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/go-xla/nvidia/pjrt_c_api_cuda_plugin.so) v0.100 [StableHLO] [1 device(s)]
Checkpointing model to "/home/janpf/work/cifar/base_fnn_model"
         1% [........................................] (370 steps/s) [0s:2m13s] [step=699] [loss=1.95] [~loss=1.98] [~acc=37.21%]        
       100% [========================================] (2321 steps/s) [step=49999] [loss=0.549] [~loss=0.68] [~acc=84.15%]

Metric: accuracy

Metric: loss

	[Step 50000] median train step: 332 microseconds

Results on Validation:
	Mean Loss (#loss): 2.45
	Mean Accuracy (#acc): 50.64%
Results on Training:
	Mean Loss (#loss): 0.659
	Mean Accuracy (#acc): 84.85%

CNN model for Cifar-10

Let's now properly define a CNN model to compare.

The model was built following a Keras model in Kaggle (thanks @ektasharma), which provided hardcoded values for all layers of the model, so it doesn't make use of the hyperparameters set in the model.Store. It achieves ~86% on the validation set after 80,000 steps of batch size 64 (~100 epochs).
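The epoch count follows from the training set size: CIFAR-10 has 50,000 training images, so one epoch at batch size 64 is 50,000 / 64 ≈ 781 steps. A small hypothetical helper in plain Go makes the conversion explicit and also gives the epochs covered by the runs in this notebook, which keep the default batch_size of 64.

```go
package main

import "fmt"

// cifarTrainSize is the number of training images in CIFAR-10 (and CIFAR-100).
const cifarTrainSize = 50_000

// epochs converts a number of training steps at a given batch size into
// passes over the training set.
func epochs(steps, batchSize int) float64 {
	return float64(steps) * float64(batchSize) / cifarTrainSize
}

func main() {
	fmt.Printf("Kaggle-style CNN, 80,000 steps: %.1f epochs\n", epochs(80_000, 64))
	fmt.Printf("CNN run below, 20,000 steps:    %.1f epochs\n", epochs(20_000, 64))
	fmt.Printf("FNN run above, 50,000 steps:    %.1f epochs\n", epochs(50_000, 64))
}
```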

Notice that since the model uses batch normalization, at the end of training the training process updates the moving averages of mean and variance, which improves the running estimates the model kept during training. This interesting blog post explains it in more detail.

In [9]:
// ConvolutionModel implements train.ModelFn and returns the logits Node, given the input images.
// It's a straightforward CNN (Convolutional Neural Network) model.
//
// This is modeled after the Keras example in Kaggle:
// https://www.kaggle.com/code/ektasharma/simple-cifar10-cnn-keras-code-with-88-accuracy
// (Thanks @ektasharma)
func ConvolutionModel(scope *model.Scope, batchedImages *Node) *Node {
	scope = scope.In("model")
	g := batchedImages.Graph()
	dtype := batchedImages.DType()
	batchSize := batchedImages.Shape().Dimensions[0]
	logits := batchedImages

	layerIdx := 0
	nextScope := func(name string) *model.Scope {
		newScope := scope.In("%03d_%s", layerIdx, name)
		layerIdx++
		return newScope
	}

	logits = layers.Convolution(nextScope("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 32, 32, 32)
	logits = activation.Relu(logits)
	logits = norm.BatchNorm(nextScope("batchnorm"), logits, -1).Done()
	logits = layers.Convolution(nextScope("conv"), logits).Filters(32).KernelSize(3).PadSame().Done()
	logits = activation.Relu(logits)
	logits = norm.BatchNorm(nextScope("batchnorm"), logits, -1).Done()
	logits = MaxPool(logits).Window(2).Done()
	logits = layers.DropoutNormalize(nextScope("dropout"), logits, Scalar(g, dtype, 0.3), true)
	logits.AssertDims(batchSize, 16, 16, 32)

	logits = layers.Convolution(nextScope("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 16, 16, 64)
	logits = activation.Relu(logits)
	logits = norm.BatchNorm(nextScope("batchnorm"), logits, -1).Done()
	logits = layers.Convolution(nextScope("conv"), logits).Filters(64).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 16, 16, 64)
	logits = activation.Relu(logits)
	logits = norm.BatchNorm(nextScope("batchnorm"), logits, -1).Done()
	logits = MaxPool(logits).Window(2).Done()
	logits = layers.DropoutNormalize(nextScope("dropout"), logits, Scalar(g, dtype, 0.5), true)
	logits.AssertDims(batchSize, 8, 8, 64)

	logits = layers.Convolution(nextScope("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 8, 8, 128)
	logits = activation.Relu(logits)
	logits = norm.BatchNorm(nextScope("batchnorm"), logits, -1).Done()
	logits = layers.Convolution(nextScope("conv"), logits).Filters(128).KernelSize(3).PadSame().Done()
	logits.AssertDims(batchSize, 8, 8, 128)
	logits = activation.Relu(logits)
	logits = norm.BatchNorm(nextScope("batchnorm"), logits, -1).Done()
	logits = MaxPool(logits).Window(2).Done()
	logits = layers.DropoutNormalize(nextScope("dropout"), logits, Scalar(g, dtype, 0.5), true)
	logits.AssertDims(batchSize, 4, 4, 128)

	// Flatten the logits: from here on the model is a regular dense (FNN-like) head.
	logits = Reshape(logits, batchSize, -1)
	logits = layers.Dense(nextScope("dense"), logits, true, 128)
	logits = activation.Relu(logits)
	logits = norm.BatchNorm(nextScope("batchnorm"), logits, -1).Done()
	logits = layers.DropoutNormalize(nextScope("dropout"), logits, Scalar(g, dtype, 0.5), true)
	numClasses := len(cifar.C10Labels)
	logits = layers.Dense(nextScope("dense"), logits, true, numClasses)
	return logits
}

%% -set="batch_size=11"
// Let's test that the logits are coming out with the right shape: we want [batch_size, 10], since there are 10 classes.
AssertDownloaded()
store, _ := StoreFromSettings()
g := NewGraph(compute.MustNew(), "placeholder")
model.SetStore(g, store)
batchSize := model.GetRootParamOr(store, "batch_size", int(100))
logits := ConvolutionModel(store.RootScope(), Parameter(g, "images", shapes.Make(DType, batchSize, cifar.Height, cifar.Width, cifar.Depth)))
fmt.Printf("Logits shape for batch_size=%d: %s\n", batchSize, logits.Shape())
Logits shape for batch_size=11: (Float32)[11, 10]
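The AssertDims checks in ConvolutionModel follow from simple shape arithmetic: the PadSame convolutions keep height and width and set the number of channels, while each MaxPool with a window of 2 halves height and width (the asserted dimensions indicate a stride equal to the window). The plain-Go sketch below, with a hypothetical convBlockShapes helper, traces the [height, width, channels] shape through the three blocks, ending in 4*4*128 = 2048 flattened features fed to the dense layers, versus 32*32*3 = 3072 raw inputs for the FNN.

```go
package main

import "fmt"

// convBlockShapes traces [height, width, channels] after each block of
// "same"-padded convolutions (which set the channel count) followed by a
// 2x2 max-pool with stride 2 (which halves height and width).
func convBlockShapes(height, width int, filters []int) [][3]int {
	shapes := make([][3]int, 0, len(filters))
	for _, f := range filters {
		height, width = height/2, width/2
		shapes = append(shapes, [3]int{height, width, f})
	}
	return shapes
}

func main() {
	shapes := convBlockShapes(32, 32, []int{32, 64, 128})
	for i, s := range shapes {
		fmt.Printf("after block %d: %v\n", i+1, s)
	}
	last := shapes[len(shapes)-1]
	fmt.Println("flattened features:", last[0]*last[1]*last[2])
}
```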

Training the CNN model

CNNs have a much better inductive bias for images: this model easily achieves > 80% accuracy on the training data, and somewhat less on the validation data, due to overfitting.

It would likely benefit from pre-training on a larger unlabeled dataset; see the "Dogs vs Cats" example for transfer learning in action on an image model.

In [10]:
!rm -rf ~/work/cifar/base_cnn_model
In [11]:
%% --checkpoint=base_cnn_model --set="model=cnn;learning_rate=1e-3;l2_regularization=0;l1_regularization=0;train_steps=20_000"
trainModel()
Backend "xla":	xla:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/go-xla/nvidia/pjrt_c_api_cuda_plugin.so) v0.100 [StableHLO] [1 device(s)]
Checkpointing model to "/home/janpf/work/cifar/base_cnn_model"
         3% [>.......................................] (516 steps/s) [1s:37s] [step=719] [loss=1.74] [~loss=1.7] [~acc=36.38%]           
       100% [========================================] (960 steps/s) [step=19999] [loss=0.663] [~loss=0.722] [~acc=75.12%]

Metric: accuracy

Metric: loss

	[Step 20000] median train step: 818 microseconds

Results on Validation:
	Mean Loss (#loss): 0.704
	Mean Accuracy (#acc): 76.42%
Results on Training:
	Mean Loss (#loss): 0.546
	Mean Accuracy (#acc): 80.44%

Inference

Inference, or serving the model, uses the same code that was used to train it. That is, currently the way to save a model is to export the Go function that creates it, along with the checkpoint holding the learned weights.

We created a small library, cifar/classifier, that takes an image as input, converts it to a tensor and calls the trained model.
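Converting an image to a tensor boils down to reading its pixels into a channels-last array of shape [height, width, 3], the same layout as the training images. The sketch below does that with only the standard image package. imageToHWC and exampleImage are hypothetical helpers, and both the [0, 1] scaling and ignoring the alpha channel are assumptions: a real classifier must reproduce exactly the preprocessing applied to the training data.

```go
package main

import (
	"fmt"
	"image"
	"image/color"
)

// imageToHWC flattens img into height*width*3 float32 values in
// channels-last order (row by row, then R, G, B), scaled to [0, 1].
func imageToHWC(img image.Image) (values []float32, height, width int) {
	b := img.Bounds()
	height, width = b.Dy(), b.Dx()
	values = make([]float32, 0, height*width*3)
	for y := b.Min.Y; y < b.Max.Y; y++ {
		for x := b.Min.X; x < b.Max.X; x++ {
			// RGBA returns alpha-premultiplied 16-bit components (0..0xffff);
			// for opaque images these are just the color values.
			r, g, bl, _ := img.At(x, y).RGBA()
			values = append(values, float32(r)/0xffff, float32(g)/0xffff, float32(bl)/0xffff)
		}
	}
	return values, height, width
}

// exampleImage builds a 2x1 image whose second pixel is opaque red.
func exampleImage() image.Image {
	img := image.NewRGBA(image.Rect(0, 0, 2, 1))
	img.Set(1, 0, color.RGBA{R: 255, A: 255})
	return img
}

func main() {
	values, h, w := imageToHWC(exampleImage())
	fmt.Println(h, w, values)
}
```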

In [12]:
import (
    "bytes"
    "encoding/base64"
    "image/png"
    "io"
    
    "github.com/gomlx/gomlx/examples/cifar/classifier"
    // We must also import the engine that will execute the model.
    // Currently only XLA is supported.
    _ "github.com/gomlx/gomlx/backends/default"
)

%%
// Decode and print PNG image.
imgBase64 := bytes.NewBufferString(
    "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAI5UlEQVR4nCxWSZMc1dV988upMmvo6uqWWvrUgFDwGWEMAZhwEAZ7YXvrrX+CN1765+GVg43DCwIhbAdCotVTVeX88o33OarxWWfcOPnuGS77458+Q4xWy0U1r7wF75CZIEaOEUUIDcNQ1zWlNIBnjAYIR0dHi9VKTZr5iAMYBEJKjpCx02azmi/mVrtJaWttjLGc5WyVPAwYUc1FWFo3OuMRIt4CISgEO01T13Xeey54VVVlXiUsGXZtbxSO2Acvi9wo58dJqR7jQCjBETvnxnFUatrvt+ze8SOPgAkZMFijjI7WGIxpWeYyyX3wp6enXdcpbTCi3gTlh7Efsipbnhw5FLU1hUwV4DxPEMa73Y5i2rUDAHDO0zRjxTw/TMWx67px1M55732aJsYq65AQAgAwxk3TqVFvqoVHlAG8tb7nGL1otlab0werfC26vhGCpFnS1v00TTHGIs+r+Zw5ZExQ+6692e66znjviiIDIEoF57yU0vsAENNE1tt9GHTGkv9/6+2TxfHXL//VDQ1F5NWLl6v1ilJKCBFCIjTEiBjn5XzBOCdmckM/9Z1q26FpGoyJ9zBNpm0H52A+X3KeVOX8yfnDk9W8HdX8wfkHX/z+P5fbq6axMWDOt7vm6nKbJrM0LTHiQqSE8ojZvu2ud3vGuXDWq1FBgKIoZrNZCMFab61zzl9dXRNCnPOJkKcPz4NY/fKLP3zy+e+e/fMfTT/gFCOHY8RdNxjtACBCZIzneaGN2W13CCHW32FUIwAImXjvrbV3+0kRQtbasiwRQsoBYunx2fGgyYvXzRSinkxCeKCwXK5RjJeXN3mRZJls+65rO8EYBrRarch2ux3GASEUY0QIOXcg4r03+gBKKQAYY6xz2tim7V+8uLq57SxCRrv2tnHa+RDrup2UybNZ2w510xymeZCUIetYjHD3JtYYmxclQhhhhCKCGKdpAgCtdZ7nszIbVJhUNKO22iXFzFrnxyEC3e46PaknTx5vNqfW2n1XV/OKuIC073c7+vS9pz4AIgRjioE7G4zWd/SnEBzGiDGWSJ6nSSJTpT0iPMuzJCFGjRSR4EPbDYQy66wQsppXMYLTJjpPIDKC6Re/+e3q6LgoyxiIVbFv+6Ht1DCg6DECKXmepcH7u5cjL1/92PZ7jIMU+e3tHkFs673RBlHmnEsSWVYlAdTtmuAsZ0xKwYw2Zhh7ddDo0I6jGkMwjBFKCKOcYu7twZMR0/2ujtqCMeryqmWcU7IfBmsdwXgaRylZs9+1y/J0vTbjkCVilsk8TVj0h8QZ+8EZbf0Qo8UkMk6dA8Z5kc8PciIkUjpNbc6kHVWom5rhPE8XVdnvbiilpeSr1YJitLu+miUcYVBqEBQQOPr+06eT1sZo66zzh6AIHsVI8qzcbO4vl2uCGaUiIIIxAeNU01qtgEQEQRz+lKVSSEE5p3kmrZmGoevapqn3nGHOKBv7wXsPHhgmgifWxBABxwRFMQ6aURXv9Cvy9PTk/qWDtqmbtvENkpzYg9igH3pGQAoavJvUMIxBCi45JQQRghiR0gfoxskHiIitlqdt21tjtzdbNYyJkPTgZEcJdoTM8kwfLfuhmYYGryrOBCMoOQzzlCEpudYMvMMEECYBRxM8sQhdbbej0s5BkhTL5TocXKa8NUZNzXan+sFpjazptrcMoaPFIhUSeevN5K1eVLN79zZpLiY9dn0XABjjANH64AFjJpkZVfSO4sAoWVV5V98EZ4TgEMIwdgHc5vgYIeTDwRllWXIuAnhKKcYYIjjns0LM0CzLMn8HSuSkpjxPpJilScmCtlWaQsJWq2VRyLa+KQoJmEUUrFHD0GIMaZomqdjttk27n81m3tsqz7xzwzBmARWzo7IsASBJkhijFKm1Xggxn8+rcsaIhypL15t7y0X546sLisysSPvpwCWAizF0XTOOfV6kITgAb8x
kjd4s5oSQpm4Q5ofgskoIwRhPEqmUgkOBCwBr3UTPVieUACPBWbXf3gBCiPB+0sPQoxhmZRERaDMJwU5PT4TkxugQfJ5lj99+ks/Kru+l5HW7k1IeAt7am9vX1zcX1ukAVqmeea9PVuuEQd/U2mqgiXa2GYaA/HozJ4RhErWZ8jJbHi+MMUmRQADwcFnfzBdL0bBtty3nM0cRZoCJK2a5dwaCF+yQpPTs3oYSoNHXTV0rjXliPLSjWi7K9apKk1SNkxrV6f2TYpYTSmMEhOKuri9eXxCKN5u1sdrYgy4IBs5JdAGFKIVMZJImKauOl5PXHBFeljISQqgzapYmp+t1Kgmj4ofvX2VptqzmyEPCeVrN67qOzj16cHa0WHCC33njzZffv8wR3SyP8iyJER0aUErOOaGUHW+OIvhU8izJqm6ajDXu8vH9B5vjlRCEUYmBfPfdv9948H/r9YoQIqV89uxZSunHn3xSzmbGmkzmD49Ov33+za8+/DjNJMKE/CRiAIQR++C9pygiQinngrP0xQ8/YEQ+//VnAD7EkMjs/Xd/8eWXf9vfbD96/+eUsn29d2r69KMP33v3Z9aarusikOLRo+fPv7m4uDg/f5im+U91K4TgnNO//uXP82qxXh1V1XK9Ptnebpfz6s3zc++sPHwiBePHR+tn33w9tF2WpV/9/SvB2aeffoQjaD3dbRIhQrWZur6hjNX7elJ6HIbt7c4YyyLCnDGEMYlYj8N6udA6ubi4MEYLISnVAA0h5PHbbzVte7PbdePw7vk7r68ujdGMHXomIpSm+f2zY4AgBItAnAdt7Diqvh/Zt8+/i3fAmNT7drFYtm3b9/2dJw8NhTEehgFhRBm9eH3x4OEDLuT1zbUxuqrmVVUyxjFGhHBKJaVM8JySO/cZjzFhXdcZY0IIMaLXF9fX1zdFUQghQggBwkEwMVprXfBZli2XyyRJtJ44EwRTSph3gWCGCfYu6EkbY+dzRDBTSmGMhBSMMfa/jUd8dnYWQqCUpmn604Hkvc/zXEoJKBZFcTiQlOIHylQIPJ9XhJAYo1JKa51lWTnLrLHWqjzPGWPe+/8GAAD//7swHNFia9prAAAAAElFTkSuQmCC")
imgPNG := must.M1(io.ReadAll(base64.NewDecoder(base64.StdEncoding, imgBase64)))
img := must.M1(png.Decode(bytes.NewBuffer(imgPNG)))
size := img.Bounds()
gonbui.DisplayHTML(fmt.Sprintf("<p>Image: (%d x %d)</p>", size.Dx(), size.Dy()))
gonbui.DisplayPNG(imgPNG)

// Classify:
c10Classifier := must.M1(classifier.New("~/work/cifar/base_fnn_model"))
classID := must.M1(c10Classifier.Classify(img))
className := cifar.C10Labels[classID]
gonbui.DisplayHTML(fmt.Sprintf("<p>Class: <b>%s (%d)</b></p>", className, classID))

Image: (32 x 32)

Class: horse (7)

Generate a random image as base64

In [13]:
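// Sample one random training image and display it. Uncomment the two lines below to print it as base64.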
%%
backend := compute.MustNew()
ds := cifar.NewDataset(backend, "Samples Cifar-10", *flagDataDir, cifar.C10, DType, cifar.Train).Shuffle()
var batch train.Batch
for b, err := range ds.Iter() {
    must.M(err)
    batch = b
    break
}
inputs := batch.Inputs
imgTensor := inputs[0]
img := images.ToImage().Single(imgTensor)
imgSrc := must.M1(gonbui.EmbedImageAsPNGSrc(img))
// imgBase64 := imgSrc[22:]  // Strip the preamble for a <img> src tag.
// fmt.Printf("%s\n\n", imgBase64)
size := imgTensor.Shape().Dimensions[0]
gonbui.DisplayHTML(fmt.Sprintf(`<img width="%d" height="%d" src="%s"/>`, size*2, size*2, imgSrc))
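The imgSrc[22:] slice in the cell above strips a 22-character preamble, which matches the length of the standard "data:image/png;base64," prefix of a PNG data URL. Assuming that is what gonbui.EmbedImageAsPNGSrc produces, the round trip can be sketched with the standard library alone; toPNGDataURL, fromPNGDataURL and blankImage are hypothetical helpers. The part after the prefix is the kind of base64 string pasted into the Inference cell above.

```go
package main

import (
	"bytes"
	"encoding/base64"
	"errors"
	"fmt"
	"image"
	"image/png"
	"strings"
)

const pngDataURLPrefix = "data:image/png;base64," // 22 characters.

// toPNGDataURL encodes img as PNG and wraps it as an <img src> data URL.
func toPNGDataURL(img image.Image) (string, error) {
	var buf bytes.Buffer
	if err := png.Encode(&buf, img); err != nil {
		return "", err
	}
	return pngDataURLPrefix + base64.StdEncoding.EncodeToString(buf.Bytes()), nil
}

// fromPNGDataURL reverses toPNGDataURL.
func fromPNGDataURL(src string) (image.Image, error) {
	b64, ok := strings.CutPrefix(src, pngDataURLPrefix)
	if !ok {
		return nil, errors.New("not a PNG data URL")
	}
	data, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		return nil, err
	}
	return png.Decode(bytes.NewReader(data))
}

// blankImage returns a transparent width x height image.
func blankImage(width, height int) image.Image {
	return image.NewRGBA(image.Rect(0, 0, width, height))
}

func main() {
	src, err := toPNGDataURL(blankImage(32, 32))
	if err != nil {
		panic(err)
	}
	fmt.Println(src[:len(pngDataURLPrefix)]) // The <img src> preamble.
	img, err := fromPNGDataURL(src)
	if err != nil {
		panic(err)
	}
	fmt.Println(img.Bounds().Dx(), img.Bounds().Dy())
}
```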