UCI Adult Dataset or Census Income¶

This is a very popular ML task on tabular data: the objective is to predict whether a person's income exceeds $50K/yr based on census data. It is also known as the "Census Income" dataset.

The data is old and biased in several ways, but it can still be used as an opaque benchmark for ML experimentation.

The source code of this example in one piece can be seen in the demo under .../examples/adult/demo/.

Environment Set Up¶

Let's set up a go.work file to use the local copy of GoMLX (and related modules), so that the dataset code can be developed jointly with the model. That is often how data pre-processing and model code are developed together with experimentation.

If you are not changing code, feel free to skip this cell. If you use a different directory for your projects, change it below.

Note that ${HOME}/Projects/gomlx is where the GoMLX code is copied by default in its Docker image.

For this example we force the CPU backend, even if a GPU is available: the model is so small that it is not worth going to the GPU.

In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx" "${HOME}/Projects/gonb" "${HOME}/Projects/gopjrt" "${HOME}/Projects/bsplines"
%goworkfix
%env GOMLX_BACKEND=cpu
	- Added replace rule for module "github.com/gomlx/bsplines" to local directory "/home/janpf/Projects/bsplines".
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
	- Added replace rule for module "github.com/janpfeifer/gonb" to local directory "/home/janpf/Projects/gonb".
	- Added replace rule for module "github.com/gomlx/gopjrt" to local directory "/home/janpf/Projects/gopjrt".
Set: GOMLX_BACKEND="cpu"

Data Preparation¶

GoMLX provides a simple adult library to facilitate downloading and preprocessing the data. The data is available in the UCI Machine Learning Repository.

After downloading the data and validating the checksum (both training and testing), it generates the quantiles for the continuous features and the vocabularies for the categorical features. It saves all this information in a binary file, so later runs can skip this step and start faster.
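
As a rough illustration of that preprocessing step, the plain-Go sketch below derives at most a given number of quantiles from a continuous feature (keeping only unique values) and builds a vocabulary for a categorical one. The function names `uniqueQuantiles` and `buildVocabulary` are hypothetical, and the actual adult package may compute these differently.

```go
package main

import (
	"fmt"
	"sort"
)

// uniqueQuantiles returns at most maxQuantiles sorted, unique values that
// split the data into roughly equal-sized buckets.
// Hypothetical helper: not part of GoMLX.
func uniqueQuantiles(values []float64, maxQuantiles int) []float64 {
	if len(values) == 0 || maxQuantiles < 2 {
		return nil
	}
	sorted := append([]float64(nil), values...)
	sort.Float64s(sorted)
	var quantiles []float64
	for i := 0; i < maxQuantiles; i++ {
		// Index of the i-th quantile in the sorted data.
		idx := i * (len(sorted) - 1) / (maxQuantiles - 1)
		v := sorted[idx]
		// Skip repeated values: low-variability features get fewer quantiles.
		if len(quantiles) == 0 || v != quantiles[len(quantiles)-1] {
			quantiles = append(quantiles, v)
		}
	}
	return quantiles
}

// buildVocabulary maps each distinct string to an integer id, in order of
// first appearance. Hypothetical helper: not part of GoMLX.
func buildVocabulary(values []string) map[string]int {
	vocab := make(map[string]int)
	for _, v := range values {
		if _, found := vocab[v]; !found {
			vocab[v] = len(vocab)
		}
	}
	return vocab
}

func main() {
	hoursPerWeek := []float64{40, 40, 40, 13, 20, 40, 60, 40}
	fmt.Println(uniqueQuantiles(hoursPerWeek, 5)) // [13 20 40 60]
	fmt.Println(buildVocabulary([]string{"Private", "State-gov", "Private"}))
}
```

Asking for 5 quantiles here yields only 4, because the value 40 repeats: this is the "fewer quantiles for less variability" behaviour described by the --quantiles flag.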

The quantiles are used to calibrate the values with a piece-wise linear calibration, which is a good fit for this kind of tabular feature. See the layers.PieceWiseLinearCalibration documentation.
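
To make the idea concrete, here is a minimal plain-Go sketch of piece-wise linear calibration: the quantiles act as keypoints, each keypoint has an output value, and inputs are linearly interpolated between neighbouring keypoints. The `calibrate` function and the evenly spaced outputs are assumptions for illustration; layers.PieceWiseLinearCalibration works on graph nodes and may differ in its details (for example, in how the outputs are initialized or made trainable).

```go
package main

import (
	"fmt"
	"sort"
)

// calibrate maps x through a piece-wise linear function defined by sorted,
// strictly increasing keypoints and one output value per keypoint
// (len(keypoints) == len(outputs) >= 1). Inputs outside the keypoint range
// are clamped to the first or last output.
// Hypothetical helper: not part of GoMLX.
func calibrate(x float64, keypoints, outputs []float64) float64 {
	n := len(keypoints)
	if x <= keypoints[0] {
		return outputs[0]
	}
	if x >= keypoints[n-1] {
		return outputs[n-1]
	}
	// First keypoint strictly greater than x; guaranteed to be in [1, n-1].
	hi := sort.Search(n, func(i int) bool { return keypoints[i] > x })
	lo := hi - 1
	t := (x - keypoints[lo]) / (keypoints[hi] - keypoints[lo])
	return outputs[lo] + t*(outputs[hi]-outputs[lo])
}

func main() {
	keypoints := []float64{13, 20, 40, 60}
	// Assumption: outputs evenly spaced in [0, 1].
	outputs := []float64{0, 1.0 / 3, 2.0 / 3, 1}
	for _, x := range []float64{5, 13, 30, 40, 99} {
		fmt.Printf("%v -> %.3f\n", x, calibrate(x, keypoints, outputs))
	}
}
```

Because the keypoints come from quantiles, each segment covers a similar share of the training examples, regardless of how the raw values are scaled.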

We create a flag --data to define the directory where to save the intermediary files: downloaded and preprocessed datasets. In this example we set it to ~/work/uci-adult. Verbosity can be controlled with the --verbosity flag.

We set defaults in Go for these flags, but they can easily be overridden for a new run by providing them after the %% Jupyter kernel meta-command, which indicates that the subsequent lines should be put into a func main.
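
Outside a notebook the same mechanism is just the standard flag package: defaults live in the Go code and command-line arguments override them. The sketch below mimics what passing --verbosity=2 after %% amounts to, using a private flag.FlagSet so it can parse an explicit argument list. The `parseFlags` helper is hypothetical.

```go
package main

import (
	"flag"
	"fmt"
)

// parseFlags defines the flags with their defaults and parses args,
// returning the resulting data directory and verbosity.
// Hypothetical helper, for illustration only.
func parseFlags(args []string) (dataDir string, verbosity int, err error) {
	fs := flag.NewFlagSet("adult", flag.ContinueOnError)
	dataDirFlag := fs.String("data", "~/work/uci-adult", "Directory for dataset files.")
	verbosityFlag := fs.Int("verbosity", 0, "Level of verbosity.")
	if err = fs.Parse(args); err != nil {
		return "", 0, err
	}
	return *dataDirFlag, *verbosityFlag, nil
}

func main() {
	// No arguments: the defaults are used.
	dir, v, _ := parseFlags(nil)
	fmt.Println(dir, v)
	// Equivalent of "%% --verbosity=2" in the notebook.
	dir, v, _ = parseFlags([]string{"--verbosity=2"})
	fmt.Println(dir, v)
}
```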

In [2]:
import (
    "flag"
    
    "github.com/gomlx/gomlx/examples/adult"
)

var (
    flagDataDir       = flag.String("data", "~/work/uci-adult", "Directory to save and load downloaded and generated dataset files.")
    flagVerbosity     = flag.Int("verbosity", 0, "Level of verbosity, the higher the more verbose.")
    flagForceDownload = flag.Bool("force_download", false, "Force re-download of Adult dataset files.")
    flagNumQuantiles  = flag.Int("quantiles", 100, "Max number of quantiles to use for numeric features, used during piece-wise linear calibration. It will only use unique values, so if there is less variability, fewer quantiles are used.")
)

%% --verbosity=2
adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)
Sample Categorical: (24.08% positive ratio, 23.86% weighted positive ratio)
	Row 0:	[7 10 5 1 2 5 2 39]
	Row 1:	[6 10 3 4 1 5 2 39]
	Row 2:	[4 12 1 6 2 5 2 39]
	...
	Row 32558:	[4 12 7 1 5 5 1 39]
	Row 32559:	[4 12 5 1 4 5 2 39]
	Row 32560:	[5 12 3 4 6 5 1 39]

Sample Continuous:
	Row 0:	[39 13 2174 0 40]
	Row 1:	[50 13 0 0 13]
	Row 2:	[38 9 0 0 40]
	...
	Row 32558:	[58 9 0 0 40]
	Row 32559:	[22 9 0 0 20]
	Row 32560:	[52 9 15024 0 40]
In [3]:
!ls -lh ~/work/uci-adult
total 170M
-rw-r--r-- 1 janpf janpf 3.8M Jul 20 18:36 adult.data
-rw-r--r-- 1 janpf janpf 1.3M Jul 20 18:36 adult_data-100_quantiles.bin
-rw-r--r-- 1 janpf janpf 2.0M Jul 20 18:36 adult.test
drwxr-x--- 2 janpf janpf 4.0K Aug  9 11:50 base_model
drwxr-xr-x 2 janpf janpf 4.0K Jun  4  2009 cifar-10-batches-bin
-rw-r--r-- 1 janpf janpf 163M Aug 15 09:33 cifar-10-binary.tar.gz
drwxr-x--- 2 janpf janpf 4.0K Aug  2 10:47 fnn
drwxr-x--- 2 janpf janpf 4.0K Jul 21 20:24 kan_baseline
drwxr-x--- 2 janpf janpf 4.0K Aug  9 11:51 kan_model

Hyperparameters¶

This sets the superset of hyperparameters, with their default values, that can be used by the model, by setting them in the context.

See the demo/main.go file for how to add a flag to allow setting them from the command line.
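
The override mechanism used in the cell below can be pictured as parsing a `key=value;key=value` string on top of a map of defaults. The following is only a sketch, with a hypothetical `applySettings` helper that handles integer values; commandline.ParseContextSettings is the real entry point and is more general than this.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// applySettings overrides integer hyperparameters in params with the
// "key=value" pairs found in settings (separated by ";"). It returns the
// names of the parameters that were set, in order.
// Hypothetical helper: not the GoMLX implementation.
func applySettings(params map[string]int, settings string) ([]string, error) {
	var paramsSet []string
	for _, pair := range strings.Split(settings, ";") {
		if pair == "" {
			continue
		}
		key, value, found := strings.Cut(pair, "=")
		if !found {
			return nil, fmt.Errorf("invalid setting %q, expected key=value", pair)
		}
		// Only known hyperparameters can be overridden, which catches typos.
		if _, known := params[key]; !known {
			return nil, fmt.Errorf("unknown hyperparameter %q", key)
		}
		parsed, err := strconv.Atoi(value)
		if err != nil {
			return nil, fmt.Errorf("hyperparameter %q: %w", key, err)
		}
		params[key] = parsed
		paramsSet = append(paramsSet, key)
	}
	return paramsSet, nil
}

func main() {
	params := map[string]int{"batch_size": 128, "fnn_num_hidden_layers": 1, "train_steps": 5000}
	paramsSet, err := applySettings(params, "batch_size=17;fnn_num_hidden_layers=12")
	if err != nil {
		panic(err)
	}
	for _, name := range paramsSet {
		fmt.Printf("%s=%d\n", name, params[name])
	}
}
```

Returning the list of names that were set is useful later: those are the values that should win over whatever a checkpoint has stored.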

In [4]:
import (
    "fmt"

    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/ml/layers"
    "github.com/gomlx/gomlx/ml/layers/activations"
    "github.com/gomlx/gomlx/ml/layers/fnn"
    "github.com/gomlx/gomlx/ml/layers/kan"
    "github.com/gomlx/gomlx/ml/layers/regularizers"
    "github.com/gomlx/gomlx/ml/train/optimizers"
    "github.com/gomlx/gomlx/ui/commandline"
    "github.com/janpfeifer/must"
)

// settings is bound to a "-set" flag to be used to set context hyperparameters.
var settings = commandline.CreateContextSettingsFlag(createDefaultContext(), "set")

func createDefaultContext() *context.Context {
	ctx := context.NewContext()
	ctx.RngStateReset()
	ctx.SetParams(map[string]any{
		"train_steps":     5000,
		"batch_size":      128,
		"eval_batch_size": 1000,
		"plots":           true,
		"num_checkpoints": 3,

		optimizers.ParamOptimizer:           "adam",
		optimizers.ParamLearningRate:        0.001,
		optimizers.ParamAdamEpsilon:         1e-7,
		optimizers.ParamAdamDType:           "",
		optimizers.ParamCosineScheduleSteps: 0,
		activations.ParamActivation:         "sigmoid",
		layers.ParamDropoutRate:             0.0,
		regularizers.ParamL2:                1e-5,
		regularizers.ParamL1:                1e-5,

		// FNN network parameters:
		fnn.ParamNumHiddenLayers: 1,
		fnn.ParamNumHiddenNodes:  4,
		fnn.ParamResidual:        true,
		fnn.ParamNormalization:   "layer",

		// KAN network parameters:
		"kan":                       false, // Enable kan
		kan.ParamNumControlPoints:   20,    // Number of control points
		kan.ParamNumHiddenNodes:     4,
		kan.ParamNumHiddenLayers:    1,
		kan.ParamBSplineDegree:      2,
		kan.ParamBSplineMagnitudeL1: 1e-5,
		kan.ParamBSplineMagnitudeL2: 0.0,
		kan.ParamDiscrete:           false,
		kan.ParamDiscreteSoftness:   0.1,
	})
	return ctx
}

// contextFromSettings is the default context (createDefaultContext) changed by -set flag.
func contextFromSettings() (ctx *context.Context, paramsSet []string) {
    ctx = createDefaultContext()
    paramsSet = must.M1(commandline.ParseContextSettings(ctx, *settings))
    return
}

// Let's test that we can set hyperparameters by passing them in the "-set" flag:
%% -set="batch_size=17;fnn_num_hidden_layers=12"
ctx, paramsSet := contextFromSettings()
for _, name := range paramsSet {
    if value, found := ctx.GetParam(name); found {
        fmt.Printf("\t%s=%v\n", name, value)
    }
}
	batch_size=17
	fnn_num_hidden_layers=12

Creating Datasets¶

First we create GoMLX's backend: it's the engine instance (XLA in this case) that compiles and executes our computation graph.

It's needed to create the tensors that are fed to the accelerator (GPU, or even CPU-accelerated code).

With that we create the samplers of data that we will use to train and evaluate. They implement GoMLX's train.Dataset interface, which is what is used by our training loop to draw batches to train, or our eval loop to draw batches to evaluate.

The inputs are 3 tensors: categorical values, continuous values and weights.

In the cell below we define BuildDatasets and print out some samples.

In [5]:
import (
    "flag"
    "fmt"
    "io"

    "github.com/gomlx/gomlx/backends"
    "github.com/gomlx/gomlx/examples/adult"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/types/tensors"

    _ "github.com/gomlx/gomlx/backends/xla"
)

// Global backend created at initialization, used everywhere.
var backend = backends.New()

// BuildDatasets returns 3 `train.Dataset`:
// * trainDS is an endless random sampler used for training.
// * trainEvalDS samples through exactly one epoch of the train dataset.
// * testEvalDS samples through exactly one epoch of the test dataset.
func BuildDatasets(ctx *context.Context) (trainDS, trainEvalDS, testEvalDS train.Dataset) {
	batchSize := context.GetParamOr(ctx, "batch_size", 128)
	evalBatchSize := context.GetParamOr(ctx, "eval_batch_size", 1000)
    baseDS := adult.NewDataset(backend, adult.Data.Train, "batched train")
    trainEvalDS = baseDS.Copy().BatchSize(evalBatchSize, false)
    testEvalDS = adult.NewDataset(backend, adult.Data.Test, "test").
        BatchSize(evalBatchSize, false)
    // For training, we shuffle and loop indefinitely.
    trainDS = baseDS.BatchSize(batchSize, true).Shuffle().Infinite(true)
    return
}

// PositiveRatio finds out the ratio of positive labels in the
// given dataset (used here for the training and testing data).
//
// We could do this easily with the GoMLX computation model (just `ReduceAllSum`), but
// this example shows it's also ok to mix in plain Go computations.
func PositiveRatio(ds train.Dataset) float32 {
    ds.Reset()  // Start from beginning.
    var sum float32
    var count float32
    for {
        _, _, labels, err := ds.Yield()
        if err == io.EOF {
            break
        }
        if err != nil {
            panic(err)
        }
        data := tensors.CopyFlatData[float32](labels[0])
        for _, value := range data {
            sum += value
        }
        count += float32(len(data))
    }
    return sum/count
}

%%
adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    
ctx, _ := contextFromSettings()
trainingDS, trainingEvalDS, testEvalDS := BuildDatasets(ctx)

// Take one batch.
_, inputs, labels, err := trainingDS.Yield()
if err != nil { panic(err) }
fmt.Printf("Inputs of batch (size %d):\n", context.GetParamOr(ctx, "batch_size", 0))
fmt.Printf("\tcategorical:\n\t\tFeatures=%v\n", adult.Data.VocabulariesFeatures)
fmt.Printf("\t\tValues: %s\n", inputs[0].StringN(16))
fmt.Printf("\tcontinuous:\n\t\tFeatures=%v\n", adult.Data.QuantilesFeatures)
fmt.Printf("\t\tValues: %s\n", inputs[1].StringN(10))
fmt.Printf("\tweights: %s\n", inputs[2].StringN(5))
fmt.Printf("\nLabels of batch:\n\t%s\n", labels[0].StringN(10))
fmt.Printf("\nLabels distributions:\n\tTrain:\t%.2f%% positive\n\tTest:\t%.2f%% positive\n",
           PositiveRatio(trainingEvalDS)*100.0, PositiveRatio(testEvalDS)*100.0)
Inputs of batch (size 128):
	categorical:
		Features=[workclass education marital-status occupation relationship race sex native-country]
		Values: (Int64)[128 8]: (... too large, 1024 values ..., first 16 values: [4 16 5 3 2 5 2 39 4 12 5 7 5 3 2 39])
	continuous:
		Features=[age education-num capital-gain capital-loss hours-per-week]
		Values: (Float32)[128 5]: (... too large, 640 values ..., first 10 values: [23 10 0 0 40 41 9 0 0 40])
	weights: (Float32)[128 1]: (... too large, 128 values ..., first 5 values: [175266 157025 98642 328216 303462])

Labels of batch:
	(Float32)[128 1]: (... too large, 128 values ..., first 10 values: [0 0 1 1 0 0 0 0 0 0])

Labels distributions:
	Train:	24.08% positive
	Test:	23.62% positive

Model Definition¶

There are lots of hyperparameter flags, but otherwise this is a straightforward FNN, using piece-wise linear calibration of the continuous features and embeddings for the categorical features.
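
In plain Go terms, the feature preprocessing looks roughly like this: each categorical id selects a row from a per-feature embedding table, each continuous value becomes one calibrated number, and everything is concatenated into one input vector per example. The table contents and the `embedRow` helper below are made up for illustration; in the model the tables are trainable variables managed by layers.Embedding.

```go
package main

import "fmt"

// embedRow looks up one embedding vector per categorical feature and
// concatenates them with the (already calibrated) continuous features.
// tables[f][id] is the embedding of value id for categorical feature f.
// Hypothetical helper: not part of GoMLX.
func embedRow(tables [][][]float32, categorical []int, continuous []float32) ([]float32, error) {
	if len(tables) != len(categorical) {
		return nil, fmt.Errorf("got %d tables for %d categorical features", len(tables), len(categorical))
	}
	var row []float32
	for f, id := range categorical {
		if id < 0 || id >= len(tables[f]) {
			return nil, fmt.Errorf("feature %d: id %d outside vocabulary of size %d", f, id, len(tables[f]))
		}
		row = append(row, tables[f][id]...)
	}
	return append(row, continuous...), nil
}

func main() {
	// Two categorical features, each embedded in 2 dimensions.
	tables := [][][]float32{
		{{0.1, 0.2}, {0.3, 0.4}},     // Vocabulary of size 2.
		{{1, 0}, {0, 1}, {0.5, 0.5}}, // Vocabulary of size 3.
	}
	row, err := embedRow(tables, []int{1, 2}, []float32{0.25})
	if err != nil {
		panic(err)
	}
	fmt.Println(len(row), row) // 5 [0.3 0.4 0.5 0.5 0.25]
}
```

With the notebook's defaults (8 categorical features, embedding dimension 8, 5 continuous features) the concatenated vector has 8*8 + 5 = 69 values per example.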

Note: building models involves constantly checking that shapes are compatible. It's a bit annoying, in particular because shapes are only known at runtime, so there is no compile-time check. GoMLX tries to help by providing a stack trace of where errors happen, so one can pinpoint issues quickly. But it often involves lots of experimentation (more than ordinary Go code).

Developing with a notebook (see GoNB) or simply a unit test on your ModelGraph function are quick and convenient ways to develop models, before actually training them. You can also use shape asserts in the middle of the ModelGraph, as we do below.
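
The shape asserts mentioned above boil down to comparing actual dimensions against expected ones, with -1 meaning "don't check this axis". Here is a plain-Go sketch of that rule, using a hypothetical `checkDims` helper that returns an error rather than GoMLX's own AssertDims method:

```go
package main

import "fmt"

// checkDims returns an error if dims does not match expected. An expected
// dimension of -1 matches any size on that axis.
// Hypothetical helper: not part of GoMLX.
func checkDims(dims []int, expected ...int) error {
	if len(dims) != len(expected) {
		return fmt.Errorf("rank mismatch: got shape %v, expected %v", dims, expected)
	}
	for axis, want := range expected {
		if want != -1 && dims[axis] != want {
			return fmt.Errorf("axis %d mismatch: got shape %v, expected %v", axis, dims, expected)
		}
	}
	return nil
}

func main() {
	fmt.Println(checkDims([]int{128, 69}, 128, -1)) // <nil>
	fmt.Println(checkDims([]int{128, 69}, 128, 1))  // Axis 1 mismatch.
	fmt.Println(checkDims([]int{128}, 128, 1))      // Rank mismatch.
}
```

Placing such checks right after each layer keeps the error close to the line that introduced the wrong shape.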

In [6]:
import (
    "fmt"
    "io"

    . "github.com/gomlx/gomlx/graph"

    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/examples/adult"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/optimizers"
    "github.com/gomlx/gomlx/types/shapes"
    "github.com/gomlx/gopjrt/dtypes"
)

var (
    // ModelDType used for the model. Must match RawData Go types.
    ModelDType = dtypes.Float32

    // Model hyperparameters.
    flagUseCategorical       = flag.Bool("use_categorical", true, "Use categorical features.")
    flagUseContinuous        = flag.Bool("use_continuous", true, "Use continuous features.")
    flagTrainableCalibration = flag.Bool("trainable_calibration", true, "Allow piece-wise linear calibration to adjust outputs.")
    flagEmbeddingDim         = flag.Int("embedding_dim", 8, "Default embedding dimension for categorical values.")
)


// ModelGraph outputs the logits (not the probabilities). The parameter inputs should contain 3 tensors:
//
// - categorical inputs, shaped  `(int64)[batch_size, len(VocabulariesFeatures)]`
// - continuous inputs, shaped `(float32)[batch_size, len(QuantilesFeatures)]`
// - weights: not currently used, but shaped `(float32)[batch_size, 1]`.
func ModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
	_ = spec // Not used, since the dataset is always the same.
	g := inputs[0].Graph()
	dtype := inputs[1].DType() // From continuous features.
    ctx = ctx.In("model")
    
	// Use Cosine schedule of the learning rate, if hyperparameter is set to a value > 0.
	optimizers.CosineAnnealingSchedule(ctx, g, dtype).FromContext().Done()

	categorical, continuous := inputs[0], inputs[1]
	batchSize := categorical.Shape().Dimensions[0]

	// Feature preprocessing:
	var allEmbeddings []*Node
	if *flagUseCategorical {
		// Embedding of categorical values, each with its own vocabulary.
		numCategorical := categorical.Shape().Dimensions[1]
		for catIdx := 0; catIdx < numCategorical; catIdx++ {
			// Take one column at a time of the categorical values.
			split := Slice(categorical, AxisRange(), AxisRange(catIdx, catIdx+1))
			// Embed it accordingly.
			embedCtx := ctx.In(fmt.Sprintf("categorical_%d_%s", catIdx, adult.Data.VocabulariesFeatures[catIdx]))
			vocab := adult.Data.Vocabularies[catIdx]
			vocabSize := len(vocab)
			embedding := layers.Embedding(embedCtx, split, ModelDType, vocabSize, *flagEmbeddingDim)
			embedding.AssertDims(batchSize, *flagEmbeddingDim) // 2-dim tensor, with batch size as the leading dimension.
			allEmbeddings = append(allEmbeddings, embedding)
		}
	}

	if *flagUseContinuous {
		// Piecewise-linear calibration of the continuous values. Each feature has its own number of quantiles.
		numContinuous := continuous.Shape().Dimensions[1]
		for contIdx := 0; contIdx < numContinuous; contIdx++ {
			// Take one column at a time of the continuous values.
			split := Slice(continuous, AxisRange(), AxisRange(contIdx, contIdx+1))
			featureName := adult.Data.QuantilesFeatures[contIdx]
			calibrationCtx := ctx.In(fmt.Sprintf("continuous_%d_%s", contIdx, featureName))
			quantiles := adult.Data.Quantiles[contIdx]
			layers.AssertQuantilesForPWLCalibrationValid(quantiles)
			calibrated := layers.PieceWiseLinearCalibration(calibrationCtx, split, Const(g, quantiles),
				*flagTrainableCalibration)
			calibrated.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.
			allEmbeddings = append(allEmbeddings, calibrated)
		}
	}
	logits := Concatenate(allEmbeddings, -1)
	logits.AssertDims(batchSize, -1) // 2-dim tensor, with batch size as the leading dimension (-1 means it is not checked).

	// Model itself is an FNN or a KAN.
	if context.GetParamOr(ctx, "kan", false) {
		// Use KAN, all configured by context hyperparameters. See createDefaultContext for defaults.
		logits = kan.New(ctx.In("kan"), logits, 1).Done()
	} else {
		// Normal FNN, all configured by context hyperparameters. See createDefaultContext for defaults.
		logits = fnn.New(ctx.In("fnn"), logits, 1).Done()
	}
	logits.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.
	return []*Node{logits}
}

%%
adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    

// Let's just check that we get the right shape from the model function, without any real data.
ctx, _ := contextFromSettings()
graph := NewGraph(backend, "test")
batchSize := context.GetParamOr(ctx, "batch_size", 128)
// Create placeholder (parameters) graph nodes, just to test the graph building is working.
inputs := []*Node{
    // Categorical: shaped [batch_size, num_categorical]
    Parameter(graph, "categorical", shapes.Make(dtypes.Int64, batchSize, len(adult.Data.VocabulariesFeatures))),
    // Continuous: shaped [batch_size, num_continuous]
    Parameter(graph, "continuous", shapes.Make(dtypes.Float32, batchSize, len(adult.Data.QuantilesFeatures))),
    // Weights: shaped [batch_size, 1]
    Parameter(graph, "weights", shapes.Make(dtypes.Float32, batchSize, 1)),    
}
logits := ModelGraph(ctx, nil, inputs)
fmt.Printf("Logits shape for batch_size=%d: %s\n", batchSize, logits[0].Shape())
Logits shape for batch_size=128: (Float32)[128 1]

Training Loop¶

We can create a training loop with only a Backend, a Context (for the model variables) and the ModelGraph function.

To make it more interesting we also add the following:

  • Accuracy metrics for training and testing.
  • Checkpoints, so the trained model can be saved and reloaded.
  • A progress-bar that also shows training metrics.
  • We dynamically plot how the loss and accuracy evolve.

First we define the corresponding flags and the trainModel function, and run it for very few steps to make sure it is working.

In [7]:
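import (
    "fmt"
    "time"

    "github.com/gomlx/gomlx/ml/context/checkpoints"
    "github.com/gomlx/gomlx/ml/data"
    "github.com/gomlx/gomlx/ml/train/losses"
    "github.com/gomlx/gomlx/ml/train/metrics"
    "github.com/gomlx/gomlx/ui/gonb/margaid"
    "github.com/gomlx/gomlx/ui/gonb/plotly"
)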

var (
    flagCheckpoint     = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")
    flagPlots          = flag.Bool("plots", true, "Plots during training: perform periodic evaluations, "+
                                   "save results if --checkpoint is set and draw plots, if in a Jupyter notebook.")
    flagPlotType       = flag.String("plot_type", "plotly", "Type of plot to use, values are \"plotly\" or \"margaid\"")
)

func trainModel(ctx *context.Context, paramsSet []string) {
    *flagDataDir = data.ReplaceTildeInDir(*flagDataDir)

    // Load data and create datasets.
    adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    
    trainDS, trainEvalDS, testEvalDS := BuildDatasets(ctx)

	// Checkpoints loading (and saving)
	var checkpoint *checkpoints.Handler
	if *flagCheckpoint != "" {
		numCheckpointsToKeep := context.GetParamOr(ctx, "num_checkpoints", 3)
		checkpoint = must.M1(checkpoints.Build(ctx).
			DirFromBase(*flagCheckpoint, *flagDataDir).
			Keep(numCheckpointsToKeep).
			ExcludeParams(append(paramsSet, "train_steps", "plots", "num_checkpoints")...).
			Done())
	}

	// Metrics we are interested in.
	meanAccuracyMetric := metrics.NewMeanBinaryLogitsAccuracy("Mean Accuracy", "#acc")
	movingAccuracyMetric := metrics.NewMovingAverageBinaryLogitsAccuracy("Moving Average Accuracy", "~acc", 0.01)

	// Create a train.Trainer: this object will orchestrate running the model, feeding
	// results to the optimizer, evaluating the metrics, etc. (all happens in trainer.TrainStep)
	trainer := train.NewTrainer(backend, ctx, ModelGraph, losses.BinaryCrossentropyLogits,
		optimizers.FromContext(ctx),
		[]metrics.Interface{movingAccuracyMetric}, // trainMetrics
		[]metrics.Interface{meanAccuracyMetric})   // evalMetrics

    // Use standard training loop.
    loop := train.NewLoop(trainer)
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.

    // Attach a checkpoint saver: checkpoint every 1 minute of training.
    if checkpoint != nil {
        period := time.Minute * 1
        train.PeriodicCallback(loop, period, true, "saving checkpoint", 100,
            func(loop *train.Loop, metrics []*tensors.Tensor) error {
                fmt.Printf("\n[saving checkpoint@%d] [median train step (ms): %d]\n", loop.LoopStep, loop.MedianTrainStepDuration().Milliseconds())
                return checkpoint.Save()
            })
    }

	// Attach Plotly plots: plot points at exponential steps.
	// The points generated are saved along the checkpoint directory (if one is given).
	if context.GetParamOr(ctx, margaid.ParamPlots, false) {
		_ = plotly.New().
			WithCheckpoint(checkpoint).
			Dynamic().
			WithDatasets(trainEvalDS, testEvalDS).
			ScheduleExponential(loop, 200, 1.2)
	}

	// Train up to "train_steps".
	globalStep := int(optimizers.GetGlobalStep(ctx))
	trainSteps := context.GetParamOr(ctx, "train_steps", 0)
	if globalStep < trainSteps {
		if globalStep != 0 {
			fmt.Printf("\t- restarting training from global_step=%d\n", globalStep)
            trainer.SetContext(ctx.Reuse())
		}
		_ = must.M1(loop.RunSteps(trainDS, trainSteps-globalStep))
		fmt.Printf("\t[Step %d] median train step: %d microseconds\n", loop.LoopStep, loop.MedianTrainStepDuration().Microseconds())
	} else {
		fmt.Printf("\t - target train_steps=%d already reached. To train further, set a number larger than "+
			"current global step.\n", trainSteps)
	}

	// Finally, print an evaluation on train and test datasets.
	must.M(commandline.ReportEval(trainer, trainEvalDS, testEvalDS))
}

// Notice that command line flags are passed in the %% notebook command. We set plots=false here to disable plotting,
// since this is only a quick test that trainModel is working. See the final run below for a full training.
%% -set="train_steps=100;plots=false"
trainModel(contextFromSettings())
Training (100 steps):  100% [========================================] (145 steps/s) [step=99] [loss+=0.537] [~loss+=0.618] [~loss=0.618] [~acc=75.56%]        
	[Step 100] median train step: 346 microseconds
Results on batched train:
	Mean Loss+Regularization (#loss+): 0.537
	Mean Loss (#loss): 0.537
	Mean Accuracy (#acc): 80.49%
Results on test:
	Mean Loss+Regularization (#loss+): 0.535
	Mean Loss (#loss): 0.535
	Mean Accuracy (#acc): 80.65%
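
In the output above, ~acc is a moving average over recent training batches, while #acc in the final report is a plain mean over a whole dataset, which is one reason the two can differ (75.56% vs. 80.49% after 100 steps). The sketch below contrasts the two in plain Go. It assumes the 0.01 passed to NewMovingAverageBinaryLogitsAccuracy is the weight given to each new value, and it initializes the average with the first value; GoMLX's actual update rule may differ.

```go
package main

import "fmt"

// movingAverage keeps an exponential moving average where each new value
// contributes with the given weight. The first value initializes the average.
// Hypothetical type: an assumption about the behaviour, not GoMLX code.
type movingAverage struct {
	weight  float64
	value   float64
	started bool
}

// Update adds x to the moving average and returns the new average.
func (m *movingAverage) Update(x float64) float64 {
	if !m.started {
		m.value, m.started = x, true
		return m.value
	}
	m.value = (1-m.weight)*m.value + m.weight*x
	return m.value
}

// runningMean keeps the plain mean of all values seen so far.
type runningMean struct {
	sum   float64
	count int
}

// Update adds x to the mean and returns the new mean.
func (m *runningMean) Update(x float64) float64 {
	m.sum += x
	m.count++
	return m.sum / float64(m.count)
}

func main() {
	moving := &movingAverage{weight: 0.01}
	mean := &runningMean{}
	// Per-batch accuracies that improve as training progresses.
	for _, acc := range []float64{0.50, 0.60, 0.70, 0.80, 0.85} {
		fmt.Printf("batch acc=%.2f  moving=%.4f  mean=%.4f\n",
			acc, moving.Update(acc), mean.Update(acc))
	}
}
```

With a small weight the moving average reacts slowly, so early in training it lags behind the accuracy of the current model.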

Final run with 5K steps¶

With everything working, we can do our final run.

Note: this is where one might want to tune hyperparameters, trying out different values.

In [8]:
// Remove previously trained model -- skip this cell, if you want to continue training.
!rm -rf ~/work/uci-adult/base_model
In [9]:
%% --checkpoint base_model -set="plots=true;train_steps=5000"
trainModel(contextFromSettings())
Training (5000 steps):   14% [====>...................................] (434 steps/s) [1s:9s] [step=724] [loss+=0.312] [~loss+=0.321] [~loss=0.320] [~acc=85.96%]         
Training (5000 steps):  100% [========================================] (1851 steps/s) [step=4999] [loss+=0.244] [~loss+=0.273] [~loss=0.273] [~acc=87.31%]        

[saving checkpoint@5000] [median train step (ms): 0]

[Plot: accuracy]

[Plot: loss]

	[Step 5000] median train step: 271 microseconds
Results on batched train:
	Mean Loss+Regularization (#loss+): 0.271
	Mean Loss (#loss): 0.270
	Mean Accuracy (#acc): 87.39%
Results on test:
	Mean Loss+Regularization (#loss+): 0.279
	Mean Loss (#loss): 0.279
	Mean Accuracy (#acc): 87.05%

Extend training another 5K steps¶

Since the model training went well, and it doesn't yet seem to be overfitting badly, let's train further: another 5K steps, for 10K steps in total.

Notice that the plots continue from where they stopped. This time we use Plotly to plot the training results; the plots don't display on GitHub since they depend on JavaScript.

Unfortunately, it barely helps: test accuracy moves from 87.05% to 87.32% and the test loss stays flat (0.279 vs. 0.280), so 5K steps was already enough.

In [10]:
%% --checkpoint base_model -set="plots=true;train_steps=10000"
ctx, paramsSet := contextFromSettings()
fmt.Printf("train_steps=%d\n", context.GetParamOr(ctx, "train_steps", 0))
trainModel(ctx, paramsSet)
train_steps=10000
	- restarting training from global_step=5000
Training (5000 steps):  100% [========================================] (2074 steps/s) [step=9999] [loss+=0.351] [~loss+=0.274] [~loss=0.273] [~acc=87.45%]        

[saving checkpoint@10000] [median train step (ms): 0]

[Plot: accuracy]

[Plot: loss]

	[Step 10000] median train step: 261 microseconds
Results on batched train:
	Mean Loss+Regularization (#loss+): 0.269
	Mean Loss (#loss): 0.269
	Mean Accuracy (#acc): 87.47%
Results on test:
	Mean Loss+Regularization (#loss+): 0.281
	Mean Loss (#loss): 0.280
	Mean Accuracy (#acc): 87.32%
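
The resume behaviour seen above ("restarting training from global_step=5000", followed by 5000 more steps) is simple arithmetic on the global step restored from the checkpoint. A sketch of that decision, with a hypothetical `remainingSteps` helper that mirrors the `globalStep < trainSteps` check in trainModel:

```go
package main

import "fmt"

// remainingSteps returns how many more steps to train to reach targetSteps,
// given the global step restored from a checkpoint (0 for a fresh model).
// Hypothetical helper, for illustration only.
func remainingSteps(globalStep, targetSteps int) int {
	if globalStep >= targetSteps {
		return 0 // Target already reached: nothing to train.
	}
	return targetSteps - globalStep
}

func main() {
	fmt.Println(remainingSteps(0, 5000))      // 5000
	fmt.Println(remainingSteps(5000, 10000))  // 5000
	fmt.Println(remainingSteps(10000, 10000)) // 0
}
```

This is why train_steps is a total, not an increment: rerunning the same cell without raising it trains nothing further.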

Using Kolmogorov-Arnold Networks (KAN)¶

Since it's available as a layer (see package kan), the model supports it by simply changing a hyperparameter.

See description in https://arxiv.org/pdf/2404.19756
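
The core idea of a KAN layer is that every input-to-output connection carries its own learnable univariate function, and each output is the sum of those functions applied to the inputs. The plain-Go sketch below illustrates this with a deliberate simplification: each function is a piece-wise linear curve over uniformly spaced control points, instead of the degree-2 B-splines configured in this notebook, and the extra base activation term used in the paper is left out. All type and function names are hypothetical; this is not the kan package's implementation.

```go
package main

import (
	"fmt"
	"math"
)

// univariate is a learnable 1-D function: linear interpolation between
// control values placed uniformly on [min, max] (len(control) >= 2).
// Inputs outside the range are clamped.
type univariate struct {
	min, max float64
	control  []float64
}

// Eval returns the value of the function at x.
func (u univariate) Eval(x float64) float64 {
	n := len(u.control)
	// Position of x in control-point index space, clamped to [0, n-1].
	pos := (x - u.min) / (u.max - u.min) * float64(n-1)
	pos = math.Max(0, math.Min(float64(n-1), pos))
	lo := int(math.Floor(pos))
	if lo == n-1 {
		return u.control[n-1]
	}
	t := pos - float64(lo)
	return (1-t)*u.control[lo] + t*u.control[lo+1]
}

// kanLayer computes outputs[j] = sum over i of phi[i][j](inputs[i]), where
// phi[i][j] is the function on the connection from input i to output j.
func kanLayer(phi [][]univariate, inputs []float64) []float64 {
	if len(phi) == 0 {
		return nil
	}
	outputs := make([]float64, len(phi[0]))
	for i, x := range inputs {
		for j, f := range phi[i] {
			outputs[j] += f.Eval(x)
		}
	}
	return outputs
}

func main() {
	identity := univariate{min: -1, max: 1, control: []float64{-1, 0, 1}}
	absolute := univariate{min: -1, max: 1, control: []float64{1, 0, 1}}
	// 2 inputs, 1 output: output = identity(x0) + absolute(x1).
	phi := [][]univariate{{identity}, {absolute}}
	fmt.Println(kanLayer(phi, []float64{0.5, -1})) // [1.5]
}
```

In training, the control values are the parameters being learned, which is why the number of control points is one of the KAN hyperparameters.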

In [11]:
// Remove previously trained model -- skip this cell, if you want to continue training.
!rm -rf ~/work/uci-adult/kan_model
In [12]:
%% --checkpoint kan_model -set="kan=true;activation=swish;plots=true;train_steps=5000"
trainModel(contextFromSettings())
Training (5000 steps):   14% [====>...................................] (93 steps/s) [9s:45s] [step=724] [loss+=0.223] [~loss+=0.287] [~loss=0.284] [~acc=86.97%]        
Training (5000 steps):  100% [========================================] (89 steps/s) [step=4999] [loss+=0.355] [~loss+=0.275] [~loss=0.271] [~acc=87.33%]        

[saving checkpoint@5000] [median train step (ms): 10]

[Plot: accuracy]

[Plot: loss]

	[Step 5000] median train step: 10580 microseconds
Results on batched train:
	Mean Loss+Regularization (#loss+): 0.269
	Mean Loss (#loss): 0.266
	Mean Accuracy (#acc): 87.60%
Results on test:
	Mean Loss+Regularization (#loss+): 0.282
	Mean Loss (#loss): 0.279
	Mean Accuracy (#acc): 87.38%