Build ML models
at Go speed.
GoMLX is an easy-to-use ML and math framework, like PyTorch, JAX, or TensorFlow for Go. No Python dependency, runs anywhere Go runs, with XLA-accelerated JIT compilation.
Familiar Go patterns,
ML power underneath.
// Two-layer MLP on MNIST: a full working example
package main
import (
	"github.com/gomlx/gomlx/graph"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/layers"
	"github.com/gomlx/gomlx/ml/train"
)
// modelGraph defines your model as a pure function.
// GoMLX traces it once, compiles to XLA, runs fast.
func modelGraph(ctx *context.Context, spec any, inputs []*graph.Node) []*graph.Node {
x := inputs[0] // [batch, 784]
x = layers.Dense(ctx.In("hidden"), x, true, 128) // → [batch, 128]
x = graph.Relu(x)
x = layers.Dense(ctx.In("output"), x, true, 10) // → [batch, 10]
return []*graph.Node{x}
}

// Build a trainer with loss, optimizer, and eval callbacks
trainer := train.NewTrainer(manager, ctx, modelGraph,
losses.SparseCategoricalCrossEntropyLogits,
optimizers.Adam(),
train.EveryNSteps(100, plotMetricsFn),
train.NanLogger(),
)
// Loop runs entirely on-device: no Python round-trips
loop := train.NewLoop(trainer)
loop.RunSteps(trainDS, 10_000)
// Evaluate on test set
metrics := trainer.Eval(testDS)
fmt.Printf("Test accuracy: %.2f%%\n", metrics["accuracy"]*100)

// Graphs are compiled functions: define once, run anywhere
manager := backends.New() // auto-selects GPU > CPU
computeFn := graph.Compile(manager, func(g *graph.Graph) *graph.Node {
x := graph.Parameter(g, "x", shapes.Make(dtypes.Float32, 3))
w := graph.Const(g, []float32{1, 2, 3})
return graph.ReduceSum(graph.Mul(x, w), 0)
})
result := computeFn.Call(xTensor)
fmt.Println(result) // Tensor[float32: scalar] = 14

Everything you need.
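The compiled graph above is just a dot product. As a sanity check, here is the same arithmetic in plain Go, assuming xTensor holds [1, 2, 3] (consistent with the printed result of 14):

```go
package main

import "fmt"

// dot mirrors the compiled graph: elementwise multiply, then sum-reduce.
func dot(x, w []float32) float32 {
	var sum float32
	for i := range x {
		sum += x[i] * w[i]
	}
	return sum
}

func main() {
	x := []float32{1, 2, 3} // assumed contents of xTensor
	w := []float32{1, 2, 3}
	fmt.Println(dot(x, w)) // 1*1 + 2*2 + 3*3 = 14
}
```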
Nothing you don't.
XLA acceleration
Compile your model graph once to XLA. The same Go code runs on CPU, NVIDIA GPU, and Google TPU without modification. Full JIT compilation, kernel fusion, and memory optimization built in.
Learn about backends →

Static typing, caught at compile time
Go's type system catches shape mismatches before your first training step. No more cryptic runtime crashes three epochs into a run. Tensor shapes are first-class citizens.
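As a toy illustration of the idea in plain Go (not the GoMLX API): encode a dimension in the type, and a shape mismatch becomes a compile error rather than a runtime failure.

```go
package main

import "fmt"

// Hypothetical fixed-size vector types: the dimension lives in the type.
type Vec784 = [784]float32
type Vec128 = [128]float32

// hiddenLayer only accepts a 784-wide input; passing a Vec128 would
// not compile. The body is a dummy projection for illustration.
func hiddenLayer(in Vec784) Vec128 {
	var out Vec128
	var s float32
	for _, v := range in {
		s += v
	}
	mean := s / float32(len(in))
	for i := range out {
		out[i] = mean
	}
	return out
}

func main() {
	var x Vec784
	for i := range x {
		x[i] = 1
	}
	h := hiddenLayer(x) // hiddenLayer(Vec128{}) would be rejected by the compiler
	fmt.Println(len(h), h[0]) // 128 1
}
```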
Explore the type system →

Pure Go
One go get, no virtual environments, no Python runtime, no version conflicts. Just idiomatic Go that your whole team already knows.
Rich layers API
Dense, Conv, Attention, BatchNorm, Dropout, Embeddings: all in one package, composable, easy to extend with custom layers.
Live plotting
Built-in GoNB/Jupyter integration plots training metrics in real time. Loss curves, accuracy, custom metrics: all without leaving Go.
Modular design
Swap backends, optimizers, and data pipelines independently. Use just the graph layer or the full training stack: your choice.
From MNIST to transformers.
Working notebooks, command-line tools, and models you can run right now.
Gemma 3 270M
Text generation with the Gemma 3 LLM, converted from ONNX.
CIFAR-10 Demo
Convolutional network image classifier available as a Jupyter Notebook.
KaLM-Gemma3 12B
Tencent's top-ranked sentence encoder loaded with go-huggingface.
AlphaZero for Hive
AI board evaluation using a GNN. Runs directly in your browser!
Questions? Join the conversation.
Ask on the #gomlx Slack channel, open a GitHub issue, or contribute an example. The library is actively maintained and the community is growing.