GoMLX Tutorial¶
If you just want to quickly look at a working example, check out examples/adult/demo/adult.ipynb, for a model trained on the UCI Adult Income dataset. More examples are in the examples/ subdirectory.
The tutorial won't detail the whole API, but should present all important concepts. Everything else is well documented in godoc (in the code), also available in pkg.go.dev.
The tutorial was written using a Jupyter notebook with GoNB, a kernel for Go that was co-developed with GoMLX. It has its own short tutorial for those interested. But GoMLX doesn't require GoNB.
If you are seeing this tutorial from a GitHub snapshot, you won't be able to interact with it. To play with it, try installing GoMLX, see its README Installation section. The easiest way is to start the pre-generated docker and use the Jupyter notebook there -- this tutorial can be opened from there in an interactive way.
Note: Output Not Displaying in JupyterLab?¶
If your notebook plots are not displaying correctly, it's likely because Jupyter is assuming the notebook is "not trusted". This can be easily fixed by opening the command palette (Shift+Control+C on PCs) and selecting `Trust Notebook`.
There is an open issue in JupyterLab attempting to address that.
!*rm -f go.work && go work init && go work use . "${HOME}/Projects/gomlx"
%goworkfix
- Added replace rule for module "github.com/gomlx/gomlx" to local directory "/home/janpf/Projects/gomlx".
Computation Graphs¶
To do machine learning based on neural networks and gradient descent one of the most important requirements is the ability to do mathematical computations (mostly matrix multiplications) fast.
GoMLX is built on the concept of building "computation graphs", just-in-time compiling them, and only then executing them to get the desired results. That means one has to write code that generates another type of code (the computation graph), so to speak. We do this because then we are able to execute it really fast using XLA and PJRT.
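To make the "code that generates code" idea concrete, here is a toy sketch in plain Go. It is not how GoMLX is implemented (there is no XLA or JIT compilation here), and the names `toyNode`, `param`, `add` and `eval` are made up for this illustration. The only point is that building the graph and executing it are two separate phases.

```go
package main

import "fmt"

// toyNode is a hypothetical, minimal stand-in for a graph node: it is either
// a parameter (an input slot) or the addition of two other nodes.
type toyNode struct {
	op          string // "param" or "add".
	paramIndex  int    // Only used when op == "param".
	left, right *toyNode
}

// param creates a node that reads the i-th input when the graph is executed.
func param(i int) *toyNode { return &toyNode{op: "param", paramIndex: i} }

// add only records that an addition should happen: nothing is computed here.
func add(a, b *toyNode) *toyNode { return &toyNode{op: "add", left: a, right: b} }

// eval executes the graph rooted at n for the given concrete inputs.
func eval(n *toyNode, inputs []float64) float64 {
	switch n.op {
	case "param":
		return inputs[n.paramIndex]
	case "add":
		return eval(n.left, inputs) + eval(n.right, inputs)
	}
	panic("unknown op " + n.op)
}

func main() {
	// Phase 1: build the graph once.
	sum := add(param(0), param(1))
	// Phase 2: execute it as many times as needed, with different values.
	fmt.Println(eval(sum, []float64{1, 1}))
	fmt.Println(eval(sum, []float64{3.5, 1.5}))
}
```

In this toy the "execution" is a slow recursive interpreter. The benefit of the real design comes from handing the whole graph to a compiler before running it.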
For example, let's create a computation graph to sum two values:
Note
- If executing this on a notebook, notice the very first cell execution takes a few seconds for Go to fetch the required dependencies.
import . "github.com/gomlx/gomlx/graph"
func SumGraph(a, b *Node) *Node {
	return Add(a, b)
}
Note
- The `import . "github.com/gomlx/gomlx/graph"` imports all definitions of the `graph` package into the current scope. Usually, it's easier to work like this in Go files that are going to implement graph building functions.
- Our function is named `SumGraph`: the suffix `Graph` is just a convention, but it helps identify functions that do graph building.
- The type `*Node` represents a node in the computation graph. All graph operations either take a `*Graph` object to start with, or a `*Node`, and create new nodes with the corresponding operations. So our example will create an `Add` node that takes the nodes pointed to by `a` and `b`, builds a node that represents their sum, and then returns this `*Node`.
- Every node contains a reference to the `*Graph` it's part of (see `Node.Graph()`).
- There is a rich set of operations available in GoMLX, see the `Node` documentation.
Ok, but this won't tell us what 1+1 is yet. We need to compile and then execute this graph with some input values.
Backends and executing Graphs with `Exec`¶
GoMLX can in principle work with different backends to execute its operations -- this opens up possibilities like executing models in WASM (not implemented yet).
But the default backend, which must always be imported to execute something, is `xla` -- GoMLX will provide a useful error message telling you to import it, if you forget.
Creating a `backends.Backend` object is trivial: `backends.New()`. It can use the environment variable `GOMLX_BACKEND`, if one wants to specify a specific backend / plugin (e.g.: `"xla:cpu"`, `"xla:cuda"`, etc.).
With the backend created, `Exec` is the easiest way to compile and execute computation graphs in GoMLX.
To run our `SumGraph` function above we can do:
import (
	"github.com/gomlx/gomlx/backends"
	_ "github.com/gomlx/gomlx/backends/xla"
)
var backend = backends.New()
%%
// Short version: `ExecOnce` will compile the function and execute only once.
two := ExecOnce(backend, /* GoMLX function */ SumGraph, /* Arguments */ 1.0, 1.0)
fmt.Printf("Short version:\t1+1=%v\n", two)
// Exec version: `NewExec` will compile the SumGraph function so it can be executed efficiently
// many times.
exec := NewExec(backend, SumGraph)
two = exec.Call(1, 1)[0]
fmt.Printf("Exec version:\t1+1=%v\n", two.Value())
Short version: 1+1=float64(2.0000)
Exec version: 1+1=2
Note
- `%%` is a shortcut for `func main()`: everything after it is put inside a `main` function by GoNB, the Go Notebook kernel.
- `backends.New()` creates a `backends.Backend` object, which connects to an accelerator if present. Usually one creates one at the beginning of the program and passes it around. Here GoNB will keep the global variable `backend` available to all cells, so we don't need to define it again.
- The `Exec` object created is associated with a graph building function (`SumGraph` in this case). It lazily compiles and executes the compiled computation as needed. Naturally, the first time `Call()` is invoked it is slow: it has to build the graph and just-in-time (JIT) compile it. But the compiled graph afterwards is optimized and very fast to execute, which is what we want for machine learning.
- The `Exec.Call(1, 1)` method always returns a slice of results, that's why we use `[0]` to access the first result.
- The results of graph execution are always tensors (see section below). They can be converted to Go types using the `.Value()` method.
Something important to understand is that graphs have static (fixed) shapes for their inputs and outputs. This means that, for example, if you are going to sum floats instead of ints, `Exec` has to rebuild the graph to take two floats as input. The same happens if you want to sum vectors or matrices of ints, or values of any other `shapes.Shape`.
For a detailed explanation of `Shape` and the associated concepts of Axis, Dimensions and `DType` (the underlying data type), see the documentation of the package `shapes`.
To exemplify, let's expand our code a bit:
import (
	"fmt"
	. "github.com/gomlx/gomlx/graph"
)
func SumGraph(a, b *Node) *Node {
	fmt.Printf("\t. building graph for a.shape=%s and b.shape=%s\n", a.Shape(), b.Shape())
	return Add(a, b)
}
func main() {
	sumExec := NewExec(backend, SumGraph)
	two := sumExec.Call(1, 1)[0]
	fmt.Printf("1+1=%v\n", two.Value())
	for ii := 0; ii < 5; ii++ {
		sumInts := sumExec.Call(ii, ii)[0]
		fmt.Printf("%d+%d=%v\n", ii, ii, sumInts.Value())
	}
	five := sumExec.Call(3.5, 1.5)[0]
	fmt.Printf("3.5+1.5=%v\n", five.Value())
	many := sumExec.Call([]float32{1.1, 2.2, 3.3}, []float32{10, 10, 10})[0]
	fmt.Printf("[1.1, 2.2, 3.3] + [10, 10, 10] = %v\n", many.Value())
}
	. building graph for a.shape=(Int64) and b.shape=(Int64)
1+1=2
0+0=0
1+1=2
2+2=4
3+3=6
4+4=8
	. building graph for a.shape=(Float64) and b.shape=(Float64)
3.5+1.5=5
	. building graph for a.shape=(Float32)[3] and b.shape=(Float32)[3]
[1.1, 2.2, 3.3] + [10, 10, 10] = [11.1 12.2 13.3]
Note
- Each time a new graph is created, the `fmt.Printf` we added tells us the shapes of the graph operands. Notice that `fmt.Printf` is not included in the graph, it's only part of the graph building function. We'll see later how to print intermediary results in the middle of the execution of the graph.
- Every `Node` has an associated shape (`shapes.Shape` type). A shape is defined by its underlying data type `shapes.DType` and its axes dimensions. For scalars, the shape has zero axes (dimensions). E.g.: `(Int64)[]` represents a scalar `int` value, and `(Float32)[3]` represents a vector with 3 `float32` values. More details and the list of supported data types (aka. dtype) can be found in the package github.com/gomlx/gopjrt/dtypes.
- `Exec` automatically calls `SumGraph` whenever the `Call()` method sees parameters with shapes different from those it has seen before (there is a cache of pre-compiled graphs kept in memory, with limited size).
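The caching behavior described in the last note can be pictured as a map from an input signature to an already built program. The sketch below is a plain Go toy under loud assumptions: `signatureOf` and `cachedExec` are hypothetical names, the Go type of each argument is used as a crude stand-in for its shape (it ignores slice lengths, which real shapes do not), and the cache has no size limit. The tutorial doesn't show how GoMLX's own cache is keyed or evicted.

```go
package main

import (
	"fmt"
	"strings"
)

// signatureOf builds a cache key from the Go types of the arguments. It's a
// crude stand-in for the shapes (dtype and dimensions) of the real inputs.
func signatureOf(args ...any) string {
	parts := make([]string, len(args))
	for i, arg := range args {
		parts[i] = fmt.Sprintf("%T", arg)
	}
	return strings.Join(parts, ",")
}

// cachedExec is a hypothetical executor that "builds" once per signature.
type cachedExec struct {
	builds int             // Number of times the (expensive) build step ran.
	cache  map[string]bool // Signatures already built.
}

// call triggers a build only for signatures that were not seen before.
func (e *cachedExec) call(args ...any) {
	key := signatureOf(args...)
	if !e.cache[key] {
		fmt.Printf("\t. building for signature (%s)\n", key)
		e.cache[key] = true
		e.builds++
	}
}

func main() {
	e := &cachedExec{cache: map[string]bool{}}
	e.call(1, 1)     // New signature: build.
	e.call(2, 2)     // Same signature: reuse the cached build.
	e.call(3.5, 1.5) // float64 inputs: build again.
	e.call([]float32{1.1, 2.2, 3.3}, []float32{10, 10, 10})
	fmt.Println("builds:", e.builds)
}
```

The practical consequence is the same as in the output above: calling with many distinct shapes means many builds, so it pays to keep input shapes stable (e.g. a fixed batch size).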
In general, graph operations only work on operands with the same `dtypes.DType`. If they are different, the error is reported back with a `panic` (it works like an exception, and can be caught) carrying an error with a full stack-trace.
And with that, let's talk about how errors are handled during GoMLX computation graph building:
Error Handling¶
During graph building and execution GoMLX diverges from the usual idiomatic Go error handling: it uses exceptions in the form of a `panic` with an `error` object with a stack trace. See the package github.com/gomlx/exceptions for convenient wrappers.
Note: The author is aware this is controversial for some. So below are the motivations to use exceptions, as opposed to returning an error everywhere:
- While implementing long mathematical functions, introducing error handling in between each operation would be too distracting and unwieldy. Notice the Go language makes the same choice: that's why the division operator `/` doesn't return the result of the division and an error, and instead panics on division by zero.
- Speed of building the computation graph is not a concern: it has no impact on the execution of the JIT compiled code later on.
Let's create an example with an error to see how this works:
%%
sumExec := NewExec(backend, SumGraph)
_ = sumExec.Call(1.1, 2) // Error: arguments have different dtypes Float64 and Int64.
. building graph for a.shape=(Float64) and b.shape=(Int64)
panic: Backend "xla": failed Add: dtype of first (Float64) and second (Int64) operands don't match
goroutine 1 [running]:
github.com/gomlx/gomlx/backends/xla.(*Builder).Add(0xc00028e000, {0xbd3ba0?, 0xc000290000?}, {0xbd3ba0, 0xc000290120})
	/home/janpf/Projects/gomlx/backends/xla/gen_standard_ops.go:32 +0x137
github.com/gomlx/gomlx/graph.Add(0xc0001ca580, 0xc0001ca630)
	/home/janpf/Projects/gomlx/graph/gen_backend_ops.go:176 +0x123
main.SumGraph(0xc0001ca580, 0xc0001ca630)
	 [[ Cell [4] Line 8 ]] /tmp/gonb_55ad4f33/main.go:17 +0x189
reflect.Value.call({0xbce4e0?, 0xc62b90?, 0xc0000d7850?}, {0xc4ba8a, 0x4}, {0xc000284150, 0x2, 0x0?})
	/snap/go/current/src/reflect/value.go:581 +0xca6
reflect.Value.Call({0xbce4e0?, 0xc62b90?, 0x0?}, {0xc000284150?, 0x0?, 0x0?})
	/snap/go/current/src/reflect/value.go:365 +0xb9
github.com/gomlx/gomlx/graph.(*Exec).createAndCacheGraph(0xc0001de120, {0xc0000da000, 0x2, 0xcb59a8?})
	/home/janpf/Projects/gomlx/graph/exec.go:536 +0x75c
github.com/gomlx/gomlx/graph.(*Exec).findOrCreateGraph(0xc0001de120, {0xc0000da000, 0x2, 0x2})
	/home/janpf/Projects/gomlx/graph/exec.go:595 +0x167
github.com/gomlx/gomlx/graph.(*Exec).compileAndExecute(0xc0001de120, 0x1, {0xc0000d7f20?, 0xc0000061c0?, 0xc0000d7f40?})
	/home/janpf/Projects/gomlx/graph/exec.go:436 +0x226
github.com/gomlx/gomlx/graph.(*Exec).CallWithGraph(...)
	/home/janpf/Projects/gomlx/graph/exec.go:386
github.com/gomlx/gomlx/graph.(*Exec).Call(0xcb59a8?, {0xc0000d7f20?, 0xbce4e0?, 0xc62b90?})
	/home/janpf/Projects/gomlx/graph/exec.go:372 +0x27
main.main()
	 [[ Cell [5] Line 3 ]] /tmp/gonb_55ad4f33/main.go:24 +0xbe
exit status 2
Note
- In the stack-trace above there are 2 lines of interest, that typically help to debug such issues:
  - Where in the graph building function `main.SumGraph` the invalid operation was created: Line 8 of the previous cell.
  - Where in the `main` function the graph was attempted to be executed: Line 3 of this cell.
- You can enable displaying line-numbers in JupyterLab with "ESC+L" (upper-case L).
If you want to catch errors, GoMLX provides a small `exceptions` library that defines `TryCatch[E]`, which will catch arbitrary `panic` (thrown) exceptions. GoMLX only throws `error` type exceptions. So you could do:
import "github.com/gomlx/exceptions"
err := exceptions.TryCatch[error](func() {_ = exec.Call(1.1, 2)})
if err != nil { … }
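For readers less familiar with Go's `panic`/`recover`, the stdlib-only sketch below shows the mechanism such a helper can be built on. The `tryCatch` function here is a hypothetical re-implementation written for this illustration, not the code of the `exceptions` package, and `checkSameDType` is a made-up function that merely imitates the dtype check that failed above.

```go
package main

import (
	"errors"
	"fmt"
)

// tryCatch runs fn and returns the panic value if it is of type E. Panics
// with values of other types are re-thrown. Hypothetical helper, for
// illustration only.
func tryCatch[E any](fn func()) (caught E) {
	defer func() {
		r := recover()
		if r == nil {
			return // fn finished normally.
		}
		if e, ok := r.(E); ok {
			caught = e
			return
		}
		panic(r) // Not the type we were asked to catch.
	}()
	fn()
	return
}

// checkSameDType panics with an error when the names differ, imitating a
// graph building function that rejects mismatching dtypes.
func checkSameDType(a, b string) {
	if a != b {
		panic(errors.New("dtype of first (" + a + ") and second (" + b + ") operands don't match"))
	}
}

func main() {
	err := tryCatch[error](func() { checkSameDType("Float64", "Int64") })
	if err != nil {
		fmt.Println("caught:", err)
	}
	err = tryCatch[error](func() { checkSameDType("Int64", "Int64") })
	fmt.Println("no error:", err == nil)
}
```

A typical place for such a wrapper is the boundary between graph building code and the rest of a program, so that everything outside keeps using ordinary `error` returns.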
Tensors¶
Tensors are multidimensional arrays of a given data type (`dtypes.DType`), defined in the package `github.com/gomlx/gomlx/types/tensors`.
In GoMLX, tensors work as containers of data that are used as concrete inputs and outputs for the execution of computation graphs. There is only basic support for manipulating tensors directly (it includes direct access to their data): usually one does that using computation graphs instead. Tensors have a shape (`shapes.Shape`), just like `Node`.
When talking about a tensor, we are referring to the package's `tensors.Tensor` object (usually a pointer to it).
The tensor raw values can be stored either locally (as a Go slice) or on the device where the graph is being executed (e.g.: a GPU, or even a CPU), in which case the storage is owned by the `backends.Backend`.
While the `tensors.Tensor` object handles the synchronization between the two automatically, it's important to be aware of this distinction, because there is a time cost to transfer data from, for example, a GPU to the CPU and back.
Effectively, what that means is that, during training, you don't want to transfer anything locally too often (for instance, to save a checkpoint). But again, most of this is handled automatically by GoMLX libraries. Also, because some of the resources are allocated in the accelerator (GPU) or not managed by Go, the garbage collector may not be aware of the memory pressure on these devices.
Whenever a graph is executed (with `Exec.Call`), non-tensor arguments are converted to tensors, and those are automatically transferred to the device doing the execution.
Example:
%%
onePlusExec := NewExec(backend, func (x *Node) *Node {
	return OnePlus(x)
})
// exec.Call will return a *tensors.Tensor.
counter := onePlusExec.Call(0)[0]
// counter.Value() will first transfer counter to local with counter.Local().
fmt.Printf("counter.type=%T, counter.shape=%s, counter=%v\n", counter, counter.Shape(), counter.Value())
for ii := 0; ii < 10; ii++ {
	// Since the counter is not being used locally between the calls, the tensor will only use the
	// device storage.
	counter = onePlusExec.Call(counter)[0]
}
// counter.Value() will first transfer the counter value locally, and then convert to a Go value.
fmt.Printf("counter=%v\n", counter.Value())
counter.type=*tensors.Tensor, counter.shape=(Int64), counter=1
counter=11
Note:
- In the first call to `onePlusExec.Call(0)`, the Go constant `0` is automatically converted to a `*tensors.Tensor` by the `graph.Exec` and fed to the graph. It returns a `[]*tensors.Tensor` with one element, containing `0+1=1`.
- The returned tensor is initially only stored on device. But when we print it, it is automatically transferred to Go.
- While executing the loop, the counter is not transferred locally: everything happens efficiently in the accelerator device (GPU, TPU, or even the CPU executor).
- When we print the final result with `counter.Value()`, the data is again transferred locally.
There are several ways to create tensors, the most common are:
- `tensors.FromValue[S](value S)`: generic conversion, works with any supported `DType` scalar as well as with any arbitrary multidimensional slice. Slices of rank > 1 must be regular, that is, all the sub-slices must have the same shape. E.g.: `FromValue([][]float64{{1, 2}, {3, 5}, {7, 11}})`. There is also a non-generic version `tensors.FromAnyValue(value any)`. This is what `Exec.Call()` uses to convert arbitrary values to a tensor.
- `FromShape(shape shapes.Shape)`: creates a local tensor with the given shape, initialized to zero. See the documentation on `Tensor.MutableFlatData(...)` to mutate tensors in place.
- `FromScalarAndDimensions[T](value T, dimensions ...int)`: creates a tensor with the given dimensions, filled with the given scalar value. `T` must be one of the supported types.
See more documentation in github.com/gomlx/gomlx/types/tensors.
Errors in the manipulation of tensors (e.g. invalid values) are also reported back as exceptions thrown with `panic`, with full stack-traces, just as with the `graph` package described in the previous section. The errors can easily be caught (with `recover()` or with the `exceptions.TryCatch` helper) when needed.
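The requirement that multidimensional slices be "regular" can be illustrated with a small stdlib-only check. This is only a sketch of the idea for the rank-2 case: `dimensionsOf` is a hypothetical helper, not part of the `tensors` package, and it returns an ordinary `error` instead of panicking.

```go
package main

import "fmt"

// dimensionsOf returns the dimensions of a 2D slice, or an error if it is not
// regular (rows with different lengths). Hypothetical helper for illustration.
func dimensionsOf(values [][]float64) ([]int, error) {
	if len(values) == 0 {
		return []int{0, 0}, nil
	}
	cols := len(values[0])
	for i, row := range values {
		if len(row) != cols {
			return nil, fmt.Errorf("row %d has %d elements, expected %d", i, len(row), cols)
		}
	}
	return []int{len(values), cols}, nil
}

func main() {
	// A regular slice: 3 rows with 2 elements each.
	dims, err := dimensionsOf([][]float64{{1, 2}, {3, 5}, {7, 11}})
	fmt.Println(dims, err)
	// An irregular slice: the second row is shorter.
	_, err = dimensionsOf([][]float64{{1, 2}, {3}})
	fmt.Println(err)
}
```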
Gradients¶
Another important functionality required to train machine learning models based on gradient descent is calculating the gradients of some value being optimized with respect to some variable / quantity.
GoMLX does this statically, during graph building time. It adds to the graph the computation for the gradient.
Example: let's calculate the gradient of the function $f(x, y) = x^2 + xy$ for a few values of $x$ and $y$. Algebraically we have:
$$ \begin{aligned} f(x, y) &= x^2 + xy \\ df/dx(x,y) &= 2x + y \\ df/dy(x,y) &= x \\ \end{aligned} $$
func f(x, y *Node) *Node {
	return Add(Square(x), Mul(x, y))
}
func gradOfF(x, y *Node) (output, gradX, gradY *Node) {
	output = f(x, y)
	reduced := ReduceAllSum(output) // In case x and y are not scalars.
	grads := Gradient(reduced, x, y)
	gradX, gradY = grads[0], grads[1] // df/dx, df/dy
	return output, gradX, gradY
}
%%
exec := NewExec(backend, gradOfF)
x := []float64{0, 1, 2}
y := []float64{10, 20, 30}
results := exec.Call(x, y)
fmt.Printf("f(x=%v, y=%v)=%v,\n\tdf/dx=%v,\n\tdf/dy=%v\n", x, y, results[0].Value(), results[1].Value(), results[2].Value())
f(x=[0 1 2], y=[10 20 30])=[0 21 64],
	df/dx=[10 22 34],
	df/dy=[0 1 2]
Note:
- For now GoMLX only calculates gradients of a scalar (typically a model loss) with respect to arbitrary tensors. It does not yet calculate Jacobians, which are needed when the value being differentiated is not a scalar. That's the reason for the `ReduceAllSum` in the example: the result is the gradient of the sum of the 3 values of `f`. Since each of those values only depends on the corresponding elements of `x` and `y`, this yields the element-wise derivatives.
- A question that may arise is whether it calculates second derivatives (the Hessian). In principle the machinery to do that is in place, but there are 2 limitations: (1) not all operations have their derivative implemented, in particular some of the operations that are only used when calculating the first derivative; (2) it only calculates gradients of a scalar, while in most cases the Hessian is the gradient of a gradient, usually of higher rank. Btw, contributions to the project here are welcome ;)
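A quick way to sanity-check the numbers in the output above is to compare the algebraic derivatives with central finite differences in plain Go. This is not how GoMLX computes gradients (it adds the gradient computation to the graph); `numericalGrad` is a helper made up for this check, and the step size `h` is an arbitrary choice.

```go
package main

import "fmt"

// f is the same function as in the example above: f(x, y) = x^2 + x*y.
func f(x, y float64) float64 { return x*x + x*y }

// numericalGrad approximates (df/dx, df/dy) at (x, y) with central differences.
func numericalGrad(x, y float64) (dfdx, dfdy float64) {
	const h = 1e-6
	dfdx = (f(x+h, y) - f(x-h, y)) / (2 * h)
	dfdy = (f(x, y+h) - f(x, y-h)) / (2 * h)
	return
}

func main() {
	xs := []float64{0, 1, 2}
	ys := []float64{10, 20, 30}
	for i := range xs {
		dfdx, dfdy := numericalGrad(xs[i], ys[i])
		// Print the approximations next to the algebraic values 2x+y and x.
		fmt.Printf("x=%g, y=%g: df/dx~%.4f (2x+y=%g), df/dy~%.4f (x=%g)\n",
			xs[i], ys[i], dfdx, 2*xs[i]+ys[i], dfdy, xs[i])
	}
}
```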
// Removing the previous definitions of `f` and `gradOfF`
%rm gradOfF f
. removed func gradOfF . removed func f
Variables and Context¶
Computation graphs are pure functions: they have no state, they take inputs, return outputs, and everything in between is transient.
For machine learning, as well as many other types of computations, it's convenient to store intermediary results (the model parameters, for ML) in between executions of the computation graphs.
For that GoMLX offers the `context.Context` object (completely unrelated to Go's usual `context` package), and a corresponding `context.Exec`.
It is a container of variables (whose values are tensors, usually stored on device), and it automatically manages their updates, passing them as extra inputs and taking them out (if changed) as extra outputs of the computation graph.
This may sound more complex than it is in practice. Let's see an example, where we try to find $argmin_{x}{f(x)}$ where $f(x) = ax^2 + bx + c$. If we solve it analytically we should get, for $a > 0$, $argmin_{x}{f(x)} = \frac{-b}{2a}$. Instead we solve it numerically, using gradient descent:
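import (
	"flag"
	"fmt"
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/types/shapes"
	"github.com/gomlx/gopjrt/dtypes"
)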
var (
flagA = flag.Float64("a", 1.0, "Value of a in the equation ax^2+bx+c")
flagB = flag.Float64("b", 2.0, "Value of b in the equation ax^2+bx+c")
flagC = flag.Float64("c", 4.0, "Value of c in the equation ax^2+bx+c")
flagNumSteps = flag.Int("steps", 10, "Number of gradient descent steps to perform")
flagLearningRate = flag.Float64("lr", 0.1, "Initial learning rate.")
)
// f(x) = ax^2 + bx + c
func f(x *Node) *Node {
	f := MulScalar(Square(x), *flagA)
	f = Add(f, MulScalar(x, *flagB))
	f = AddScalar(f, *flagC)
	return f
}
// minimizeF does one gradient descent step on F by moving a variable "x",
// and returns the value of the function at the current "x".
func minimizeF(ctx *context.Context, graph *Graph) *Node {
	// Create or reuse existing variable "x" -- no graph operation is created with this, it's
	// only a reference.
	xVar := ctx.VariableWithShape("x", shapes.Make(dtypes.Float64))
	x := xVar.ValueGraph(graph) // Read variable for the current graph.
	y := f(x) // Value of f(x).
	// Gradient always returns a slice, we take the first element for grad of X.
	gradX := Gradient(y, x)[0]
	// stepNum += 1
	stepNumVar := ctx.VariableWithValue("stepNum", 0.0) // Creates the variable if it doesn't exist, or retrieves it if it already exists.
	stepNum := stepNumVar.ValueGraph(graph)
	stepNum = OnePlus(stepNum)
	stepNumVar.SetValueGraph(stepNum)
	// step = -learningRate * gradX / Sqrt(stepNum)
	step := Div(gradX, Sqrt(stepNum))
	step = MulScalar(step, -*flagLearningRate)
	// x += step
	x = Add(x, step)
	xVar.SetValueGraph(x)
	return y // f(x)
}
func Solve() {
	ctx := context.New()
	exec := context.NewExec(backend, ctx, minimizeF)
	for ii := 0; ii < *flagNumSteps-1; ii++ {
		_ = exec.Call()
	}
	y := exec.Call()[0]
	x := ctx.InspectVariable(ctx.Scope(), "x").Value()
	stepNum := ctx.InspectVariable(ctx.Scope(), "stepNum").Value()
	fmt.Printf("Minimum found at x=%g, f(x)=%g after %f steps.\n", x.Value(), y.Value(), stepNum.Value())
}
The code above created `Solve()`, which will solve for the values set by the flags `a`, `b`, and `c`.
Let's try a few values:
Note:
- `%%` in GoNB automatically creates a `func main()` and passes the extra arguments to the Go program.
%% --a=1 --b=2 --c=3 --steps=10 --lr=0.5
Solve()
Minimum found at x=-0.9999999999999999, f(x)=2 after 10.000000 steps.
%% --a=2 --b=12 --c=20 --steps=10 --lr=0.5
Solve()
Minimum found at x=-3, f(x)=2 after 10.000000 steps.
Note:
- We are using `context.Exec`, while before we were using `graph.Exec`. The main difference is that `context.Exec` compiles and executes graph functions that take a context as their first parameter, and it automatically handles the passing of variables as side inputs and outputs (for those variables updated) of the computation graph.
- During graph building, we access and set the variables with `Variable.ValueGraph` and `Variable.SetValueGraph`: they return/take `*Node` types, that can be used in the graph.
- Outside graph building, we can access the last value set to a variable by using `Variable.Value()` and `Variable.SetValue`. They return/take concrete `*tensors.Tensor` types.
- We created two variables, one for "x" that we were optimizing to minimize $f(x)$, and one variable "stepNum", used to keep track of how many steps were already executed.
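- Yes, if we set the learning rate `--lr` to $1/(2a)$ (e.g. `--lr=0.5` for `a=1`, as in the first run above), it gets to the minimum in one step for the quadratic $f(x)$: in the first step `stepNum` is 1, so the update is $x - \frac{2ax + b}{2a} = \frac{-b}{2a}$. 😉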
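The update rule implemented by `minimizeF` can be replayed in plain Go to check the numbers printed above. This sketch makes two assumptions that are not in the tutorial: `x` starts at 0 (the tutorial doesn't say how the variable "x" is initialized), and the derivative $2ax + b$ is written by hand instead of being computed by `Gradient`. The name `solve` is made up for this sketch. Unlike `minimizeF`, it returns $f(x)$ evaluated after the last update.

```go
package main

import (
	"fmt"
	"math"
)

// solve runs numSteps of x += -lr * f'(x) / sqrt(stepNum) for
// f(x) = a*x^2 + b*x + c, starting from x0. It returns the final x and f(x).
func solve(a, b, c, lr, x0 float64, numSteps int) (x, fx float64) {
	x = x0
	for stepNum := 1; stepNum <= numSteps; stepNum++ {
		grad := 2*a*x + b // f'(x), written by hand.
		x += -lr * grad / math.Sqrt(float64(stepNum))
	}
	return x, a*x*x + b*x + c
}

func main() {
	x, fx := solve(1, 2, 3, 0.5, 0, 10)
	fmt.Printf("a=1, b=2, c=3: x=%g, f(x)=%g (analytical argmin: %g)\n", x, fx, -2.0/(2*1))
	x, fx = solve(2, 12, 20, 0.5, 0, 10)
	fmt.Printf("a=2, b=12, c=20: x=%g, f(x)=%g (analytical argmin: %g)\n", x, fx, -12.0/(2*2))
}
```

What the plain loop leaves out is exactly what `context.Context` adds: the state (`x` and `stepNum`) lives in variables that persist between executions of a compiled graph, instead of in Go local variables.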
There is more to `context.Context`: some of it we'll present in the next section on Machine Learning, the rest can be found in its documentation. A few things worth mentioning in advance:
- `Context` is always configured at a certain scope, and variables are unique within their scope. The scope is easily changed with `ctx.In("new_scope")`. So the `Context` object is a scope (a string) and a pointer to the actual data (variables, graph and model parameters).
- `Context` also holds model parameters (concrete Go values), which are also scoped. Those can be hyperparameters for the models (learning rate, regularization, etc.) or anything the user or any library may create a convention for. They are more convenient than using hundreds of flags. See an example of their usage in the UCI Adult example.
- Similarly, `Context` also holds "Graph parameters". Those are very similar to model parameters, but they have one value per Graph. So if a model is created with parameters of different shapes (or for training/evaluation), each version will have its own Graph parameters. Don't worry about this now -- if you need it later when building complex graphs, the functionality will be there.
Machine Learning (ML)¶
The previous sections presented the fundamentals of what is needed to implement machine learning. In this section we present various sub-packages that provide high level ML layers and tools that make building, training and evaluating a model trivial.
First is the package `layers` (see the code in ml/layers). It provides several composable ML layers. These are graph building functions, most of which take a `*context.Context` as their first parameter, where they store variables or access hyperparameters.
There are several such layers, for example: `layers.Dense`, `layers/fnn.New`, `layers/kan.New`, `layers.Dropout`, `layers.PiecewiseLinearCalibration` (very good for normalization of inputs), `layers.BatchNorm`, `layers.LayerNorm`, `layers.Convolution`, `layers.MultiHeadAttention` (for Transformer layers), etc.
The package `train` offers two main functionalities. `train.Trainer` will build a train step graph and an eval step graph, given a model graph building function and an optimizer. The train step graph can be executed repeatedly to train a model.
The package also provides `train.Loop`, which simply loops over a `train.Dataset` interface, reading data and feeding it to `Trainer.TrainStep` or `Trainer.EvalStep`, along with executing configurable hooks on the training loop. One such hook is provided by `gomlx/train/commandline.AttachProgressBar(loop)`: it pretty-prints the progress during training on the command line.
There is also a collection of optimizers, initializers, loss functions, metrics, etc. For any functionality there is always an example under the examples/ subdirectory.
Let's look at the simplest ML example: `linear`, which trains a linear model on noisy generated data.
Here are the constants of our problem:
const (
CoefficientMu = 0.0
CoefficientSigma = 5.0
BiasMu = 1.0
BiasSigma = 10.0
)
To generate the synthetic data, we first randomly choose coefficients and a bias, from which the data is generated. These selected coefficients are the ones we want to try to learn using ML. The coefficients could have been selected in Go directly using `math/rand`, but just for fun, we do it using a computation graph.
import (
	"github.com/gomlx/gomlx/types/shapes"
	"github.com/gomlx/gomlx/types/tensors"
)
// initCoefficients chooses random coefficients and bias. These are the true values the model will
// attempt to learn.
func initCoefficients(backend backends.Backend, numVariables int) (coefficients, bias *tensors.Tensor) {
	e := NewExec(backend, func(g *Graph) (coefficients, bias *Node) {
		rngState := Const(g, RngState())
		rngState, coefficients = RandomNormal(rngState, shapes.Make(dtypes.Float64, numVariables))
		coefficients = AddScalar(MulScalar(coefficients, CoefficientSigma), CoefficientMu)
		rngState, bias = RandomNormal(rngState, shapes.Make(dtypes.Float64))
		bias = AddScalar(MulScalar(bias, BiasSigma), BiasMu)
		return
	})
	results := e.Call()
	coefficients, bias = results[0], results[1]
	return
}
%%
coef, bias := initCoefficients(backend, 3)
fmt.Printf("Example of target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value())
Example of target: coefficients=[-3.4 -7.17 3.85], bias=-15.3
Note
- This code should look familiar, it uses things we presented earlier in the tutorial. It creates a computation graph to randomly generate the `coefficients` and `bias`. Then it executes it and returns the results.
- Notice that since the computation graph is functional, we need to pass around the random number generator state, which gets updated at each call to `RandomUniform` or `RandomNormal`.
  - Alternatively, the `context.Context` introduced earlier can keep the state as a variable, and provides a simpler interface: see `Context.RandomUniform` and `Context.RandomNormal`.
Next, we want to generate the data (examples): we generate random inputs, and then the label using the selected coefficients plus some normal noise.
func buildExamples(backend backends.Backend, coef, bias *tensors.Tensor, numExamples int, noise float64) (inputs, labels *tensors.Tensor) {
	e := NewExec(backend, func(coef, bias *Node) (inputs, labels *Node) {
		g := coef.Graph()
		numFeatures := coef.Shape().Dimensions[0]
		// Random inputs (observations).
		rngState := Const(g, RngState())
		rngState, inputs = RandomNormal(rngState, shapes.Make(coef.DType(), numExamples, numFeatures))
		coef = ExpandDims(coef, 0)
		// Calculate perfect labels.
		labels = ReduceAndKeep(Mul(inputs, coef), ReduceSum, -1)
		labels = Add(labels, bias)
		if noise > 0 {
			// Add some noise to the labels.
			var noiseVector *Node
			rngState, noiseVector = RandomNormal(rngState, labels.Shape())
			noiseVector = MulScalar(noiseVector, noise)
			labels = Add(labels, noiseVector)
		}
		return
	})
	examples := e.Call(coef, bias)
	inputs, labels = examples[0], examples[1]
	return
}
%%
coef, bias := initCoefficients(backend, 3)
numExamples := 5
inputsTensor, labelsTensor := buildExamples(backend, coef, bias, numExamples, 0.2)
fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", coef.Value(), bias.Value())
fmt.Printf("%d dataset examples:\n", numExamples)
inputs := inputsTensor.Value().([][]float64)
labels := labelsTensor.Value().([][]float64)
for ii := 0; ii < numExamples; ii ++ {
fmt.Printf("\tx=%0.3v; label=%0.3v\n", inputs[ii], labels[ii])
}
Target: coefficients=[7.07 1.15 -3.4], bias=4.38
5 dataset examples:
	x=[0.0184 1.35 -0.627]; label=[8.5]
	x=[0.876 -1.57 0.0963]; label=[8.32]
	x=[0.493 -0.332 0.89]; label=[3.87]
	x=[-0.535 0.923 1.01]; label=[-2.15]
	x=[0.701 0.357 0.696]; label=[7.72]
Dataset¶
Now the first new concept of this section: `train.Dataset` is the interface used to feed data to a training loop or an evaluation.
There are three methods: `Dataset.Yield` returns the next batch of examples; `Dataset.Reset` restarts the dataset, for datasets that don't loop indefinitely; finally, `Dataset.Name` returns the dataset name, usually used for metric names, logging and printing.
Datasets also yield a `spec`, an opaque type for GoMLX (defined as `any`), that allows the dataset to communicate to the model which type of data it is generating. In our case, since it's always the same data, the model doesn't need it and simply ignores it.
If one were to implement a dataset for something like a generic CSV file, one might want to communicate the field names to the model through the `spec`, for instance. See the documentation for more details.
For our linear synthetic data we implement the simplest `train.Dataset`: the whole data is pre-generated, and we return a giant batch with the full data every time -- we could also have more easily used `data.InMemoryDataset` for that, but for didactic purposes we create it from scratch:
import "github.com/gomlx/gomlx/ml/train"
// TrivialDataset always returns the whole data.
type TrivialDataset struct {
	name string
	inputs, labels []*tensors.Tensor
}
var (
	// Assert Dataset implements train.Dataset.
	_ train.Dataset = &TrivialDataset{}
)
// Name implements train.Dataset.
func (ds *TrivialDataset) Name() string { return ds.name }
// Yield implements train.Dataset.
func (ds *TrivialDataset) Yield() (spec any, inputs, labels []*tensors.Tensor, err error) {
	return ds, ds.inputs, ds.labels, nil
}
// IsOwnershipTransferred tells the training loop that the dataset keeps ownership of the yielded tensors.
func (ds *TrivialDataset) IsOwnershipTransferred() bool {
	return false
}
// Reset implements train.Dataset.
func (ds *TrivialDataset) Reset() {}
Note:
- More often than not, pre-processing data is more work than actually building an ML model ... that's life 🙁
- In the examples/ subdirectory we implement `train.Dataset` for some well known datasets: UCI Adult, Cifar-10, Cifar-100, IMDB Reviews, Kaggle's Dogs vs Cats, Oxford Flowers 102. These can be used as libraries to easily try different models. If you are working on public datasets, please consider contributing similar libraries.
- GoMLX also includes `data.InMemoryDataset`, which can be created from tensors in one line.
The package github.com/gomlx/gomlx/ml/data provides several tools to facilitate the work here:
- `Parallel`: parallelizes any dataset, includes some buffer.
- `InMemory`: reads a dataset into (accelerator) memory, and then serves it from there -- greatly accelerates training. We could have used that instead of defining `TrivialDataset`, but we left it because it's common to specialize datasets.
- `ConstantDataset`: trivial dataset that always returns empty inputs and labels. Used when one is going to generate the data dynamically -- it could also have been used for this example. See example in the KAN Shapes demo.
- Downloading of datasets (with progress-bar) and checksum functions.
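The `Yield`/`Reset`/`Name` contract described at the start of this section can be sketched without GoMLX, for a dataset that does end (unlike `TrivialDataset`). Everything here is hypothetical: `toyDataset` uses plain `float64` slices instead of tensors and has no `spec`, and signaling the end of an epoch by returning `io.EOF` is an assumption of this sketch, to be checked against the `train.Dataset` documentation.

```go
package main

import (
	"fmt"
	"io"
)

// toyDataset mirrors the three methods described above, using plain float64
// slices instead of tensors. Hypothetical interface, for illustration only.
type toyDataset interface {
	Name() string
	Yield() (inputs, labels []float64, err error)
	Reset()
}

// batchedDataset yields fixed-size batches and signals the end of an epoch
// with io.EOF (an assumption of this sketch).
type batchedDataset struct {
	name           string
	inputs, labels []float64
	batchSize, pos int
}

func (ds *batchedDataset) Name() string { return ds.name }

// Reset restarts the dataset from the first example.
func (ds *batchedDataset) Reset() { ds.pos = 0 }

// Yield returns the next batch, or io.EOF when all examples were consumed.
func (ds *batchedDataset) Yield() (inputs, labels []float64, err error) {
	if ds.pos >= len(ds.inputs) {
		return nil, nil, io.EOF
	}
	end := ds.pos + ds.batchSize
	if end > len(ds.inputs) {
		end = len(ds.inputs) // The last batch may be smaller.
	}
	inputs, labels = ds.inputs[ds.pos:end], ds.labels[ds.pos:end]
	ds.pos = end
	return
}

// countBatches consumes the dataset until the end of the epoch.
func countBatches(ds toyDataset) int {
	n := 0
	for {
		if _, _, err := ds.Yield(); err == io.EOF {
			return n
		}
		n++
	}
}

func main() {
	ds := &batchedDataset{
		name:      "toy",
		inputs:    []float64{1, 2, 3, 4, 5},
		labels:    []float64{2, 4, 6, 8, 10},
		batchSize: 2,
	}
	fmt.Printf("%s: %d batches in the first epoch\n", ds.Name(), countBatches(ds))
	ds.Reset()
	fmt.Printf("%s: %d batches after Reset\n", ds.Name(), countBatches(ds))
}
```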
ModelFn¶
Next we build a model, which for our train package means implementing a function with the following signature:
type ModelFn func(ctx *context.Context, spec any, inputs []*graph.Node) (predictions []*graph.Node)
It takes a context.Context for the variables and hyperparameters, the spec and a slice of inputs; the last two are fed by Dataset.Yield above. It returns a slice of predictions; in most cases there is just one value in the slice (only one prediction). During training the predictions are fed to the loss function, and during inference they can be returned directly.
Our linear example has the simplest model possible:
import "github.com/gomlx/gomlx/ml/context"
func modelGraph(ctx *context.Context, spec any, inputs []*Node) ([]*Node) {
_ = spec // Not needed here, we know the dataset.
logits := layers.DenseWithBias(ctx, inputs[0], /* outputDim= */ 1)
return []*Node{logits}
}
Note:
- It uses the layers.DenseWithBias layer, which simply multiplies the input by a learnable matrix (weights) and adds a learnable bias. It's the most basic building block of neural networks (NNs). The implementation of layers.DenseWithBias is pretty simple, and worth checking out to refresh how variables from the Context are used.
- Since it's a linear model, we don't use an activation function. The usual ones are available for NNs (Relu, Sigmoid, Tanh and more to come).
- The spec parameter allows the creation of a ModelFn that can be used for different types of data. The dataset can also Yield a spec about the type of data it is reading. Each different value of spec will trigger the creation of a different computation graph, so ideally there would be at most a few different types of data source spec. Most commonly there is only one, like in this example, and the parameter can be ignored.
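To make concrete what "multiply by a weight matrix and add a bias" computes, here is a plain-Go sketch on ordinary slices. It is not the GoMLX implementation (which builds graph nodes rather than computing values directly), and the function name denseWithBias is only borrowed for illustration.

```go
package main

import "fmt"

// denseWithBias computes out[i][k] = b[k] + sum_j x[i][j]*w[j][k].
// x is [batch][features], w is [features][outputDim], b is [outputDim].
func denseWithBias(x, w [][]float64, b []float64) [][]float64 {
	out := make([][]float64, len(x))
	for i, row := range x {
		out[i] = make([]float64, len(b))
		for k := range b {
			sum := b[k]
			for j, v := range row {
				sum += v * w[j][k]
			}
			out[i][k] = sum
		}
	}
	return out
}

func main() {
	x := [][]float64{{1, 2, 3}, {0, 1, 0}} // 2 examples, 3 features.
	w := [][]float64{{1}, {-1}, {2}}       // 3 features, outputDim=1.
	b := []float64{0.5}
	fmt.Println(denseWithBias(x, w, b)) // [[5.5] [-0.5]]
}
```

With outputDim=1, as in the model above, each example is reduced to a single predicted value: exactly a linear regression.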
Trainer and Loop¶
The last part is to put together the train.Trainer and train.Loop objects in our main() function. The first stitches together the model, the optimizer and the loss function, and is able to run training steps and evaluations. The second, train.Loop, loops over the dataset executing one training step at a time and supports a subscription (hooking) system, where one attaches things like a progress bar or the plotting of a graph.
import (
"os"
"github.com/gomlx/gomlx/ui/commandline"
"github.com/gomlx/gomlx/ml/layers/regularizers"
)
var (
flagNumExamples = flag.Int("num_examples", 10000, "Number of examples to generate")
flagNumFeatures = flag.Int("num_features", 3, "Number of features")
flagNoise = flag.Float64("noise", 0.2, "Noise in synthetic data generation")
flagNumSteps = flag.Int("steps", 100, "Number of gradient descent steps to perform")
flagLearningRate = flag.Float64("lr", 0.1, "Initial learning rate.")
)
// AttachToLoop attaches decorators to the training loop. It will be redefined later.
func AttachToLoop(loop *train.Loop) {
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
}
// TrainMain() does everything to train the linear model.
func TrainMain() {
flag.Parse()
// Select coefficients that we will try to predict.
trueCoefficients, trueBias := initCoefficients(backend, *flagNumFeatures)
fmt.Printf("Target: coefficients=%0.3v, bias=%0.3v\n", trueCoefficients.Value(), trueBias.Value())
// Generate training data with noise.
inputs, labels := buildExamples(backend, trueCoefficients, trueBias, *flagNumExamples, *flagNoise)
fmt.Printf("Training data (inputs, labels): (%s, %s)\n\n", inputs.Shape(), labels.Shape())
dataset := &TrivialDataset{"linear", []*tensors.Tensor{inputs}, []*tensors.Tensor{labels}}
// Creates Context with learned weights and bias.
ctx := context.New()
ctx.SetParam(optimizers.ParamLearningRate, *flagLearningRate) // = "learning_rate"
ctx.SetParam(regularizers.ParamL2, 1e-3) // 1e-3 of L2 regularization.
// train.Trainer executes a training step.
trainer := train.NewTrainer(backend, ctx, modelGraph,
losses.MeanSquaredError,
optimizers.StochasticGradientDescent(),
nil, nil) // trainMetrics, evalMetrics
loop := train.NewLoop(trainer)
AttachToLoop(loop)
// Loop for given number of steps. must.M1() panics, if loop.RunSteps returns an error.
metrics := must.M1(loop.RunSteps(dataset, *flagNumSteps))
_ = metrics // We are not interested in them in this example.
// Print learned coefficients and bias -- from the weights in the dense layer.
fmt.Println()
coefVar, biasVar := ctx.GetVariableByScopeAndName("/dense", "weights"), ctx.GetVariableByScopeAndName("/dense", "biases")
learnedCoef, learnedBias := coefVar.Value(), biasVar.Value()
fmt.Printf("Learned: coefficients=%0.3v, bias=%0.3v\n", learnedCoef.Value(), learnedBias.Value())
}
%%
TrainMain()
Target: coefficients=[-5.43 -1.65 -4.84], bias=6.74
Training data (inputs, labels): ((Float64)[10000 3], (Float64)[10000 1])
Training (100 steps): 100% [========================================] (4946 steps/s) [step=99] [loss+=0.144] [~loss+=5.162] [~loss=5.120]
Learned: coefficients=[[-5.31] [-1.62] [-4.73]], bias=[6.59]
Note:
- Hyperparameters are set on the context. Layers and optimizers can define their own hyperparameters independently. Context uses a scoping mechanism (like directories), and hyperparameters can take specialized values under specific scopes -- e.g.: doing ctx.In("dense_5").SetParam(regularizers.ParamL2, 0.1) would set L2 regularization to 0.1 only for the layer dense_5 of the model.
- The trainer constructor also takes as input arbitrary metrics (for training and evaluation). The metric of the loss of the last batch, and a moving average of the loss are always included automatically. There are many others, that can be means or moving averages, etc.
- The Loop object is very flexible. One can attach any functionality (hooks) to be executed during training with OnStart, OnStep, OnEnd, EveryNSteps, NTimesDuringLoop, PeriodicCallback (e.g.: every N minutes) and ExponentialCallback.
- The most common such functionality is the commandline.AttachProgressBar. There is also plotting of any arbitrary metric or any arbitrary Node in the computation graph.
- Loop.RunSteps() also returns the final metrics from the training, usually printed out.
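The division of labor between a training step and a loop with hooks can be sketched in plain Go, without GoMLX. Everything here is hypothetical (stepHook, everyNSteps, trainLinear) and heavily simplified: one feature, full-batch gradient descent on the mean squared error, no regularization and no compiled computation graph. It is meant only to show the shape of the idea, not the library's implementation.

```go
package main

import "fmt"

// stepHook is a hypothetical callback type, called after every training step.
type stepHook func(step int, loss float64)

// everyNSteps wraps fn so that it only fires on every n-th step.
func everyNSteps(n int, fn stepHook) stepHook {
	return func(step int, loss float64) {
		if (step+1)%n == 0 {
			fn(step, loss)
		}
	}
}

// trainLinear fits y = w*x + b by gradient descent on the mean squared error,
// starting from w=0, b=0, and calls every hook after each step.
func trainLinear(xs, ys []float64, lr float64, numSteps int, hooks ...stepHook) (w, b float64) {
	n := float64(len(xs))
	for step := 0; step < numSteps; step++ {
		var loss, gradW, gradB float64
		for i, x := range xs {
			diff := w*x + b - ys[i]
			loss += diff * diff / n
			gradW += 2 * diff * x / n
			gradB += 2 * diff / n
		}
		w -= lr * gradW
		b -= lr * gradB
		for _, h := range hooks {
			h(step, loss) // loss is the value measured before this update.
		}
	}
	return w, b
}

func main() {
	xs := []float64{0, 1, 2, 3}
	ys := []float64{1, 3, 5, 7} // Generated from y = 2x + 1, without noise.
	report := everyNSteps(250, func(step int, loss float64) {
		fmt.Printf("step=%d loss=%.6f\n", step, loss)
	})
	w, b := trainLinear(xs, ys, 0.05, 1000, report)
	fmt.Printf("Learned: w=%.3f, b=%.3f\n", w, b)
}
```

Keeping the hooks outside the step function is what lets a progress bar or a plot be attached without touching the model or the optimizer.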
Training and Plotting¶
As our last example, let's train it "for real", that is, with more steps.
And to make things prettier, let's also attach a plot of the registered metrics. In our example the only metrics are the default ones, the batch and mean of the loss -- the mean squared error.
import "github.com/gomlx/gomlx/ui/gonb/plotly"
import "os"
func AttachToLoop(loop *train.Loop) {
commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.
_ = plotly.New().Dynamic().ScheduleExponential(loop, 50, 1.1)
}
%% --steps=5000
TrainMain()
Target: coefficients=[-1.74 3.39 1.23], bias=3.83 Training data (inputs, labels): ((Float64)[10000 3], (Float64)[10000 1])
Training (5000 steps): 3% [>.......................................] (5847 steps/s) [0s:0s] [step=164] [loss+=0.058] [~loss+=0.864] [~loss=0.851]
Training (5000 steps): 100% [========================================] (7731 steps/s) [step=4999] [loss+=0.056] [~loss+=0.056] [~loss=0.040]
Metric: loss
Learned: coefficients=[[-1.73] [3.38] [1.23]], bias=[3.83]
Note:
- The gomlx/ui/gonb/plotly package will automatically plot all the metrics registered in the trainer. When it attaches itself to loop it collects the metric values (and runs evaluation on any requested datasets), and at the end plots them.
- Optionally, with Plots.Dynamic() it also plots intermediary results, displaying the progress of the training as it happens.
- Previously, we added a little L2 regularization as a hyperparameter. The layers.DenseWithBias layer picks up on that hyperparameter, and adds that bit of L2 regularization. GoMLX tracks the loss with/without regularization separately, and one can see the difference in the two lines.
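As a rough check on the size of that difference, the sketch below computes an L2 penalty in plain Go. It assumes the common form amount × Σw² applied to the weights only (not the bias); that is an assumption about the regularizer, not something taken from the GoMLX source, and l2Penalty is a hypothetical helper.

```go
package main

import "fmt"

// l2Penalty returns amount times the sum of the squared weights.
// This form of the penalty is an assumption made for illustration.
func l2Penalty(weights []float64, amount float64) float64 {
	var sumSquares float64
	for _, w := range weights {
		sumSquares += w * w
	}
	return amount * sumSquares
}

func main() {
	learned := []float64{-1.73, 3.38, 1.23} // Learned coefficients printed above.
	fmt.Printf("L2 penalty: %.3f\n", l2Penalty(learned, 1e-3)) // L2 penalty: 0.016
}
```

Under that assumption the penalty is about 0.016, which is consistent with the gap between ~loss+=0.056 and ~loss=0.040 at the end of the run above.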
Other examples¶
There is much more in the libraries and examples.
GoMLX libraries are well documented, and the implementation is generally simple to use and understand -- don't hesitate to dive into the code.
For machine learning, we highly recommend diving into the context.Context package/objects, and looking at the layers library, to learn how many of the common ML layers work.
Debugging¶
Unfortunately, computers just "don't get it": they do exactly what we tell them to do, as opposed to what we want them to do, and thus programs fail or crash. GoMLX provides different ways to track down various types of errors. The most commonly used ones are listed below:
Good old "printf"¶
It's convenient because of Go's fast compilation (changing something and running it to see the result is almost instant). Logging results to stdout is a valid way of developing. During graph building development, one often prints the shape of the Node being operated on to confirm (or not) one's expectations.
Errors reported with panic -- Go "exceptions".¶
Errors during the building of the graph are reported back with panic. These can be caught in Go with recover, but GoMLX offers the exceptions library to make it easy.
All libraries always panic with helpful error messages -- they can be printed with "%+v" to get full stack-trace output.
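For readers less familiar with Go's panic and recover, here is a minimal standard-library sketch of turning a panic into an ordinary error. The helpers try and buildGraph are hypothetical and do not reproduce the API of the exceptions library. Errors created with the standard errors package carry no stack trace, so "%+v" prints only the message in this sketch.

```go
package main

import (
	"errors"
	"fmt"
)

// try (hypothetical helper) runs fn and converts any panic into an error.
func try(fn func()) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if e, ok := r.(error); ok {
				err = e // The panic value was already an error: keep it as is.
				return
			}
			err = fmt.Errorf("panic: %v", r)
		}
	}()
	fn()
	return nil
}

// buildGraph stands in for graph-building code that panics on invalid input.
func buildGraph(numFeatures int) {
	if numFeatures <= 0 {
		panic(errors.New("numFeatures must be positive"))
	}
}

func main() {
	if err := try(func() { buildGraph(0) }); err != nil {
		fmt.Printf("graph building failed: %+v\n", err)
	}
}
```

Catching the panic at one well-defined boundary keeps the graph-building code free of error plumbing while still letting the caller decide how to react.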
Node Shape Asserts¶
When writing complex models, it's very common to add comments on the expected shapes of the graph nodes, to help the reader (and developer) of the code keep the right mental model of what is going on.
GoMLX provides a series of assert methods that can be used instead. They serve both as documentation, and an early exit in case of some unexpected results.
For example, a modelGraph function could contain:
batchSize := inputs[0].Shape().Dimensions[0]
...
layer := Concatenate(allEmbeddings, -1)
layer.AssertDims(batchSize, -1) // 2D tensor, with batch size as the leading dimension.
Although using these when building graphs is the most common case, there are similar assert functions for tensors and shapes themselves in the package gomlx/types/shapes.
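The kind of check such an assert performs can be sketched in plain Go on a slice of dimensions. The helper checkDims is hypothetical and returns an error instead of panicking; it treats -1 as "any size", mirroring how -1 is used in the AssertDims call above.

```go
package main

import "fmt"

// checkDims (hypothetical helper) reports an error if dims does not match want.
// A -1 in want matches any size on that axis.
func checkDims(dims []int, want ...int) error {
	if len(dims) != len(want) {
		return fmt.Errorf("rank mismatch: got shape %v, wanted rank %d", dims, len(want))
	}
	for axis, w := range want {
		if w != -1 && dims[axis] != w {
			return fmt.Errorf("axis %d: got dimension %d, wanted %d (shape %v)", axis, dims[axis], w, dims)
		}
	}
	return nil
}

func main() {
	batchSize := 32
	fmt.Println(checkDims([]int{32, 128}, batchSize, -1))    // <nil>
	fmt.Println(checkDims([]int{32, 128, 3}, batchSize, -1)) // Rank mismatch error.
}
```

Failing at the line where the expectation is stated is usually far easier to debug than a shape error raised several operations later.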
Graph Execution Logging¶
Every Node of the graph can be flagged with SetLogged(msg).
The executor (Exec) will, at the end of the execution, log all these values. The default logger (set with Exec.SetLogger) will simply print the message msg along with the value of the Node of interest. Creating specialized loggers that handle arbitrary nodes is trivial.
Catching NaN and Inf in your training¶
This is a common source of headaches when training complex models.
The package nanlogger implements a specialized logger that monitors results on arbitrary nodes in the graph, and will panic with custom messages (and a stack-trace) at the first sight of a NaN (or ±Inf).
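The underlying check is simple, and can be sketched in plain Go with the standard math package. This is not the nanlogger API: the helper firstNonFinite is hypothetical and works on an ordinary slice rather than on graph nodes.

```go
package main

import (
	"fmt"
	"math"
)

// firstNonFinite returns the index of the first NaN or ±Inf in values,
// or -1 if every value is finite.
func firstNonFinite(values []float64) int {
	for i, v := range values {
		if math.IsNaN(v) || math.IsInf(v, 0) {
			return i
		}
	}
	return -1
}

func main() {
	logits := []float64{0.3, math.Log(0), 1.2} // math.Log(0) is -Inf.
	if i := firstNonFinite(logits); i >= 0 {
		fmt.Printf("non-finite value %v at index %d\n", logits[i], i)
	}
}
```

Checking intermediate values, rather than only the final loss, is what makes it possible to point at the operation where the numbers first went wrong.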
More Debugging¶
Tests and these methods have been enough to develop most of GoMLX so far. But there are other debugging tools that could be built; see the discussion in the Debugging document. Let us know if you need something specialized.
Happy coding and good luck on modeling!!