Jan Ebert


1. Setup

1.1. About this Text

I created this as a slide-based presentation for the Helmholtz AI Food for Thought seminar on December 02, 2021, in order to introduce researchers from various backgrounds to JAX. This text was produced from the same Org source, with some extra commentary text interleaved. That’s why the text may feel choppy at times and code snippets are more compressed instead of following good style. I tried my best to explain everything I mentioned in the presentation as well.

Also, this is the “extended edition”, including many more code snippets as well as some behind-the-scenes stuff. This way, everything should be reproducible and you should have a good base reference in case you decide to pick up JAX.

You can download the original slides here.

1.2. Environment Setup

This is the environment I used. I gave specific version numbers as comments in case you need better reproducibility.

Sadly, I was unable to compile JAX with GPU support; the version of CUDA the official NVIDIA drivers support for Ubuntu 20.04 is no longer supported by JAX.

That’s why all numbers I will show you need to be taken with a grain of salt – we aren’t able to use JAX’ killer feature!

conda create -p env python=3.8
conda activate ./env
# `pytorch` version: 1.10.0
conda install pytorch cudatoolkit=10.2 -c pytorch
python -m pip install --upgrade pip
# `jax` version: 0.2.25
# `jaxlib` version: 0.1.74
python -m pip install --upgrade jax jaxlib
# I only have the CPU version.
clang --version | sed 's/ *$//'
clang version 10.0.0-4ubuntu1
Target: x86_64-pc-linux-gnu
Thread model: posix
InstalledDir: /usr/bin

2. Overview

Topic for Today
import functools
import os
import time
import timeit

import jax
import jax.numpy as jnp
import numpy as np
import torch
Behind-the-Scenes Setup

The following setup code can safely be ignored (skip ahead) but achieves the following:

  • XLA IR output in directory generated.
  • 2 simulated devices for XLA.
  • Activate/disable GPUs in both JAX and PyTorch.
  • Less print output.
  • Initialize JAX and PyTorch so we don’t see warnings later.
# Need to do this before even the `jax.config` update.
os.environ['XLA_FLAGS'] = ' '.join([
    os.getenv('XLA_FLAGS', ''),
    '--xla_dump_to=./generated',
    '--xla_force_host_platform_device_count=2',
])
jax.lib.xla_bridge.get_backend.cache_clear()

try:
    gpus = jax.local_devices(None, 'gpu')
except RuntimeError:
    gpus = None
has_gpu = bool(gpus) and torch.cuda.is_available()

# By default, JAX runs on the GPU (if available) while PyTorch by
# default runs on the CPU. So we choose the common denominator depending
# on GPU availability.
if not has_gpu:
    # Force JAX on CPU.
    jax.config.update('jax_platform_name', 'cpu')
    # PyTorch will complain if we call this without a GPU active.
    torch.cuda.synchronize = lambda: None
else:
    # Force PyTorch on GPU.
    torch.cuda.set_device(0)

print_precision = 3
np.set_printoptions(precision=print_precision)
jnp.set_printoptions(precision=print_precision)
torch.set_printoptions(precision=print_precision)

def round_complex(val, precision):
    if isinstance(val, complex):
        return round(val.real, precision) + round(val.imag, precision) * 1j
    return round(val, precision)

def round_tree(tree, precision=print_precision):
    return jax.tree_util.tree_map(lambda x: round_complex(x, precision), tree)

# Initialize JAX.
jnp.array(0)
# Initialize PyTorch.
torch.tensor(0)
Why is it Cool?
  • NumPy/SciPy on the GPU.
  • No mutable state means much joy: easier maintainability, easier scalability.
  • No need for manual batching.
  • Complex number differentiation previously1 not possible on PyTorch. Certain unsupported cases may still be out there.
  • User-friendly parallelism.
JAX’ Libraries
jax
Mostly important for function transformations/compositions.
jax.numpy
NumPy replacement.
jax.scipy
SciPy replacement.
jax.random
Deterministic randomness.
jax.lax
Lower-level operations.
jax.profiler
Analyze performance.

Just a selection, there are more!

3. Mutating Immutable Arrays

Mutating Arrays in NumPy

In NumPy:

x = np.eye(2)
x[0, 0] = 2
x[:, -1] *= -3
x
2 0
0 -3

Some things to note for those not familiar with Python and/or NumPy:

  • In Python, indices start at 0, so the first element of an array is at index 0.
  • NumPy and JAX call any \( n \)-dimensional tensor an array, so do not expect a 1-dimensional vector when you see the word “array”.
  • The colon “:” we used for indexing selects all elements over that dimension. In our case, it selected every row, so together with the second index a whole column of the matrix was indexed.
  • Negative values index from the end (but, confusingly, starting at 1 this time). So indexing with -1 yields the very last value of an array. Combining the colon with -1 in that order means we index the last column of the array.
  • The shape of an array is a collection of the sizes of its dimensions (x’s shape is (2, 2)).
  • The dtype of an array is the type of its elements.
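
To make the indexing rules above concrete, here is a small sketch on a non-square array, where rows and columns cannot be confused. The array contents and variable names are only for illustration.

```python
import numpy as np

x = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

first_row = x[0, :]     # Colon in the second position: all columns of row 0.
last_column = x[:, -1]  # Colon in the first position: all rows of the last column.
last_value = x[-1, -1]  # -1 counts from the end in each dimension.

print(first_row)    # [0 1 2]
print(last_column)  # [2 5]
print(last_value)   # 5
print(x.shape)      # (2, 3)
print(x.dtype)      # An integer dtype; the exact width depends on the platform.
```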
Mutating Arrays in JAX

In JAX:

x = jnp.eye(2)

try:
    x[0, 0] = 2
except TypeError:
    pass
else:
    raise TypeError('array is mutable')

x = x.at[0, 0].set(2)
x = x.at[:, -1].multiply(-3)
x
2 0
0 -3

You will receive helpful errors in case you forget this!

The try-except-else block of the code has the following effect: if the statement in the try block x[0, 0] = 2 raises a TypeError, we simply ignore it. That’s what the except TypeError block does. If we did not get any error, i.e. no exception was caught, the else block executes and raises a TypeError telling us that the array was mutable after all. Since the code executed and evaluated just fine, we know that JAX raised an error upon trying to mutate the array, meaning JAX arrays are not mutable as expected. (Whew!)

You may think performance tanks due to the extra allocations; however, these are optimized away to in-place mutations during JIT (just-in-time) compilation.
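
As a small sketch of what “immutable” means in practice: the `at[...].set(...)` update returns a new array and leaves the one it was called on untouched. The variable names are made up for this example.

```python
import jax.numpy as jnp

original = jnp.eye(2)
updated = original.at[0, 0].set(2.0)

# The update produced a new array; `original` is still the identity matrix.
original_corner = float(original[0, 0])  # 1.0
updated_corner = float(updated[0, 0])    # 2.0
print(original_corner, updated_corner)
```

This is why the snippets above always rebind the name (`x = x.at[...]...`); without the assignment, the result would simply be discarded.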

Closures (Functional Programming Interlude)

Function capturing something from an outer environment. (You’ve been using this whenever you refer to a non-local, e.g. global, variable inside a function.) Very useful for programming with JAX, but be careful about mutable state!

def create_counter():
    count = 0

    def inc():
        nonlocal count
        count += 1
        return count

    return inc

counter = create_counter()
', '.join(map(str, [counter(), counter(), counter()]))
1, 2, 3

If you don’t know Python: the def indicates a function definition.

In Python, closure capturing rules are the same as for function argument passing, also called “pass-by-sharing” – non-primitives are referenced (imagine a C pointer on that object). So outside modifications to non-primitives will be visible inside the closure as well!

This behavior varies between programming languages, so keep in mind that this is a Python implementation detail.
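
Here is a minimal sketch of the “outside modifications are visible inside the closure” point, using a plain Python list. The names `create_reader` and `shared` are hypothetical.

```python
def create_reader(values):
    def read_total():
        # `values` refers to the same list object the caller holds.
        return sum(values)

    return read_total

shared = [1, 2, 3]
read_total = create_reader(shared)
before = read_total()  # 6

shared.append(4)       # Mutate the list from outside the closure.
after = read_total()   # 10: the closure sees the mutation.
print(before, after)
```

With immutable JAX arrays, this kind of surprise cannot happen through the array itself, which is one reason closures and JAX get along well.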

4. Randomness in JAX

Randomness in JAX

Stateless RNG makes reproducibility fun and easy as pie!2

seed = 0
rng_key = jax.random.PRNGKey(seed)

def randn(rng_key, shape, dtype=float):
    rng_key, subkey = jax.random.split(rng_key)
    rands = jax.random.normal(subkey, shape, dtype=dtype)
    return rng_key, rands

rng_key, rands = randn(rng_key, (2, 3))
rands
-1.458 -2.047 -1.424
1.168 -0.976 -1.272

RNG keys are the way JAX represents RNG state. We called the key we use right away subkey even though it is the exact same kind of object as rng_key. This naming is simply a nice pattern to use because we know from the names which keys we are going to “consume” and which we will pass along.

You can also split off more than one key by passing an additional integer to jax.random.split.
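
A short sketch of splitting off several keys at once; the names `subkey_a` and `subkey_b` are made up. The same subkey always reproduces the same numbers, while different subkeys give different streams.

```python
import jax

key = jax.random.PRNGKey(0)
# One key to pass along, two keys to consume.
key, subkey_a, subkey_b = jax.random.split(key, 3)

a = jax.random.normal(subkey_a, (2,))
b = jax.random.normal(subkey_b, (2,))
a_again = jax.random.normal(subkey_a, (2,))  # Reusing a key repeats the draw.
```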

Another thing I wanted to show off here is the pattern of using standard Python types with JAX. Since we only specified float instead of the more specific jnp.float32 or jnp.float64, JAX will automatically pick the correct one based on its configured default.

RNG Keys
_, rands = randn(rng_key, (2, 3))
rands
0.932 -0.004 -0.74
-1.427 1.06 0.929

Notice we did not update our rng_key, instead discarding that part of the result. Can you guess what happens when we generate numbers again?

rng_key, rands = randn(rng_key, (2, 3))
rands
0.932 -0.004 -0.74
-1.427 1.06 0.929

This time, we updated our rng_key: the world of randomness is whole again! Notice this is a really easy way to produce random numbers that you (hopefully) want to stay the same.

rng_key, rands = randn(rng_key, (2, 3))
rands
1.169 0.312 -0.571
0.137 0.742 0.038

5. grad: Advanced AutoDiff

5.1. grad How-To

How to Take Gradients

JAX’ grad function transformation takes a function we want to differentiate as input and returns another function calculating the original one’s gradient given arbitrary inputs. By default, it takes the derivative with regard to the first argument. Multiple jax.grad applications can be nested to take higher-order derivatives.

Another Python syntax thing here, the ** is the exponentiation operator. If you’re curious, ^ is used for a bitwise exclusive or.

def expo_fn(x, y):
    return x**4 + 2**y + 3

x = 1.0
y = 2.0

grad_x_fn = jax.grad(expo_fn)  # 4x^3
grad2_x_fn = jax.grad(grad_x_fn)  # 12x^2
grad3_x_fn = jax.grad(grad2_x_fn)  # 24x
[
    ['function', r'\partial{}expo_fn/\partial{}\(x\)'],
    ['grad_x', grad_x_fn(x, y).item()],
    ['grad2_x', grad2_x_fn(x, y).item()],
    ['grad3_x', grad3_x_fn(x, y).item()],
]
function ∂expo_fn/∂\(x\)
grad_x 4.0
grad2_x 12.0
grad3_x 24.0
Differentiating Different Arguments

To differentiate with regard to other arguments of our functions, we pass jax.grad the indices of those arguments in the argnums argument. We can also specify multiple argnums.

x = 1.0
y = 2.0

grad_y_fn = jax.grad(expo_fn, argnums=1)  # ln(2) · 2^y
grad_xy_fn = jax.grad(expo_fn, argnums=(0, 1))
[
    ['function', 'result'],
    ['grad_y', grad_y_fn(x, y).item()],
    ['grad_xy', [g.item() for g in grad_xy_fn(x, y)]],
]
function result
grad_y 2.7725887298583984
grad_xy (4.0 2.7725887298583984)
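
If you want to sanity-check such results without any autodiff, a central finite difference is a quick sketch. The derivative of 2^y with respect to y is ln(2) · 2^y, which gives the 2.7725… above for y = 2. The helper `central_difference` is hypothetical and only meant for this check.

```python
import math

def expo_fn(x, y):
    return x**4 + 2**y + 3

def central_difference(fn, value, eps=1e-6):
    # Symmetric finite-difference approximation of the derivative.
    return (fn(value + eps) - fn(value - eps)) / (2 * eps)

x, y = 1.0, 2.0
numeric_dx = central_difference(lambda v: expo_fn(v, y), x)
numeric_dy = central_difference(lambda v: expo_fn(x, v), y)
analytic_dx = 4 * x**3            # 4x^3
analytic_dy = math.log(2) * 2**y  # ln(2) · 2^y

print(numeric_dx, analytic_dx)
print(numeric_dy, analytic_dy)
```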
Differentiation and Outputs

In machine learning, we really like to monitor our loss values, which are the non-differentiated results of the function we differentiate (i.e. the value we want to minimize). In order to not have to evaluate the function twice and lose precious performance, JAX offers the jax.value_and_grad function transformation which returns a new function that calculates the result of the function as well as its gradient. Now we can log our losses and sleep well again.

result_grad_fn = jax.value_and_grad(expo_fn)

result, grad = result_grad_fn(x, y)

Fun fact: jax.value_and_grad is actually what jax.grad calls as well; it just tosses the result.


Let’s assume the function we want to differentiate has multiple outputs; for example, maybe we need to return some new state!

Say we went through the trouble of collecting all our extra return values in a tuple. We then also changed our function to return a pair (i.e. another tuple) containing (in this order)

  1. the value we want to differentiate through, and
  2. the tuple of extra return values.

We can then supply the has_aux=True argument to jax.grad and happily differentiate again while keeping our state intact:

def poly_fn_and_aux(x, y):
    aux_output = ({'y': y}, 1337)
    return x**4 + (y - 1)**2 + 3, aux_output

grad_aux_fn = jax.grad(poly_fn_and_aux, has_aux=True)
grad, aux = grad_aux_fn(x, y)

Of course, the same works for jax.value_and_grad as well; however, its tree of return values needs some special care to deconstruct:

result_aux_grad_fn = jax.value_and_grad(poly_fn_and_aux, has_aux=True)
(result, aux), grad = result_aux_grad_fn(x, y)

By the way, if you ever want to disable gradient computation inside a jax.grad context, you can use jax.lax.stop_gradient. Its use is a bit unintuitive, so I’d recommend checking out the link.
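
As a rough sketch of how jax.lax.stop_gradient behaves: the wrapped value still contributes to the function’s result, but it is treated as a constant during differentiation. The two toy functions below are made up for illustration.

```python
import jax

def with_stop(x):
    # The cubic term is treated as a constant when differentiating.
    return x**2 + jax.lax.stop_gradient(x)**3

def without_stop(x):
    return x**2 + x**3

x = 2.0
value = with_stop(x)                           # 4 + 8 = 12, value is unchanged
grad_with_stop = jax.grad(with_stop)(x)        # 2x = 4
grad_without_stop = jax.grad(without_stop)(x)  # 2x + 3x^2 = 16
```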

5.2. Differentiating Spectral Radius

Default Precision Interlude

Simple Spectral Radius
def jax_spectral_radius(mat):
    eigvals = jnp.linalg.eigvals(mat)
    spectral_radius = jnp.max(jnp.abs(eigvals))
    return spectral_radius

def torch_spectral_radius(mat):
    eigvals = torch.linalg.eigvals(mat)
    spectral_radius = torch.max(torch.abs(eigvals))
    return spectral_radius

# Eigenvalues: 1 ± i
ceig_mat = np.array([[1.0, -1.0], [1.0, 1.0]])
jax_mat = jnp.array(ceig_mat)
torch_mat = torch.from_numpy(ceig_mat)
[
    ['function', 'result'],
    ['jax', jax_spectral_radius(jax_mat).item()],
    ['torch', torch_spectral_radius(torch_mat).item()],
    ['sqrt', np.sqrt(2)],
]
JAX’ Default Precision
function result
jax 1.4142135381698608
torch 1.4142135623730951
sqrt 1.4142135623730951

Wait! JAX’ precision is seriously behind PyTorch here! Is PyTorch just more precise or what’s going on?

While both JAX and PyTorch have single-precision (32-bit) floating point numbers as their default dtype, NumPy uses double-precision (64-bit) floats by default.

Now, when we converted the matrix from NumPy to the respective frameworks, JAX’ jnp.array created a new array from NumPy’s, thus converting the dtype to JAX’ default. This leaves us with jax_mat.dtype being jnp.float32. However, PyTorch’s torch.from_numpy adapted the dtype exactly, which is why PyTorch had double the precision to work with.

With that knowledge, let’s make the test more fair by converting torch_mat to single-precision as well:

torch_mat = torch_mat.float()
[
    ['jax', jax_spectral_radius(jax_mat).item()],
    ['torch', torch_spectral_radius(torch_mat).item()],
]
jax 1.4142135381698608
torch 1.4142135381698608

Ahh, all is well; JAX is not lacking in terms of precision after all (at least in this small example).

Let’s check out what the square root of 2 is in single-precision to see just how precise we are:

two_f32 = np.array(2, dtype=np.float32)
[['sqrt_f32', '{:.16f}'.format(np.sqrt(two_f32))]]
sqrt_f32 1.4142135381698608

Gradients and Complex Numbers

Complex Differentiation

Let’s finally differentiate some complex numbers. You may not ever have seen this way to differentiate in PyTorch – it’s the functional way!

Just to clarify again, being able to take this gradient is a rather recent change in PyTorch: stable since PyTorch 1.9.0, June 2021. Originally, I wanted to show here that JAX is capable of differentiating something that PyTorch cannot. The competition has caught up, though!

jax_grad = jax.grad(jax_spectral_radius)(jax_mat)

torch_mat.requires_grad_(True)
torch_rho = torch_spectral_radius(torch_mat)
torch_grad = torch.autograd.grad(torch_rho, torch_mat)

decimals = 3
[
    ['function', 'grad'],
    ['jax', round_tree(jax_grad.tolist(), decimals)],
    ['torch', round_tree(torch_grad[0].tolist(), decimals)],
]
function grad
jax ((0.354 -0.354) (0.354 0.354))
torch ((0.354 -0.354) (0.354 0.354))
Setting Up Complex Input
complex_mat = np.array([[1 + 1j, -1], [1, 1]])
jax_complex_mat = jnp.array(complex_mat)
torch_complex_mat = torch.from_numpy(complex_mat).to(torch.complex64)
[
    ['function', 'result'],
    None,
    ['jax', jax_spectral_radius(jax_complex_mat).item()],
    ['torch', torch_spectral_radius(torch_complex_mat).item()],
]
function result
jax 1.9021130800247192
torch 1.9021130800247192
Differentiating Through Complex Gradients

Because JAX works heavily with tree-like nested data structures (so-called pytrees), it includes a nice module with tree-related functions called jax.tree_util. We will use it to conjugate the tree of gradients we obtain from jax.grad (for no special reason at all).

jax_grad = jax.grad(jax_spectral_radius)(jax_complex_mat)
jax_conj_grad = jax.tree_util.tree_map(jnp.conj, jax_grad)

torch_complex_mat.requires_grad_(True)
torch_rho = torch_spectral_radius(torch_complex_mat)
torch_grad = torch.autograd.grad(torch_rho, torch_complex_mat)

decimals = 3
[
    ['type', 'gradient'],
    None,
    ['jax', round_tree(jax_grad.tolist(), decimals)],
    ['jax conj', round_tree(jax_conj_grad.tolist(), decimals)],
    ['torch', round_tree(torch_grad[0].tolist(), decimals)],
]
type gradient
jax (((0.38-0.616j) (-0.38-0.235j)) ((0.38+0.235j) (0.145-0.235j)))
jax conj (((0.38+0.616j) (-0.38+0.235j)) ((0.38-0.235j) (0.145+0.235j)))
torch (((0.38+0.616j) (-0.38+0.235j)) ((0.38-0.235j) (0.145+0.235j)))

Oops! It seems there is a discrepancy here…

Without going too much into the math, when optimizing a function with complex inputs and real outputs, steepest-descent algorithms need the conjugate of the complex gradient in order to walk in the correct direction. As PyTorch is a deep learning framework first and foremost, it conjugates its gradients by default so users can go plug-and-play when fooling around with complex numbers and optimization.

Why Different Gradients?
  • JAX chose mathematical/API consistency (also same behavior as in Autograd3).
  • However, not the gradients to use for steepest-descent optimization! (Conjugate before.)
  • PyTorch has the more practical default here.
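
To see why the conjugate matters, here is a framework-free sketch on f(z) = |z|² = x² + y². I assume the unconjugated convention ∂f/∂x − i·∂f/∂y (my reading of the JAX result above), which for this function works out by hand to 2·conj(z). The function names are hypothetical.

```python
def loss(z):
    # Real-valued function of a complex input: |z|^2 = x^2 + y^2.
    return abs(z) ** 2

def unconjugated_grad(z):
    # Hand-derived: d/dx - i * d/dy of x^2 + y^2 is 2x - 2iy = 2 * conj(z).
    return 2 * z.conjugate()

z = 3 + 4j
lr = 0.25
g = unconjugated_grad(z)

step_conj = z - lr * g.conjugate()  # 1.5 + 2j: moves towards the minimum at 0
step_raw = z - lr * g               # 1.5 + 6j: moves away from it

print(loss(z), loss(step_conj), loss(step_raw))
```

Starting from a loss of 25, the conjugated step lands at 6.25 while the unconjugated one ends up at about 38.25, so only the conjugated gradient actually descends.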

6. jit Compilation via XLA

6.1. Introducing jit

Side Effects in a jit Context

Just-in-Time-Compiling via XLA
def print_hello():
    print('Hello!')  # Side effect!

jax.jit(print_hello)()

Hello!

jit_print_hello = jax.jit(print_hello)
jit_print_hello()
jit_print_hello()
print('... Hello? :(')

... Hello? :(

Before we dive in here, a quick terminology heads-up: “to JIT something” means “to just-in-time-compile something”.

Multiple interesting things happened in these short snippets:

  1. We did not get any other “Hello!” after the first one in general.
  2. We did not get a second “Hello!” even though we applied jax.jit to print_hello a second time.
  3. We did get a “Hello!” for the first call of the JITted function.

Let’s go through these in order:

  1. The print call is called a side effect in computer science. JAX does not care for these during JIT compilation; it only cares about math – or, to be more exact, whatever comes in and what comes out of the function.
  2. The reason we do not get a second “Hello!” even though we apply jax.jit again (and we would expect the side effect to happen again before the function is compiled) is that JAX caches its compilation output in the background. So if we JIT the same function twice with the same arguments, the previous compilation output will be reused.
  3. JAX traces what happens inside the function on its first call, building a computational graph. This means the first call of the function executes just like a standard Python function (though even slower due to the computational graph building).

    This also means that JITting functions that result in a large computational graph (for example a Python loop that is executed very many times) can take forever to JIT only because the first tracing of it takes so long. When you encounter this issue, you can replace your loop with control flow substitutes from the jax.lax module.
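
The tracing behavior can be made visible with a side effect that records something instead of printing. This is only a sketch; `trace_log` and `double` are made-up names.

```python
import jax
import jax.numpy as jnp

trace_log = []

def double(x):
    # Python side effect: only runs while JAX traces the function.
    trace_log.append(x.shape)
    return x * 2

jit_double = jax.jit(double)

jit_double(jnp.ones(3))  # First call: traced and compiled.
jit_double(jnp.ones(3))  # Same shape and dtype: cached, no side effect.
jit_double(jnp.ones(4))  # New shape: traced again.

print(trace_log)  # [(3,), (4,)]
```

Three calls, but only two entries: the side effect shows up once per trace, not once per call.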

JITting State

Here, we’ll see that jax.jit can also be used as a decorator. However, because we need to supply another argument to jax.jit, we cannot simply write @jax.jit. Instead, we need to combine it with the functools.partial function. An explanation follows after the code block.

class Counter:
    count = 0

    @functools.partial(jax.jit, static_argnums=(0,))
    def inc(self):
        self.count += 1

a = Counter()
print(a.count)
a.inc()
print(a.count)
a.inc()
print(a.count)
0
1
1

Applying functools.partial like this results in the following (actually anonymous) function:

# Result of `functools.partial(jax.jit, static_argnums=(0,))`.
def partial_jit(*args, **kwargs):
    return jax.jit(*args, static_argnums=(0,), **kwargs)
 

This new partial_jit function wraps the inc method of Counter, resulting in the equivalent of the following:

class Counter:
    count = 0

    def inc(self): self.count += 1

Counter.inc = jax.jit(Counter.inc, static_argnums=(0,))
 

I hope that helped make sense of the code. static_argnums basically tells jax.jit to recompile the JITted function whenever the argument at that position changes. In return, we get some freedoms (for example, we would not be able to JIT the function otherwise). From now on, we call the arguments at the positions designated by static_argnums static. More on static and non-static arguments later.
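
If the functools.partial trick still feels magical, here is the same pattern without JAX. The `tag` function is a hypothetical stand-in for jax.jit: it takes the function to wrap plus an extra keyword option.

```python
import functools

def tag(func, label):
    # Stand-in for `jax.jit`: wraps `func` and takes an extra option.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return f'{label}: {func(*args, **kwargs)}'
    return wrapper

# Decorator form: `partial` pre-fills the option, Python supplies `func`.
@functools.partial(tag, label='static')
def greet(name):
    return f'hello {name}'

# Equivalent manual form.
def greet_plain(name):
    return f'hello {name}'

greet_manual = tag(greet_plain, label='static')

print(greet('jax'))         # static: hello jax
print(greet_manual('jax'))  # static: hello jax
```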

JITting State Again
a = Counter()
print(a.count)
a.inc()
print(a.count)
a.inc()
print(a.count)
0
1
1

Due to self being a static argument as specified via static_argnums, the function is recompiled for a new, different self4.

Benchmarking jit on randn

JITting our randn

We’ll now use Python’s timeit module to benchmark a JIT-compiled version of our old friend randn (remember we implemented this in the section on RNG in JAX). We implement a simple wrapper around it in order to initialize the JIT-compiled function before we benchmark it.

You will notice some block_until_ready calls. These are due to JAX’ asynchronous execution engine. Whenever JAX executes code on a non-host device (such as a GPU), it happens asynchronously. This means that the main thread of the program continues to run ahead with a “mock” result, also called a “future”, while the actual result is computed in the background. Only when we actually query the result will we wait until it’s available.

During benchmarking, we would get these futures immediately – that’s not much use to us. So we call the block_until_ready function in order to wait until the result of the computation is actually available. You achieve the same in PyTorch using a call to torch.cuda.synchronize.

jit_randn = jax.jit(randn, static_argnums=(1, 2))

def time_str(code, number=5000):
    # Initialize cache
    exec(code)
    return timeit.timeit(code, globals=globals(), number=number)

randn_time = time_str(
    'randn(rng_key, (100, 100))[1].block_until_ready()')
jit_randn_time = time_str(
    'jit_randn(rng_key, (100, 100))[1].block_until_ready()')
[
    ['function', 'time [s]'],
    ['randn', randn_time],
    ['jit_randn', jit_randn_time],
]
JIT Results
function time [s]
randn 1.0273832870007027
jit_randn 0.7757850260022678

A 25 % reduction! Not bad at all, especially since we shouldn’t really have much room to optimize here.

Let’s see how PyTorch does:

np_rng = np.random.default_rng()
np_randn_time = time_str('np_rng.normal(size=(100, 100))')
torch_randn_time = time_str(
    'torch.randn((100, 100)); '
    'torch.cuda.synchronize()'
)
[['np_randn', np_randn_time], ['torch_randn', torch_randn_time]]
np_randn 0.6026334530001805
torch_randn 0.25071304299990516

Apparently, PyTorch is super good at generating random numbers (3 times as fast as JIT-compiled JAX!). I did not analyze this and can’t say much more about this as there can be a myriad of reasons.

6.2. More about XLA

About XLA
  • Works via completely known shapes.
  • Can’t work dynamically with non-static values! That means (for jax.jit):
    • No if x > 0.
    • No jnp.unique.
    • No y[y % 2 == 0] or y[:x].
    • However, we can mark x (or what it depends on) as static.
    • Alternatively, “disable” JIT in section via experimental host_callback module.
  • Most important optimization: operator/kernel fusion.
  • Best to apply at outer-most location only5.
When XLA Recompiles

Function is recompiled for…

  • different static argument values,
  • different argument shapes, and
  • different argument dtypes.

When you hit performance issues, constant recompilation may be the problem!

When XLA Recompiles (in Text)

Imagine a dictionary (hash map) compcache and a function with arguments args:

  • For each argument x in args, collect the following in cache_key:
    • if it’s static, x (identity),
    • if it’s not static, (x.shape, x.dtype).

compcache maps from key cache_key to JIT-output (value). Recompile and insert if compcache does not contain cache_key.

Any static x not hashable? Bzzzt, error!

When XLA Recompiles (in Code)
import collections.abc

compcache = {}

def maybe_compile(compcache, func, args):
    # `is_static_arg` and `xla_compile` are placeholders for JAX internals.
    cache_key = []
    for x in args:
        if is_static_arg(x):
            assert isinstance(x, collections.abc.Hashable)
            cache_key.append(x)  # Identity
        else:
            # Imagine this works on Python primitives.
            cache_key.append((x.shape, x.dtype))
    # Lists are not hashable, so convert before using it as a dict key.
    cache_key = tuple(cache_key)

    try:
        return compcache[cache_key]
    except KeyError:
        jit_output = xla_compile(func)
        compcache[cache_key] = jit_output
        return jit_output

Just an intuitive example! It fails, for instance, with arbitrarily ordered keyword arguments; for those, cache_key would need to be keyed by argument name, like a dict.
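
For those who like to poke at things, here is a runnable toy version of the same idea. It is a sketch, not how JAX actually implements its cache: `fake_compile` merely records the key, static arguments are given by position, and NumPy arrays stand in for JAX arrays. All names are hypothetical.

```python
import numpy as np

compile_log = []

def fake_compile(func, cache_key):
    # Stand-in for XLA compilation: just record that it happened.
    compile_log.append(cache_key)
    return func

def make_key(args, static_positions):
    key = []
    for position, x in enumerate(args):
        if position in static_positions:
            hash(x)  # Raises TypeError for unhashable static arguments.
            key.append(x)
        else:
            x = np.asarray(x)
            key.append((x.shape, x.dtype))
    return tuple(key)

def call_cached(cache, func, args, static_positions=()):
    key = make_key(args, static_positions)
    if key not in cache:
        cache[key] = fake_compile(func, key)
    return cache[key](*args)

def add(x, n):
    return x + n

cache = {}
call_cached(cache, add, (np.zeros(3), 1), static_positions=(1,))  # compile
call_cached(cache, add, (np.ones(3), 1), static_positions=(1,))   # cache hit
call_cached(cache, add, (np.ones(4), 1), static_positions=(1,))   # new shape
call_cached(cache, add, (np.ones(4), 2), static_positions=(1,))   # new static value
call_cached(cache, add, (np.ones(4, dtype=np.float32), 2),
            static_positions=(1,))                                # new dtype

print(len(compile_log))  # 4
```

Five calls lead to four “compilations”: only the second call reuses a cache entry, because it differs from the first in values alone, not in shape, dtype or static arguments.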

JIT and C++
  • Possible to run JIT-compiled functions from C++ via XLA runtime6 and jax_to_hlo utility.
  • Intermediate representation (IR) output from jax.jit will be JIT-compiled in C++ program.
  • However, a bit involved.
  • Example in JAX repository.

6.3. LLVM is Smart… And XLA?

In this section, I’d like to show you a clever optimization compilers do for you.

We’ll take a look at a simple sum implementation in C and the code generated from it. We will compare that with several implementations in Python, compiling the JAX version and seeing why (spoiler alert) XLA does not match up in terms of performance, even though it also uses LLVM for compilation.

Scalar Evolution with Clang

Giant’s Shoulders
#include <stdio.h>

int sum(int limit)
{
    int i, total;

    total = 0;
    for (i = 0; i < limit; ++i) {
        total += i;
    }
    return total;
}

int main(int argc, char** argv)
{
    printf("%d\n", sum(50000));
    return 0;
}
1249975000

If you’re following along, save the above to a file called sum.c.

This is the simplest sum function we can implement. Let’s see what a modern C compiler does to this code by outputting LLVM’s lower-level representation.

Outputting LLVM IR

Intermediate representation (IR) is a lower-level (in this case assembly-like) representation of the program. The compiler backend LLVM uses IR to achieve portability across assembly languages.

Don’t worry too much about the vim call below – we are simply filtering the LLVM IR output so it only shows the definition of the sum function. The sed call strips trailing spaces.

clang -S -emit-llvm sum.c -O1
cat sum.ll \
    | vim - +'/^define.*sum(.*{$/,/^}$/p' \
          -es --not-a-term \
    | sed 's/ *$//'

Clang is the official C frontend7 of the LLVM project.

LLVM Scalar Evolution

Warning: assembly-like language below! Don’t worry about reading this, I’ll give the gist of it below.

define dso_local i32 @sum(i32 %0) local_unnamed_addr #0 {
  %2 = icmp sgt i32 %0, 0
  br i1 %2, label %3, label %13

3:                                                ; preds = %1
  %4 = add i32 %0, -1
  %5 = zext i32 %4 to i33
  %6 = add i32 %0, -2
  %7 = zext i32 %6 to i33
  %8 = mul i33 %5, %7
  %9 = lshr i33 %8, 1
  %10 = trunc i33 %9 to i32
  %11 = add i32 %10, %0
  %12 = add i32 %11, -1
  br label %13

13:                                               ; preds = %3, %1
  %14 = phi i32 [ 0, %1 ], [ %12, %3 ]
  ret i32 %14
}

So, without focusing too much on the details and interpreting this intuitively: please believe me that LLVM converted our sum function that used a for-loop into… the closed-form sum formula \( n \cdot (n - 1) / 2 \) (minus instead of plus due to the limit being excluded)! Isn’t that amazing?
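
If you’d rather not take my word for it, here is the IR’s arithmetic transcribed to Python and checked against the loop. This is a sketch that ignores the i33/i32 bit-width handling, since Python integers do not overflow; the function names are made up.

```python
def sum_loop(limit):
    total = 0
    for i in range(limit):
        total += i
    return total

def sum_like_ir(limit):
    # Mirrors the IR: guard on limit > 0, then (n-1)*(n-2)/2 + n - 1.
    if limit <= 0:
        return 0
    return (((limit - 1) * (limit - 2)) >> 1) + limit - 1

def sum_closed_form(limit):
    # n * (n - 1) / 2, with the same guard for non-positive limits.
    return limit * (limit - 1) // 2 if limit > 0 else 0

print(sum_loop(50000), sum_like_ir(50000), sum_closed_form(50000))
```

Algebraically, (n − 1)(n − 2)/2 + (n − 1) = (n − 1)·n/2, so the slightly odd-looking form in the IR is the same closed-form formula.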

This form of optimization is called scalar evolution and is an induction-based technique for – as you can see – quite substantial performance improvements. If this topic interests you, the next section links to the source code, which includes references to the papers it implements.

If you really wanted to make sure the optimization happens on the machine code level, you can compile the code to an object file and disassemble it using for example the radare2 program.

More on Scalar Evolution

LLVM scalar evolution analysis (SCEV) source code with links to papers.

Weirdly enough, sums with a step other than 1 are not optimized even though a closed-form solution exists…

To explain a bit more, an earlier version had an int sum(int limit, int step) function that allowed a varying step size. However, LLVM did not optimize this function to the closed-form solution, even though it really should be able to (from what I could see in the comments of the scalar evolution source code).
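
For reference, this is the closed-form solution I had in mind for the stepped version, sketched in Python. I’m assuming the loop had the form `for (i = 0; i < limit; i += step)` with a positive step; the function names are made up.

```python
def sum_step_loop(limit, step):
    total = 0
    for i in range(0, limit, step):
        total += i
    return total

def sum_step_closed(limit, step):
    # Assumes step > 0. The loop visits 0, step, 2*step, ..., (count-1)*step.
    if limit <= 0:
        return 0
    count = (limit + step - 1) // step  # Number of iterations (ceiling division).
    return step * count * (count - 1) // 2

print(sum_step_loop(10, 3), sum_step_closed(10, 3))  # 18 18
```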

Benchmarking C Sum

This section is about obtaining a C timing for the sum function and can safely be skipped.

In order to get some timings for C, which does not include a nice timeit module, what follows is a benchmark program for the above sum function allowing various timing methods. The idea is to mimic timeit with this. I saved this to a file called benchmark_sum.c.

#include <stdio.h>
#include <time.h>

// The CPU cycle-based timings suffer from poor resolution compared to
// wall-time measurements.
#define MY_CLOCK 0
#define MY_CLOCK_GETTIME 1
#define MY_CLOCK_GETTIME_WALL 2
#define MY_TIME 3

#define MY_CLOCK_FUN MY_CLOCK_GETTIME_WALL

int sum(int limit)
{
    int i, total;

    total = 0;
    for (i = 0; i < limit; ++i) {
        total += i;
    }
    return total;
}

int main(int argc, char** argv)
{
    double duration;
    int res;
#if MY_CLOCK_FUN == MY_CLOCK
    clock_t start_time, end_time;

    start_time = clock();
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME
    struct timespec start_time, end_time;

    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start_time);
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME_WALL
    struct timespec start_time, end_time;

    clock_gettime(CLOCK_MONOTONIC, &start_time);
#elif MY_CLOCK_FUN == MY_TIME
    time_t start_time, end_time;

    start_time = time(NULL);
#endif

    res = sum(50000);

#if MY_CLOCK_FUN == MY_CLOCK
    end_time = clock();
    duration = (double) (end_time - start_time) / CLOCKS_PER_SEC;
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end_time);
    duration = (
        end_time.tv_sec + 1e-9 * end_time.tv_nsec
        - start_time.tv_sec - 1e-9 * start_time.tv_nsec
        );
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME_WALL
    clock_gettime(CLOCK_MONOTONIC, &end_time);
    duration = (
        end_time.tv_sec + 1e-9 * end_time.tv_nsec
        - start_time.tv_sec - 1e-9 * start_time.tv_nsec
        );
#elif MY_CLOCK_FUN == MY_TIME
    end_time = time(NULL);
    duration = difftime(end_time, start_time);
#endif

    // Use `res` so the `sum` call is not optimized away.
    printf("sum %d\n", res);
    printf("dur %.17g\n", duration);
    return 0;
}

Check that we get the sum formula optimization we want. I also manually checked to make sure that main does not optimize away the sum call.

clang -S -emit-llvm benchmark_sum.c -O1
cat benchmark_sum.ll \
    | vim - +'/^define.*sum(.*{$/,/^}$/p' \
          -es --not-a-term \
    | sed 's/ *$//'
define dso_local i32 @sum(i32 %0) local_unnamed_addr #0 {
  %2 = icmp sgt i32 %0, 0
  br i1 %2, label %3, label %13

3:                                                ; preds = %1
  %4 = add i32 %0, -1
  %5 = zext i32 %4 to i33
  %6 = add i32 %0, -2
  %7 = zext i32 %6 to i33
  %8 = mul i33 %5, %7
  %9 = lshr i33 %8, 1
  %10 = trunc i33 %9 to i32
  %11 = add i32 %10, %0
  %12 = add i32 %11, -1
  br label %13

13:                                               ; preds = %3, %1
  %14 = phi i32 [ 0, %1 ], [ %12, %3 ]
  ret i32 %14
}

Execute the same number of times as we do with time_str and add up the timings.

clang -o benchmark_sum.o benchmark_sum.c -O1
./benchmark_sum.o
seq 5000 \
    | xargs -n 1 ./benchmark_sum.o \
    | awk '/dur/ {total+=$2} END {print total}'
sum 1249975000
dur 9.532824124368238e-130
4.76641e-126

I copy-pasted this number at a later location so that I was able to give a live surprise.

Clang/Math vs. XLA

How does XLA stack sum up?

Back to Python, let’s see whether XLA timings match Clang. Also, let’s acknowledge Python’s miserable performance.

def py_sum_up(limit):
    return sum(range(limit))

def jax_sum_up(limit):
    return jnp.sum(jnp.arange(limit))

limit = 50000

python_time = time_str('py_sum_up(limit)')
jax_time = time_str('jax_sum_up(limit).block_until_ready()')
jit_sum_up = jax.jit(jax_sum_up, static_argnums=(0,))
jax_jit_time = time_str('jit_sum_up(limit).block_until_ready()')
[
    ['function', 'time [s]', 'result'],
    ['py_sum_up', python_time, py_sum_up(limit)],
    ['jax_sum_up', jax_time, jax_sum_up(limit).item()],
    ['jit_sum_up', jax_jit_time, jit_sum_up(limit).item()],
]

Sum Timings
function     time [s]              result
py_sum_up    3.0464433249981084    1249975000
jax_sum_up   0.44727893300296273   1249975000
jit_sum_up   0.1480330039994442    1249975000

JIT vs. Math

Here, we manually implement the sum formula optimization to get a better idea of how fast jit_sum_up should be in the best case.

def sum_up_const(limit):
    # Since we exclude the limit, subtract one instead of adding.
    return limit * (limit - 1) // 2

[
    ['function', 'time [s]', 'result'],
    [
        'jit_sum_up',
        jax_jit_time,
        jit_sum_up(limit).item(),
    ],
    [
        'sum_up_const',
        time_str('sum_up_const(limit)'),
        sum_up_const(limit),
    ],
]

JIT vs. Math Results
function       time [s]                result
jit_sum_up     0.1480330039994442      1249975000
sum_up_const   0.0008629900003143121   1249975000

Reason? The optimization is too slow to apply and was disabled (but only on the CPU)! [8]

C with -O1 takes around 5e-126 seconds.

Even More on JIT and Math

More JIT vs. Math

Just for fun, we try the same in PyTorch, using its JIT. We need to save the following code to a file torch_sum.py so PyTorch can parse the (TorchScript) source code.

import torch

__all__ = ['torch_sum_up', 'torch_jit_sum_up']

def torch_sum_up(limit):
    return torch.sum(torch.arange(limit))

torch_jit_sum_up = torch.jit.script(torch_sum_up)

# In the main session (not part of `torch_sum.py`):
from torch_sum import *

torch_limit = torch.tensor(limit)
[
    ['function', 'time [s]', 'result'],
    [
        'torch_sum_up',
        time_str(
            'torch_sum_up(torch_limit); '
            'torch.cuda.synchronize()'
        ),
        torch_sum_up(torch_limit).item(),
    ],
    [
        'torch_jit_sum_up',
        time_str(
            'torch_jit_sum_up(torch_limit); '
            'torch.cuda.synchronize()'
        ),
        torch_jit_sum_up(torch_limit).item(),
    ],
]

function           time [s]              result
torch_sum_up       0.15606207399832783   1249975000
torch_jit_sum_up   0.1805207170000358    1249975000

XLA-Optimized LLVM IR

If you’ve been executing jax.jit code snippets, you can find XLA IR output in the generated directory (if you have set the XLA_FLAGS as in the behind-the-scenes setup code block).

7. vmap: No Batching Required

What is Batching

If you don’t know what batching is, here are some simple examples. First three non-batched versions, then a batched version. Whether you’re using R, Octave/MATLAB, NumPy, or PyTorch, you always want to batch your calculations for optimum performance. Especially when you are interested in taking gradients, batching greatly simplifies the computational graph.

For the setting, assume we have a small set of 15 3-dimensional edges and we want to sum up their norms because we need it for some computer graphics algorithm.

Non-batched:

rng_key, edges = randn(rng_key, (15, 3))

norm_sum = 0
for edge in edges:
    norm_sum += jnp.linalg.norm(edge)
norm_sum
27.265522

Now, we can write this in a more pythonic way by using the built-in sum function. However, we are still not batching.

sum(jnp.linalg.norm(edge) for edge in edges)

The following is a more NumPy-like way to write the non-batched version. It may even be faster than the pythonic version due to being able to use SIMD operations. Whether performance is gained or lost depends on the size of our dataset, though.

jnp.sum(jnp.array([jnp.linalg.norm(edge) for edge in edges]))

Finally, we arrive at the batched version. We avoid any and all Python loops, calculating our norm sum in a much more efficient manner.

Batched:

norms = jnp.linalg.norm(edges, axis=-1)  # shape: (15,)
norm_sum = jnp.sum(norms)
norm_sum
27.265522
vmap: No Batching Required

Now, what if I told you you no longer needed to do batching manually? Enter jax.vmap!

The following example calculates the spectral radius on a 128-size batch of 3×3 matrices. With the assert statement, we make sure our old version jax_spectral_radius is not batched already.

Notice also that you can combine jax.jit with jax.vmap – any of the function transformations in JAX are arbitrarily nestable; quite the magic! I’ll let the below timings speak for themselves.

Why prefer the jax.jit-jax.vmap combo over a jax.jit-ted loop? If anything, it keeps the computational graph simpler!

rng_key, batch = jit_randn(rng_key, (128, 3, 3))

assert jax_spectral_radius(batch).shape != (128,)

def looped_spectral_radius(batch):
    return list(map(jax_spectral_radius, batch))

jit_looped_spectral_radius = jax.jit(looped_spectral_radius)
batched_spectral_radius = jax.vmap(jax_spectral_radius)
jit_batched_spectral_radius = jax.jit(batched_spectral_radius)

function      time [s]
looped        10.506742246001522
jit_looped    1.7194310759987275
batched       3.65071794799951
jit_batched   0.9865755659993738

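As a self-contained illustration of the same idea, here is a minimal sketch that returns to the edge-norm example. The function `edge_norm` is a hypothetical helper written for a single edge only; jax.vmap adds the batch axis, and jax.jit is nested on top.

```python
import jax
import jax.numpy as jnp

def edge_norm(edge):
    # Written for a single edge of shape (3,); no batch axis in sight.
    return jnp.sqrt(jnp.sum(edge ** 2))

edges = jax.random.normal(jax.random.PRNGKey(0), (15, 3))

# Function transformations nest: first batch, then JIT-compile.
batched_norm = jax.jit(jax.vmap(edge_norm))
norms = batched_norm(edges)  # shape: (15,)

# The manually batched version should agree.
manual_norms = jnp.linalg.norm(edges, axis=-1)
print(jnp.allclose(norms, manual_norms, atol=1e-5))
```
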
More Batching

Because we can’t get enough of comparisons across frameworks, here are some more batched implementations.

Since we are going to JIT-compile PyTorch’s TorchScript again, you need to save the following code block to torch_batching.py.

By the way, if you were wondering: implementing looped versions with either jnp.stack or a preallocated result matrix did not improve results. In fact, the pre-allocating version took too long to JIT-compile so I never got a number for that one. Below, you can find a version using jax.lax.scan.

import torch

__all__ = ['torch_batched_spectral_radius', 'torch_jit_batched_spectral_radius']

def torch_batched_spectral_radius(mat):
    eigvals = torch.linalg.eigvals(mat)
    # `torch.max` with a `dim` returns `(values, indices)`; keep the values.
    spectral_radius = torch.max(torch.abs(eigvals), -1)[0]
    return spectral_radius

torch_jit_batched_spectral_radius = torch.jit.script(
    torch_batched_spectral_radius)

# In the main session (not part of `torch_batching.py`):
from torch_batching import *

def np_batched_spectral_radius(mat):
    eigvals = np.linalg.eigvals(mat)
    spectral_radius = np.max(np.abs(eigvals), -1)
    return spectral_radius

def jax_manually_batched_spectral_radius(mat):
    eigvals = jnp.linalg.eigvals(mat)
    spectral_radius = jnp.max(jnp.abs(eigvals), -1)
    return spectral_radius

jax_jit_manually_batched_spectral_radius = jax.jit(
    jax_manually_batched_spectral_radius)

np_batch = np.array(batch)
torch_batch = torch.from_numpy(np_batch)

[
    ['function', 'time [s]'],
    [
        'np_batched',
        time_str('np_batched_spectral_radius(np_batch)'),
    ],
    [
        'jax_manually_batched',
        time_str(
            'jax_manually_batched_spectral_radius(batch).block_until_ready()'),
    ],
    [
        'jax_jit_manually_batched',
        time_str(
            'jax_jit_manually_batched_spectral_radius(batch)'
            '.block_until_ready()'),
    ],
    [
        'torch_batched',
        time_str(
            'torch_batched_spectral_radius(torch_batch); '
            'torch.cuda.synchronize()'
        ),
    ],
    [
        'torch_jit_batched',
        time_str(
            'torch_jit_batched_spectral_radius(torch_batch); '
            'torch.cuda.synchronize()'
        ),
    ],
]

More Batching Results
function                   time [s]
np_batched                 0.8191060569988622
jax_manually_batched       1.0442584550000902
jax_jit_manually_batched   0.8894995779992314
torch_batched              1.276841124999919
torch_jit_batched          1.373109233998548

When comparing these timings, keep in mind that we are not running on the GPU.
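
For reference, a looped version based on jax.lax.scan could look roughly like the following. This is only a sketch under my own assumptions, not the code behind any of the timings above: `spectral_radius` and `scanned_spectral_radius` are names made up here, and like the rest of this text it assumes the CPU backend.

```python
import jax
import jax.numpy as jnp

def spectral_radius(mat):
    # Largest absolute eigenvalue of a single square matrix.
    return jnp.max(jnp.abs(jnp.linalg.eigvals(mat)))

def scanned_spectral_radius(batch):
    def step(carry, mat):
        # Nothing is carried between iterations; only the outputs are stacked.
        return carry, spectral_radius(mat)

    _, radii = jax.lax.scan(step, None, batch)
    return radii

batch = jax.random.normal(jax.random.PRNGKey(0), (8, 3, 3))
radii = jax.jit(scanned_spectral_radius)(batch)  # shape: (8,)
```

Unlike a Python loop under jax.jit, the scan body is traced only once, so the compiled program does not grow with the batch size.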

8. pmap: Simple, Differentiable MPI

pmap

To activate for local tests (adjust num_devices as desired):

import multiprocessing
import os

num_devices = multiprocessing.cpu_count()
os.environ['XLA_FLAGS'] = (
    os.getenv('XLA_FLAGS', '')
    + ' --xla_force_host_platform_device_count='
    + str(num_devices)
)

jax.pmap is actually not just related to jax.vmap in name – the functions do the exact same thing, just different: jax.vmap batches its function and can be imagined as a for-loop over the mapped-over axis. jax.pmap also batches its function but is instead a for-loop executed in parallel. That’s really all there is to it; when you know jax.vmap, you know jax.pmap. Of course, the parallelism offers some extra functionality, which is exactly what we are going to learn about in this section.

pmap and Axes
  • Splits computation according to an axis; works over multiple devices and/or JAX processes.
  • This broadcasting axis must have as many entries as we have devices/processes, so reshape your data accordingly:
    • Assuming 4 devices/processes and broadcasting axis 0, reshape a dataset of shape (128, 3, 3) to (4, 32, 3, 3).
    • In code:

      world_size = jax.device_count()  # or `jax.process_count()`
      dataset = dataset.reshape((world_size, -1) + dataset.shape[1:])
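
To make the reshape above concrete, here is a minimal NumPy sketch. The world size of 4 is hard-coded as an assumption instead of being queried from JAX, and the dataset is dummy data.

```python
import numpy as np

world_size = 4  # assumed number of devices/processes
dataset = np.arange(128 * 3 * 3, dtype=np.float32).reshape(128, 3, 3)

# New leading axis of size `world_size`; the rest of the batch is inferred.
sharded = dataset.reshape((world_size, -1) + dataset.shape[1:])
print(sharded.shape)  # (4, 32, 3, 3)

# Shard i holds the contiguous block of examples [32 * i, 32 * (i + 1)).
print(np.array_equal(sharded[1], dataset[32:64]))  # True
```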
      

8.1. Write your own Horovod!

The following is a case study of using JAX to parallelize simple deep learning code. We are going to train a simple multilayer perceptron and teach it to calculate spectral radii.

I referenced the parallel deep learning framework Horovod because it seems to be the most-used tool at Jülich Supercomputing Centre for data-parallel training.

You can substitute “Horovod” with whatever you like to use, though; be it MPI, tf.distribute.Strategy, torch.distributed, torch.distributed with torch.nn.parallel.DistributedDataParallel (DDP), or whatever else you know and love. The principles are all the same, although our JAX implementation is more similar to Horovod than PyTorch DDP, for example (due to global instead of node-local splitting of the batch).

Non-Distributed Setup

Training Code

What follows is some boilerplate setup code for deep learning using JAX’ built-in example machine learning libraries jax.example_libraries.stax and jax.example_libraries.optimizers. Notice how all state is explicitly handled with these. If you don’t care for this, you can skip straight to the interesting part.

I also quickly want to mention why the code is a bit larger than it could be: I like my model to be able to work with dynamic batch sizes; however, if you follow JAX’ official example, it would seem that the batch size needs to be fixed. By implementing two little extras, we are able to handle arbitrary batch sizes. First, we implement our own small flatten function/layer that flattens the whole input (unlike stax.Flatten). Second, we apply jax.vmap to our model function. The model function is always passed the model’s parameters as well, that is why we do not want to vmap over the first argument (indicated by in_axes=(None, [...])). And that’s it!

The only slight inconvenience with this setup is that we need to add an extra size-1 batch dimension in order to handle singular inputs. You’ll see this when we test our model later.

from jax.example_libraries import stax
from jax.example_libraries import optimizers

input_shape = (2, 2)

def flatten():
    def init_fun(rng_key, input_shape):
        flattened_size = jnp.prod(jnp.array(list(input_shape)))
        return (flattened_size,), ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.ravel()

    return init_fun, apply_fun

def build_model(rng_key):
    model_init, model = stax.serial(
        flatten(),
        stax.Dense(64),
        stax.Relu,
        stax.Dense(64),
        stax.Relu,
        stax.Dense(64),
        stax.Relu,
        stax.Dense(1),
    )
    # Handle varying batch sizes.
    model = jax.vmap(model, in_axes=(None, 0))

    rng_key, subkey = jax.random.split(rng_key)
    output_shape, params = model_init(subkey, input_shape)

    assert output_shape == (1,)
    return rng_key, params, model

def build_opt(params):
    opt_init, update, get_params = optimizers.adam(3e-4)
    opt_state = opt_init(params)
    return opt_state, update, get_params

rng_key, params, model = build_model(rng_key)

orig_opt_state, opt_update, get_params = build_opt(params)
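
The batch-size trick can also be looked at in isolation. In the following sketch, `single_example_model` is a hypothetical stand-in for the stax model (flatten, then one linear map), written for exactly one input; jax.vmap with in_axes=(None, 0) shares the parameters and maps over the inputs.

```python
import jax
import jax.numpy as jnp

def single_example_model(params, x):
    # Hypothetical stand-in for the stax model: flatten, then one linear map.
    w, b = params
    return jnp.dot(w, x.ravel()) + b

# Share `params` across the batch (None); map over axis 0 of the inputs.
model = jax.vmap(single_example_model, in_axes=(None, 0))

params = (jnp.ones(4), jnp.float32(0.5))
batch = jnp.arange(12.0).reshape(3, 2, 2)
single = jnp.arange(4.0).reshape(2, 2)

batch_out = model(params, batch)                        # shape: (3,)
# A single input needs an extra size-1 batch dimension.
single_out = model(params, jnp.expand_dims(single, 0))  # shape: (1,)
```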
Interesting Part of the Training Code

Here we implement our update methods. Notice how we use batched_spectral_radius instead of jit_batched_spectral_radius in order to give XLA more optimization freedom. Also, here we see conjugating the possibly complex gradients in action.

def batch_loss(params, batch):
    preds = model(params, batch)
    targets = batched_spectral_radius(batch)
    return jnp.mean(jnp.abs(preds - targets))

@jax.jit
def train_batch(step, opt_state, batch):
    params = get_params(opt_state)
    loss, grad = jax.value_and_grad(batch_loss)(params, batch)
    # Conjugate gradient for steepest-descent optimization.
    grad = jax.tree_util.tree_map(jnp.conj, grad)
    opt_state = opt_update(step, grad, opt_state)
    return opt_state, loss
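
Why the conjugation? For a real-valued loss of complex parameters, the steepest-descent step in JAX goes along the negative conjugate of what jax.grad returns. Here is a minimal sketch on a toy problem; `target` and `loss` are made up for illustration and have nothing to do with the model above.

```python
import jax
import jax.numpy as jnp

target = 1.0 - 2.0j

def loss(z):
    # Real-valued loss of a complex parameter: |z - target|^2.
    d = z - target
    return jnp.real(d) ** 2 + jnp.imag(d) ** 2

z = jnp.asarray(3.0 + 4.0j)
for _ in range(100):
    # Step against the conjugated gradient.
    z = z - 0.1 * jnp.conj(jax.grad(loss)(z))

print(z)  # should end up close to `target`
```

Dropping the jnp.conj call makes the imaginary part move away from the target instead of towards it.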
Training a Spectral Radius MLP

A really simple deep learning training loop. We generate our batches on-demand, taking care to update our rng_key, of course!

opt_state = orig_opt_state
batch_size = 64
batch_shape = (batch_size,) + input_shape

steps = 10000
log_step_interval = 1000
start_time = time.perf_counter()
for step in range(steps):
    rng_key, batch = jit_randn(rng_key, batch_shape, dtype=complex)
    opt_state, loss = train_batch(step, opt_state, batch)
    if step % log_step_interval == 0:
        print('step ', step, '; loss ', loss, sep='')

end_time = time.perf_counter()
print('Training took', end_time - start_time, 'seconds.')
Training Results
step 0; loss 1.6026156
step 1000; loss 0.39599544
step 2000; loss 0.38151193
step 3000; loss 0.4386006
step 4000; loss 0.3645811
step 5000; loss 0.38383436
step 6000; loss 0.4037715
step 7000; loss 0.3104779
step 8000; loss 0.32223767
step 9000; loss 0.40970623
Training took 7.869086120001157 seconds.

Okay, that’s some sound old MLP training. Let’s get into parallelization already.

Multi-Node Distribution

A quick interlude on some extra distributed use cases and GPU memory pre-allocation. These sections are only interesting if you plan to distribute code yourself; they are mostly code snippets I wanted to leave here as a reference. To skip them, jump ahead to “Distributed Training”.

The following is some setup code you would use on a Slurm-managed cluster, for example. But first, a word of caution…

Multi-Node Distributed Setup
  • Caution: Experimental, undocumented and not even used, much less tested, anywhere inside JAX!
  • Future versions will have a function jax.distributed.initialize, working much like PyTorch’s

    torch.distributed.init_process_group(
        [...],
        init_method="tcp://[...]",  # or "env://"
    )
    
Multi-Node Distributed Setup Code

Adapted from JAX source code (the link is in the comment below):

# Adapted and compressed from
# https://github.com/google/jax/blob/4d6467727709de1c9ad220ac62783a18bcbf4990/jax/_src/distributed.py

def jax_distributed_initialize(
        coordinator_address, num_processes, process_id):
    if process_id == 0:
        global _service
        _service = \
            jax.lib.xla_extension.get_distributed_runtime_service(
                coordinator_address, num_processes)

    client = jax.lib.xla_extension.get_distributed_runtime_client(
        coordinator_address, process_id)
    client.connect()

    factory = functools.partial(
        jax.lib.xla_client.make_gpu_client, client, process_id)
    jax.lib.xla_bridge.register_backend_factory(
        'gpu', factory, priority=300)

Handling GPU Memory Pre-Allocation

If you ever had trouble with running out of GPU memory when using multi-process TensorFlow, you may have fixed it by enabling “GPU memory growing”. (The default is that TensorFlow pre-allocates a large block of memory in order to reduce memory fragmentation.)

JAX does the same, so in case you need it, what follows is the JAX equivalent to enabling GPU memory growing in TensorFlow.

Disable GPU Memory Pre-Allocation

Equivalent of the following TensorFlow:

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In JAX:

import os

# Before first JAX computation.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

Distributed Training

Distributing our Training Code

With all of that out of the way, let’s finally write some distributed training code!

… What’s that? We only need to add a single line in our train_batch function?

Well, almost; we also need to apply the titular jax.pmap function transformation. I’ll explain what’s going on here after the code block.

batch_axis = 'batch'

def distributed_train_batch(step, opt_state, batch):
    params = get_params(opt_state)
    loss, grad = jax.value_and_grad(batch_loss)(params, batch)
    # This is the only line we had to add to `train_batch`.
    loss, grad = jax.lax.pmean((loss, grad), batch_axis)
    # Conjugate gradient for steepest-descent optimization.
    grad = jax.tree_util.tree_map(jnp.conj, grad)
    opt_state = opt_update(step, grad, opt_state)
    return opt_state, loss

# `pmap` also `jit`s our function.
pmap_train_batch = jax.pmap(
    distributed_train_batch,
    batch_axis,
    in_axes=(None, None, 0),
    out_axes=(None, None),
)

This is already all the magic behind Horovod! (Ignoring the super-optimized communication.)

The first thing that should have caught your eye is the definition of batch_axis at the very top of the code block. This batch_axis is passed to both jax.lax.pmean and jax.pmap as the reduction axis and mapped-over axis, respectively. We need to do this because – and I hope you’ve started to notice a pattern here – jax.pmap is also nestable! By having to specify an axis for jax.pmap, the reduction operations in jax.lax will always have an axis to refer to. We use a string to name the axis here, but any hashable would do.

The call to jax.lax.pmean averages a tree of values over all processes. In our case, it averages the loss and gradient. We average the loss as well here because you usually want to log the globally averaged loss instead of the local loss to get a smoother overall picture (in Horovod, you need to enable this explicitly). The averaged gradient is then used to do the same update step on each process, so we don’t need any more synchronization afterwards.

jax.pmap has some other arguments here we haven’t seen yet, namely in_axes and out_axes. jax.vmap accepts these too, and they are very important! With them, you control the axis of each argument that the function transformation maps over. If you don’t want to map over an argument (maybe you are passing a single constant that you don’t want to copy until it has the size of the mapped-over axis), you specify in_axes at the position of the argument as None. We do this for the training step (a single integer) and model parameters (a tree of matrices). However, we do want to map over the batch somehow. We specify the very first axis here.

out_axes is similar, but for the output. If we wanted to collect the different outputs of the function transformation for each mapped-over input, we would specify the return values that we want to collect and the axis we want to collect them over at the corresponding positions in out_axes. Since we already reduce over the mapped-over values with the jax.lax.pmean call, we will not have multiple different outputs and thus use None just like for in_axes to prevent the collection.
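
The interplay of the axis name, in_axes and jax.lax.pmean can be tried out without any training code. The following is a minimal sketch that runs on however many local devices JAX sees (even just one); `scaled_global_mean` is a made-up function for illustration. It leaves out_axes at its default, so we get one (identical) value back per device.

```python
import jax
import jax.numpy as jnp

axis_name = 'batch'
n_dev = jax.local_device_count()

def scaled_global_mean(scale, shard):
    local_mean = jnp.mean(shard)
    # Average the per-device means over the named, mapped-over axis.
    return scale * jax.lax.pmean(local_mean, axis_name)

# `scale` is broadcast to every device (None); `data` is split along axis 0.
pmapped = jax.pmap(scaled_global_mean, axis_name, in_axes=(None, 0))

data = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
out = pmapped(2.0, data)  # shape: (n_dev,), the same value on every device
```

Because all shards have the same size, the mean of the per-device means equals the global mean.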

Training a Spectral Radius MLP Distributively

Here, we reshape the original batch so we can split it over its first axis as desired. In the code, we obtain the world_size as the number of devices known to JAX. Going back to the multi-node setup code from before, if you initialize your distributed training like that, i.e. by starting multiple Python processes, you may want to use world_size = jax.process_count() instead. Like so often, this depends on your use case.

opt_state = orig_opt_state
# Since we are only using a single JAX process with (possibly
# emulated) multiple devices, we use `jax.device_count`.
# Usually, this would be `jax.process_count`.
world_size = jax.device_count()
assert batch_size % world_size == 0
local_batch_size = batch_size // world_size

start_time = time.perf_counter()
print('Training with', world_size, '(possibly simulated) devices.')
for step in range(steps):
    rng_key, batch = jit_randn(rng_key, batch_shape, dtype=complex)
    batch = batch.reshape(
        (world_size, local_batch_size) + batch.shape[1:])
    opt_state, loss = pmap_train_batch(step, opt_state, batch)
    if step % log_step_interval == 0:
        print('step ', step, '; loss ', loss, sep='')

end_time = time.perf_counter()
print('Distributed training took', end_time - start_time, 'seconds.')
Distributed Training Results
Training with 2 (possibly simulated) devices.
step 0; loss 1.6025591
step 1000; loss 0.3969112
step 2000; loss 0.38199607
step 3000; loss 0.4381593
step 4000; loss 0.36498725
step 5000; loss 0.38379186
step 6000; loss 0.40361017
step 7000; loss 0.3104655
step 8000; loss 0.32245204
step 9000; loss 0.40979913
Distributed training took 7.309056613001303 seconds.

Even though I only used 2 simulated devices, training did speed up a bit already. Nice!

Did it Learn?

Let’s test whether our model actually learned anything useful. Maybe the logged losses don’t tell the whole story. Also, please notice that I JITted the model function for inference here; I wanted to show this off as it will really speed up your inference times. (Of course, JITting did not really make sense here since I only use the function once.)

params = get_params(opt_state)
jit_model = jax.jit(model)

ceig_mat = jnp.array([[1.0, -1.0], [1.0, 1.0]])
batched_ceig_mat = jnp.expand_dims(ceig_mat, 0)

[
    ['function', 'spectral radius sample'],
    [
        'jax_spectral_radius',
        jax_spectral_radius(ceig_mat).item(),
    ],
    [
        'model',
        jit_model(params, batched_ceig_mat)[0].item(),
    ],
]

function              spectral radius sample
jax_spectral_radius   1.4142135381698608
model                 1.4250930547714233

9. Summary

Advantages
  • Educational documentation.
  • Familiar API.
  • Interoperate with TensorFlow using experimental jax2tf (included in JAX; despite the name, also supports TensorFlow to JAX).
  • Faster code!
  • More explicit, less magic, less trouble understanding.
  • Functional style will avoid headaches.
  • Look out for stabilization of pjit (previously sharded_jit); even simpler Horovod, ability to JIT huge functions!

    pjit is actually much cooler than you would expect. The old name is a bit more descriptive here, as pjit – aside from being a more abstracted pmap – even allows us to split super large functions that do not fit in the memory of a single device.

Disadvantages

Initial hurdles:

  • Getting used to it.
  • Sometimes not as quick to write; however, payoff in the long term.

Better with time:

  • Sometimes unpredictable or unstable.
  • Lacking ecosystem.

Will never change:

  • Hidden functionalities/undocumented APIs; some useful code (intentionally?) not public.
  • Mutating state during JITting will cause headaches.
  • Backend in TensorFlow: code split and dependency.

    The code split means that the TensorFlow repository also contains JAX-only code. So you have another code location to keep in mind if you need to dive deeper into the JAX JIT for some reason.

Neural Network Libraries
  • jax.example_libraries.stax: Included in JAX, bare-bones and requires more manual work. However, simplicity is an advantage.
  • flax: Most features, most user-friendly in my opinion.
  • haiku: Goes against JAX’ explicit/immutable state but converts models back to stateless transformations. Thus, maybe better user experience.
  • objax: Similar API to torch.nn.Module.
  • trax: Focus on sequence data and large-scale training.

These are all made by the same company!

Just to clarify, while the above statement can be interpreted as a cheap jab at Alphabet and the TensorFlow situation, I do believe that designing sensible APIs for JAX’ paradigms is hard. Most of these libraries can most likely be viewed as experiments with regard to this design.

functorch
  • JAX-like function transformations (vmap, grad, …)
  • Stateless models

for PyTorch! Experimental, but best of both worlds.

If you ever have trouble batching something in PyTorch, it may help. It will probably be a while until it’s included in PyTorch.

By the way, PyTorch also has a JIT!

Since this is the extended version, you did see some results with PyTorch’s JIT. None of those seemed… promising. The non-JITted PyTorch code was consistently as fast as or even faster than the JITted version.

However, I believe that PyTorch’s JIT compiler’s use case is a bit different, leaning more towards optimizing deep learning models for inference. I don’t understand why the super simple linear algebra we tried to JIT did not get optimized (un-optimized, instead!) at all, but I did not dive into PyTorch’s JIT and thus can’t say too much about this.

Thanks for Reading!

Thank you for your attention.

I hope you had as much fun as I had preparing.

10. Appendix

References
Extra Recommendations

I recommend this article and corresponding paper on limitations of XLA and TorchScript compilation.

Maybe of interest: there are already JAX libraries for differentiable rigid-body physics simulation and molecular dynamics. You will find other fields covered as well.

Footnotes:

1

Stable since PyTorch 1.9.0, released in June 2021.

3

Some of the Autograd developers took part in building JAX upon its ideas.

4

Why is the second a considered different from the first? The default __hash__ implementation is based on the object’s id, its location in memory. (Both of these are CPython implementation details.)

5

That way, the compiler gets the most options for optimization. Also, an already jit-ted inner function cannot be optimized further.

6

There are multiple XLA runtimes.

7

C family of languages, to be exact.

8

Not quite sure since they only disable --loop-unroll.

9

The Message Passing Interface (MPI) is a standard in parallel computing.