โš ๏ธ๐Ÿšง This site is currently under construction. Documentation is actively being written and expected to be released along with the next GoMLX release v0.28.0. ๐Ÿšงโš ๏ธ
Machine learning for Go

Build ML models
at Go speed.

GoMLX is an easy-to-use ML and math framework, like PyTorch, JAX, or TensorFlow for Go. No Python dependency: it runs anywhere Go runs, with XLA-accelerated JIT compilation.

Runs on CPU · GPU · TPU via XLA · Apache 2.0 · Go 1.25+

Familiar Go patterns,
ML power underneath.

// Two-layer MLP on MNIST: model definition
package main

import (
    "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/ml/layers"
    "github.com/gomlx/gomlx/ml/train"
)

// modelGraph defines your model as a pure function.
// GoMLX traces it once, compiles to XLA, runs fast.
func modelGraph(ctx *context.Context, spec any, inputs []*graph.Node) []*graph.Node {
    x := inputs[0]                                  // [batch, 784]
    x = layers.Dense(ctx.In("hidden"), x, true, 128) // -> [batch, 128]
    x = graph.Relu(x)
    x = layers.Dense(ctx.In("output"), x, true, 10)  // -> [batch, 10]
    return []*graph.Node{x}
}
// Build a trainer with loss, optimizer, and eval callbacks
trainer := train.NewTrainer(manager, ctx, modelGraph,
    losses.SparseCategoricalCrossEntropyLogits,
    optimizers.Adam(),
    train.EveryNSteps(100, plotMetricsFn),
    train.NanLogger(),
)

// The loop runs entirely on-device: no Python round-trips
loop := train.NewLoop(trainer)
loop.RunSteps(trainDS, 10_000)

// Evaluate on test set
metrics := trainer.Eval(testDS)
fmt.Printf("Test accuracy: %.2f%%\n", metrics["accuracy"]*100)
// Graphs are compiled functions: define once, run anywhere
manager := backends.New() // auto-selects GPU > CPU

computeFn := graph.Compile(manager, func(g *graph.Graph) *graph.Node {
    x := graph.Parameter(g, "x", shapes.Make(dtypes.Float32, 3))
    w := graph.Const(g, []float32{1, 2, 3})
    return graph.ReduceSum(graph.Mul(x, w), 0)
})

result := computeFn.Call(xTensor)
fmt.Println(result) // Tensor[float32: scalar] = 14

Everything you need.
Nothing you don't.

XLA acceleration

Compile your model graph once to XLA. The same Go code runs on CPU, NVIDIA GPU, and Google TPU without modification. Full JIT compilation, kernel fusion, and memory optimization built in.

Learn about backends →

Static typing, caught at compile time

Go's type system catches shape mismatches before your first training step. No more cryptic runtime crashes three epochs into a run. Tensor shapes are first-class citizens.

Explore the type system →
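The idea can be illustrated in plain Go (a toy sketch, not GoMLX's actual shape API): giving different tensor ranks distinct types turns a shape mismatch into a compile error rather than a runtime crash.

```go
package main

import "fmt"

// Distinct types for distinct ranks: passing a Vector where a
// Matrix is expected fails to compile.
type Vector []float32
type Matrix [][]float32

// matVec multiplies a matrix by a vector.
func matVec(m Matrix, v Vector) Vector {
	out := make(Vector, len(m))
	for i, row := range m {
		var sum float32
		for j, x := range row {
			sum += x * v[j]
		}
		out[i] = sum
	}
	return out
}

func main() {
	m := Matrix{{1, 2}, {3, 4}}
	v := Vector{1, 1}
	fmt.Println(matVec(m, v)) // [3 7]
	// matVec(v, m) // compile error: cannot use v (type Vector) as Matrix
}
```
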

Pure Go

One go get, no virtual environments, no Python runtime, no version conflicts. Just idiomatic Go that your whole team already knows.

Rich layers API

Dense, Conv, Attention, BatchNorm, Dropout, Embeddings: all in one package, composable, easy to extend with custom layers.

Live plotting

Built-in GoNB/Jupyter integration plots training metrics in real time. Loss curves, accuracy, custom metrics: all without leaving Go.

Modular design

Swap backends, optimizers, and data pipelines independently. Use just the graph layer, or the full training stack: your choice.
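A minimal pure-Go sketch of that pattern (hypothetical `Backend` and `Optimizer` interfaces for illustration, not GoMLX's actual ones): when each component sits behind a small interface, the training code never changes as you swap implementations.

```go
package main

import "fmt"

// Backend abstracts where computation runs.
type Backend interface{ Name() string }

// Optimizer abstracts the parameter-update rule.
type Optimizer interface{ Step(loss float32) float32 }

type CPU struct{}

func (CPU) Name() string { return "cpu" }

// SGD is a toy optimizer: it just shaves LR off the loss.
type SGD struct{ LR float32 }

func (o SGD) Step(loss float32) float32 { return loss - o.LR }

// runStep is generic over both interfaces; swapping CPU for a GPU
// backend, or SGD for Adam, requires no change here.
func runStep(b Backend, o Optimizer, loss float32) float32 {
	fmt.Printf("running on %s\n", b.Name())
	return o.Step(loss)
}

func main() {
	loss := runStep(CPU{}, SGD{LR: 0.1}, 1.0)
	fmt.Printf("loss after step: %.1f\n", loss)
}
```
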

Questions? Join the conversation.

Ask on the #gomlx Slack channel, open a GitHub issue, or contribute an example. The library is actively maintained and the community is growing.

v0.27.3 Latest release
Apache 2.0 license