UCI Adult Dataset or Census Income

This is a very popular ML task with tabular data. The objective is to predict whether income exceeds $50K/yr based on census data. It is also known as the "Census Income" dataset.

The data is old and biased in several ways, but it can still be used as an opaque benchmark for ML experimentation.

The complete source code of this example can be found in the demo under .../examples/adult/demo/.

Data Preparation

GoMLX provides a simple adult library to facilitate downloading and preprocessing the data. The data is available from the UCI Machine Learning Repository.

After downloading the data and validating the checksums (of both the training and test files), it generates the quantiles for the continuous features and the vocabularies for the categorical features. It saves all this information in a binary file (adult_data-100_quantiles.bin in the listing below), so this step isn't needed again on later runs.
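To make the preprocessing step concrete, here is a minimal plain-Go sketch of the two statistics it computes. The helpers buildVocabulary and computeQuantiles are hypothetical names for illustration, and the adult package's actual ordering, sampling and encoding may differ. The sketch assumes vocabulary IDs are assigned in order of first appearance, and that quantiles are taken at evenly spaced ranks of the sorted values, keeping only unique values (as the description of the --quantiles flag in the next cell says).

```go
package main

import (
	"fmt"
	"sort"
)

// buildVocabulary assigns a dense integer ID to each distinct categorical
// value, in order of first appearance. (Hypothetical helper: the adult
// package may order or encode its vocabularies differently.)
func buildVocabulary(values []string) map[string]int {
	vocab := make(map[string]int)
	for _, v := range values {
		if _, found := vocab[v]; !found {
			vocab[v] = len(vocab)
		}
	}
	return vocab
}

// computeQuantiles returns at most maxQuantiles sorted, unique values taken
// at evenly spaced ranks of the sorted input. Repeated values collapse, so
// low-variability features end up with fewer quantiles.
func computeQuantiles(values []float64, maxQuantiles int) []float64 {
	if len(values) == 0 || maxQuantiles < 2 {
		return nil
	}
	sorted := append([]float64(nil), values...) // Copy: don't reorder the caller's data.
	sort.Float64s(sorted)
	var quantiles []float64
	for i := 0; i < maxQuantiles; i++ {
		// Quantile i maps linearly onto the ranks [0, len(sorted)-1].
		idx := i * (len(sorted) - 1) / (maxQuantiles - 1)
		q := sorted[idx]
		if len(quantiles) == 0 || q != quantiles[len(quantiles)-1] {
			quantiles = append(quantiles, q)
		}
	}
	return quantiles
}

func main() {
	fmt.Println(buildVocabulary([]string{"Private", "State-gov", "Private"}))
	// Capital-gain-like values: mostly zeros, so only 2 unique quantiles survive.
	fmt.Println(computeQuantiles([]float64{0, 0, 0, 0, 2174, 15024}, 3))
}
```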

The quantiles are used to calibrate the continuous values with a piece-wise linear calibration, which works well for features like these. See the layers.PieceWiseLinearCalibration documentation.
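A piece-wise linear calibration maps each value through a function that is linear between consecutive keypoints (here, the feature's quantiles) and flat beyond both ends. The sketch below is a generic plain-Go version, not GoMLX's layers.PieceWiseLinearCalibration. The evenly spaced outputs in [0, 1] and the age keypoints are assumptions for illustration. With a trainable calibration (the --trainable_calibration flag defined later), the outputs at each keypoint would be adjusted during training.

```go
package main

import "fmt"

// pwlCalibrate maps x through a piece-wise linear function defined by
// strictly increasing keypoints (e.g. unique quantiles) and the output
// value at each keypoint. Values outside the keypoint range are clamped to
// the first/last output. Generic sketch, not GoMLX's implementation.
func pwlCalibrate(x float64, keypoints, outputs []float64) float64 {
	n := len(keypoints)
	if x <= keypoints[0] {
		return outputs[0]
	}
	if x >= keypoints[n-1] {
		return outputs[n-1]
	}
	for i := 1; i < n; i++ {
		if x <= keypoints[i] {
			// Linear interpolation inside [keypoints[i-1], keypoints[i]].
			t := (x - keypoints[i-1]) / (keypoints[i] - keypoints[i-1])
			return outputs[i-1] + t*(outputs[i]-outputs[i-1])
		}
	}
	return outputs[n-1] // Unreachable for increasing keypoints.
}

// evenOutputs returns n values evenly spaced in [0, 1]: an assumed
// starting point for the calibration outputs.
func evenOutputs(n int) []float64 {
	out := make([]float64, n)
	for i := range out {
		out[i] = float64(i) / float64(n-1)
	}
	return out
}

func main() {
	ages := []float64{17, 28, 37, 48, 90} // Hypothetical age quantiles.
	outs := evenOutputs(len(ages))
	for _, age := range []float64{10, 30, 42.5, 95} {
		fmt.Printf("age=%.1f -> %.3f\n", age, pwlCalibrate(age, ages, outs))
	}
}
```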

We create a flag --data to define the directory where the intermediary files are saved: the downloaded and preprocessed dataset. In this example we set it to ~/work/uci-adult. Verbosity can be controlled with the --verbosity flag.

We set defaults in Go for these flags, but they can easily be overridden for a new run by passing them after the %% Jupyter kernel meta-command, which indicates that the subsequent lines should be put into a func main.
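Conceptually, the arguments after %% are parsed like ordinary command-line arguments by Go's standard flag package, so any flag that isn't given keeps its Go default. The sketch below illustrates this with a separate flag.FlagSet (an illustration choice, since the notebook itself uses the global flags), parsing the same --data and --verbosity flags from an explicit argument list.

```go
package main

import (
	"flag"
	"fmt"
)

// parseNotebookArgs parses args the way command-line flags are parsed:
// flags present in args override their defaults, the others keep them.
func parseNotebookArgs(args []string) (dataDir string, verbosity int, err error) {
	fs := flag.NewFlagSet("adult", flag.ContinueOnError)
	dataDirFlag := fs.String("data", "~/work/uci-adult", "Directory for dataset files.")
	verbosityFlag := fs.Int("verbosity", 0, "Level of verbosity.")
	if err = fs.Parse(args); err != nil {
		return "", 0, err
	}
	return *dataDirFlag, *verbosityFlag, nil
}

func main() {
	// Equivalent of a cell starting with: %% --verbosity=2
	dataDir, verbosity, err := parseNotebookArgs([]string{"--verbosity=2"})
	if err != nil {
		panic(err)
	}
	fmt.Println(dataDir, verbosity) // ~/work/uci-adult 2
}
```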

In [1]:
import (
    "flag"
    
    "github.com/gomlx/gomlx/examples/adult"
)

var (
    flagDataDir       = flag.String("data", "~/work/uci-adult", "Directory to save and load downloaded and generated dataset files.")
    flagVerbosity     = flag.Int("verbosity", 0, "Level of verbosity, the higher the more verbose.")
    flagForceDownload = flag.Bool("force_download", false, "Force re-download of Adult dataset files.")
    flagNumQuantiles  = flag.Int("quantiles", 100, "Max number of quantiles to use for numeric features, used during piece-wise linear calibration. It will only use unique values, so if there is less variability, fewer quantiles are used.")
)

%% --verbosity=2
adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)
Sample Categorical: (24.08% positive ratio, 23.86% weighted positive ratio)
	Row 0:	[7 10 5 1 2 5 2 39]
	Row 1:	[6 10 3 4 1 5 2 39]
	Row 2:	[4 12 1 6 2 5 2 39]
	...
	Row 32558:	[4 12 7 1 5 5 1 39]
	Row 32559:	[4 12 5 1 4 5 2 39]
	Row 32560:	[5 12 3 4 6 5 1 39]

Sample Continuous:
	Row 0:	[39 13 2174 0 40]
	Row 1:	[50 13 0 0 13]
	Row 2:	[38 9 0 0 40]
	...
	Row 32558:	[58 9 0 0 40]
	Row 32559:	[22 9 0 0 20]
	Row 32560:	[52 9 15024 0 40]
In [2]:
!ls -lh ~/work/uci-adult
total 170M
drwxr-x--- 2 janpf janpf 4.0K Dec  4  2025 -h
-rw-r--r-- 1 janpf janpf 3.8M Jul 20  2024 adult.data
-rw-r--r-- 1 janpf janpf 2.0M Jul 20  2024 adult.test
-rw-r--r-- 1 janpf janpf 1.3M Jul 20  2024 adult_data-100_quantiles.bin
drwxr-x--- 2 janpf janpf 4.0K May 24 11:55 base_mode
drwxr-x--- 2 janpf janpf 4.0K Jun  9 07:38 base_model
drwxr-xr-x 2 janpf janpf 4.0K Jun  4  2009 cifar-10-batches-bin
-rw-r--r-- 1 janpf janpf 163M Aug 15  2024 cifar-10-binary.tar.gz
drwxr-x--- 2 janpf janpf 4.0K Dec  4  2025 distributed
drwxr-x--- 2 janpf janpf 4.0K Aug  2  2024 fnn
drwxr-x--- 2 janpf janpf 4.0K Jul 21  2024 kan_baseline
drwxr-x--- 2 janpf janpf 4.0K Jun  9 07:38 kan_model
drwxr-x--- 2 janpf janpf 4.0K Feb 23  2025 test
drwxr-x--- 2 janpf janpf 4.0K Jun 18  2025 test_model
drwxr-x--- 3 janpf janpf 4.0K May 24 11:58 v028_test

Hyperparameters

The cell below defines the superset of hyperparameters that can be used by the model, with their default values, set in a model.Store.

See the demo/main.go file for how to add a flag to allow setting them from the command line.

In [3]:
import (
    "github.com/gomlx/gomlx/ml/model"
    "github.com/gomlx/gomlx/ml/layers"
    "github.com/gomlx/gomlx/ml/layers/fnn"
    "github.com/gomlx/gomlx/ml/layers/kan"
    "github.com/gomlx/gomlx/ml/layers/regularizer"
    "github.com/gomlx/gomlx/ui/commandline"
    "github.com/gomlx/gomlx/ml/train/optimizer"
    "github.com/gomlx/gomlx/ml/train/optimizer/cosineschedule"
    "github.com/janpfeifer/must"
)

// settings is bound to a "-set" flag to be used to set scope hyperparameters.
var settings = commandline.CreateSettingsFlag(createModelStore(), "set")

func createModelStore() *model.Store {
	store := model.NewStore()
	store.SetParams(map[string]any{
		"train_steps":     5000,
		"batch_size":      128,
		"eval_batch_size": 1000,
		"plots":           true,
		"num_checkpoints": 3,

		optimizer.ParamOptimizer:        "adam",
		optimizer.ParamLearningRate:     0.001,
		optimizer.ParamAdamEpsilon:      1e-7,
		optimizer.ParamAdamDType:        "",
		cosineschedule.ParamPeriodSteps: 0,
		activation.ParamActivation:      "sigmoid",
		layers.ParamDropoutRate:         0.0,
		regularizer.ParamL2:             1e-5,
		regularizer.ParamL1:             1e-5,

		// FNN network parameters:
		fnn.ParamNumHiddenLayers: 1,
		fnn.ParamNumHiddenNodes:  4,
		fnn.ParamResidual:        true,
		fnn.ParamNormalization:   "layer",

		// KAN network parameters:
		"kan":                       false, // Enable kan
		kan.ParamNumControlPoints:   20,    // Number of control points
		kan.ParamNumHiddenNodes:     4,
		kan.ParamNumHiddenLayers:    1,
		kan.ParamBSplineDegree:      2,
		kan.ParamBSplineMagnitudeL1: 1e-5,
		kan.ParamBSplineMagnitudeL2: 0.0,
		kan.ParamDiscrete:           false,
		kan.ParamDiscreteSoftness:   0.1,
	})
	return store
}

// storeFromSettings returns the default store (from createModelStore) updated with the values
// given in the -set flag, along with the names of the hyperparameters that were set.
func storeFromSettings() (store *model.Store, paramsSet []string) {
    store = createModelStore()
    paramsSet = must.M1(commandline.ParseSettings(store, *settings))
    return
}
In [4]:
// Let's test that we can set hyperparameters with the "-set" flag:
%% -set="batch_size=17;fnn_num_hidden_layers=12"
store, paramsSet := storeFromSettings()
for _, name := range paramsSet {
    if value, found := store.GetParam(name); found {
        fmt.Printf("\t%s=%v\n", name, value)
    }
}
	batch_size=17
	fnn_num_hidden_layers=12
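The -set value is a list of key=value pairs separated by ";", where each value must be converted to the type of the hyperparameter's default. The following hypothetical plain-Go parser illustrates that format. It is not the code of commandline.ParseSettings, whose exact rules (e.g. for scoped names or type conversions) may differ.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseSettings applies a "-set"-style string ("key=value;key=value") on top
// of a map of defaults, converting each value to the type of its default.
// It returns the updated parameters and the names that were set.
// Hypothetical illustration, not commandline.ParseSettings.
func parseSettings(defaults map[string]any, settings string) (map[string]any, []string, error) {
	params := make(map[string]any, len(defaults))
	for k, v := range defaults {
		params[k] = v // Copy, so the defaults are not modified.
	}
	var set []string
	for _, pair := range strings.Split(settings, ";") {
		if pair == "" {
			continue // Tolerate empty entries, e.g. a trailing ";".
		}
		key, value, ok := strings.Cut(pair, "=")
		if !ok {
			return nil, nil, fmt.Errorf("setting %q is not in the form key=value", pair)
		}
		def, found := defaults[key]
		if !found {
			return nil, nil, fmt.Errorf("unknown hyperparameter %q", key)
		}
		var err error
		switch def.(type) {
		case int:
			// Base 0 accepts Go literal syntax, including underscores as in "10_000".
			var n int64
			n, err = strconv.ParseInt(value, 0, 64)
			params[key] = int(n)
		case float64:
			params[key], err = strconv.ParseFloat(value, 64)
		case bool:
			params[key], err = strconv.ParseBool(value)
		default:
			params[key] = value // Strings (and other types) are kept as given.
		}
		if err != nil {
			return nil, nil, fmt.Errorf("hyperparameter %q: %w", key, err)
		}
		set = append(set, key)
	}
	return params, set, nil
}

func main() {
	defaults := map[string]any{"batch_size": 128, "train_steps": 5000, "plots": true, "activation": "sigmoid"}
	params, set, err := parseSettings(defaults, "batch_size=17;train_steps=10_000")
	if err != nil {
		panic(err)
	}
	for _, name := range set {
		fmt.Printf("\t%s=%v\n", name, params[name])
	}
}
```

Using base 0 in strconv.ParseInt is why a value like 10_000 (used later in this notebook) parses to 10000 in this sketch; a side effect is that a leading 0 would be read as octal.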

Creating Datasets

First we create the GoMLX backend (a compute.Backend interface defined in github.com/gomlx/compute): the engine (XLA in this case) that compiles and executes our computation graph on top of some software/hardware stack (e.g. "xla:cuda", "xla:cpu", "cpu").

With it we create the data samplers used for training and evaluation. They implement GoMLX's train.Dataset interface, from which the training loop draws batches to train on, and the evaluation loop draws batches to evaluate.
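The difference between the training and evaluation datasets comes down to how an epoch is split into batches. Below is a plain-Go sketch of that logic. It assumes (our reading, not confirmed here) that the boolean in BatchSize(size, bool) means "drop the incomplete last batch": training then uses shuffled full batches, while evaluation keeps the final partial batch so every example is seen exactly once. An infinite training sampler would simply call this again, with a new seed, for every epoch. The function epochBatches is hypothetical.

```go
package main

import (
	"fmt"
	"math/rand"
)

// epochBatches splits the example indices 0..n-1 of one epoch into batches.
// shuffle randomizes the order (seeded, for reproducibility) and
// dropIncomplete discards a final batch smaller than batchSize.
func epochBatches(n, batchSize int, shuffle, dropIncomplete bool, seed int64) [][]int {
	order := make([]int, n)
	for i := range order {
		order[i] = i
	}
	if shuffle {
		rng := rand.New(rand.NewSource(seed))
		rng.Shuffle(n, func(i, j int) { order[i], order[j] = order[j], order[i] })
	}
	var batches [][]int
	for start := 0; start < n; start += batchSize {
		end := start + batchSize
		if end > n {
			if dropIncomplete {
				break
			}
			end = n
		}
		batches = append(batches, order[start:end])
	}
	return batches
}

func main() {
	// Evaluation-style: one exact pass, keeping the last partial batch.
	fmt.Println(epochBatches(10, 4, false, false, 0)) // [[0 1 2 3] [4 5 6 7] [8 9]]
	// Training-style: shuffled, full batches only.
	fmt.Println(epochBatches(10, 4, true, true, 1))
}
```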

Each batch has 3 input tensors (categorical values, continuous values and weights) and one tensor of labels.
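The weights matter when computing statistics: the preprocessing output above reports both a plain and a weighted positive ratio (24.08% vs. 23.86%). The plain ratio is sum(y)/n, while the weighted one is sum(w*y)/sum(w). A minimal sketch with made-up toy values (positiveRatios is a hypothetical helper):

```go
package main

import "fmt"

// positiveRatios returns the plain and the weighted fraction of positive
// labels, sum(y)/n and sum(w*y)/sum(w), for labels in {0, 1}.
func positiveRatios(labels, weights []float64) (plain, weighted float64) {
	var sumY, sumWY, sumW float64
	for i, y := range labels {
		sumY += y
		sumWY += weights[i] * y
		sumW += weights[i]
	}
	return sumY / float64(len(labels)), sumWY / sumW
}

func main() {
	labels := []float64{1, 0, 0, 1}
	weights := []float64{1, 1, 1, 5} // The last (positive) example counts 5 times.
	plain, weighted := positiveRatios(labels, weights)
	fmt.Printf("plain=%.2f weighted=%.2f\n", plain, weighted) // plain=0.50 weighted=0.75
}
```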

In the cell below we define BuildDatasets and print out some samples.

In [5]:
import (
    "flag"
    "fmt"
    "io"

    "github.com/gomlx/compute"
    "github.com/gomlx/gomlx/examples/adult"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/core/tensors"
    . "github.com/gomlx/gomlx/core/graph"

    _ "github.com/gomlx/gomlx/backends/default"
)

// Global backend, created at initialization and used everywhere.
var backend = compute.MustNew()

// BuildDatasets returns 3 `train.Dataset`:
// * trainDS is an endless random sampler used for training.
// * trainEvalDS samples through exactly one epoch of the train dataset.
// * testEvalDS samples through exactly one epoch of the test dataset.
func BuildDatasets(store *model.Store) (trainDS, trainEvalDS, testEvalDS train.Dataset) {
    batchSize := model.GetRootParamOr(store, "batch_size", 128)
    evalBatchSize := model.GetRootParamOr(store, "eval_batch_size", 1000)
    baseDS := adult.NewDataset(backend, adult.Data.Train, "batched train")
    trainEvalDS = baseDS.Copy().BatchSize(evalBatchSize, false)
    testEvalDS = adult.NewDataset(backend, adult.Data.Test, "test").
        BatchSize(evalBatchSize, false)
    // For training, we shuffle and loop indefinitely.
    trainDS = baseDS.BatchSize(batchSize, true).Shuffle().Infinite(true)
    return
}

// PositiveRatio returns the ratio of positive labels in the given dataset
// (used below for both the training and the test data).
//
// We could compute everything in GoMLX (e.g. `ReduceAllSum` over all labels), but
// this example shows it's also fine to mix in plain Go computations.
func PositiveRatio(ds train.Dataset) float32 {
    var sum float32
    var count float32
    reduceSumExec := MustNewExec1(backend, func(labels *Node) *Node {
        return ReduceAllSum(labels)
    })
    for batch, err := range ds.Iter() {
        if err != nil { panic(err) }
        sum += reduceSumExec.MustCall(batch.Labels[0]).Value().(float32)
        count += float32(batch.Labels[0].Size())
        batch.Finalize()
    }
    return sum/count
}

%%
fmt.Printf("Backend: %s - %s\n", backend, backend.Description())

adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    
store, _ := storeFromSettings()
trainDS, trainEvalDS, testEvalDS := BuildDatasets(store)

// Take one batch.
var batch train.Batch
var err error
for batch, err = range trainDS.Iter() {
    break
}
if err != nil { panic(err) }
fmt.Printf("Inputs of batch (size %d):\n", model.GetRootParamOr(store, "batch_size", 0))
fmt.Printf("\tcategorical:\n\t\tFeatures=%v\n", adult.Data.VocabulariesFeatures)
fmt.Printf("\t\tValues: %s\n", batch.Inputs[0])
fmt.Printf("\tcontinuous:\n\t\tFeatures=%v\n", adult.Data.QuantilesFeatures)
fmt.Printf("\t\tValues: %s\n", batch.Inputs[1])
fmt.Printf("\tweights: %s\n", batch.Inputs[2])
fmt.Printf("\nLabels of batch:\n\t%s\n", batch.Labels[0])
fmt.Printf("\nLabels distributions:\n\tTrain:\t%.2f%% positive\n\tTest:\t%.2f%% positive\n",
           PositiveRatio(trainEvalDS)*100.0, PositiveRatio(testEvalDS)*100.0)
Backend: xla - xla:cuda - PJRT "cuda" plugin (/home/janpf/.local/lib/go-xla/nvidia/pjrt_c_api_cuda_plugin.so) v0.100 [StableHLO] [1 device(s)]
Inputs of batch (size 128):
	categorical:
		Features=[workclass education marital-status occupation relationship race sex native-country]
		Values: [128][8]int64{
 {4, 16, 3, ..., 5, 2, 39},
 {6, 8, 3, ..., 5, 2, 0},
 {4, 10, 5, ..., 5, 1, 39},
 ...,
 {4, 6, 3, ..., 5, 2, 39},
 {4, 8, 3, ..., 2, 2, 39},
 {4, 12, 1, ..., 5, 1, 39}}
	continuous:
		Features=[age education-num capital-gain capital-loss hours-per-week]
		Values: [128][5]float32{
 {58, 10, 0, 0, 36},
 {42, 12, 0, 0, 40},
 {46, 13, 0, 0, 40},
 ...,
 {55, 4, 0, 0, 75},
 {41, 12, 0, 0, 40},
 {36, 9, 0, 0, 40}}
	weights: [128][1]float32{
 {9.584e+04},
 {1.838e+05},
 {2.792e+05},
 ...,
 {1.905e+05},
 {3.497e+05},
 {1.166e+05}}

Labels of batch:
	[128][1]float32{
 {0},
 {0},
 {0},
 ...,
 {0},
 {0},
 {0}}

Labels distributions:
	Train:	24.08% positive
	Test:	23.62% positive

Model Definition

Lots of hyperparameters and flags, but otherwise a straightforward FNN (optionally a KAN), using piece-wise linear calibration of the continuous features and embeddings for the categorical features.
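To visualize what the model feeds into the FNN (or KAN): each categorical feature is replaced by a row of its own embedding table, each continuous feature by one calibrated scalar, and everything is concatenated. With the 8 categorical features, the default --embedding_dim=8 and 5 continuous features, that gives 8×8 + 5 = 69 values per example. The sketch below is plain Go with placeholder tables (a hypothetical uniform vocabulary size of 64), not GoMLX's layers.Embedding implementation.

```go
package main

import "fmt"

// featureVector concatenates, for one example, the embedding row of each
// categorical feature followed by the calibrated value of each continuous
// feature. tables[i][v] is the embedding of vocabulary entry v of feature i.
func featureVector(categorical []int64, tables [][][]float32, calibrated []float32) []float32 {
	var out []float32
	for i, v := range categorical {
		out = append(out, tables[i][v]...)
	}
	return append(out, calibrated...)
}

// placeholderTables builds numFeatures embedding tables with vocabSize rows
// and dim columns. The first column holds the row index, just so lookups are
// visible; real embeddings would be learned.
func placeholderTables(numFeatures, vocabSize, dim int) [][][]float32 {
	tables := make([][][]float32, numFeatures)
	for i := range tables {
		tables[i] = make([][]float32, vocabSize)
		for v := range tables[i] {
			tables[i][v] = make([]float32, dim)
			tables[i][v][0] = float32(v)
		}
	}
	return tables
}

func main() {
	tables := placeholderTables(8, 64, 8)
	row := []int64{7, 10, 5, 1, 2, 5, 2, 39}       // Row 0 of the categorical sample above.
	calibrated := []float32{0.3, 0.6, 0.1, 0, 0.5} // Made-up calibrated values.
	vec := featureVector(row, tables, calibrated)
	fmt.Println(len(vec)) // 69
}
```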

Note: building models involves constantly checking that shapes are compatible. This is a bit annoying, in particular because shapes are only known at runtime: there is no compile-time check. GoMLX helps by providing a stack trace of where errors happen, so one can pinpoint issues quickly. But it often involves lots of experimentation (more than ordinary Go code).

Developing in a notebook (see GoNB) or simply writing a unit test for your model graph function are quick and convenient ways to develop models before actually training them. You can also use shape asserts in the middle of model building, as we do below.
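Shape checks can also be prototyped and unit-tested in plain Go. checkDims below is a hypothetical helper, not a GoMLX function. It mirrors the convention of the AssertDims calls in Model, where -1 means the axis size is not checked, and it returns an error instead of panicking so that a test can report it.

```go
package main

import "fmt"

// checkDims returns an error if shape doesn't match want, where a -1 in
// want accepts any size on that axis.
func checkDims(shape []int, want ...int) error {
	if len(shape) != len(want) {
		return fmt.Errorf("rank mismatch: got shape %v, want %d axes", shape, len(want))
	}
	for axis, w := range want {
		if w != -1 && shape[axis] != w {
			return fmt.Errorf("axis %d: got %d, want %d (shape %v)", axis, shape[axis], w, shape)
		}
	}
	return nil
}

func main() {
	// Concatenated features: the batch size is checked, the width is not.
	fmt.Println(checkDims([]int{128, 69}, 128, -1)) // <nil>
	// Logits must be [batch_size, 1]: a wrong width is reported.
	fmt.Println(checkDims([]int{128, 4}, 128, 1))
}
```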

In [6]:
import (
    "fmt"
    "io"

    . "github.com/gomlx/gomlx/core/graph"

    "github.com/gomlx/gomlx/examples/adult"
    "github.com/gomlx/compute/shapes"
    "github.com/gomlx/gomlx/ml/model"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/optimizer"
    "github.com/gomlx/gomlx/ml/train/optimizer/cosineschedule"
    "github.com/gomlx/compute/dtypes"
)

var (
    // ModelDType used for the model. Must match RawData Go types.
    ModelDType = dtypes.Float32
    

    // Model hyperparameters.
    flagUseCategorical       = flag.Bool("use_categorical", true, "Use categorical features.")
    flagUseContinuous        = flag.Bool("use_continuous", true, "Use continuous features.")
    flagTrainableCalibration = flag.Bool("trainable_calibration", true, "Allow piece-wise linear calibration to adjust outputs.")
    flagEmbeddingDim         = flag.Int("embedding_dim", 8, "Default embedding dimension for categorical values.")
)

// Model outputs the logits (not the probabilities).
// It has the `train.ModelFn` signature, so we can use it later with the trainer to automatically train
// and evaluate the model.
//
// It takes 3 input tensors:
//
// - categorical inputs, shaped `(int64)[batch_size, len(VocabulariesFeatures)]`
// - continuous inputs, shaped `(float32)[batch_size, len(Quantiles)]`
// - weights: not currently used, but shaped `(float32)[batch_size, 1]`.
func Model(scope *model.Scope, categorical, continuous, weights *Node) *Node {
	g := categorical.Graph()
	dtype := continuous.DType() // From continuous features.
	scope = scope.In("model")

	// Use a cosine schedule for the learning rate, if its hyperparameter is set to a value > 0.
	cosineschedule.New(scope, g, dtype).FromScope().Done()

	batchSize := categorical.Shape().Dimensions[0]

	// Feature preprocessing:
	var allEmbeddings []*Node
	if *flagUseCategorical {
		// Embedding of categorical values, each with its own vocabulary.
		numCategorical := categorical.Shape().Dimensions[1]
		for catIdx := 0; catIdx < numCategorical; catIdx++ {
			// Take one column at a time of the categorical values.
			split := Slice(categorical, AxisRange(), AxisRange(catIdx, catIdx+1))
			// Embed it accordingly.
			embedScope := scope.In(fmt.Sprintf("categorical_%d_%s", catIdx, adult.Data.VocabulariesFeatures[catIdx]))
			vocab := adult.Data.Vocabularies[catIdx]
			vocabSize := len(vocab)
			embedding := layers.Embedding(embedScope, split, ModelDType, vocabSize, *flagEmbeddingDim)
			embedding.AssertDims(batchSize, *flagEmbeddingDim) // 2-dim tensor, with batch size as the leading dimension.
			allEmbeddings = append(allEmbeddings, embedding)
		}
	}

	if *flagUseContinuous {
		// Piecewise-linear calibration of the continuous values. Each feature has its own number of quantiles.
		numContinuous := continuous.Shape().Dimensions[1]
		for contIdx := 0; contIdx < numContinuous; contIdx++ {
			// Take one column at a time of the continuous values.
			split := Slice(continuous, AxisRange(), AxisRange(contIdx, contIdx+1))
			featureName := adult.Data.QuantilesFeatures[contIdx]
			calibrationScope := scope.In(fmt.Sprintf("continuous_%d_%s", contIdx, featureName))
			quantiles := adult.Data.Quantiles[contIdx]
			layers.AssertQuantilesForPWLCalibrationValid(quantiles)
			calibrated := layers.PieceWiseLinearCalibration(calibrationScope, split, Const(g, quantiles),
				*flagTrainableCalibration)
			calibrated.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.
			allEmbeddings = append(allEmbeddings, calibrated)
		}
	}
	logits := Concatenate(allEmbeddings, -1)
	logits.AssertDims(batchSize, -1) // 2-dim tensor, with batch size as the leading dimension (-1 means it is not checked).

	// Model itself is an FNN or a KAN.
	if model.GetParamOr(scope, "kan", false) {
		// Use KAN, all configured by scope hyperparameters. See createModelStore for defaults.
		logits = kan.New(scope.In("kan"), logits, 1).Done()
	} else {
		// Normal FNN, all configured by scope hyperparameters. See createModelStore for defaults.
		logits = fnn.New(scope.In("fnn"), logits, 1).Done()
	}
	logits.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.
	return logits
}

Training Loop

To create a training loop we only need a backend, a store (container for the model variables), the Model function, a loss function and an optimizer.

To make it more interesting we also add the following:

  • Accuracy metrics for training and testing.
  • Checkpoints, so the trained model can be saved and reloaded.
  • A progress bar that also shows training metrics.
  • Dynamic plots of how the loss and accuracy evolve.
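The trainer below is configured with loss.BinaryCrossentropyLogits and binary logits accuracy metrics, which work on logits rather than probabilities. Here is a minimal plain-Go sketch of what these compute for single examples. It uses the standard numerically stable form of binary cross-entropy on logits, and it predicts positive when the logit is > 0. The moving-average update is our assumption about how the ~acc metric (with its 0.01 parameter) smooths values, not GoMLX's code.

```go
package main

import (
	"fmt"
	"math"
)

// bceWithLogits is binary cross-entropy computed directly from a logit x and
// a label y in {0, 1}, in the stable form max(x,0) - x*y + log(1+exp(-|x|)).
func bceWithLogits(x, y float64) float64 {
	return math.Max(x, 0) - x*y + math.Log1p(math.Exp(-math.Abs(x)))
}

// logitCorrect reports whether the prediction (positive iff logit > 0)
// matches the label.
func logitCorrect(x, y float64) bool {
	return (x > 0) == (y > 0.5)
}

// movingAverage updates an exponential moving average, giving weight alpha
// to the newest value.
func movingAverage(avg, value, alpha float64) float64 {
	return (1-alpha)*avg + alpha*value
}

func main() {
	logits := []float64{2.1, -0.4, 0.3, -3.0} // Made-up model outputs.
	labels := []float64{1, 0, 0, 0}
	var totalLoss float64
	var correct int
	movingAcc := 0.5
	for i, x := range logits {
		totalLoss += bceWithLogits(x, labels[i])
		acc := 0.0
		if logitCorrect(x, labels[i]) {
			correct++
			acc = 1
		}
		movingAcc = movingAverage(movingAcc, acc, 0.01)
	}
	n := float64(len(logits))
	fmt.Printf("mean loss=%.3f accuracy=%.2f moving accuracy=%.4f\n",
		totalLoss/n, float64(correct)/n, movingAcc)
}
```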

First we define the corresponding flags and the trainModel function, and run it for a few steps to make sure it works.

In [7]:
import (
    "fmt"
    "time"
    "github.com/gomlx/compute/support/humanize"
    "github.com/gomlx/gomlx/ui/gonb/plotly"
    "github.com/gomlx/gomlx/support/fsutil"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/metric"
    "github.com/gomlx/gomlx/ml/model/checkpoint"
    "github.com/gomlx/gomlx/ml/train/loss"
    "github.com/gomlx/gomlx/core/tensors"
)

var (
    flagCheckpoint     = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")
    flagPlots          = flag.Bool("plots", true, "Plots during training: perform periodic evaluations, "+
                                   "save results if --checkpoint is set and draw plots, if in a Jupyter notebook.")
    flagPlotType       = flag.String("plot_type", "plotly", "Type of plot to use, values are \"plotly\" or \"margaid\"")
)

func trainModel(store *model.Store, paramsSet []string) {
    scope := store.RootScope()
    *flagDataDir = must.M1(fsutil.ReplaceTildeInDir(*flagDataDir))

    // Load data and create dataset.
    adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    
    trainDS, trainEvalDS, testEvalDS := BuildDatasets(store)

	// Checkpoints loading (and saving)
    var checkpointHandler *checkpoint.Handler
    if *flagCheckpoint != "" {
        numCheckpointsToKeep := model.GetParamOr(scope, "num_checkpoints", 3)
        checkpointHandler = must.M1(checkpoint.Build(store).
            DirFromBase(*flagCheckpoint, *flagDataDir).
            Keep(numCheckpointsToKeep).
            ExcludeParams(append(paramsSet, "train_steps", "plots", "num_checkpoints")...).
            Done())
    }

    // Metrics we are interested in.
    meanAccuracyMetric := metric.NewMeanBinaryLogitsAccuracy("Mean Accuracy", "#acc")
    movingAccuracyMetric := metric.NewMovingAverageBinaryLogitsAccuracy("Moving Average Accuracy", "~acc", 0.01)

	// Create a train.Trainer: this object will orchestrate running the model, feeding
	// results to the optimizer, evaluating the metrics, etc. (all happens in trainer.TrainStep)
    trainer := train.NewTrainer(backend, store, Model, loss.BinaryCrossentropyLogits,
        optimizer.FromStore(store),
        []metric.Interface{movingAccuracyMetric}, // trainMetrics
        []metric.Interface{meanAccuracyMetric})   // evalMetrics

    // Use standard training loop.
    loop := train.NewLoop(trainer)
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.

    // Attach a checkpoint saver: checkpoint every 1 minute of training.
    if checkpointHandler != nil {
        period := time.Minute * 1
        train.PeriodicCallback(loop, period, true, "saving checkpoint", 100,
            func(loop *train.Loop, metrics []*tensors.Tensor) error {
                fmt.Printf("\n[saving checkpoint@%d] [median train step: %s]\n", 
                           loop.LoopStep, humanize.Duration(loop.MedianTrainStepDuration()))
                return checkpointHandler.Save()
            })
    }

	// Attach Plotly plots: plot points at exponential steps.
	// The points generated are saved along the checkpoint directory (if one is given).
	if model.GetParamOr(scope, "plots", false) {
		_ = plotly.New().
			WithCheckpoint(checkpointHandler).
			Dynamic().
			WithDatasets(trainEvalDS, testEvalDS).
			ScheduleExponential(loop, 200, 1.2)
	}

	// Train up to "train_steps".
	trainSteps := model.GetParamOr(scope, "train_steps", 0)
	globalStep := int(optimizer.GetGlobalStep(scope))
	time.Sleep(time.Millisecond) // Ensure no issues with fast loops
	if globalStep < trainSteps {
		if globalStep != 0 {
			fmt.Printf("\t- restarting training from global_step=%d\n", globalStep)
		}
		_ = must.M1(loop.RunSteps(trainDS, trainSteps-globalStep))
		fmt.Printf("\t[Step %d] median train step: %d microseconds\n", loop.LoopStep, loop.MedianTrainStepDuration().Microseconds())
	} else {
		fmt.Printf("\t - target train_steps=%d already reached. To train further, set a number larger than "+
			"current global step.\n", trainSteps)
	}

	// Finally, print an evaluation on train and test dataset.
	must.M(commandline.ReportEval(trainer, trainEvalDS, testEvalDS))
}

// Notice command line flags are passed in the %% notebook command. We set plots=false here to disable plotting
// since this is only a quick test that our train() loop is working. See below the final run for a full training.
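The plots are scheduled with ScheduleExponential(loop, 200, 1.2), so evaluation points get sparser as training progresses. This sketch lists the steps within a 5,000-step run, assuming the first point is at step 200 and each following one is at 1.2× the previous step. That is our reading of the arguments, not confirmed by the source. The function exponentialSchedule is hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// exponentialSchedule lists the steps at which points would be plotted,
// starting at step `start` and multiplying by `factor` (> 1) each time,
// up to maxStep. Steps are rounded, and duplicates are skipped.
func exponentialSchedule(start int, factor float64, maxStep int) []int {
	if start < 1 || factor <= 1 {
		return nil // The step would never grow: avoid an infinite loop.
	}
	var steps []int
	for next := float64(start); math.Round(next) <= float64(maxStep); next *= factor {
		step := int(math.Round(next))
		if len(steps) == 0 || step > steps[len(steps)-1] {
			steps = append(steps, step)
		}
	}
	return steps
}

func main() {
	steps := exponentialSchedule(200, 1.2, 5000)
	fmt.Println(len(steps), steps)
}
```

Under these assumptions, a 5,000-step run gets 18 plot points, the last one at step 4437.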

Final run with 5K steps

With everything working, we can do our final run.

This is also where one might do hyperparameter tuning, trying out different hyperparameter values.
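A simple way to organize such a search is a grid over a few hyperparameters, generating one -set value per combination. gridSettings below is a hypothetical helper, not part of GoMLX. Each string it returns could be passed as -set="..." together with a distinct --checkpoint directory (the grid_N names are made up), so that runs don't overwrite each other.

```go
package main

import (
	"fmt"
	"sort"
)

// gridSettings expands a grid of hyperparameter values into "-set" strings
// (key=value pairs joined by ";"), one per combination. Keys are sorted so
// the output order is deterministic.
func gridSettings(grid map[string][]string) []string {
	keys := make([]string, 0, len(grid))
	for k := range grid {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	combos := []string{""}
	for _, k := range keys {
		var next []string
		for _, prefix := range combos {
			for _, v := range grid[k] {
				pair := k + "=" + v
				if prefix != "" {
					pair = prefix + ";" + pair
				}
				next = append(next, pair)
			}
		}
		combos = next
	}
	return combos
}

func main() {
	grid := map[string][]string{
		"fnn_num_hidden_layers": {"1", "2"},
		"activation":            {"sigmoid", "swish"},
	}
	for i, s := range gridSettings(grid) {
		// Hypothetical checkpoint names, one per run.
		fmt.Printf("%%%% --checkpoint grid_%d -set=%q\n", i, s)
	}
}
```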

In [8]:
// Remove previously trained model -- skip this cell, if you want to continue training.
!rm -rf ~/work/uci-adult/base_model
In [9]:
%% --checkpoint base_model -set="plots=true;train_steps=5000"
trainModel(storeFromSettings())
        14% [====>...................................] (108 steps/s) [2s:39s] [step=724] [loss=0.351] [~loss=0.319] [~acc=85.74%]        
       100% [========================================] (1414 steps/s) [step=4999] [loss=0.275] [~loss=0.273] [~acc=87.34%]                

[saving checkpoint@5000] [median train step: 257.8µs]

Metric: accuracy

Metric: loss

	[Step 5000] median train step: 257 microseconds
Results on batched train:
	Mean Loss (#loss): 0.27
	Mean Accuracy (#acc): 87.48%
Results on test:
	Mean Loss (#loss): 0.28
	Mean Accuracy (#acc): 87.09%

Extend training another 5K steps

Since the model training went well, and it doesn't yet seem to be overfitting much, let's train further: another 5K steps, for 10K steps in total.

Notice the plots continue from where they stopped. The plots are drawn with Plotly, so they don't display on GitHub, since they depend on JavaScript.

Unfortunately, it doesn't help: the accuracy on the test set barely changes (87.09% vs. 87.10%), so 5K steps were already enough.

In [10]:
%% --checkpoint base_model -set="plots=true;train_steps=10_000"
store, paramsSet := storeFromSettings()
fmt.Printf("train_steps=%d\n", model.GetRootParamOr(store, "train_steps", 0))
trainModel(store, paramsSet)
train_steps=10000
	- restarting training from global_step=5000
       100% [========================================] (1486 steps/s) [step=9999] [loss=0.233] [~loss=0.269] [~acc=87.30%]                  

[saving checkpoint@10000] [median train step: 243.2µs]

Metric: accuracy

Metric: loss

	[Step 10000] median train step: 243 microseconds
Results on batched train:
	Mean Loss (#loss): 0.268
	Mean Accuracy (#acc): 87.43%
Results on test:
	Mean Loss (#loss): 0.279
	Mean Accuracy (#acc): 87.10%

Using Kolmogorov-Arnold Networks (KAN)

Since KAN is available as a layer (see package kan), the model supports it by simply setting the "kan" hyperparameter.

See the description in https://arxiv.org/pdf/2404.19756.
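In a KAN, each edge applies a learnable univariate function, which the paper parameterizes with B-splines. The kan hyperparameters in createModelStore (number of control points, B-spline degree) configure splines of this kind. The plain-Go sketch below evaluates the spline part of one edge, phi(x) = sum_i c_i*B_i(x), with the Cox-de Boor recursion on a uniform knot vector. The knot placement and the input range [-1, 1] are assumptions for illustration; GoMLX's kan package may parameterize its splines differently.

```go
package main

import "fmt"

// bsplineBasis evaluates the i-th B-spline basis function of the given
// degree at x over the knot vector t, using the Cox-de Boor recursion.
func bsplineBasis(i, degree int, t []float64, x float64) float64 {
	if degree == 0 {
		if t[i] <= x && x < t[i+1] {
			return 1
		}
		return 0
	}
	var left, right float64
	if d := t[i+degree] - t[i]; d != 0 {
		left = (x - t[i]) / d * bsplineBasis(i, degree-1, t, x)
	}
	if d := t[i+degree+1] - t[i+1]; d != 0 {
		right = (t[i+degree+1] - x) / d * bsplineBasis(i+1, degree-1, t, x)
	}
	return left + right
}

// uniformKnots returns numControl+degree+1 evenly spaced knots, such that
// the spline's valid domain [t[degree], t[numControl]] equals [lo, hi].
func uniformKnots(numControl, degree int, lo, hi float64) []float64 {
	h := (hi - lo) / float64(numControl-degree)
	knots := make([]float64, numControl+degree+1)
	for j := range knots {
		knots[j] = lo + float64(j-degree)*h
	}
	return knots
}

// kanEdge computes one KAN edge's spline phi(x) = sum_i coeffs[i]*B_i(x).
// In a KAN, the coefficients would be learned.
func kanEdge(x float64, coeffs []float64, degree int, knots []float64) float64 {
	var sum float64
	for i, c := range coeffs {
		sum += c * bsplineBasis(i, degree, knots, x)
	}
	return sum
}

func main() {
	const numControl, degree = 20, 2 // Same values as the kan defaults in createModelStore.
	knots := uniformKnots(numControl, degree, -1, 1)
	// Coefficients at the Greville abscissae turn the spline into phi(x) = x.
	coeffs := make([]float64, numControl)
	for i := range coeffs {
		for j := 1; j <= degree; j++ {
			coeffs[i] += knots[i+j]
		}
		coeffs[i] /= degree
	}
	for _, x := range []float64{-0.5, 0, 0.3} {
		fmt.Printf("phi(%.1f) = %.4f\n", x, kanEdge(x, coeffs, degree, knots))
	}
}
```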

In [11]:
// Remove previously trained model -- skip this cell, if you want to continue training.
!rm -rf ~/work/uci-adult/kan_model
In [12]:
%% --checkpoint kan_model -set="kan=true;activation=swish;plots=true;train_steps=10_000"
trainModel(storeFromSettings())
         7% [=>......................................] (106 steps/s) [2s:1m27s] [step=719] [loss=4.29] [~loss=4.35] [~acc=80.09%]        
       100% [========================================] (1669 steps/s) [step=9999] [loss=1.27] [~loss=1.29] [~acc=86.83%]                  

[saving checkpoint@10000] [median train step: 315.8µs]

Metric: accuracy

Metric: loss

	[Step 10000] median train step: 315 microseconds
Results on batched train:
	Mean Loss (#loss): 1.28
	Mean Accuracy (#acc): 86.86%
Results on test:
	Mean Loss (#loss): 1.29
	Mean Accuracy (#acc): 86.40%