![[GoMLX Mascot]](gomlx_gopher2.jpg)
GoMLX Tutorial¶
If you just want to quickly look at a working example, check out examples/cifar/demo/adult.ipynb, a model trained on the UCI Adult Income dataset. More examples are in the examples/ subdirectory.
The tutorial won't detail the whole API, but it should present all the important concepts. Everything else is well documented in godoc (in the code), also available on pkg.go.dev.
The tutorial was written using a Jupyter notebook with GoNB, a kernel for Go that was co-developed with GoMLX. It has its own short tutorial for those interested. But GoMLX doesn't require GoNB.
If you are viewing this tutorial from a GitHub snapshot, you won't be able to interact with it. To play with it, install GoMLX (see the Installation section of its README). The easiest way is to start the pre-generated Docker image and use the Jupyter notebook there: this tutorial can be opened interactively from it.
Note: Output Not Displaying in JupyterLab ?¶
If your notebook plots are not displaying correctly, it's likely because Jupyter is assuming the notebook is "not trusted". This can be easily fixed by opening the command palette (Shift+Control+C on PCs) and selecting Trust Notebook.
There is an open issue in JupyterLab attempting to address that.
Computation Graphs (or Symbolic Computation)¶
To do machine learning based on neural networks and gradient descent, one of the most important requirements is the ability to do mathematical computations (mostly matrix multiplications) fast.
GoMLX is built on the concept of building "computation graphs", just-in-time compiling them, and only then executing them to get the desired results. In other words, one writes code that generates another type of code (the computation graph). We do this because the compiled graph can then be executed really fast by XLA on various accelerators (GPUs, TPUs, etc.).
For example, let's create a computation graph to sum two values:
Note
- If executing this on a notebook, notice the very first cell execution takes a few seconds for Go to fetch the required dependencies.
import . "github.com/gomlx/gomlx/core/graph"
func Sum(a, b *Node) *Node {
return Add(a, b)
}
Note
- The import . "github.com/gomlx/gomlx/core/graph" statement imports all definitions of the graph package into the current scope. Usually, it's easier to work like this in Go files that implement graph building functions.
- The type *Node represents a node in the computation graph. All graph operations either take a *Graph object to start with, or a *Node, and create new nodes with the corresponding operations. So our example creates an Add node that takes the nodes pointed to by a and b, builds a node that represents their sum, and then returns this *Node.
- Every node contains a reference to the *Graph it's part of (see Node.Graph()).
- There is a rich set of operations available in GoMLX, see the Node documentation.
Ok, but this is "symbolic" only, it won't tell us what is 1+1 yet. We need to compile and then execute this graph with some input values.
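To make the split between building and executing concrete, here is a minimal pure-Go sketch of a symbolic graph. The node type and the param, add and eval helpers are hypothetical names invented for this illustration. They are not GoMLX's implementation, which JIT-compiles the graph instead of interpreting it as done here.

```go
package main

import "fmt"

// node is a toy symbolic node: either an input parameter or the sum of two nodes.
// Hypothetical type for illustration only, unrelated to GoMLX's graph.Node.
type node struct {
	op          string // "param" or "add"
	paramIndex  int    // which input to read, when op == "param"
	left, right *node  // operands, when op == "add"
}

// param creates a node that stands for the i-th input; no value is known yet.
func param(i int) *node { return &node{op: "param", paramIndex: i} }

// add creates a node representing a+b; no arithmetic happens here.
func add(a, b *node) *node { return &node{op: "add", left: a, right: b} }

// eval executes the graph rooted at n with concrete input values.
func eval(n *node, inputs []float64) float64 {
	switch n.op {
	case "param":
		return inputs[n.paramIndex]
	case "add":
		return eval(n.left, inputs) + eval(n.right, inputs)
	}
	panic("unknown op: " + n.op)
}

func main() {
	sum := add(param(0), param(1)) // Build once (symbolic only).

	// Execute many times with different concrete inputs.
	fmt.Println(eval(sum, []float64{1, 1}))     // prints 2
	fmt.Println(eval(sum, []float64{3.5, 1.5})) // prints 5
}
```

The graph is built a single time and reused for every execution, which is the property a compiler can exploit to make the executions fast.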
Backends and executing Graphs with graph.Exec¶
GoMLX can work with different backends to execute its operations. Currently 2 backends are supported:
- "xla": It uses OpenXLA/XLA, a "state of the art" execution engine for ML operations, that supports CPU (with SIMD), various accelerators (GPUs, TPUs), JIT (Just In Time) compilation, providing the best performance. For now though, it only works on the Linux/amd64 platform.
- "go": A pure Go backend engine. Much slower than xla, but very portable: it will likely work on all platforms supported by Go. Tested on Linux, Windows and WASM (yes, it runs in the browser).
Both backends are included if you import _ "github.com/gomlx/gomlx/backends/default" like in the example below (xla only supported in Linux/Windows/Darwin, with GPU support only on Linux). Or you can include them individually.
Creating a backend (compute.Backend object) is trivial: compute.New() (or compute.MustNew()).
If the environment variable GOMLX_BACKEND is set, it can define the backend and its configuration (e.g.: "go", "xla:cpu", "xla:cuda", etc.). Here we set it to xla:cpu to force the CPU backend for this tutorial.
With the backend created, Exec is the easiest way to compile and execute computation graphs in GoMLX.
To run our Sum function above we can do:
%env GOMLX_BACKEND=xla:cpu
import "github.com/gomlx/compute"
import _ "github.com/gomlx/gomlx/backends/default"
// must returns value if err is nil; otherwise it logs the error and exits.
func must[T any](value T, err error) T {
if err != nil {
log.Fatalf("Fatal error: %+v", err)
}
return value
}
var backend = must(compute.New())
%%
// Short version: `CallOnce` will compile the function and execute only once.
two := must(CallOnce(backend, /* GoMLX function */ Sum, /* Arguments */ 1.0, 1.0))
fmt.Printf("CallOnce:\t1+1=%v\n", two)
// Exec version: `NewExec` will compile the Sum function so it can be executed efficiently
// many times.
exec := must(NewExec(backend, Sum))
results := must(exec.Call(1, 1))
two = results[0]
fmt.Printf("Exec/Call:\t1+1=%v\n", two)
Set: GOMLX_BACKEND="xla:cpu"
CallOnce: 1+1=float64(2)
Exec/Call: 1+1=int64(2)
Note
- %% is a shortcut for func main(): everything after it is put inside a main function by GoNB, the Go Notebook kernel.
- We define the must function: a convenient way to handle errors in notebooks, by immediately logging the error and exiting.
- compute.New() creates a compute.Backend object, which connects to an accelerator if present. compute.MustNew is the version that panics on error. Usually one creates a backend at the beginning of the program and passes it around. Here GoNB will keep the global variable backend available to all cells, so we don't need to define it again.
- The Exec object created is associated with a graph building function (Sum in this case). It lazily compiles and executes the compiled computation as needed. Naturally the first time Call() is invoked it is slow: it has to build the graph and just-in-time (JIT) compile it. But afterwards the compiled graph is optimized and very fast to execute, which is what we want for machine learning.
- The Exec.Call(1, 1) method always returns a slice of results, that's why we use [0] to access the first result.
- The results of graph execution are always tensors (see section below). They can be printed as is, or be converted to Go types using the .Value() method.
Something important to understand is that graphs have static (fixed) shapes for their inputs and outputs. This means that, for example, if you are going to sum floats instead of ints, Exec has to rebuild the graph to take two floats as input. The same happens if you want to sum vectors or matrices of ints, or inputs with any other shapes.Shape.
For a detailed explanation of Shape and the associated concepts of Axis, Dimensions and DType (the underlying data type), see the shapes package documentation. Dynamic shapes (only the input dimensions are dynamic, the rank is fixed) are also supported in some backends; see the reference documentation for details.
To exemplify, let's expand our code a bit:
import (
"fmt"
. "github.com/gomlx/gomlx/core/graph"
"github.com/gomlx/compute/dtypes"
)
// Sum is the symbolic function (creates a graph) of summing two numbers.
func Sum(a, b *Node) *Node {
g := a.Graph() // Graph on which this symbolic computation is being built.
fmt.Printf("\n- Building graph (GraphId=%d) for a.shape=%s and b.shape=%s\n", g.GraphId(), a.Shape(), b.Shape())
return Add(a, b)
}
func main() {
sumExec := must(NewExec(backend, Sum))
two := sumExec.MustCall1(1, 1)
fmt.Printf("\t1+1=%s\n", two)
for ii := 0; ii < 5; ii++ {
sumInts := sumExec.MustCall1(ii, ii)
fmt.Printf("\t%d+%d=%s\n", ii, ii, sumInts)
}
five := sumExec.MustCall1(3.5, 1.5)
fmt.Printf("\t3.5+1.5=%s\n", five)
many := sumExec.MustCall1([]float32{1.1, 2.2, 3.3}, float32(10))
fmt.Printf("\t[1.1, 2.2, 3.3] + 10 = %s\n", many)
}
- Building graph (GraphId=0) for a.shape=(Int64) and b.shape=(Int64)
1+1=int64(2)
0+0=int64(0)
1+1=int64(2)
2+2=int64(4)
3+3=int64(6)
4+4=int64(8)
- Building graph (GraphId=1) for a.shape=(Float64) and b.shape=(Float64)
3.5+1.5=float64(5)
- Building graph (GraphId=2) for a.shape=(Float32)[3] and b.shape=(Float32)
[1.1, 2.2, 3.3] + 10 = [3]float32{11.1, 12.2, 13.3}
Note
- Each time new input shapes are found, a new graph is created. We added a fmt.Printf to tell us the GraphId and, more importantly, the shapes of the graph operands. Notice that fmt.Printf is not part of the computation graph, it's only part of the graph building function. Later we'll see Node.SetLogged, which prints intermediary results in the middle of the execution of the graph.
- Every Node has an associated shape (shapes.Shape type). A shape is defined by its underlying data type dtypes.DType and its axes dimensions. For scalars, the shape has zero axes (dimensions). E.g.: (Int64)[] represents a scalar int value, and (Float32)[3] represents a vector with 3 float32 values. More details and the list of supported data types (aka. dtype) are in the package github.com/gomlx/compute/dtypes.
- Exec.MustCall1 is a variant of the Exec.Call used previously that conveniently: (a) panics on error; (b) returns only the first result.
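The "one graph per combination of input shapes" behavior seen in the output above can be pictured as a cache keyed by the input shapes. The following is a pure-Go sketch under that assumption: graphCache, lookup and shapeOf are hypothetical names for this illustration, and GoMLX's Exec may organize its cache differently.

```go
package main

import (
	"fmt"
	"strings"
)

// shapeOf returns a textual shape for a few Go types, mimicking the
// "(DType)[dims]" notation printed in the output above.
func shapeOf(v any) string {
	switch x := v.(type) {
	case int:
		return "(Int64)"
	case float64:
		return "(Float64)"
	case float32:
		return "(Float32)"
	case []float32:
		return fmt.Sprintf("(Float32)[%d]", len(x))
	}
	return fmt.Sprintf("(%T)", v)
}

// graphCache is a hypothetical stand-in for the per-shape cache of an executor.
type graphCache struct {
	ids map[string]int // input-shapes signature -> graph id
}

// lookup returns the graph id for the given arguments, and whether a new
// graph had to be "built" because these input shapes were not seen before.
func (c *graphCache) lookup(args ...any) (id int, built bool) {
	parts := make([]string, len(args))
	for i, a := range args {
		parts[i] = shapeOf(a)
	}
	signature := strings.Join(parts, ",")
	if cached, ok := c.ids[signature]; ok {
		return cached, false
	}
	id = len(c.ids)
	c.ids[signature] = id
	return id, true
}

func main() {
	cache := &graphCache{ids: map[string]int{}}
	calls := [][]any{
		{1, 1},
		{2, 2}, // Same shapes as the previous call: the graph is reused.
		{3.5, 1.5},
		{[]float32{1.1, 2.2, 3.3}, float32(10)},
	}
	for _, args := range calls {
		id, built := cache.lookup(args...)
		fmt.Printf("graph id=%d, newly built=%v\n", id, built)
	}
}
```

Only the shapes take part in the cache key, never the values, which is why the five integer sums in the loop above shared a single graph.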
In general, graph operations only work on operands with the same dtypes.DType.
If the dtypes differ, the error is reported with a panic (it works like an exception, and can be caught) that carries an error with a full stack trace.
How errors are handled during graph building is covered in the Error Handling section below, after a short look at the Exec aliases.
Exec aliases:¶
Because the graph.Exec (and later the similar model.Exec, in a few sections below) is so central, we provide some aliases to make it more "ergonomic" to use.
To create an Exec object:
- NewExec, MustNewExec: the first returns an error, the second panics.
To create, execute and finalize an Exec object in one line, for when we want to execute a computation only once:
- CallOnce and CallOnceN: the first can be used if the computation returns only one result. The second is for computations that return more than one result; it returns a slice of tensors (see below).
To execute an Exec object:
- Exec.Call, Exec.Call1, Exec.Call2, ... Exec.Call4: the first returns a slice of tensors; Call1 returns only one tensor, and assumes the computation graph has only one output; Call2 returns 2 tensors, and so on. They all also return an error.
- Exec.CallWithGraph executes the computation and returns the results (as a slice of tensors) as well as the graph.Graph object that holds the compiled computation.
There is always a MustCall... version of each method that panics on error, instead of returning them.
Error Handling¶
During graph building (or what we call symbolic computation) GoMLX diverges from the usual idiomatic Go error handling: it uses exceptions in the form of a panic() (with an error object with a stack trace).
Usually, the symbolic computation functions (like the Sum function we defined above) are called by the Exec.Call methods, which capture any panics and return them as errors to the caller.
Note: The author is aware this is controversial for some. So below are the motivations to use exceptions, as opposed to returning an error everywhere:
- While implementing long mathematical functions, introducing error handling between each operation would be too distracting and unwieldy. Notice the Go language makes the same choice in a few situations; e.g.: the division operator / doesn't return the result of the division and an error, and instead panics on (integer) division by zero.
- Speed of building the computation graph is not a concern: it has no impact on the execution of the JIT-compiled code later on.
Let's create an example with an error to see how this works:
%%
sumExec := MustNewExec(backend, Sum)
_ = must(sumExec.Call(1.1, 2)) // Error: arguments have different dtypes Float64 and Int64.
- Building graph (GraphId=0) for a.shape=(Float64) and b.shape=(Int64)
2026/06/09 13:45:31 Fatal error: cannot broadcast Float64 and Int64 for "Add": they have different dtypes
github.com/gomlx/go-xla/compute/xla.(*Builder).broadcastForBinaryOps
    /home/janpf/Projects/gomlx/go-xla/compute/xla/builder.go:317
github.com/gomlx/go-xla/compute/xla.(*Function).Add
    /home/janpf/Projects/gomlx/go-xla/compute/xla/gen_binary_ops.go:13
github.com/gomlx/gomlx/core/graph.Add
    /home/janpf/Projects/gomlx/gomlx/core/graph/gen_compute_ops.go:212
main.Sum [[ Cell [3] Line 11 ]]
    /tmp/gonb_bd1ce53c/main.go:20
github.com/gomlx/gomlx/core/graph.convertExecFn[...].func5
    /home/janpf/Projects/gomlx/gomlx/core/graph/gen_constraints.go:84
github.com/gomlx/gomlx/core/graph.(*Exec).buildAndCompileGraph.func1
    /home/janpf/Projects/gomlx/gomlx/core/graph/exec.go:801
github.com/gomlx/gomlx/support/exceptions.TryCatch[...]
    /home/janpf/Projects/gomlx/gomlx/support/exceptions/exceptions.go:91
github.com/gomlx/gomlx/core/graph.(*Exec).buildAndCompileGraph
    /home/janpf/Projects/gomlx/gomlx/core/graph/exec.go:799
github.com/gomlx/gomlx/core/graph.(*Exec).findOrCreateGraph
    /home/janpf/Projects/gomlx/gomlx/core/graph/exec.go:981
github.com/gomlx/gomlx/core/graph.(*Exec).compileAndExecute
    /home/janpf/Projects/gomlx/gomlx/core/graph/exec.go:644
github.com/gomlx/gomlx/core/graph.(*Exec).CallWithGraphOnDevice
    /home/janpf/Projects/gomlx/gomlx/core/graph/exec.go:432
github.com/gomlx/gomlx/core/graph.(*Exec).ExecWithGraphOnDevice
    /home/janpf/Projects/gomlx/gomlx/core/graph/exec.go:444
github.com/gomlx/gomlx/core/graph.(*Exec).CallWithGraph
    /home/janpf/Projects/gomlx/gomlx/core/graph/execaliases.go:122
github.com/gomlx/gomlx/core/graph.(*Exec).Call
    /home/janpf/Projects/gomlx/gomlx/core/graph/execaliases.go:38
main.main [[ Cell [4] Line 3 ]]
    /tmp/gonb_bd1ce53c/main.go:35
runtime.main
    /home/janpf/sdk/go1.26.3/src/runtime/proc.go:290
runtime.goexit
    /home/janpf/sdk/go1.26.3/src/runtime/asm_amd64.s:1771
graphFn for "Exec:main.Sum" panicked
compilation failed for Exec "Exec:main.Sum" failed during graph construction/compilation
exit status 1
exit status 1
Note
- In the stack trace above there are 2 frames of interest, which typically help to debug such issues:
  - Where in the graph building function main.Sum the invalid operation was created: Line 11 of the cell that defines Sum (Cell [3]).
  - Where in the main function the execution of the graph was attempted: Line 3 of this cell (Cell [4]).
- You can enable displaying line numbers in JupyterLab with "ESC+L" (upper-case L).
Tensors¶
Tensors are multidimensional arrays of a given data type (dtypes.DType) defined in the package github.com/gomlx/gomlx/core/tensors.
In GoMLX, tensors work as containers of data that are used as concrete inputs and outputs for the execution of computational graphs. There is only basic support to manipulate tensors directly (it includes direct access to their data): instead, one usually operates on tensors using computational graphs and the Exec object. Tensors have a shape (shapes.Shape) just like a graph.Node.
When talking about a tensor, we are referring to a *tensors.Tensor object.
The tensor raw values can be stored either locally (as a Go slice) or on the device where the graph is being executed (e.g.: a GPU, or even a CPU), in which case the storage is owned by the compute.Backend.
While the tensors.Tensor object handles the synchronization between the local and on-device versions automatically, it's important to be aware of this distinction: there is a time cost to transfer data from, for example, a GPU to the CPU and back.
Effectively, this means that during training you don't want to transfer data locally often, for instance to save a checkpoint. Most of this is handled automatically by GoMLX libraries, but not everything.
Example: because some of the resources are allocated in the accelerator (GPU) and not managed by Go, the garbage collector may not be aware of the memory pressure on these devices--there is a Tensor.Finalize call that immediately releases the data associated with a tensor.
Whenever a graph is executed (with Exec.Call), non-tensor input values are automatically converted to tensors, and those are transferred to the device doing the execution automatically.
Example:
%%
onePlusExec := MustNewExec(backend, func (x *Node) *Node {
return OnePlus(x)
})
// MustCall1 returns a *tensors.Tensor.
counter := onePlusExec.MustCall1(0)
// counter.String() will first transfer counter to local (using counter.Local()) to print its values.
fmt.Printf("counter.type=%T, counter=%s\n", counter, counter.String())
for ii := 0; ii < 10; ii++ {
// Since the counter is not being used locally between the calls, the tensor will only use the
// device storage.
counter = onePlusExec.MustCall1(counter)
}
// Again counter.String() (called implicitly by fmt.Printf) will first transfer the counter value locally, and then convert to a Go value.
fmt.Printf("counter=%v\n", counter)
counter.type=*tensors.Tensor, counter=int64(1)
counter=int64(11)
Note:
- In the first call to onePlusExec.MustCall1(0), the Go constant 0 is automatically converted to a *tensors.Tensor by the graph.Exec and fed to the graph. The call returns a single *tensors.Tensor, containing 0+1=1.
- The returned tensor is initially only stored on device. But when we print it, it is automatically transferred to Go.
- While executing the loop the counter is not transferred locally: everything happens efficiently on the accelerator device (GPU, TPU, or even the CPU executor).
- When we print the final result, counter.String() again transfers the data locally automatically.
New tensors¶
There are several ways to create tensors. The most common are:
- tensors.FromValue[S](value S): generic conversion; it works with any scalar of a supported DType as well as with arbitrary multidimensional slices. Slices of rank > 1 must be regular, that is, all the sub-slices must have the same shape. E.g.: FromValue([][]float64{{1, 2}, {3, 5}, {7, 11}}). There is also a non-generic version, tensors.MustFromAnyValue(value any). This is what Exec.Call() uses to convert arbitrary values to tensors.
- FromShape(shape shapes.Shape): creates a local tensor with the given shape, initialized to zero. See the documentation on Tensor.MutableFlatData(...) to mutate tensors in place.
- FromScalarAndDimensions[T](value T, dimensions ...int): creates a tensor with the given dimensions, filled with the given scalar value. T must be one of the supported types.
You can print the contents of a tensors.Tensor simply using Tensor.String(), or convert it back to the corresponding Go type using Tensor.Value() (it returns an any that you can type-assert to the expected type).
See more documentation in github.com/gomlx/gomlx/core/tensors.
Errors in the manipulation of tensors (e.g. invalid values) are also reported as exceptions thrown with panic, with full stack traces. They only happen due to invalid inputs (nil values, wrong data types, etc.). The errors can easily be caught (with recover() or with the exceptions.TryCatch helper) when needed.
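The requirement that multidimensional slices be regular can be illustrated with a small pure-Go check for the rank-2 case. The function dimensionsOf is a hypothetical helper written for this illustration; it is not part of the tensors package, and unlike the library it returns an error instead of panicking.

```go
package main

import "fmt"

// dimensionsOf returns the [rows, columns] dimensions of a rank-2 slice, or an
// error if the rows have different lengths (an irregular slice).
func dimensionsOf(rows [][]float64) ([]int, error) {
	if len(rows) == 0 {
		return []int{0, 0}, nil
	}
	columns := len(rows[0])
	for i, row := range rows {
		if len(row) != columns {
			return nil, fmt.Errorf("row %d has %d elements, expected %d", i, len(row), columns)
		}
	}
	return []int{len(rows), columns}, nil
}

func main() {
	// Regular: every row has 2 elements, so the shape is [3 2].
	fmt.Println(dimensionsOf([][]float64{{1, 2}, {3, 5}, {7, 11}}))

	// Irregular: the second row is shorter, so no shape can describe it.
	fmt.Println(dimensionsOf([][]float64{{1, 2}, {3}}))
}
```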
Gradients¶
An important functionality required to train machine learning models based on gradient descent is calculating the gradients of some value (loss) being optimized with respect to some variable / quantity.
GoMLX does this symbolically, during graph building time. It adds to the graph the computation for the gradient.
Example: let's calculate the gradient of the function $f(x, y) = x^2 + xy$ for a few values of $x$ and $y$. Algebraically we have:
$$ \begin{aligned} f(x, y) &= x^2 + xy \\ df/dx(x,y) &= 2x + y \\ df/dy(x,y) &= x \\ \end{aligned} $$
func f(x, y *Node) *Node {
return Add(Square(x), Mul(x, y))
}
func gradOfF(x, y *Node) (output, gradX, gradY *Node) {
output = f(x, y)
reduced := ReduceAllSum(output) // In case x and y are not scalars.
grads := Gradient(reduced, x, y)
gradX, gradY = grads[0], grads[1] // df/dx, df/dy
return output, gradX, gradY
}
%%
exec := MustNewExec(backend, gradOfF)
x := []float64{0, 1, 2}
y := []float64{10, 20, 30}
output, gradX, gradY := exec.MustCall3(x, y)
fmt.Printf("f(x=%v, y=%v)=%s,\n\tdf/dx=%s,\n\tdf/dy=%s\n", x, y, output, gradX, gradY)
f(x=[0 1 2], y=[10 20 30])=[3]float64{0, 21, 64},
df/dx=[3]float64{10, 22, 34},
df/dy=[3]float64{0, 1, 2}
Note:
- As of yet, GoMLX only calculates gradients of a scalar (typically a model loss) with respect to arbitrary tensors. It does not yet calculate Jacobians, that is, the derivatives when the value being differentiated is not a scalar.
- A question that may arise is whether it calculates the second derivative (Hessian). In principle the machinery to do that is in place, but there are 2 limitations: (1) not all operations have their derivative implemented, in particular some of the operations that are only used when calculating the first derivative; (2) it only calculates the gradient of a scalar, while in most cases the Hessian is the gradient of a gradient, which usually has higher rank.
- It's not impossible (or hard) to add support for those. If you need it, contributions to the project are welcome ;)
- The Gradient function uses back-propagation (more efficient when calculating the gradient of a single output). As with Jacobians and Hessians, it could be extended to support forward-propagation if needed.
// Removing the previous definitions of `f` and `gradOfF`
%rm gradOfF f
. removed func gradOfF
. removed func f
Machine Learning with GoMLX¶
Variables and scope¶
Computation graphs are pure functions: they have no state, they take inputs, return outputs and everything in between is transient.
For Machine Learning as well as many types of computations, it's convenient to use and update "scoped" information about a model being trained or executed, like its weights and hyperparameters. We want to be able to easily access and update these weights in the middle of a computation graph.
In GoMLX this functionality is provided by the github.com/gomlx/gomlx/ml/model package, in the form of:
- model.Store: a container object holding all the variables and hyperparameters associated with the model, stored as a tree (like a directory).
- model.Scope: a pointer with a "scope" (like a "current directory") into its Store container.
- model.Variable: a reference to a variable, with a link to its current value, and a method to access it in the current graph.Graph.
- model.Exec: an executor similar to the graph.Exec we already discussed, but which also attaches a Store to the graph.Graph being built, and handles the "magic" of passing the variables used in the graph as "side-inputs", and the updated variables as "side-outputs", without the user needing to think about it.
This may sound more complex than it is in practice. Let's see a simple example of a "model" that maintains a counter:
import "github.com/gomlx/gomlx/ml/model"
%%
store := model.NewStore()
exec := model.MustNewExec(backend, store, func(scope *model.Scope, g *Graph) *Node {
counterVar := scope.VariableWithValue("count", int32(10)) // Creates the variable with initial value 10, if it doesn't exist yet.
count := counterVar.NodeValue(g) // Get the current value of the variable as a graph.Node.
count = AddScalar(count, 1) // Increment it.
counterVar.SetNodeValue(count) // Store the value (as a graph.Node) into the variable.
return count
})
fmt.Println("Counting:")
for range(3) {
fmt.Printf("\tcount=%s\n", exec.MustCall1())
}
counterVar := store.GetVariable("/count")
fmt.Printf("- State of counter=%s\n", must(counterVar.Value()))
Counting:
    count=int32(11)
    count=int32(12)
    count=int32(13)
- State of counter=int32(13)
Note:
- We are using model.Exec, while before we were using graph.Exec. The main difference is that model.Exec compiles and executes graph functions that take a *model.Scope as their first parameter, and it automatically handles the passing of variables as side inputs and outputs (for the updated variables) of the computation graph.
- During graph building, we access and set the variables' symbolic values (graph.Node) with Variable.NodeValue and Variable.SetNodeValue.
- Outside graph building, we can access the last value set to a variable with Variable.Value() and Variable.SetValue. They return/take concrete *tensors.Tensor values.
Finding $argmin_{x}{f(x)}$ example¶
A more elaborate example: let's find $argmin_{x}{f(x)}$ where $f(x) = ax^2 + bx + c$.
If we solve it analytically we get, for $a > 0$, $argmin_{x}{f(x)} = \frac{-b}{2a}$.
Instead, we solve it numerically, using gradient descent:
import (
"flag"
"github.com/gomlx/compute/shapes"
"github.com/gomlx/compute/dtypes"
)
var (
flagA = flag.Float64("a", 1.0, "Value of a in the equation ax^2+bx+c")
flagB = flag.Float64("b", 2.0, "Value of b in the equation ax^2+bx+c")
flagC = flag.Float64("c", 4.0, "Value of c in the equation ax^2+bx+c")
flagNumSteps = flag.Int("steps", 10, "Number of gradient descent steps to perform")
flagLearningRate = flag.Float64("lr", 0.1, "Initial learning rate.")
)
// f(x) = ax^2 + bx + c
func f(x *Node) *Node {
f := MulScalar(Square(x), *flagA)
f = Add(f, MulScalar(x, *flagB))
f = AddScalar(f, *flagC)
return f
}
// minimizeF does one gradient descent step on F by moving a variable "x",
// and returns the value of the function at the current "x".
func minimizeF(scope *model.Scope, graph *Graph) *Node {
// Create or reuse existing variable "x" -- no graph operation is created with this, it's
// only a reference.
xVar := scope.VariableWithShape("x", shapes.Make(dtypes.Float64))
x := xVar.NodeValue(graph) // Read variable for the current graph.
y := f(x) // Value of f(x).
// Gradient always returns a slice; we take the first element, the gradient with respect to x.
gradX := Gradient(y, x)[0]
// stepNum += 1
stepNumVar := scope.VariableWithValue("stepNum", 0.0) // Creates the variable if not existing, or retrieve it if already exists.
stepNum := stepNumVar.NodeValue(graph)
stepNum = OnePlus(stepNum)
stepNumVar.SetNodeValue(stepNum)
// step = -learningRate * gradX / Sqrt(stepNum)
step := Div(gradX, Sqrt(stepNum))
step = MulScalar(step, -*flagLearningRate)
// x += step
x = Add(x, step)
xVar.SetNodeValue(x)
return y // f(x)
}
func Solve() {
store := model.NewStore()
exec := model.MustNewExec(backend, store, minimizeF)
for ii := 0; ii < *flagNumSteps-1; ii++ {
_ = exec.MustCall()
}
y := exec.MustCall1()
x := must(store.GetVariable("/x").Value())
stepNum := must(store.GetVariable("/stepNum").Value())
fmt.Printf("Minimum found at x=%g, f(x)=%g after %f steps.\n", x.Value(), y.Value(), stepNum.Value())
}
The code above created Solve() that will solve for the values set by the flags a, b, and c.
Let's try a few values:
Note:
- %% in GoNB automatically creates a func main() and passes the extra arguments to the Go program.
%% --a=1 --b=2 --c=3 --steps=10 --lr=0.5
Solve()
Minimum found at x=-1, f(x)=2 after 10.000000 steps.
%% --a=2 --b=12 --c=20 --steps=10 --lr=0.5
Solve()
Minimum found at x=-3, f(x)=2 after 10.000000 steps.
Note:
- We created two variables: "x", which we optimize to minimize $f(x)$, and "stepNum", which keeps track of how many steps were already executed.
- Yes, if the learning rate equals $\frac{1}{2a}$ (e.g. --lr=0.5 for $a=1$), it gets to the minimum of the quadratic $f(x)$ in one step. 😉
There is more to model.Store: some of it we'll present in the next section on Machine Learning, the rest can be found in its documentation. A few things worth mentioning in advance:
- scope is always configured at a certain scope, and variables are unique within their scope. The scope is easily changed with scope.In("new_scope"). So the scope object is a scope name (a string) and a pointer to the actual data (variables, graph and model parameters).
- scope also holds model parameters (concrete Go values), which are also scoped. Those can be hyperparameters for the models (learning rate, regularization, etc.) or anything the user or any library may create a convention for. They are more convenient than using hundreds of flags. See an example of their usage in the UCI Adult example.
- Similarly, scope also holds "Graph parameters". Those are very similar to model parameters, but they have one value per Graph. So if a model is created with parameters of different shapes (or for training/evaluation), each version will have its own Graph parameters. Don't worry about this now: if you need it later when building complex graphs, the functionality will be there.
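The "current directory" analogy for scopes can be sketched in a few lines of plain Go. Everything here is hypothetical and simplified for illustration (the store and scope types, the in and variableWithValue methods, float64 values instead of tensors); the real model package is richer.

```go
package main

import "fmt"

// store holds every variable under its full path, e.g. "/dense_0/bias".
// Hypothetical type for illustration; values are plain float64 for brevity.
type store struct {
	variables map[string]float64
}

// scope is a path (the "current directory") plus a pointer to the shared store.
type scope struct {
	st   *store
	path string
}

// in returns a new scope one level deeper; the receiver is left unchanged.
func (s scope) in(name string) scope {
	return scope{st: s.st, path: s.path + "/" + name}
}

// variableWithValue creates the variable with the initial value if it doesn't
// exist yet, and returns its full path.
func (s scope) variableWithValue(name string, initial float64) string {
	fullPath := s.path + "/" + name
	if _, found := s.st.variables[fullPath]; !found {
		s.st.variables[fullPath] = initial
	}
	return fullPath
}

func main() {
	st := &store{variables: map[string]float64{}}
	root := scope{st: st}

	fmt.Println(root.variableWithValue("count", 10))             // prints /count
	fmt.Println(root.in("dense_0").variableWithValue("bias", 0)) // prints /dense_0/bias

	st.variables["/count"] = 13
	root.variableWithValue("count", 10) // Already exists: the current value is kept.
	fmt.Println(st.variables["/count"]) // prints 13
}
```

Because the same name in two different scopes maps to two different paths, independent layers can each create a variable called "bias" without clashing.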
Training a machine learning model¶
The previous sections presented the fundamentals of what is needed to implement machine learning. In this section we present various sub-packages that provide high-level ML layers and tools that make building, training and evaluating a model trivial.
First is the package layers (see code in ml/layers). It provides several composable ML layers.
These are graph building functions, most of which take a *model.Scope as first parameter, where they store
variables or access hyperparameters.
There are several such layers, for example: layers.Dense, layers/fnn.New, layers/kan.New,
layers.Dropout, layers.PiecewiseLinearCalibration (very good for normalization of inputs), layers.BatchNorm,
layers.LayerNorm, layers.Convolution, layers.MultiHeadAttention (for Transformers
layers), etc.
The package train offers two main functionalities. train.Trainer builds a train step graph and an eval step graph, given a model graph building function and an optimizer. These graphs can be executed in sequence to train a model.
The package also provides train.Loop, which simply loops over a train.Dataset interface, reading data and feeding it to Trainer.TrainStep or Trainer.EvalStep, along with executing configurable hooks on the training loop. One such hook is provided by gomlx/train/commandline.AttachProgressBar(loop): it pretty-prints the progress during training on the command line.
There is also a collection of optimizers, initializers, loss functions, metrics, etc. For any functionality there is always an example under the examples/ subdirectory.
Let's look at the simplest ML example: linear, which trains a linear model on noisy generated data.
Here are the constants of our problem:
const (
CoefficientMu = 0.0
CoefficientSigma = 5.0
BiasMu = 1.0
BiasSigma = 10.0
)
To generate synthetic data, we first randomly choose some coefficients and a bias, based on which the data is generated. These selected coefficients are the ones we want to try to learn using ML. The coefficients could have been selected in Go directly using math/rand, but just for fun, we do it using a computation graph.
import (
"github.com/gomlx/compute/shapes"
"github.com/gomlx/compute/dtypes"
"github.com/gomlx/compute"
"github.com/gomlx/gomlx/core/tensors"
)
// initCoefficients chooses random coefficients and bias. These are the true values the model will
// attempt to learn.
func initCoefficients(backend compute.Backend, numVariables int) (coefficients, bias *tensors.Tensor) {
e := MustNewExec(backend, func(g *Graph) (coefficients, bias *Node) {
rngState := Const(g, must(RNGState()))
rngState, coefficients = RandomNormal(rngState, shapes.Make(dtypes.Float64, numVariables))
coefficients = AddScalar(MulScalar(coefficients, CoefficientSigma), CoefficientMu)
rngState, bias = RandomNormal(rngState, shapes.Make(dtypes.Float64))
bias = AddScalar(MulScalar(bias, BiasSigma), BiasMu)
return
})
coefficients, bias = e.MustCall2()
return
}
%%
coef, bias := initCoefficients(backend, 3)
fmt.Printf("Example of target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value())
Example of target: coefficients=[6.3 -7.22 6.58], bias=12
Note
- This code should look familiar, using things we presented earlier in the tutorial. It creates a computation graph to randomly generate the coefficients and bias. Then it executes it and returns the result.
- Notice that since the computation graph is functional, we need to pass around the random number generator state, which gets updated at each call to RandomUniform or RandomNormal.
- Alternatively, the model.Store introduced earlier can keep the state as a variable, and provides a simpler interface: see scope.RandomUniform and scope.RandomNormal.
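The idea of a functional random number generator, where the state goes in and a new state comes out, can be shown in plain Go. The sketch below uses a 64-bit linear congruential generator picked only for brevity; it says nothing about the algorithm GoMLX actually uses, and nextUniform is a hypothetical name.

```go
package main

import "fmt"

// nextUniform is a pure function: given a state it returns the next state and
// a pseudo-random value in [0, 1). Calling it twice with the same state gives
// the same result, so the caller must thread the new state through.
func nextUniform(state uint64) (newState uint64, value float64) {
	// 64-bit linear congruential step (arithmetic wraps around on overflow).
	newState = state*6364136223846793005 + 1442695040888963407
	// Use the top 53 bits to build a float64 in [0, 1).
	value = float64(newState>>11) / float64(uint64(1)<<53)
	return newState, value
}

func main() {
	state := uint64(42) // Arbitrary seed.

	state, a := nextUniform(state)
	state, b := nextUniform(state)
	fmt.Printf("a=%.4f b=%.4f (different, because the state was updated)\n", a, b)

	// Forgetting to update the state repeats the same number.
	_, c := nextUniform(42)
	_, d := nextUniform(42)
	fmt.Printf("c == d: %v\n", c == d)
}
```

This is the same discipline followed by the graph code above, where each random call returns the updated rngState alongside the generated values.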
Next, we want to generate the data (examples): we generate random inputs, and then the label using the selected coefficients plus some normal noise.
func buildExamples(backend compute.Backend, coef, bias *tensors.Tensor, numExamples int, noise float64) train.Batch {
e := MustNewExec(backend, func(coef, bias *Node) (inputs, labels *Node) {
g := coef.Graph()
numFeatures := coef.Shape().Dimensions[0]
// Random inputs (observations).
rngState := Const(g, must(RNGState()))
rngState, inputs = RandomNormal(rngState, shapes.Make(coef.DType(), numExamples, numFeatures))
coef = ExpandDims(coef, 0)
// Calculate perfect labels.
labels = ReduceAndKeep(Mul(inputs, coef), ReduceSum, -1)
labels = Add(labels, bias)
if noise > 0 {
// Add some noise to the labels.
var noiseVector *Node
rngState, noiseVector = RandomNormal(rngState, labels.Shape())
noiseVector = MulScalar(noiseVector, noise)
labels = Add(labels, noiseVector)
}
return
})
inputs, labels := e.MustCall2(coef, bias)
return train.Batch{Inputs: []*tensors.Tensor{inputs}, Labels: []*tensors.Tensor{labels}}
}
%%
coef, bias := initCoefficients(backend, 3)
fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value())
numExamples := 5
theBatch := buildExamples(backend, coef, bias, numExamples, 0.2)
fmt.Printf("%d dataset examples:\n", numExamples)
inputs := theBatch.Inputs[0].Value().([][]float64)
labels := theBatch.Labels[0].Value().([][]float64)
for ii := 0; ii < numExamples; ii++ {
fmt.Printf("\tx=%0.3v; label=%0.3v\n", inputs[ii], labels[ii])
}
Target: coefficients=[-1.29 -5.26 -2.09], bias=-22.9
5 dataset examples:
	x=[0.161 -2.22 0.845]; label=[-12.8]
	x=[1.47 -0.576 -0.217]; label=[-21.3]
	x=[0.606 0.173 -0.161]; label=[-24.1]
	x=[-0.644 -0.0497 1.41]; label=[-24.7]
	x=[1.53 0.893 -1.01]; label=[-27.2]
Dataset¶
Now the first new concept of this section: train.Dataset is the interface used to feed data
during a training loop or evaluation. It must implement a Dataset.Name() method, and a Dataset.Iter() method to iterate over the data, yielding a train.Batch with inputs and labels.
For our linear model, we always train with the full training data, so there is only one batch, which is yielded every time.
GoMLX provides several Dataset tools in the package ml/dataset, one of which does exactly what we want: dataset.Const(batch) creates a constant dataset that always yields the same batch.
Note:
- More often it is more work pre-processing data than actually building an ML model ... that's life 🙁
- In the examples/ subdirectory we implement train.Dataset for some well-known datasets: UCI Adult, Cifar-10, Cifar-100, IMDB Reviews, Kaggle's Dogs vs Cats, Oxford Flowers 102. These can be used as libraries to easily try different models. If you are working on public datasets, please consider contributing similar libraries.
The package github.com/gomlx/gomlx/ml/dataset provides several other tools to facilitate the work here:
- InMemory: reads a dataset into (accelerator) memory, and then serves it from there -- greatly accelerates training.
- Take(inputDS, n), Batch(inputDS, batchSize), Buffer, Map(inputDS, mapFn), MapOnHost(inputDS, mapFn): usual transformation functions that transform one dataset into another.
- ParallelizeMap, ParallelizeMapOnHost: parallelized versions of Map and MapOnHost, for performant transformation of datasets.
- Const(batch), Zero(): constant datasets that always yield the same batch.
See also github.com/gomlx/go-huggingface for a project that among other things makes it trivial to iterate over HuggingFace datasets (many standard public datasets are stored there).
Below is a small demo that creates our training batch and iterates over it: it always yields the same value.
%%
const numFeatures = 4
const batchSize = 1
trueCoefficients, trueBias := initCoefficients(backend, numFeatures)
fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", trueCoefficients.Value(), trueBias.Value())
theBatch := buildExamples(backend, trueCoefficients, trueBias, batchSize, 0.1)
ds := dataset.Const(theBatch)
fmt.Printf("Dataset: %s\n", ds.Name())
var count int
for batch, err := range ds.Iter() {
if err != nil { panic(err) }
fmt.Printf("- Batch: input=%s, label=%s\n", batch.Inputs[0], batch.Labels[0])
batch.Finalize()
count++
if count==2 {
break
}
}
Target: coefficients=[3.38 -1.14 -0.244 -3.09], bias=-7.57
Dataset: dataset.Const
- Batch: input=[1][4]float64{{0.6384, 0.7118, -0.1299, 1.708}}, label=[1][1]float64{{-11.43}}
- Batch: input=[1][4]float64{{0.6384, 0.7118, -0.1299, 1.708}}, label=[1][1]float64{{-11.43}}
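To show the shape of the contract described in this section (a Name() method plus an Iter() method that yields batches), here is a plain-Go sketch of a constant dataset. `Batch` and `constDataset` are hypothetical, simplified stand-ins and not the GoMLX types; the real train.Dataset interface may differ in its details.

```go
package main

import "fmt"

// Batch is a hypothetical, simplified stand-in for train.Batch.
type Batch struct {
	Inputs, Labels [][]float64
}

// constDataset always yields the same batch, forever.
type constDataset struct {
	batch Batch
}

// Name identifies the dataset, e.g. in logs.
func (d constDataset) Name() string { return "constDataset" }

// Iter returns a push-style iterator. It keeps yielding the same batch until
// the consumer returns false.
func (d constDataset) Iter() func(yield func(Batch, error) bool) {
	return func(yield func(Batch, error) bool) {
		for yield(d.batch, nil) {
		}
	}
}

func main() {
	ds := constDataset{batch: Batch{
		Inputs: [][]float64{{0.5, -1.0}},
		Labels: [][]float64{{2.0}},
	}}
	count := 0
	// With Go 1.23+ this could also be written `for batch, err := range ds.Iter()`.
	ds.Iter()(func(batch Batch, err error) bool {
		if err != nil {
			panic(err)
		}
		fmt.Printf("%s batch %d: inputs=%v labels=%v\n", ds.Name(), count, batch.Inputs, batch.Labels)
		count++
		return count < 2 // Stop after two batches; the dataset itself never ends.
	})
}
```

Because a constant dataset never ends, it is the consumer (here the callback, in training the loop's step count) that decides when to stop.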
ModelFn¶
Next we write a modelFn to be used by our trainer.
The modelFn should always take a model.Scope (from which it accesses the variables),
optionally a spec of type any from the dataset (used for multi-task training), and zero to three input nodes or
a slice of nodes.
It should output zero to three prediction nodes, or a slice of nodes. It uses generics
to match any of the accepted signatures, so it won't compile if you pass an invalid modelFn.
See model.ModelFnCompatible for details.
During training the predictions returned are fed to the loss function, and during inference they can be used directly.
Our linear example has the simplest model possible:
import "github.com/gomlx/gomlx/ml/model"
func linearModel(scope *model.Scope, input *Node) (prediction *Node) {
return layers.DenseWithBias(scope, input, /* outputDim= */ 1)
}
Note
- It uses the layers.DenseWithBias layer, which simply multiplies the input by a learnable matrix (weights) and adds a learnable bias. It's the most basic building block of neural networks (NNs). The implementation of layers.DenseWithBias is pretty simple, and worth checking out to refresh how variables from the scope are used.
- Since it's a linear model, we don't use an activation function. The usual ones are available for NNs (Relu, Sigmoid, Tanh, Gelu, Silu, etc.).
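The computation a dense layer with bias performs, "multiply by a weight matrix, then add a bias", can be written out in a few lines of plain Go. This sketch is not the GoMLX implementation; `denseWithBias` is a hypothetical function operating on slices instead of graph nodes.

```go
package main

import "fmt"

// denseWithBias computes x·W + b for a batch of inputs.
// x is [batch][in], weights is [in][out], bias is [out]; result is [batch][out].
func denseWithBias(x, weights [][]float64, bias []float64) [][]float64 {
	out := make([][]float64, len(x))
	for i, row := range x {
		out[i] = make([]float64, len(bias))
		for k := range bias {
			sum := bias[k]
			for j, v := range row {
				sum += v * weights[j][k]
			}
			out[i][k] = sum
		}
	}
	return out
}

func main() {
	x := [][]float64{{1, 2, 3}, {0, 1, 0}}
	weights := [][]float64{{0.5}, {-1}, {2}} // 3 features -> 1 output.
	bias := []float64{10}
	fmt.Println(denseWithBias(x, weights, bias)) // [[14.5] [9]]
}
```

With an output dimension of 1, as in linearModel, the weights play the role of the coefficients and the single bias value plays the role of the bias we are trying to learn.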
Trainer and Loop¶
The last part is to put together the train.Trainer and train.Loop objects in our main() function.
The first, train.Trainer, stitches together the model, the optimizer and the loss function, and is able to run training
steps and evaluations (you can always write your own if you want, GoMLX is designed for simple plug&play).
The second, train.Loop, loops over the dataset executing one training step at
a time and supports a subscription (hooking) system, where one attaches things like a progress bar, periodic checkpointing, evaluation runs, and/or plotting of metrics.
import (
"os"
"github.com/gomlx/gomlx/ui/commandline"
"github.com/gomlx/gomlx/ml/layers/regularizer"
"github.com/gomlx/gomlx/ml/train/loss"
"github.com/gomlx/gomlx/ml/train/optimizer"
)
var (
flagNumExamples = flag.Int("num_examples", 10000, "Number of examples to generate")
flagNumFeatures = flag.Int("num_features", 3, "Number of features")
flagNoise = flag.Float64("noise", 0.2, "Noise in synthetic data generation")
flagNumSteps = flag.Int("steps", 100, "Number of gradient descent steps to perform")
flagLearningRate = flag.Float64("lr", 0.1, "Initial learning rate.")
)
// AttachToLoop attaches decorators to the training loop. It will be redefined later.
func AttachToLoop(loop *train.Loop) {
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
}
// TrainMain() does everything to train the linear model.
func TrainMain() error {
flag.Parse()
// Select coefficients that we will try to predict.
trueCoefficients, trueBias := initCoefficients(backend, *flagNumFeatures)
fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", trueCoefficients.Value(), trueBias.Value())
// Generate training data with noise.
theBatch := buildExamples(backend, trueCoefficients, trueBias, *flagNumExamples, *flagNoise)
fmt.Printf("Training data (inputs, labels): (%s, %s)\n\n", theBatch.Inputs[0].Shape(), theBatch.Labels[0].Shape())
ds := dataset.Const(theBatch)
// Creates scope with learned weights and bias.
store := model.NewStore()
store.SetParam(optimizer.ParamLearningRate, *flagLearningRate) // = "learning_rate"
store.SetParam(regularizer.ParamL2, 1e-3) // 1e-3 of L2 regularization.
// train.Trainer executes a training step.
trainer := train.NewTrainer(backend, store,
linearModel,
loss.MeanSquaredError,
optimizer.StochasticGradientDescent(),
nil, nil) // trainMetrics, evalMetrics
loop := train.NewLoop(trainer)
AttachToLoop(loop)
// Loop for given number of steps.
metrics, err := loop.RunSteps(ds, *flagNumSteps)
if err != nil { return err }
_ = metrics // We are not interested in them in this example.
// Print learned coefficients and bias -- from the weights in the dense layer.
fmt.Println()
coefVar, biasVar := store.GetVariable("/dense/weights"), store.GetVariable("/dense/biases")
learnedCoef, learnedBias := must(coefVar.Value()), must(biasVar.Value())
fmt.Printf("Learned: coefficients=%0.3v, bias=%0.3v\n", learnedCoef.Value(), learnedBias.Value())
return nil
}
%%
err := TrainMain()
if err != nil { panic(err) }
Target: coefficients=[0.388 8.48 1.81], bias=10.4
Training data (inputs, labels): ((Float64)[10000, 3], (Float64)[10000, 1])
100% [========================================] (8000 steps/s) [step=99] [loss=0.219] [~loss=9.92]
Learned: coefficients=[[0.441] [8.27] [1.79]], bias=[10.2]
Note:
- Hyperparameters are also set in the model.Store. Layers and optimizers can define their own hyperparameters independently. Hyperparameters can be set and read using a model.Scope (like a current directory) within the model.Store, and they can take specialized values under specific scopes -- e.g.: doing scope.In("dense_5").SetParam(regularizer.ParamL2, 0.1) would set L2 regularization only for the layer dense_5 of the model to 0.1.
- The trainer constructor also takes as input arbitrary metrics (for training and evaluation). The loss of the last batch and a moving average of the loss are always included automatically. There are many others, which can be means or moving averages, etc.
- The Loop object is very flexible. One can attach any functionality (hooks) to be executed during training with OnStart, OnStep, OnEnd, EveryNSteps, NTimesDuringLoop, PeriodicCallback (e.g.: every N minutes) and ExponentialCallback.
- The most common such functionality is the commandline.AttachProgressBar. There is also plotting of any arbitrary metric or any arbitrary Node in the computation graph.
- Loop.RunSteps() also returns the final metrics from the training, which are usually printed out.
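To demystify what a training step does for this linear model, here is a plain-Go sketch of full-batch gradient descent on the mean squared error with a small L2 term. It is not how GoMLX computes it (GoMLX differentiates the computation graph automatically and its optimizer may schedule the learning rate differently); `trainStep` is a hypothetical helper with hand-written gradients.

```go
package main

import (
	"fmt"
	"math/rand"
)

// trainStep performs one full-batch gradient descent update in place on w and b
// for the loss mean((x·w + b - y)^2) + l2*sum(w^2), and returns the mean squared
// error measured before the update.
func trainStep(xs [][]float64, ys, w []float64, b *float64, lr, l2 float64) float64 {
	gradW := make([]float64, len(w))
	var gradB, mse float64
	n := float64(len(xs))
	for i, x := range xs {
		pred := *b
		for j := range w {
			pred += x[j] * w[j]
		}
		diff := pred - ys[i]
		mse += diff * diff / n
		for j := range w {
			gradW[j] += 2 * diff * x[j] / n
		}
		gradB += 2 * diff / n
	}
	for j := range w {
		w[j] -= lr * (gradW[j] + 2*l2*w[j])
	}
	*b -= lr * gradB
	return mse
}

func main() {
	rng := rand.New(rand.NewSource(42))
	trueW, trueB := []float64{0.388, 8.48, 1.81}, 10.4
	xs := make([][]float64, 1000)
	ys := make([]float64, len(xs))
	for i := range xs {
		xs[i] = []float64{rng.NormFloat64(), rng.NormFloat64(), rng.NormFloat64()}
		ys[i] = trueB + 0.2*rng.NormFloat64()
		for j, v := range xs[i] {
			ys[i] += v * trueW[j]
		}
	}
	w, b := make([]float64, 3), 0.0
	var mse float64
	for step := 0; step < 100; step++ {
		mse = trainStep(xs, ys, w, &b, 0.1, 1e-3)
	}
	fmt.Printf("mse=%0.3v learned w=%0.3v b=%0.3v\n", mse, w, b)
}
```

The trainer and loop replace the hand-written gradient and the `for` loop here, which is what makes swapping the model, loss or optimizer a one-line change.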
Training and Plotting¶
As our last example, let's train it "for real", that is, with more steps.
And to make things prettier, let's attach also a plot of the metrics registered. In our example the only metrics are the default ones, the batch and mean of the loss -- the mean squared error.
import "github.com/gomlx/gomlx/ui/gonb/plotly"
import "os"
func AttachToLoop(loop *train.Loop) {
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
_ = plotly.New().Dynamic().ScheduleExponential(loop, 50, 1.1)
}
%% --steps=5000
err := TrainMain()
if err != nil { panic(err) }
Target: coefficients=[4.88 -5.79 0.911], bias=3.83
Training data (inputs, labels): ((Float64)[10000, 3], (Float64)[10000, 1])
3% [>.......................................] (12697 steps/s) [0s:0s] [step=164] [loss=0.101] [~loss=2.14]
100% [========================================] (12719 steps/s) [step=4999] [loss=0.0975] [~loss=0.0975]
Metric: loss
Learned: coefficients=[[4.87] [-5.78] [0.91]], bias=[3.83]
Note:
- The ui/gonb/plotly package (imported above) will automatically plot all the metrics registered in the trainer. When it attaches itself to loop it collects the metric values (and runs evaluation on any requested datasets), and at the end of training, plots them.
- Optionally, with Plots.Dynamic() it also plots intermediary results, displaying the progress of the training as it happens.
- Previously, we added a little L2 regularization as a hyperparameter. The layers.DenseWithBias layer picks up on that hyperparameter, and adds that bit of L2 regularization. GoMLX tracks the loss with/without regularization separately, and one can see the difference in the two lines.
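The gap between the two loss lines can be estimated by hand. The sketch below computes an L2 penalty in its common textbook form, l2 * sum(w^2); the exact scaling GoMLX applies is an assumption here, and `l2Penalty` is a hypothetical helper. The 0.04 used for the unregularized loss is only the noise variance (0.2 squared), a rough stand-in and not a value taken from the run.

```go
package main

import "fmt"

// l2Penalty returns l2 * sum(w^2). The exact scaling GoMLX applies is not shown
// in the tutorial; this is the common textbook form.
func l2Penalty(w []float64, l2 float64) float64 {
	var sum float64
	for _, v := range w {
		sum += v * v
	}
	return l2 * sum
}

func main() {
	w := []float64{4.87, -5.78, 0.91} // Learned coefficients printed above.
	mse := 0.04                       // Illustrative stand-in: the noise variance 0.2*0.2.
	penalty := l2Penalty(w, 1e-3)
	fmt.Printf("loss without regularization: %0.4f\n", mse)
	fmt.Printf("loss with regularization:    %0.4f\n", mse+penalty)
}
```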
Other examples¶
There is much more in the libraries and examples.
GoMLX libraries are well documented, and the implementation is generally simple to use and understand -- don't hesitate to dive into the code.
For machine learning, we highly recommend diving into the model.Store package/objects, and looking at the layers library, to learn how many of the common ML layers work.
Debugging¶
Unfortunately, the computers just "don't get it": they do exactly what we tell them to do, as opposed to what we want them to do, and thus programs fail or crash. And AIs sometimes "get it", but still write wrong programs.
GoMLX provides different ways to track down various types of errors. The most commonly used are listed below:
Good old "printf"¶
It's convenient because Go compiles fast: changing something and running it to see what one gets is
almost instant. Logging results to stdout is a valid way of developing. While developing graph-building code,
one often prints the shape of the Node being operated on to confirm (or not) one's expectations.
Errors in GoMLX have stacktrace¶
Errors during the building of the graph are reported back with panic.
The graph.Exec or model.Exec recover and convert them to errors for you.
These errors include a stack trace that tells you exactly where things happened. The stack trace can be printed with the "%+v" format string.
Node Shape Asserts¶
During the writing of complex models, it's very common to add comments on the expected shapes of the graph nodes, to facilitate the reader (and developer) of the code to have the right mental model of what is going on.
GoMLX provides a series of assert methods that can be used instead. They serve both as documentation, and an early exit in case of some unexpected results.
For example, a modelGraph function could contain:
batchSize := inputs[0].Shape().Dimensions[0]
...
layer := Concatenate(allEmbeddings, -1)
layer.AssertDims(batchSize, -1) // 2D tensor, with batch size as the leading dimension.
Although using these when building graphs is the most common case, there are similar assert functions for tensors and shapes themselves in the package github.com/gomlx/compute/shapes.
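The idea behind such an assert, including the -1 wildcard used above, fits in a few lines of plain Go. This sketch is an assumption about the semantics, not the GoMLX code: `checkDims` is a hypothetical helper that works on a plain []int and returns an error instead of panicking.

```go
package main

import "fmt"

// checkDims reports an error unless shape has the same rank as want and matches
// it on every axis; a negative entry in want accepts any size on that axis.
func checkDims(shape []int, want ...int) error {
	if len(shape) != len(want) {
		return fmt.Errorf("rank mismatch: shape %v, want %v", shape, want)
	}
	for axis, w := range want {
		if w >= 0 && shape[axis] != w {
			return fmt.Errorf("axis %d mismatch: shape %v, want %v", axis, shape, want)
		}
	}
	return nil
}

func main() {
	batchSize := 32
	shape := []int{32, 128} // E.g. concatenated embeddings.
	fmt.Println(checkDims(shape, batchSize, -1)) // <nil>
	fmt.Println(checkDims(shape, 64, -1) != nil) // true
}
```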
Graph Execution Logging¶
Every Node of the graph can be flagged with SetLogged(msg).
At the end of the execution, the executor (Exec) logs all these values. The default logger
(set with Exec.SetLogger) will simply print the message msg along with the value of the Node of
interest. Creating specialized loggers that handle arbitrary nodes is trivial.
Catching NaN and Inf in your training¶
This is a common source of headaches when training complex models.
The package nanlogger implements a specialized logger that monitors results on arbitrary nodes in the graph,
and will panic with custom messages (and a stack trace) at the first sight of a NaN (or ±Inf).
More Debugging¶
Tests and these methods have been enough to develop most of GoMLX so far. But there are other debugging tools that could be made, see discussion in the Debugging document. Let us know if you need something specialized.
Happy coding and good luck on modeling!!
![[Zürich See]](zurich_see.jpg)