[GoMLX Mascot]

GoMLX Tutorial¶

If you just want to quickly look at a working example, check out examples/adult/demo/adult.ipynb, for a model trained on the UCI Adult Income dataset. There are more examples in the examples/ subdirectory.

The tutorial won't detail the whole API, but should present all important concepts. Everything else is well documented in godoc (in the code), also available in pkg.go.dev.

The tutorial was written using a Jupyter notebook with GoNB, a kernel for Go that was co-developed with GoMLX. It has its own short tutorial for those interested. But GoMLX doesn't require GoNB.

If you are viewing this tutorial as a GitHub snapshot, you won't be able to interact with it. To play with it, install GoMLX: see the Installation section of its README. The easiest way is to start the pre-generated Docker image and use the Jupyter notebook there -- this tutorial can be opened interactively from there.

Note: Output Not Displaying in JupyterLab?¶

If your notebook plots are not displaying correctly, it's likely because Jupyter is assuming the notebook is "not trusted". This can be easily fixed by opening the command palette (Shift+Control+C on PCs) and selecting Trust Notebook. There is an open issue in JupyterLab attempting to address that.

In [1]:
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx"
%goworkfix
	- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".

Computation Graphs¶

Package graph reference documentation

To do machine learning based on neural networks and gradient descent, one of the most important requirements is the ability to do mathematical computations (mostly matrix multiplications) fast.

GoMLX is built on the concept of building "computation graphs", just-in-time compiling them, and only then executing them to get the desired results. In other words, one writes code that generates another kind of code (the computation graph). We do this because we are then able to execute it really fast using XLA and PJRT.

For example, let's create a computation graph to sum two values:

Note

  • If executing this on a notebook, notice the very first cell execution takes a few seconds for Go to fetch the required dependencies.
In [2]:
import . "github.com/gomlx/gomlx/graph"

func SumGraph(a, b *Node) *Node {
    return Add(a, b)
}

Note

  • The import . "github.com/gomlx/gomlx/graph" imports all definitions of the graph package into the current scope. Usually, it's easier to work like this in Go files that implement graph building functions.
  • Our function is named SumGraph: the suffix Graph is just a convention, but it helps identify functions that do graph building.
  • The type *Node represents a node in the computation graph. All graph operations take either a *Graph object to start with, or a *Node, and create new nodes with the corresponding operations. So our example creates an Add node that takes the nodes pointed to by a and b, builds a node that represents their sum, and then returns this *Node.
  • Every node contains a reference to the *Graph it's part of (see Node.Graph()).
  • There is a rich set of operations available in GoMLX, see Node documentation.

Ok, but this won't tell us what is 1+1 yet. We need to compile and then execute this graph with some input values.
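
To make the split between building and executing concrete, here is a minimal plain-Go sketch. It does not use GoMLX: the node type and the param, add and eval helpers are hypothetical stand-ins invented for illustration, and eval interprets the graph directly, whereas GoMLX compiles it first.

```go
package main

import "fmt"

// node is a toy stand-in for a graph node: it records an operation and its
// operands instead of computing anything.
type node struct {
	op       string // "param" or "add"
	index    int    // which input a "param" node reads
	lhs, rhs *node  // operands of an "add" node
}

func param(i int) *node    { return &node{op: "param", index: i} }
func add(a, b *node) *node { return &node{op: "add", lhs: a, rhs: b} }

// sumGraph only builds nodes; no arithmetic happens here.
func sumGraph(a, b *node) *node { return add(a, b) }

// eval walks the recorded graph with concrete input values.
func eval(n *node, inputs []float64) float64 {
	switch n.op {
	case "param":
		return inputs[n.index]
	case "add":
		return eval(n.lhs, inputs) + eval(n.rhs, inputs)
	}
	panic("unknown op " + n.op)
}

func main() {
	g := sumGraph(param(0), param(1))         // Build once.
	fmt.Println(eval(g, []float64{1, 1}))     // 2
	fmt.Println(eval(g, []float64{3.5, 1.5})) // 5
}
```

The same graph value is reused for both evaluations, which is the property that makes it worthwhile to spend time compiling it.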

Backends and executing Graphs with Exec¶

GoMLX can in principle work with different backends to execute its operations -- this opens up possibilities like executing models in WASM (not implemented yet). But the default backend, which must always be imported to execute something, is xla. If you forget to import it, GoMLX provides a useful error message telling you to do so.

Creating a backends.Backend object is trivial: backends.New(). It uses the environment variable GOMLX_BACKEND, if set, to select a specific backend / plugin (e.g.: "xla:cpu", "xla:cuda", etc.).

With the backend created, Exec is the easiest way to compile and execute computation graphs in GoMLX. To run our SumGraph function above we can do:

In [3]:
import _ "github.com/gomlx/gomlx/backends/xla"

var backend = backends.New()

%%

// Short version: `ExecOnce` will compile the function and execute only once.
two := ExecOnce(backend, /* GoMLX function */ SumGraph, /* Arguments */ 1.0, 1.0)
fmt.Printf("Short version:\t1+1=%v\n", two)

// Exec version: `NewExec` will compile the SumGraph function so it can be executed efficiently
// many times.
exec := NewExec(backend, SumGraph)  
two = exec.Call(1, 1)[0]
fmt.Printf("Exec version:\t1+1=%v\n", two.Value())
Short version:	1+1=float64(2.0000)
Exec version:	1+1=2

Note

  • %% is a shortcut for func main(): everything after it is put inside a main function by GoNB, the Go Notebook kernel.
  • backends.New() creates a backends.Backend object, which connects to an accelerator if present. Usually one creates one at the beginning of the program and passes it around. Here GoNB will keep the global variable backend available to all cells, so we don't need to define it again.
  • The Exec object created is associated with a graph building function (SumGraph in this case). It lazily compiles and executes the compiled computation as needed. Naturally the first time Call() is invoked it is slow: it has to build the graph and just-in-time (JIT) compile it. But the compiled graph afterwards is optimized and very fast to execute, which is what we want for machine learning.
  • The Exec.Call(1, 1) method always returns a slice of results, that's why we use [0] to access the first result.
  • The results of graph execution are always tensors (see section below). They can be converted to Go types using .Value() method.

Something important to understand is that graphs have static (fixed) shapes for their inputs and outputs. This means that, for example, if you are going to sum floats instead of ints, Exec has to rebuild the graph to take two floats as input. The same happens if you want to sum vectors or matrices of ints, or values of any other shapes.Shape.

For a detailed explanation of Shape and the associated concepts of Axis, Dimensions and DType (the underlying data type), see the package shapes documentation.

To exemplify, let's expand our code a bit:

In [4]:
import (
    "fmt"
    . "github.com/gomlx/gomlx/graph"
)

func SumGraph(a, b *Node) *Node {
    fmt.Printf("\t. building graph for a.shape=%s and b.shape=%s\n", a.Shape(), b.Shape())
    return Add(a, b)
}


func main() {
    sumExec := NewExec(backend, SumGraph)
    two := sumExec.Call(1, 1)[0]
    fmt.Printf("1+1=%v\n", two.Value())

    for ii := 0; ii < 5; ii++ {
        sumInts := sumExec.Call(ii, ii)[0]
        fmt.Printf("%d+%d=%v\n", ii, ii, sumInts.Value())
    }

    five := sumExec.Call(3.5, 1.5)[0]
    fmt.Printf("3.5+1.5=%v\n", five.Value())

    many := sumExec.Call([]float32{1.1, 2.2, 3.3}, []float32{10, 10, 10})[0]
    fmt.Printf("[1.1, 2.2, 3.3] + [10, 10, 10] = %v\n", many.Value())
}
	. building graph for a.shape=(Int64) and b.shape=(Int64)
1+1=2
0+0=0
1+1=2
2+2=4
3+3=6
4+4=8
	. building graph for a.shape=(Float64) and b.shape=(Float64)
3.5+1.5=5
	. building graph for a.shape=(Float32)[3] and b.shape=(Float32)[3]
[1.1, 2.2, 3.3] + [10, 10, 10] = [11.1 12.2 13.3]

Note

  • We added a fmt.Printf to the graph building function to tell us the shapes of the operands each time a new graph is built. Notice that fmt.Printf is not included in the graph, it's only part of the graph building function. We'll see later how to print intermediary results in the middle of the execution of the graph.
  • Every Node has an associated shape (shapes.Shape type). A shape is defined by its underlying data type shapes.DType and the dimensions of its axes. For scalars, the shape has zero axes (dimensions). E.g.: (Int64)[] represents a scalar int value, and (Float32)[3] represents a vector with 3 float32 values. More details and the list of supported data types (aka. dtype) are in the package github.com/gomlx/gopjrt/dtypes.
  • Exec automatically calls SumGraph whenever the Call() method sees parameters with shapes different from those it has seen before (a cache of pre-compiled graphs, with limited size, is kept in memory).
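
The rebuild-on-new-shape behavior can be pictured with a small plain-Go sketch. The names shapeOf and cachingExec are hypothetical and only mimic the idea: the cache here is an unbounded map keyed by the input shape signature, whereas the real Exec keeps compiled graphs and limits the cache size.

```go
package main

import (
	"fmt"
	"strings"
)

// shapeOf returns a toy shape signature for a few Go types, formatted like
// the shapes printed in the output above.
func shapeOf(v any) string {
	switch x := v.(type) {
	case int:
		return "(Int64)"
	case float64:
		return "(Float64)"
	case []float32:
		return fmt.Sprintf("(Float32)[%d]", len(x))
	}
	panic(fmt.Sprintf("unsupported type %T", v))
}

// cachingExec "builds" only when it sees a new combination of input shapes.
type cachingExec struct {
	builds int
	cache  map[string]bool // shape signature -> already built
}

// call reports whether this call triggered a build.
func (e *cachingExec) call(args ...any) bool {
	sigs := make([]string, len(args))
	for i, a := range args {
		sigs[i] = shapeOf(a)
	}
	key := strings.Join(sigs, ",")
	if e.cache[key] {
		return false
	}
	e.cache[key] = true
	e.builds++
	return true
}

func main() {
	e := &cachingExec{cache: map[string]bool{}}
	for i := 0; i < 5; i++ {
		e.call(i, i) // Same signature every time: built once.
	}
	e.call(3.5, 1.5)
	e.call([]float32{1.1, 2.2, 3.3}, []float32{10, 10, 10})
	fmt.Println("graphs built:", e.builds) // graphs built: 3
}
```

The three builds correspond to the three "building graph" lines in the cell output above.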

In general, graph operations only work on operands of the same dtypes.DType. If the dtypes differ, the error is reported with a panic (which works like an exception and can be caught) carrying an error with a full stack trace. With that, let's talk about how errors are handled when building GoMLX computation graphs:

Error Handling¶

During graph building and execution, GoMLX diverges from idiomatic Go error handling: it uses exceptions in the form of a panic with an error object that carries a stack trace. See the package github.com/gomlx/exceptions for convenient wrapper functions.

Note: The author is aware this is controversial for some. So below are the motivations to use exceptions, as opposed to returning an error everywhere:

  • While implementing long mathematical functions, introducing error handling in between each operation would be too distracting and unwieldy. Notice the Go language makes the same choice: that's why the division operator / doesn't return the result of the division and an error, and instead panics on division by zero.
  • Speed of building the computation graph is not a concern: it has no impact on the execution of the JIT compiled code later on.

Let's create an example with an error to see how this works:

In [5]:
%%
sumExec := NewExec(backend, SumGraph)
_ = sumExec.Call(1.1, 2) // Error: arguments have different dtypes Float64 and Int64.
	. building graph for a.shape=(Float64) and b.shape=(Int64)
panic: Backend "xla": failed Add: dtype of first (Float64) and second (Int64) operands don't match

goroutine 1 [running]:
github.com/gomlx/gomlx/backends/xla.(*Builder).Add(0xc00028e000, {0xbd3ba0?, 0xc000290000?}, {0xbd3ba0, 0xc000290120})
	/home/janpf/Projects/gomlx/backends/xla/gen_standard_ops.go:32 +0x137
github.com/gomlx/gomlx/graph.Add(0xc0001ca580, 0xc0001ca630)
	/home/janpf/Projects/gomlx/graph/gen_backend_ops.go:176 +0x123
main.SumGraph(0xc0001ca580, 0xc0001ca630)
	 [[ Cell [4] Line 8 ]] /tmp/gonb_55ad4f33/main.go:17 +0x189
reflect.Value.call({0xbce4e0?, 0xc62b90?, 0xc0000d7850?}, {0xc4ba8a, 0x4}, {0xc000284150, 0x2, 0x0?})
	/snap/go/current/src/reflect/value.go:581 +0xca6
reflect.Value.Call({0xbce4e0?, 0xc62b90?, 0x0?}, {0xc000284150?, 0x0?, 0x0?})
	/snap/go/current/src/reflect/value.go:365 +0xb9
github.com/gomlx/gomlx/graph.(*Exec).createAndCacheGraph(0xc0001de120, {0xc0000da000, 0x2, 0xcb59a8?})
	/home/janpf/Projects/gomlx/graph/exec.go:536 +0x75c
github.com/gomlx/gomlx/graph.(*Exec).findOrCreateGraph(0xc0001de120, {0xc0000da000, 0x2, 0x2})
	/home/janpf/Projects/gomlx/graph/exec.go:595 +0x167
github.com/gomlx/gomlx/graph.(*Exec).compileAndExecute(0xc0001de120, 0x1, {0xc0000d7f20?, 0xc0000061c0?, 0xc0000d7f40?})
	/home/janpf/Projects/gomlx/graph/exec.go:436 +0x226
github.com/gomlx/gomlx/graph.(*Exec).CallWithGraph(...)
	/home/janpf/Projects/gomlx/graph/exec.go:386
github.com/gomlx/gomlx/graph.(*Exec).Call(0xcb59a8?, {0xc0000d7f20?, 0xbce4e0?, 0xc62b90?})
	/home/janpf/Projects/gomlx/graph/exec.go:372 +0x27
main.main()
	 [[ Cell [5] Line 3 ]] /tmp/gonb_55ad4f33/main.go:24 +0xbe
exit status 2

Note

  • In the stack-trace above there are 2 lines of interest, that typically help to debug such issues:
    1. Where in the graph building function main.SumGraph the invalid operation was created: Line 8 of the previous cell.
    2. Where in the main function, the graph was attempted to be executed: Line 3 of this cell.
  • You can enable displaying line-numbers in the JupyterLab with "ESC+L" (upper-case L).

If you want to catch errors, GoMLX provides a small exceptions library that defines TryCatch[E], which catches panics (thrown exceptions). GoMLX only throws exceptions of type error. So you could do:

import "github.com/gomlx/exceptions"

err := exceptions.TryCatch[error](func() {_ = exec.Call(1.1, 2)})
if err != nil { … }
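
For readers curious how such a helper can work, below is a minimal sketch built on defer and recover. The function tryCatch is hypothetical and is not the library's implementation. In particular, re-panicking when the panic value is not of type E is a choice made for this sketch. The example also shows the integer division panic mentioned earlier being caught as an error.

```go
package main

import (
	"errors"
	"fmt"
)

// tryCatch runs fn and returns the panic value if it is of type E.
// Panic values of any other type are re-thrown (a choice of this sketch).
func tryCatch[E any](fn func()) (caught E) {
	defer func() {
		r := recover()
		if r == nil {
			return // fn finished normally.
		}
		e, ok := r.(E)
		if !ok {
			panic(r)
		}
		caught = e
	}()
	fn()
	return
}

// divide panics with a runtime error (which implements error) if b == 0.
func divide(a, b int) int { return a / b }

func main() {
	err := tryCatch[error](func() { panic(errors.New("dtypes don't match")) })
	fmt.Println(err) // dtypes don't match

	err = tryCatch[error](func() { _ = divide(1, 0) })
	fmt.Println(err) // runtime error: integer divide by zero

	err = tryCatch[error](func() { _ = divide(4, 2) })
	fmt.Println(err == nil) // true
}
```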

Tensors¶

Tensors are multidimensional arrays of a given data type (dtypes.DType) defined in the package github.com/gomlx/gomlx/types/tensors.

In GoMLX, tensors work as containers of data that are used as concrete inputs and outputs for the execution of computation graphs. There is only basic support for manipulating tensors directly (including direct access to their data); instead, one usually does that with computation graphs. Tensors have a shape (shapes.Shape), just like a Node.

When talking about a tensor, we are referring to the package's tensors.Tensor object (usually a pointer to it).

The tensor raw values can be stored either locally (as a Go slice) or on the device where the graph is being executed. E.g.: a GPU, or even a CPU, in which case it is owned by the backends.Backend.

While the tensors.Tensor object handles the synchronization between the two automatically, it's important to be aware of this distinction, because there is a time cost to transfer data from, for example, a GPU to the CPU and back.

Effectively, what that means is that, during training, you don't want to transfer data locally too often, for instance to save a checkpoint. But again, most of this is handled automatically by GoMLX libraries. Also, because some of the resources are allocated in the accelerator (GPU) or not managed by Go, the garbage collector may not be aware of the memory pressure on these devices.

Whenever a graph is executed (with Exec.Call), non-tensor values are converted to tensors, and those are automatically transferred to the device doing the execution.

Example:

In [6]:
%%
onePlusExec := NewExec(backend, func (x *Node) *Node {
    return OnePlus(x) 
})
// Exec.Call returns a slice of *tensors.Tensor; we take the first element.
counter := onePlusExec.Call(0)[0]
// counter.Value() will first transfer counter to local with counter.Local().
fmt.Printf("counter.type=%T, counter.shape=%s, counter=%v\n", counter, counter.Shape(), counter.Value())
for ii := 0; ii < 10; ii++ {
    // Since the counter is not being used locally between the calls, the tensor will only use the 
    // device storage.
    counter = onePlusExec.Call(counter)[0]
}
// counter.Value() will first transfer the counter value locally, and then convert to a Go value.
fmt.Printf("counter=%v\n", counter.Value())
counter.type=*tensors.Tensor, counter.shape=(Int64), counter=1
counter=11

Note:

  • In the first call to onePlusExec.Call(0), the Go constant 0 is automatically converted to a *tensors.Tensor by the graph.Exec and fed to the graph. It returns a []*tensors.Tensor with one element, containing 0+1=1.
  • The returned tensor is initially only stored on device. But when we print it, it is automatically transferred to Go.
  • While executing the loop the counter will not be transferred locally, everything happens efficiently in the accelerator device (GPU, TPU, or even the CPU executor).
  • When we print the final result in counter.Value() again the data is transferred locally.

There are several ways to create tensors, the most common:

  • tensors.FromValue[S](value S): generic conversion that works with any supported DType scalar as well as with any arbitrary multidimensional slice. Slices of rank > 1 must be regular, that is, all the sub-slices must have the same shape. E.g.: FromValue([][]float64{{1, 2}, {3, 5}, {7, 11}}). There is also a non-generic version, tensors.FromAnyValue(value any). This is what Exec.Call() uses to convert arbitrary values to a tensor.
  • FromShape(shape shapes.Shape): creates a Local tensor with the given shape, and initialized to zero. See documentation on Tensor.MutableFlatData(...) to mutate tensors in place.
  • FromScalarAndDimensions[T](value T, dimensions ...int): creates a tensor with the given dimensions, filled with the scalar value given. T must be one of the supported types.
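
The "regular slice" requirement can be illustrated without GoMLX. The sketch below uses a hypothetical helper, flatten2D, that derives the dimensions of a rank-2 slice and copies it into flat storage, rejecting irregular input. The row-major layout is an assumption of the sketch, and it returns an error where the tensors package would panic.

```go
package main

import "fmt"

// flatten2D checks that rows is regular (all rows have the same length) and
// returns its dimensions and a row-major flat copy of the data.
func flatten2D(rows [][]float64) (dims [2]int, flat []float64, err error) {
	if len(rows) == 0 {
		return dims, nil, nil
	}
	dims = [2]int{len(rows), len(rows[0])}
	flat = make([]float64, 0, dims[0]*dims[1])
	for i, row := range rows {
		if len(row) != dims[1] {
			return dims, nil, fmt.Errorf("row %d has %d elements, want %d", i, len(row), dims[1])
		}
		flat = append(flat, row...)
	}
	return dims, flat, nil
}

func main() {
	dims, flat, err := flatten2D([][]float64{{1, 2}, {3, 5}, {7, 11}})
	fmt.Println(dims, flat, err) // [3 2] [1 2 3 5 7 11] <nil>

	_, _, err = flatten2D([][]float64{{1, 2}, {3}})
	fmt.Println(err) // row 1 has 1 elements, want 2
}
```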

See more documentation in github.com/gomlx/gomlx/types/tensors.

Errors in the manipulation of tensors (e.g. invalid values) are also reported as exceptions thrown with panic, with full stack traces, just as with the graph package described in the previous section. The errors can easily be caught (with recover() or with the exceptions.TryCatch helper) when needed.

Gradients¶

Another important functionality required to train machine learning models based on gradient descent is calculating the gradients of some value being optimized with respect to some variable / quantity.

GoMLX does this statically, during graph building time. It adds to the graph the computation for the gradient.

Example: let's calculate the gradient of the function $f(x, y) = x^2 + xy$ for a few values of $x$ and $y$. Algebraically we have:

$$ \begin{aligned} f(x, y) &= x^2 + xy \\ df/dx(x,y) &= 2x + y \\ df/dy(x,y) &= x \\ \end{aligned} $$

In [7]:
func f(x, y *Node) *Node {
    return Add(Square(x), Mul(x, y))
}

func gradOfF(x, y *Node) (output, gradX, gradY *Node) {
    output = f(x, y)
    reduced := ReduceAllSum(output) // In case x and y are not scalars.
    grads := Gradient(reduced, x, y)
    gradX, gradY = grads[0], grads[1] // df/dx, df/dy
    return output, gradX, gradY
}

%%
exec := NewExec(backend, gradOfF)
x := []float64{0, 1, 2}
y := []float64{10, 20, 30}
results := exec.Call(x, y)
fmt.Printf("f(x=%v, y=%v)=%v,\n\tdf/dx=%v,\n\tdf/dy=%v\n", x, y, results[0].Value(), results[1].Value(), results[2].Value())
f(x=[0 1 2], y=[10 20 30])=[0 21 64],
	df/dx=[10 22 34],
	df/dy=[0 1 2]

Note:

  • For now GoMLX only calculates gradients of a scalar (typically a model loss) with respect to arbitrary tensors. It does not yet calculate Jacobians, which would be needed if the value being differentiated were not a scalar. That's the reason for the ReduceAllSum in the example: the result is the gradient of the sum of f over all 3 elements, which here equals the element-wise derivative because each output depends only on its own x and y.
  • A question that may arise is whether it calculates the second derivative (hessian). In principle the machinery to do that is in place, but there are 2 limitations: (1) not all operations have their derivative implemented, in particular some of the operations that are only used when calculating the first derivative; (2) it only calculates the gradient with respect to a scalar, in most cases the hessian will be the gradient of a gradient, usually of higher rank -- Btw, contributions to the project here are welcome ;)
In [8]:
// Removing the previous definitions of `f` and `gradOfF`
%rm gradOfF f
. removed func gradOfF
. removed func f
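
As a sanity check independent of GoMLX, the gradient values printed for $f(x, y) = x^2 + xy$ can be compared against a numerical approximation. The sketch below uses central finite differences in plain Go. The step size h and the tolerance are arbitrary choices for this illustration.

```go
package main

import (
	"fmt"
	"math"
)

func f(x, y float64) float64 { return x*x + x*y }

// numericGrad approximates (df/dx, df/dy) with central differences.
func numericGrad(x, y float64) (dx, dy float64) {
	const h = 1e-6
	dx = (f(x+h, y) - f(x-h, y)) / (2 * h)
	dy = (f(x, y+h) - f(x, y-h)) / (2 * h)
	return
}

func main() {
	xs := []float64{0, 1, 2}
	ys := []float64{10, 20, 30}
	for i := range xs {
		dx, dy := numericGrad(xs[i], ys[i])
		// Compare against the analytic derivatives 2x+y and x.
		okX := math.Abs(dx-(2*xs[i]+ys[i])) < 1e-4
		okY := math.Abs(dy-xs[i]) < 1e-4
		fmt.Printf("x=%g y=%g df/dx≈%.3f df/dy≈%.3f match=%v\n", xs[i], ys[i], dx, dy, okX && okY)
	}
}
```

Unlike this approximation, the Gradient operation adds the exact derivative computation to the graph at build time, so no step size needs tuning.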

Variables and Context¶

Computation graphs are pure functions: they have no state, they take inputs, return outputs and everything in between is transient.

For Machine Learning as well as many types of computations, it's convenient to store intermediary results (the model parameters for ML) in between the execution of the computation graphs.

For that GoMLX offers the context.Context object (completely unrelated to Go's standard context package), and a corresponding context.Exec. It is a container of variables (whose values are tensors, usually stored on device), and it automatically manages their updates, passing them as extra inputs and taking them out (if changed) as extra outputs of the computation graph.

This may sound more complex than it is in practice. Let's see an example, where we try to find $argmin_{x}{f(x)}$ where $f(x) = ax^2 + bx + c$. If we solve it analytically we get, for $a > 0$, $argmin_{x}{f(x)} = \frac{-b}{2a}$. Instead, we solve it numerically using gradient descent:

In [9]:
import "flag"

var (
    flagA = flag.Float64("a", 1.0, "Value of a in the equation ax^2+bx+c")
    flagB = flag.Float64("b", 2.0, "Value of b in the equation ax^2+bx+c")
    flagC = flag.Float64("c", 4.0, "Value of c in the equation ax^2+bx+c")
    flagNumSteps = flag.Int("steps", 10, "Number of gradient descent steps to perform")
    flagLearningRate    = flag.Float64("lr", 0.1, "Initial learning rate.")
)

// f(x) = ax^2 + bx + c
func f(x *Node) *Node {
    f := MulScalar(Square(x), *flagA)
    f = Add(f, MulScalar(x, *flagB))
    f = AddScalar(f, *flagC)
    return f
}

// minimizeF does one gradient descent step on F by moving a variable "x",
// and returns the value of the function at the current "x".
func minimizeF(ctx *context.Context, graph *Graph) *Node {
    // Create or reuse existing variable "x" -- no graph operation is created with this, it's
    // only a reference.
    xVar := ctx.VariableWithShape("x", shapes.Make(dtypes.Float64))
    
    x := xVar.ValueGraph(graph)                        // Read variable for the current graph.
    y := f(x)                                          // Value of f(x).
    
    // Gradient always return a slice, we take the first element for grad of X.
    gradX := Gradient(y, x)[0] 
    
    // stepNum += 1
    stepNumVar := ctx.VariableWithValue("stepNum", 0.0)  // Creates the variable if not existing, or retrieve it if already exists.
    stepNum := stepNumVar.ValueGraph(graph)
    stepNum = OnePlus(stepNum)
    stepNumVar.SetValueGraph(stepNum)
    
    // step = -learningRate * gradX / Sqrt(stepNum)
    step := Div(gradX, Sqrt(stepNum))
    step = MulScalar(step, -*flagLearningRate)
    
    // x += step
    x = Add(x, step)
    xVar.SetValueGraph(x)
    return y  // f(x)
}

func Solve() {
    ctx := context.New()
    exec := context.NewExec(backend, ctx, minimizeF)
    for ii := 0; ii < *flagNumSteps-1; ii++ {
        _ = exec.Call()
    }
    y := exec.Call()[0]
    x := ctx.InspectVariable(ctx.Scope(), "x").Value()
    stepNum := ctx.InspectVariable(ctx.Scope(), "stepNum").Value()
    fmt.Printf("Minimum found at x=%g, f(x)=%g after %f steps.\n", x.Value(), y.Value(), stepNum.Value())
}

The code above created Solve() that will solve for the values set by the flags a, b, and c.

Let's try a few values:

Note: %% in GoNB automatically creates a func main() and passes the extra arguments to the Go program.

In [10]:
%% --a=1 --b=2 --c=3 --steps=10 --lr=0.5
Solve()
Minimum found at x=-0.9999999999999999, f(x)=2 after 10.000000 steps.
In [11]:
%% --a=2 --b=12 --c=20 --steps=10 --lr=0.5
Solve()
Minimum found at x=-3, f(x)=2 after 10.000000 steps.

Note:

  • We are using context.Exec, while before we were using graph.Exec. The main difference is that context.Exec compiles and executes graph functions that take a context as their first parameter, and it automatically handles the passing of variables as side inputs and outputs (for those variables updated) of the computation graph.
  • During graph building, we access and set the variables with Variable.ValueGraph and Variable.SetValueGraph: they return/take *Node types, which can be used in the graph.
  • Outside graph building, we can access the last value set to a variable by using Variable.Value() and Variable.SetValue. They return/take concrete *tensors.Tensor types.
  • We created two variables, one for "x" that we were optimizing to minimize $f(x)$, and one variable "stepNum", used to keep track of how many steps were already executed.
  • Yes, if the learning rate is set to 1/(2a) (--lr=0.5 for a=1), the first step lands on the minimum of the quadratic f(x). 😉
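
The way variables are threaded through a pure function can be mimicked in plain Go. In the sketch below the state struct plays the role of the context's variables, and step is a pure function that receives the current state and returns the updated one, using the same update rule as minimizeF. The names state, step and solve are hypothetical, and starting from x=0 is an arbitrary choice of the sketch.

```go
package main

import (
	"fmt"
	"math"
)

// state holds everything that must survive between steps, like the "x" and
// "stepNum" variables kept in the context.
type state struct {
	x       float64
	stepNum float64
}

// step is a pure function: it returns f(x) at the current x together with
// the updated state, and never mutates its input.
func step(s state, a, b, c, lr float64) (y float64, next state) {
	y = a*s.x*s.x + b*s.x + c
	grad := 2*a*s.x + b
	next.stepNum = s.stepNum + 1
	next.x = s.x - lr*grad/math.Sqrt(next.stepNum)
	return
}

// solve runs numSteps gradient descent steps starting from x=0.
func solve(a, b, c, lr float64, numSteps int) (x, y float64) {
	var s state
	for i := 0; i < numSteps; i++ {
		y, s = step(s, a, b, c, lr)
	}
	return s.x, y
}

func main() {
	x, y := solve(2, 12, 20, 0.5, 10)
	fmt.Printf("x=%.6f f(x)=%.6f\n", x, y)
}
```

With this update rule, the distance to the minimum is multiplied by (1 - 2a·lr/√n) at step n. That factor is zero at n=1 for a=1, lr=0.5 and at n=4 for a=2, lr=0.5, which explains why both runs above converge so quickly.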

There is more to context.Context: some of it we'll present in the next section on Machine Learning, and the rest can be found in its documentation. A few things worth mentioning in advance:

  • Context is always configured at a certain scope, and variables are unique within its scope. Scope is easily changed with ctx.In("new_scope"). So the Context object is a scope (a string) and a pointer to the actual data (variables, graph and model parameters).
  • Context also holds model parameters (concrete Go values), which are also scoped. Those can be hyperparameters for the models (learning rate, regularization, etc.) or anything the user or any library may create a convention for. They are more convenient than using hundreds of flags. See example of its usage in the UCI Adult example.
  • Similarly, Context also holds "Graph parameters". Those are very similar to model parameters, but they have one value per Graph. So if a model is created with parameters of different shapes (or for training/evaluation), each version will have its own Graph parameters. Don't worry about this now -- if you need it later when building complex graphs, the functionality will be there.

Machine Learning (ML)¶

The previous sections presented the fundamentals of what is needed to implement machine learning. In this section we present various sub-packages that provide high-level ML layers and tools that make building, training and evaluating a model trivial.

First is the package layers (see code in ml/layers). It provides several composable ML layers. These are graph building functions, most of which take a *context.Context as first parameter, where they store variables or access hyperparameters. There are several such layers, for example: layers.Dense, layers/fnn.New, layers/kan.New, layers.Dropout, layers.PiecewiseLinearCalibration (very good for normalization of inputs), layers.BatchNorm, layers.LayerNorm, layers.Convolution, layers.MultiHeadAttention (for Transformer layers), etc.

The package train offers two main functionalities. train.Trainer builds a train step graph and an eval step graph, given a model graph building function and an optimizer. These graphs can be executed in sequence to train a model. The package also provides train.Loop, which simply loops over a train.Dataset interface, reading data and feeding it to Trainer.TrainStep or Trainer.EvalStep, along with executing configurable hooks on the training loop. One such hook is provided by gomlx/train/commandline.AttachProgressBar(loop): it pretty-prints the progress during training on the command line.

There is also a collection of optimizers, initializers, loss functions, metrics, etc. For any functionality there is always an example under the examples/ subdirectory.

Let's look at the simplest ML example: linear, which trains a linear model on noisy generated data.

Here are the constants of our problem:

In [12]:
const (
    CoefficientMu    = 0.0
    CoefficientSigma = 5.0
    BiasMu           = 1.0
    BiasSigma        = 10.0
)

To generate synthetic data, we first randomly choose some coefficients and a bias, based on which the data is generated. These selected coefficients are the ones we want to learn using ML. The coefficients could have been selected in Go directly using math/rand, but just for fun, we do it using a computation graph.

In [13]:
import (
    "github.com/gomlx/gomlx/types/shapes"
    "github.com/gomlx/gomlx/types/tensors"
)

// initCoefficients chooses random coefficients and bias. These are the true values the model will
// attempt to learn.
func initCoefficients(backend backends.Backend, numVariables int) (coefficients, bias *tensors.Tensor) {
    e := NewExec(backend, func(g *Graph) (coefficients, bias *Node) {
        rngState := Const(g, RngState())
        rngState, coefficients = RandomNormal(rngState, shapes.Make(dtypes.Float64, numVariables))
        coefficients = AddScalar(MulScalar(coefficients, CoefficientSigma), CoefficientMu)
        rngState, bias = RandomNormal(rngState, shapes.Make(dtypes.Float64))
        bias = AddScalar(MulScalar(bias, BiasSigma), BiasMu)
        return
    })
    results := e.Call()
    coefficients, bias = results[0], results[1]
    return
}

%%
coef, bias := initCoefficients(backend, 3)
fmt.Printf("Example of target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value()) 
Example of target: coefficients=[-3.4 -7.17 3.85], bias=-15.3

Note

  • This code should look familiar, using things we presented earlier in the tutorial. It creates a computation graph to generate randomly the coefficients and bias. Then it executes it and returns the result.
  • Notice that, since the computation graph is functional, we need to pass around the random number generator state, which gets updated at each call to RandomUniform or RandomNormal.
    • Alternatively the context.Context introduced earlier can keep the state as a variable, and provides a simpler interface: see Context.RandomUniform and Context.RandomNormal.
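
The idea of threading the random number generator state can be shown in plain Go. The sketch below uses a toy generator (one SplitMix64 step) chosen only for illustration. It is unrelated to the algorithm GoMLX uses, and nextUniform is a hypothetical name. The point is that the function mutates nothing: it returns the new state along with the value, so the caller decides which state to use next.

```go
package main

import "fmt"

// nextUniform takes a state and returns the updated state together with a
// value in [0, 1). The input state is never modified.
func nextUniform(state uint64) (newState uint64, value float64) {
	newState = state + 0x9E3779B97F4A7C15
	z := newState
	z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9
	z = (z ^ (z >> 27)) * 0x94D049BB133111EB
	z ^= z >> 31
	return newState, float64(z>>11) / (1 << 53)
}

func main() {
	s0 := uint64(42)
	s1, a := nextUniform(s0)
	_, b := nextUniform(s1)  // Threading the new state gives the next value.
	_, a2 := nextUniform(s0) // Reusing the old state repeats the first value.
	fmt.Println(a, b)
	fmt.Println(a == a2) // true
}
```

Forgetting to use the returned state is the typical mistake with this style: every call would then produce the same numbers.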

Next, we want to generate the data (examples): we generate random inputs, and then the label using the selected coefficients plus some normal noise.

In [14]:
func buildExamples(backend backends.Backend, coef, bias *tensors.Tensor, numExamples int, noise float64) (inputs, labels *tensors.Tensor) {
    e := NewExec(backend, func(coef, bias *Node) (inputs, labels *Node) {
        g := coef.Graph()
        numFeatures := coef.Shape().Dimensions[0]

        // Random inputs (observations).
        rngState := Const(g, RngState())
        rngState, inputs = RandomNormal(rngState, shapes.Make(coef.DType(), numExamples, numFeatures))
        coef = ExpandDims(coef, 0)

        // Calculate perfect labels.
        labels = ReduceAndKeep(Mul(inputs, coef), ReduceSum, -1)
        labels = Add(labels, bias)
        if noise > 0 {
            // Add some noise to the labels.
            var noiseVector *Node
            rngState, noiseVector = RandomNormal(rngState, labels.Shape())
            noiseVector = MulScalar(noiseVector, noise)
            labels = Add(labels, noiseVector)
        }
        return
    })
    examples := e.Call(coef, bias)
    inputs, labels = examples[0], examples[1]
    return
}

%%
coef, bias := initCoefficients(backend, 3)
numExamples := 5
inputsTensor, labelsTensor := buildExamples(backend, coef, bias, numExamples, 0.2)
fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value()) 

fmt.Printf("%d dataset examples:\n", numExamples)
inputs := inputsTensor.Value().([][]float64)
labels := labelsTensor.Value().([][]float64)
for ii := 0; ii < numExamples; ii ++ {
    fmt.Printf("\tx=%0.3v; label=%0.3v\n", inputs[ii], labels[ii])
}
Target: coefficients=[7.07 1.15 -3.4], bias=4.38
5 dataset examples:
	x=[0.0184 1.35 -0.627]; label=[8.5]
	x=[0.876 -1.57 0.0963]; label=[8.32]
	x=[0.493 -0.332 0.89]; label=[3.87]
	x=[-0.535 0.923 1.01]; label=[-2.15]
	x=[0.701 0.357 0.696]; label=[7.72]
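
For comparison, the same data generation can be written in plain Go with math/rand. This sketch is only meant to make the formula explicit: label = sum(inputs * coef) + bias + noise * N(0, 1). The function name buildExamplesGo and the seed parameter are hypothetical, and labels are returned as a flat slice rather than with a trailing axis of size 1.

```go
package main

import (
	"fmt"
	"math/rand"
)

// buildExamplesGo generates numExamples random inputs and their noisy labels.
func buildExamplesGo(seed int64, coef []float64, bias float64, numExamples int, noise float64) (inputs [][]float64, labels []float64) {
	rng := rand.New(rand.NewSource(seed))
	inputs = make([][]float64, numExamples)
	labels = make([]float64, numExamples)
	for i := range inputs {
		inputs[i] = make([]float64, len(coef))
		label := bias
		for j, c := range coef {
			inputs[i][j] = rng.NormFloat64() // Random observation.
			label += c * inputs[i][j]        // Perfect label so far.
		}
		labels[i] = label + noise*rng.NormFloat64() // Add normal noise.
	}
	return
}

func main() {
	coef := []float64{7.07, 1.15, -3.4}
	inputs, labels := buildExamplesGo(1, coef, 4.38, 5, 0.2)
	for i := range inputs {
		fmt.Printf("x=%0.3v; label=%0.3v\n", inputs[i], labels[i])
	}
}
```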

Dataset¶

Now the first new concept of this section: train.Dataset is the interface used to feed data to a training loop or evaluation.

There are three methods: Dataset.Yield returns the next batch of examples; Dataset.Reset restarts the dataset, for datasets that don't loop indefinitely; finally, Dataset.Name returns the dataset name, usually used for metric names, logging and printing.

Datasets also yield a spec, an opaque type for GoMLX (defined as any), that allows the dataset to communicate to the model which type of data it is generating. In our case, since it's always the same data, the model doesn't need it and ignores it. If one were to implement a dataset for something like a generic CSV file, one might want to communicate the field names to the model through the spec, for instance. See the documentation for more details.

For our linear synthetic data we implement the simplest train.Dataset: the whole data is pre-generated, and we return a giant batch with the full data every time -- we could also have more easily used data.InMemoryDataset for that, but for didactic purposes we create it from scratch:

In [15]:
import "github.com/gomlx/gomlx/ml/train"

// TrivialDataset always returns the whole data.
type TrivialDataset struct {
    name string
    inputs, labels []*tensors.Tensor
}

var (
    // Assert TrivialDataset implements train.Dataset.
    _ train.Dataset = &TrivialDataset{}
)
// Name implements train.Dataset.
func (ds *TrivialDataset) Name() string { return ds.name }

// Yield implements train.Dataset.
func (ds *TrivialDataset) Yield() (spec any, inputs, labels []*tensors.Tensor, err error) {
    return ds, ds.inputs, ds.labels, nil
}

// IsOwnershipTransferred tells the training loop that the dataset keeps ownership of the yielded tensors.
func (ds *TrivialDataset) IsOwnershipTransferred() bool {
    return false
}

// Reset implements train.Dataset.
func (ds *TrivialDataset) Reset() {}

Note:

  • Often pre-processing the data is more work than actually building an ML model ... that's life 🙁
  • In the examples/ subdirectory we implement train.Dataset for some well known data sets: UCI Adult, Cifar-10, Cifar-100, IMDB Reviews, Kaggle's Dogs vs Cats, Oxford Flowers 102. These can be used as libraries to easily try different models. If you are working on public datasets, please consider contributing similar libraries.
  • GoMLX also includes data.InMemoryDataset, which can be created from tensors in one line.

The package github.com/gomlx/gomlx/ml/data provides several tools to facilitate the work here:

  • Parallel: parallelizes any dataset, includes some buffer.
  • InMemory: reads a dataset into (accelerator) memory, and then serves it from there -- greatly accelerates training. We could have used that instead of defining TrivialDataset, but we left it because it's common to specialize datasets.
  • ConstantDataset: trivial Dataset that always returns empty inputs and labels. Used when one is going to generate the data dynamically -- it could also have been used for this example. See the example in the KAN Shapes demo.
  • Downloading of datasets (with progress-bar) and checksum functions.
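
To make the Yield/Reset contract more concrete, here is a small stdlib-only sketch of a dataset that serves mini-batches instead of one giant batch. It does not use GoMLX types: sliceDataset, its fields and the plain [][]float64 rows are made up for illustration, and signalling the end of an epoch with io.EOF is an assumption about the convention train.Dataset follows (check its godoc before relying on it).

```go
package main

import (
	"fmt"
	"io"
)

// sliceDataset is a hypothetical, stdlib-only stand-in for a train.Dataset:
// it serves the rows of an in-memory table in fixed-size batches.
type sliceDataset struct {
	name      string
	rows      [][]float64
	batchSize int
	next      int // Index of the first row of the next batch.
}

// Name returns the dataset name.
func (ds *sliceDataset) Name() string { return ds.name }

// Yield returns the next batch, or io.EOF once the epoch is exhausted.
// The last batch may be smaller than batchSize.
func (ds *sliceDataset) Yield() ([][]float64, error) {
	if ds.next >= len(ds.rows) {
		return nil, io.EOF
	}
	end := ds.next + ds.batchSize
	if end > len(ds.rows) {
		end = len(ds.rows)
	}
	batch := ds.rows[ds.next:end]
	ds.next = end
	return batch, nil
}

// Reset rewinds the dataset so that it can be iterated again.
func (ds *sliceDataset) Reset() { ds.next = 0 }

// countBatches drains one epoch and returns the number of batches served.
func countBatches(ds *sliceDataset) int {
	n := 0
	for {
		if _, err := ds.Yield(); err == io.EOF {
			return n
		}
		n++
	}
}

func main() {
	ds := &sliceDataset{name: "toy", rows: [][]float64{{1}, {2}, {3}, {4}, {5}}, batchSize: 2}
	fmt.Println(ds.Name(), countBatches(ds)) // Batches of sizes 2, 2 and 1.
	ds.Reset()
	fmt.Println(ds.Name(), countBatches(ds)) // The same batches again after Reset.
}
```

A dataset that loops indefinitely would simply wrap around instead of returning io.EOF, in which case Reset can be a no-op, as in TrivialDataset.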

ModelFn¶

Next we build a model, which for our train package means implementing a function with the following signature:

type ModelFn func(ctx *context.Context, spec any, inputs []*graph.Node) (predictions []*graph.Node)

It takes a context.Context for the variables and hyperparameters, the spec and a slice of inputs -- the last two are fed by Dataset.Yield above. It returns a slice of predictions -- in most cases there is just one value in the slice (only one prediction). During training the predictions are fed to the loss function, and during inference they can be returned directly.

Our linear example has the simplest model possible:

In [16]:
import "github.com/gomlx/gomlx/ml/context"

func modelGraph(ctx *context.Context, spec any, inputs []*Node) ([]*Node) {
    _ = spec  // Not needed here, we know the dataset.
    logits := layers.DenseWithBias(ctx, inputs[0], /* outputDim= */ 1)
    return []*Node{logits}
}

Note:

  • It uses the layers.DenseWithBias layer, which simply multiplies the input by a learnable matrix (weights) and adds a learnable bias. It's the most basic building block of neural networks (NNs). The implementation of layers.DenseWithBias is pretty simple, and worth checking out to refresh how variables from the Context are used.
  • Since it's a linear model, we don't use an activation function. The usual ones are available for NNs (Relu, Sigmoid, Tanh and more to come).
  • The spec parameter allows the creation of a ModelFn that can be used for different types of data. The dataset can also yield a spec about the type of data it is reading. Each different value of spec will trigger the creation of a different computation graph, so ideally there would be at most a few different data source specs. Most commonly there is only one, like in this example, and the parameter can be ignored.
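
As a reminder of what a dense layer with bias computes, the sketch below spells out y = x·W + b on plain Go slices. It is only an illustration of the arithmetic, not the GoMLX implementation: denseWithBias and its slice-based signature are made up here, and in GoMLX the weights and bias are variables held in the Context rather than arguments.

```go
package main

import "fmt"

// denseWithBias computes x·W + b for a batch of inputs.
// x is [batch][inputDim], w is [inputDim][outputDim] and b is [outputDim].
func denseWithBias(x, w [][]float64, b []float64) [][]float64 {
	out := make([][]float64, len(x))
	for i, row := range x {
		out[i] = make([]float64, len(b))
		for j := range b {
			sum := b[j]
			for k, v := range row {
				sum += v * w[k][j]
			}
			out[i][j] = sum
		}
	}
	return out
}

func main() {
	x := [][]float64{{1, 2, 3}, {0, 1, 0}} // Two examples with three features each.
	w := [][]float64{{1}, {2}, {3}}        // Three features mapped to outputDim=1.
	b := []float64{0.5}
	fmt.Println(denseWithBias(x, w, b))
}
```

With outputDim=1 and no activation this is exactly a linear regression, which is why the learned weights can be compared directly with the target coefficients.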

Trainer and Loop¶

The last part is to put together a train.Trainer and a train.Loop object in our main() function. The first stitches together the model, the optimizer and the loss function, and is able to run training steps and evaluations. The second, train.Loop, loops over the dataset executing one training step at a time and supports a subscription (hooking) system, where one attaches things like a progress bar or the plotting of a graph.

In [17]:
import (
    "os"
    "github.com/gomlx/gomlx/ui/commandline"
    "github.com/gomlx/gomlx/ml/layers/regularizers"
)

var (
    flagNumExamples  = flag.Int("num_examples", 10000, "Number of examples to generate")
    flagNumFeatures  = flag.Int("num_features", 3, "Number of features")
    flagNoise        = flag.Float64("noise", 0.2, "Noise in synthetic data generation")
    flagNumSteps     = flag.Int("steps", 100, "Number of gradient descent steps to perform")
    flagLearningRate    = flag.Float64("lr", 0.1, "Initial learning rate.")
)

// AttachToLoop attaches decorators to the loop. It will be redefined later.
func AttachToLoop(loop *train.Loop) {
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
}

// TrainMain does everything to train the linear model.
func TrainMain() {
    flag.Parse()

    // Select coefficients that we will try to predict.
    trueCoefficients, trueBias := initCoefficients(backend, *flagNumFeatures)
    fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", trueCoefficients.Value(), trueBias.Value()) 

    // Generate training data with noise.
    inputs, labels := buildExamples(backend, trueCoefficients, trueBias, *flagNumExamples, *flagNoise)
    fmt.Printf("Training data (inputs, labels): (%s, %s)\n\n", inputs.Shape(), labels.Shape())
    dataset := &TrivialDataset{"linear", []*tensors.Tensor{inputs}, []*tensors.Tensor{labels}}

    // Creates Context with learned weights and bias.
    ctx := context.New()
    ctx.SetParam(optimizers.ParamLearningRate, *flagLearningRate)  // = "learning_rate"
    ctx.SetParam(regularizers.ParamL2, 1e-3)  // 1e-3 of L2 regularization.

    // train.Trainer executes a training step.
    trainer := train.NewTrainer(backend, ctx, modelGraph,
        losses.MeanSquaredError,
        optimizers.StochasticGradientDescent(),
        nil, nil) // trainMetrics, evalMetrics
    loop := train.NewLoop(trainer)
    AttachToLoop(loop)

    // Loop for given number of steps. must.M1() panics, if loop.RunSteps returns an error.
    metrics := must.M1(loop.RunSteps(dataset, *flagNumSteps))
    _ = metrics  // We are not interested in them in this example.

    // Print learned coefficients and bias -- from the weights in the dense layer.
    fmt.Println()
    coefVar, biasVar := ctx.GetVariableByScopeAndName("/dense", "weights"), ctx.GetVariableByScopeAndName("/dense", "biases")
    learnedCoef, learnedBias := coefVar.Value(), biasVar.Value()
    fmt.Printf("Learned: coefficients=%0.3v, bias=%0.3v\n", learnedCoef.Value(), learnedBias.Value())
}

%%
TrainMain()
Target: coefficients=[-5.43 -1.65 -4.84], bias=6.74
Training data (inputs, labels): ((Float64)[10000 3], (Float64)[10000 1])

Training (100 steps):  100% [========================================] (4946 steps/s) [step=99] [loss+=0.144] [~loss+=5.162] [~loss=5.120]                       

Learned: coefficients=[[-5.31] [-1.62] [-4.73]], bias=[6.59]

Note:

  • Hyperparameters are set on the context. Layers and optimizers can define their own hyperparameters independently. Context uses a scoping mechanism (like directories), and hyperparameters can take specialized values under specific scopes -- e.g.: doing ctx.In("dense_5").SetParam(regularizers.ParamL2, 0.1) would set L2 regularization only for the layer dense_5 of the model to 0.1.
  • The trainer constructor also takes as input arbitrary metrics (for training and evaluation). The loss of the last batch and a moving average of the loss are always included automatically. There are many others, which can be means, moving averages, etc.
  • The 'Loop' object is very flexible. One can attach any functionality (hooks) to be executed during training with OnStart, OnStep, OnEnd, EveryNSteps, NTimesDuringLoop, PeriodicCallback (e.g.: every N minutes) and ExponentialCallback.
  • The most common such functionality is commandline.AttachProgressBar. There is also support for plotting any arbitrary metric or any arbitrary Node in the computation graph.
  • Loop.RunSteps() also returns the final metrics from the training, which are usually printed out.
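
The hook mechanism can be pictured with the following stdlib-only sketch. It borrows the names OnStep, EveryNSteps and RunSteps from the list above, but the loop type and all signatures here are invented for illustration and differ from the real train.Loop API; see the train package godoc for the actual ones.

```go
package main

import "fmt"

// loop is a hypothetical miniature of a training loop with hooks.
type loop struct {
	onStep []func(step int, loss float64)
}

// OnStep registers a hook that is called after every step.
func (l *loop) OnStep(fn func(step int, loss float64)) {
	l.onStep = append(l.onStep, fn)
}

// EveryNSteps registers a hook that is called after every n-th step.
func (l *loop) EveryNSteps(n int, fn func(step int, loss float64)) {
	l.OnStep(func(step int, loss float64) {
		if (step+1)%n == 0 {
			fn(step, loss)
		}
	})
}

// RunSteps runs trainStep numSteps times, calls the hooks after each step,
// and returns the loss of the last step.
func (l *loop) RunSteps(numSteps int, trainStep func(step int) float64) float64 {
	var loss float64
	for step := 0; step < numSteps; step++ {
		loss = trainStep(step)
		for _, hook := range l.onStep {
			hook(step, loss)
		}
	}
	return loss
}

func main() {
	l := &loop{}
	l.EveryNSteps(25, func(step int, loss float64) {
		fmt.Printf("step=%d loss=%.3f\n", step, loss)
	})
	// A fake training step whose loss shrinks as 1/(step+1).
	final := l.RunSteps(100, func(step int) float64 { return 1 / float64(step+1) })
	fmt.Printf("final loss=%.3f\n", final)
}
```

The point of the design is that the loop itself knows nothing about progress bars, plots or checkpoints: each one is just another function registered on it.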

Training and Plotting¶

As our last example, let's train it "for real", that is, with more steps.

And to make things prettier, let's also attach a plot of the registered metrics. In our example the only metrics are the default ones, the batch and mean of the loss -- the mean squared error.

In [18]:
import "github.com/gomlx/gomlx/ui/gonb/plotly"
import "os"

func AttachToLoop(loop *train.Loop) {
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop. 
    _ = plotly.New().Dynamic().ScheduleExponential(loop, 50, 1.1)
}

%% --steps=5000
TrainMain()
Target: coefficients=[-1.74 3.39 1.23], bias=3.83
Training data (inputs, labels): ((Float64)[10000 3], (Float64)[10000 1])

Training (5000 steps):    3% [>.......................................] (5847 steps/s) [0s:0s] [step=164] [loss+=0.058] [~loss+=0.864] [~loss=0.851]           
Training (5000 steps):  100% [========================================] (7731 steps/s) [step=4999] [loss+=0.056] [~loss+=0.056] [~loss=0.040]                

[Plot: Metric: loss]

Learned: coefficients=[[-1.73] [3.38] [1.23]], bias=[3.83]

Note:

  • The github.com/gomlx/gomlx/ui/gonb/plotly package (imported above) will automatically plot all the metrics registered in the trainer. When it attaches itself to the loop it collects the metric values (and runs evaluation on any requested datasets), and at the end it plots them.
  • Optionally, with Dynamic(), it also plots intermediary results, displaying the progress of the training as it happens.
  • Previously, we added a little L2 regularization as a hyperparameter. The layers.DenseWithBias layer picks up on that hyperparameter, and adds that bit of L2 regularization. GoMLX tracks the loss with/without regularization separately, and one can see the difference in the two lines.
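
To see why the two loss lines differ, the sketch below computes a mean squared error with and without an L2 penalty on plain Go slices. The function names are made up, the numbers are arbitrary, and the penalty form (amount times the sum of squared weights) is one common definition assumed here; the exact formula GoMLX's regularizers use is not stated in this tutorial.

```go
package main

import "fmt"

// meanSquaredError returns the mean of (prediction - label)^2.
func meanSquaredError(predictions, labels []float64) float64 {
	var sum float64
	for i, p := range predictions {
		d := p - labels[i]
		sum += d * d
	}
	return sum / float64(len(predictions))
}

// l2Penalty returns amount * sum(w^2), one common form of L2 regularization
// (an assumption, not necessarily the exact formula used by GoMLX).
func l2Penalty(weights []float64, amount float64) float64 {
	var sum float64
	for _, w := range weights {
		sum += w * w
	}
	return amount * sum
}

func main() {
	predictions := []float64{1.0, 2.0, 3.0}
	labels := []float64{1.5, 2.0, 2.5}
	weights := []float64{-5.0, -2.0, -4.0}

	loss := meanSquaredError(predictions, labels)
	lossWithReg := loss + l2Penalty(weights, 1e-3)
	fmt.Printf("loss=%.4f loss+regularization=%.4f\n", loss, lossWithReg)
}
```

The optimizer minimizes the regularized value, while the plain loss is what tells you how well the model fits the data, so it is useful to track both.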

Other examples¶

There is much more in the libraries and examples. GoMLX libraries are well documented, and the implementation is generally simple to use and understand -- don't hesitate to dive into the code. For machine learning, we highly recommend diving into the context.Context package/objects, and looking at the layers library, to learn how many of the common ML layers work.

Debugging¶

Unfortunately, computers just "don't get it": they do exactly what we tell them to do, as opposed to what we want them to do, and thus programs fail or crash. GoMLX provides different ways to track down various types of errors. The most commonly used are described below:

Good old "printf"¶

It's convenient because of Go's fast compilation (changing something and running it to see what one gets is almost instant). Logging results to stdout is a valid way of developing. During graph building development, one often prints the shape of the Node being operated on to confirm (or not) one's expectations.

Errors reported with panic -- Go "exceptions".¶

Errors during the building of the graph are reported back with panic. These can be caught in Go with recover, but GoMLX offers the exceptions library to make it easy. All libraries panic with helpful error messages -- they can be printed with "%+v" to get the full stack-trace output.
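
For reference, this is what catching such a panic looks like with plain recover. The sketch uses only the standard library: tryCatch and buildGraph are hypothetical names, not part of GoMLX or of its exceptions library, and a standard-library error carries no stack trace, so "%+v" here prints only the message.

```go
package main

import "fmt"

// tryCatch is a hypothetical helper: it runs fn and converts a panic into an error.
func tryCatch(fn func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if e, ok := r.(error); ok {
				err = e // The panic value already is an error: return it as is.
				return
			}
			err = fmt.Errorf("panic: %v", r)
		}
	}()
	fn()
	return nil
}

// buildGraph stands in for graph-building code that panics on a shape mismatch.
func buildGraph(lhsDim, rhsDim int) {
	if lhsDim != rhsDim {
		panic(fmt.Errorf("cannot add operands with dimensions %d and %d", lhsDim, rhsDim))
	}
}

func main() {
	err := tryCatch(func() { buildGraph(3, 4) })
	fmt.Printf("caught: %+v\n", err)
}
```

The named return value is what lets the deferred function hand the recovered error back to the caller.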

Node Shape Asserts¶

When writing complex models, it's very common to add comments on the expected shapes of the graph nodes, to help the reader (and developer) of the code keep the right mental model of what is going on.

GoMLX provides a series of assert methods that can be used instead. They serve both as documentation, and an early exit in case of some unexpected results.

For example, a modelGraph function could contain:

batchSize := inputs[0].Shape().Dimensions[0]
...
layer := Concatenate(allEmbeddings, -1)
layer.AssertDims(batchSize, -1)  // 2D tensor, with batch size as the leading dimension.

Although using these when building graphs is the most common case, there are similar assert functions for tensors and shapes themselves in the package gomlx/types/shapes.
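
The semantics of such an assert, with -1 acting as a wildcard for an axis, can be sketched in a few lines of plain Go. The helper assertDims below is hypothetical and works on a slice of ints; it only mimics the behavior described above and is not the GoMLX implementation.

```go
package main

import "fmt"

// assertDims is a hypothetical helper: it panics if dims doesn't match expected,
// where a -1 in expected matches any size for that axis.
func assertDims(dims []int, expected ...int) {
	if len(dims) != len(expected) {
		panic(fmt.Sprintf("expected rank %d, got rank %d (dims=%v)", len(expected), len(dims), dims))
	}
	for axis, want := range expected {
		if want != -1 && dims[axis] != want {
			panic(fmt.Sprintf("axis %d: expected dimension %d, got %d (dims=%v)", axis, want, dims[axis], dims))
		}
	}
}

func main() {
	batchSize := 32
	dims := []int{32, 128} // For example, the dimensions of concatenated embeddings.
	assertDims(dims, batchSize, -1)
	fmt.Println("shape ok:", dims)
}
```

Unlike a comment, the assert cannot silently go stale: if the shape changes, the graph building fails right at the line that documents it.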

Graph Execution Logging¶

Every Node of the graph can be flagged with SetLogged(msg). At the end of the execution, the executor (Exec) will log all these values. The default logger (set with Exec.SetLogger) will simply print the message msg along with the value of the Node of interest. Creating a specialized logger that handles arbitrary nodes is trivial.

Catching NaN and Inf in your training¶

This is a common source of headaches when training complex models. The package nanlogger implements a specialized logger that monitors results on arbitrary nodes in the graph, and panics with custom messages (and a stack-trace) at the first sight of a NaN (or ±Inf).
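
The core check is simple, as the stdlib-only sketch below shows on a plain slice of values. firstNonFinite and checkFinite are hypothetical names; the real nanlogger operates on graph nodes during execution and has a different API.

```go
package main

import (
	"fmt"
	"math"
)

// firstNonFinite returns the index of the first NaN or ±Inf in values,
// or -1 if all values are finite.
func firstNonFinite(values []float64) int {
	for i, v := range values {
		if math.IsNaN(v) || math.IsInf(v, 0) {
			return i
		}
	}
	return -1
}

// checkFinite panics with a message naming the node if values contains a NaN or ±Inf.
func checkFinite(nodeName string, values []float64) {
	if i := firstNonFinite(values); i >= 0 {
		panic(fmt.Sprintf("node %q: non-finite value %v at index %d", nodeName, values[i], i))
	}
}

func main() {
	logits := []float64{0.5, -1.2, 3.0}
	checkFinite("logits", logits)
	fmt.Println("logits are finite")

	zero := 0.0
	bad := []float64{1.0, math.Log(-1), 1 / zero} // Log(-1) is NaN and 1/0 is +Inf.
	fmt.Println("first non-finite index:", firstNonFinite(bad))
}
```

Naming the node in the message matters in practice: a NaN usually propagates to the loss within one step, so the first node where it appears is the useful piece of information.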

More Debugging¶

Tests and these methods have been enough to develop most of GoMLX so far. But there are other debugging tools that could be made; see the discussion in the Debugging document. Let us know if you need something specialized.


Happy coding and good luck on modeling!!

[Zürich See]