โš ๏ธ๐Ÿšง This site is currently under construction. Documentation is actively being written and expected to be released along with the next GoMLX release v0.28.0. ๐Ÿšงโš ๏ธ

Core Concepts

Understand the three building blocks of every GoMLX program: the backend manager, computation graphs, and the context.

Overview

GoMLX is built on three layered abstractions. Understanding them makes every other part of the library click:

  1. Manager โ€” the connection to a hardware backend (CPU, GPU, TPU)
  2. Graph โ€” a computation graph that you define as a pure Go function
  3. Context โ€” a store for named, typed model parameters (weights)

You can use just the manager and graph for mathematical computing, or add the context to build full trainable models.


The Manager

The manager connects your Go process to a hardware backend. Create one at program startup and reuse it everywhere:

import "github.com/gomlx/gomlx/backends"

manager := backends.New() // auto-selects best available backend

The manager owns the device memory, compiles graphs to native code, and manages data transfer between host and device. One manager per process is the typical pattern.

backends.New() selects the best available backend in order: CUDA GPU โ†’ Metal (Apple) โ†’ CPU. To pin a specific backend, use backends.NewWithName("cpu").

Computation Graphs

A graph is a pure function that describes a computation in terms of *graph.Node values. GoMLX traces this function once, compiles it to XLA HLO, and produces an executable that runs entirely on the device.

// Define a graph function โ€” just a Go function returning nodes
addFn := graph.Compile(manager, func(g *graph.Graph) *graph.Node {
    a := graph.Parameter(g, "a", shapes.Make(dtypes.Float32, 4))
    b := graph.Parameter(g, "b", shapes.Make(dtypes.Float32, 4))
    return graph.Add(a, b)
})

// Execute it โ€” inputs move to device, result moves back
result := addFn.Call(tensorA, tensorB)

Why graphs?

This design gives XLA visibility over the entire computation so it can apply aggressive optimizations: operator fusion, memory layout selection, and loop unrolling โ€” automatically.

Your Go code never runs on the GPU. Only the compiled graph runs there. This is the same design used by JAX and TensorFlow’s tf.function.

Nodes are values, not tensors

Inside a graph function, *graph.Node represents a future value. You cannot inspect its contents during graph construction โ€” only after calling .Call(). Operations on nodes describe the graph structure.

// This is graph construction โ€” no computation happens here
x := graph.Parameter(g, "x", shapes.Make(dtypes.Float32, 3))
y := graph.Mul(x, graph.Const(g, float32(2.0)))
z := graph.ReduceSum(y, 0) // sum all elements

// This executes the compiled graph on-device
result := compiledFn.Call(xTensor) // returns a *tensors.Tensor
fmt.Println(result.Value()) // []float32{...}

Shapes and Dtypes

Every node has a shape: a list of dimension sizes, plus a dtype. GoMLX checks shape compatibility at graph construction time โ€” mismatches are caught before any computation runs.

// Shape: [batch, height, width, channels]
imgShape := shapes.Make(dtypes.Float32, 32, 224, 224, 3)

// Scalar
scalarShape := shapes.Make(dtypes.Float64) // no dimensions

// Check shape at construction
x := graph.Parameter(g, "x", imgShape)
fmt.Println(x.Shape()) // Float32[32, 224, 224, 3]

Common dtypes: dtypes.Float32, dtypes.Float64, dtypes.Int32, dtypes.Int64, dtypes.Bool.


The Context

The context is a hierarchical store for model parameters. Think of it as the model’s named weight dictionary, with Go’s type safety built in.

import "github.com/gomlx/gomlx/ml/context"

ctx := context.New()

// Inside a graph function, variables are created or retrieved by name
func denseLayer(ctx *context.Context, x *graph.Node, units int) *graph.Node {
    w := ctx.WithInitializer(initializers.GlorotUniform).
        VariableWithShape("weights", shapes.Make(dtypes.Float32, x.Shape().Dim(-1), units))
    b := ctx.VariableWithShape("bias", shapes.Make(dtypes.Float32, units))
    return graph.Add(graph.MatMul(x, w.ValueGraph(x.Graph())), b.ValueGraph(x.Graph()))
}

Scoping

Use ctx.In("name") to create named sub-scopes, which keeps weight names unique across layers:

x = denseLayer(ctx.In("layer1"), x, 128) // weights stored at "layer1/weights"
x = denseLayer(ctx.In("layer2"), x, 64)  // weights stored at "layer2/weights"

Checkpointing

The context can serialize all its variables to disk and restore them:

checkpoint := checkpoints.Build(ctx).Dir("/tmp/my-model").Done()
checkpoint.Save()   // saves all variables to disk
checkpoint.Restore() // restores from the latest checkpoint

Putting it together

Here is the minimal skeleton of a trainable GoMLX program:

func main() {
    // 1. Connect to hardware
    manager := backends.New()

    // 2. Create a context to hold weights
    ctx := context.New()

    // 3. Define your model as a graph function
    trainer := train.NewTrainer(manager, ctx, myModelFn,
        losses.SparseCategoricalCrossEntropyLogits,
        optimizers.Adam(),
    )

    // 4. Run the training loop
    loop := train.NewLoop(trainer)
    loop.RunSteps(trainDataset, 10_000)
}

Each of these pieces โ€” manager, graph, context, trainer โ€” is independently replaceable. You can swap the optimizer, the backend, or the loss function without touching the rest.

Last updated April 23, 2026