1. Setup
1.1. About this Text
I created this as a slide-based presentation for the Helmholtz AI Food for Thought seminar on December 02, 2021, in order to introduce researchers from various backgrounds to JAX. This text was produced from the same Org source, with some extra commentary text interleaved. That’s why the text may feel choppy at times and code snippets are more compressed instead of following good style. I tried my best to explain everything I mentioned in the presentation as well.
Also, this is the “extended edition”, including many more code snippets as well as some behind-the-scenes stuff. This way, everything should be reproducible and you should have a good base reference in case you decide to pick up JAX.
You can download the original slides here.
1.2. Environment Setup
This is the environment I used. I give specific version numbers as comments in case you need better reproducibility.
Sadly, I was unable to compile JAX with GPU support; the version of CUDA the official NVIDIA drivers support for Ubuntu 20.04 is no longer supported by JAX.
That’s why all numbers I will show you need to be taken with a grain of salt – we aren’t able to use JAX’ killer feature!
```shell
conda create -p env python=3.8
conda activate ./env

# `pytorch` version: 1.10.0
conda install pytorch cudatoolkit=10.2 -c pytorch

python -m pip install --upgrade pip
# `jax` version: 0.2.25
# `jaxlib` version: 0.1.74
# I only have the CPU version.
python -m pip install --upgrade jax jaxlib
```
```shell
clang --version | sed 's/ *$//'
```
```
clang version 10.0.0-4ubuntu1
Target: x86_64-pc-linux-gnu
Thread model: posix
InstalledDir: /usr/bin
```
2. Overview
Topic for Today
- JAX is a cool new corporation-backed framework for differentiable programming/scientific computing.
- Faster than NumPy/SciPy due to GPU usage and…
- Compilability via Accelerated Linear Algebra compiler (XLA, reused from TensorFlow).
- Better usability than NumPy and PyTorch.
- Due to “forced” functional style, get good code for free!
- Everything I will show you ran on the CPU!
```python
import functools
import os
import time
import timeit

import jax
import jax.numpy as jnp
import numpy as np
import torch
```
Behind-the-Scenes Setup
The following setup code can safely be ignored (skip ahead) but achieves the following:
- XLA IR output in the `generated` directory.
- 2 simulated devices for XLA.
- Activate/disable GPUs in both JAX and PyTorch.
- Less print output.
- Initialize JAX and PyTorch so we don’t see warnings later.
```python
# Need to do this before even the `jax.config` update.
os.environ['XLA_FLAGS'] = ' '.join([
    os.getenv('XLA_FLAGS', ''),
    '--xla_dump_to=./generated',
    '--xla_force_host_platform_device_count=2',
])
jax.lib.xla_bridge.get_backend.cache_clear()

try:
    gpus = jax.local_devices(None, 'gpu')
except RuntimeError:
    gpus = None

has_gpu = bool(gpus) and torch.cuda.is_available()

# By default, JAX runs on the GPU (if available) while PyTorch by
# default runs on the CPU. So we choose the common denominator depending
# on GPU availability.
if not has_gpu:
    # Force JAX on CPU.
    jax.config.update('jax_platform_name', 'cpu')
    # PyTorch will complain if we call this without a GPU active.
    torch.cuda.synchronize = lambda: None
else:
    # Force PyTorch on GPU.
    torch.cuda.set_device(0)

print_precision = 3
np.set_printoptions(precision=print_precision)
jnp.set_printoptions(precision=print_precision)
torch.set_printoptions(precision=print_precision)


def round_complex(val, precision):
    if isinstance(val, complex):
        return round(val.real, precision) + round(val.imag, precision) * 1j
    return round(val, precision)


def round_tree(tree, precision=print_precision):
    return jax.tree_util.tree_map(lambda x: round_complex(x, precision), tree)


# Initialize JAX.
jnp.array(0)
# Initialize PyTorch.
torch.tensor(0)
```
Why is it Cool?
- NumPy/SciPy on the GPU.
- No mutable state means much joy: easier maintainability, easier scalability.
- No need for manual batching.
- Complex number differentiation previously1 not possible on PyTorch. Certain unsupported cases may still be out there.
- User-friendly parallelism.
JAX’ Libraries
- `jax`: Mostly important for function transformations/compositions.
- `jax.numpy`: NumPy replacement.
- `jax.scipy`: SciPy replacement.
- `jax.random`: Deterministic randomness.
- `jax.lax`: Lower-level operations.
- `jax.profiler`: Analyze performance.
Just a selection, there are more!
3. Mutating Immutable Arrays
Mutating Arrays in NumPy
In NumPy:
```python
x = np.eye(2)
x[0, 0] = 2
x[:, -1] *= -3
x
```
| 2 | 0 |
| 0 | -3 |
Some things to note for those not familiar with Python and/or NumPy:
- In Python, indices start at 0, so the first element of an array is at index 0.
- NumPy and JAX call any \( n \)-dimensional tensor an array, so do not expect a 1-dimensional vector when you see the word “array”.
- The colon “:” we used for indexing selects all elements over that dimension. In our case, this resulted in a whole column of the matrix being indexed.
- Negative values index from the end (but, confusingly, starting at 1 this time). So indexing with -1 yields the very last value of an array. Combining the colon with -1 in that order means we index the last column of the array.
- The shape of an array is the tuple of its dimension sizes (`x`'s shape is `(2, 2)`).
- The dtype of an array is the type of its elements (see the short example after this list).
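To make these points concrete, here is a tiny NumPy snippet of my own (not from the original slides) showing the colon, negative indexing, shape, and dtype:

```python
y = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
y[:, -1]                        # last column: array([2, 5])
y[-1]                           # last row: array([3, 4, 5])
y.shape                         # (2, 3)
y.dtype                         # e.g. dtype('int64') on 64-bit Linux
```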
Mutating Arrays in JAX
In JAX:
```python
x = jnp.eye(2)

try:
    x[0, 0] = 2
except TypeError:
    pass
else:
    raise TypeError('array is mutable')

x = x.at[0, 0].set(2)
x = x.at[:, -1].multiply(-3)
x
```
| 2 | 0 |
| 0 | -3 |
You will receive helpful errors in case you forget this!
The try-except-else block of the code has the following effect: if the statement in the `try` block, `x[0, 0] = 2`, raises a `TypeError`, we simply ignore it. That's what the `except TypeError` block does. If we did not get any error, i.e. no exception was caught, the `else` block executes and raises a `TypeError` telling us that the array was mutable after all. Since the code executed and evaluated just fine, we know that JAX raised an error upon trying to mutate the array, meaning JAX arrays are not mutable, as expected. (Whew!)
You may think performance tanks due to the extra allocations; however, these are optimized away to in-place mutations during JIT (just-in-time) compilation.
Closures (Functional Programming Interlude)
A closure is a function that captures something from an outer environment. (You’ve been using this whenever you refer to a non-local, e.g. global, variable inside a function.) Very useful for programming with JAX, but be careful about mutable state!
```python
def create_counter():
    count = 0

    def inc():
        nonlocal count
        count += 1
        return count

    return inc


counter = create_counter()
', '.join(map(str, [counter(), counter(), counter()]))
```
1, 2, 3
If you don't know Python: the `def` keyword indicates a function definition.
In Python, closure capturing rules are the same as for function argument passing, also called “pass-by-sharing” – non-primitives are referenced (imagine a C pointer on that object). So outside modifications to non-primitives will be visible inside the closure as well!
This behavior varies between programming languages, so keep in mind that this is a Python implementation detail.
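As a tiny illustration of this capture-by-reference behavior (my own example, not from the slides):

```python
def make_reader(items):
    def read():
        # `items` is captured by reference, so later outside changes show up here.
        return list(items)
    return read


data = [1, 2]
read = make_reader(data)
data.append(3)  # Modify the captured object from outside.
read()          # [1, 2, 3]
```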
4. Randomness in JAX
Randomness in JAX
Stateless RNG makes reproducibility fun and easy as pie!2
```python
seed = 0
rng_key = jax.random.PRNGKey(seed)


def randn(rng_key, shape, dtype=float):
    rng_key, subkey = jax.random.split(rng_key)
    rands = jax.random.normal(subkey, shape, dtype=dtype)
    return rng_key, rands


rng_key, rands = randn(rng_key, (2, 3))
rands
```
| -1.458 | -2.047 | -1.424 |
| 1.168 | -0.976 | -1.272 |
RNG keys are the way JAX represents RNG state. We called the key we use right away `subkey` even though it is the exact same kind of object as `rng_key`. This naming is simply a nice pattern to use because we know from the names which keys we are going to “consume” and which we will pass along.
You can also split off more than one key by passing an additional integer to `jax.random.split`.
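For example (a small sketch of my own, not from the slides), splitting off two subkeys at once could look like this:

```python
rng_key, subkey_a, subkey_b = jax.random.split(rng_key, 3)
normal_rands = jax.random.normal(subkey_a, (2,))
uniform_rands = jax.random.uniform(subkey_b, (2,))
```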
Another thing I wanted to show off here is the pattern of using standard Python types with JAX. Since we only specified `float` instead of the more specific `jnp.float32` or `jnp.float64`, JAX will automatically pick the correct one based on its configured default.
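A quick sketch of what this means in practice (my own example; the `jax_enable_x64` flag is the documented way to switch the default, left commented out here so it does not affect the rest of this text):

```python
jnp.array(1.0, dtype=float).dtype   # dtype('float32') with default settings

# To make `float` mean 64-bit instead, enable x64 support early on:
# jax.config.update('jax_enable_x64', True)
```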
RNG Keys
```python
_, rands = randn(rng_key, (2, 3))
rands
```
| 0.932 | -0.004 | -0.74 |
| -1.427 | 1.06 | 0.929 |
Notice we did not update our `rng_key`, instead discarding that part of the result. Can you guess what happens when we generate numbers again?
```python
rng_key, rands = randn(rng_key, (2, 3))
rands
```
| 0.932 | -0.004 | -0.74 |
| -1.427 | 1.06 | 0.929 |
This time, we updated our `rng_key`: the world of randomness is whole again! Notice that reusing a key like we just did is a really easy way to reproduce random numbers that you (hopefully) want to stay the same.
```python
rng_key, rands = randn(rng_key, (2, 3))
rands
```
| 1.169 | 0.312 | -0.571 |
| 0.137 | 0.742 | 0.038 |
5. grad: Advanced AutoDiff
5.1. grad How-To
How to Take Gradients
JAX' `grad` function transformation takes a function we want to differentiate as input and returns another function calculating the original one's gradient given arbitrary inputs. By default, it takes the derivative with regard to the first argument. Multiple `jax.grad` applications can be nested to take higher-order derivatives.

Another Python syntax thing here: the `**` is the exponentiation operator. If you're curious, `^` is used for a bitwise exclusive or.
```python
def expo_fn(x, y):
    return x**4 + 2**y + 3


x = 1.0
y = 2.0
grad_x_fn = jax.grad(expo_fn)      # 4x^3
grad2_x_fn = jax.grad(grad_x_fn)   # 12x^2
grad3_x_fn = jax.grad(grad2_x_fn)  # 24x

[
    ['function', r'\partial{}expo_fn/\partial{}\(x\)'],
    ['grad_x', grad_x_fn(x, y).item()],
    ['grad2_x', grad2_x_fn(x, y).item()],
    ['grad3_x', grad3_x_fn(x, y).item()],
]
```
| function | ∂expo_fn/∂x |
|---|---|
| grad_x | 4.0 |
| grad2_x | 12.0 |
| grad3_x | 24.0 |
Differentiating Different Arguments
To differentiate with regard to other arguments of our functions, we pass `jax.grad` the indices of those arguments in the `argnums` argument. We can also specify multiple `argnums`.
```python
x = 1.0
y = 2.0
grad_y_fn = jax.grad(expo_fn, argnums=1)        # ln(2) · 2^y
grad_xy_fn = jax.grad(expo_fn, argnums=(0, 1))

[
    ['function', 'result'],
    ['grad_y', grad_y_fn(x, y).item()],
    ['grad_xy', [g.item() for g in grad_xy_fn(x, y)]],
]
```
| function | result |
|---|---|
| grad_y | 2.7725887298583984 |
| grad_xy | (4.0 2.7725887298583984) |
Differentiation and Outputs
In machine learning, we really like to monitor our loss values,
which are the non-differentiated results of the function we
differentiate (i.e. the value we want to minimize). In
order to not have to evaluate the function twice and lose precious
performance, JAX offers the jax.value_and_grad
function
transformation which returns a new function that calculates the result
of the function as well as its gradient. Now we can log our losses and
sleep well again.
```python
result_grad_fn = jax.value_and_grad(expo_fn)
result, grad = result_grad_fn(x, y)
```
Fun fact: `jax.value_and_grad` is actually what `jax.grad` calls as well; it just tosses `result`.
Let’s assume the function we want to differentiate has multiple outputs, for example maybe we need to return some new state!
Let’s assume we went through the trouble of collecting all our extra return values in a tuple. We then also changed our function to return a pair (i.e. another tuple) containing (in this order)
- the value we want to differentiate through, and
- the tuple of extra return values.
We can then supply the `has_aux=True` argument to `jax.grad` and happily differentiate again while keeping our state intact:
```python
def poly_fn_and_aux(x, y):
    aux_output = ({'y': y}, 1337)
    return x**4 + (y - 1)**2 + 3, aux_output


grad_aux_fn = jax.grad(poly_fn_and_aux, has_aux=True)
grad, aux = grad_aux_fn(x, y)
```
Of course, the same works for `jax.value_and_grad` as well; however, its tree of return values needs some special care to deconstruct:
```python
result_aux_grad_fn = jax.value_and_grad(poly_fn_and_aux, has_aux=True)
(result, aux), grad = result_aux_grad_fn(x, y)
```
By the way, if you ever want to disable gradient computation inside a `jax.grad` context, you can use `jax.lax.stop_gradient`. Its use is a bit unintuitive, so I'd recommend checking out the link.
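To illustrate that slightly unintuitive use, here is a small straight-through-style sketch of my own (not from the slides): the rounding contributes to the value but is hidden from autodiff.

```python
def straight_through(x):
    # The value is `round(x)`, but autodiff only "sees" the identity part.
    return x + jax.lax.stop_gradient(jnp.round(x) - x)


straight_through(0.3)            # 0.0
jax.grad(straight_through)(0.3)  # 1.0
```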
5.2. Differentiating Spectral Radius
Default Precision Interlude
Simple Spectral Radius
```python
def jax_spectral_radius(mat):
    eigvals = jnp.linalg.eigvals(mat)
    spectral_radius = jnp.max(jnp.abs(eigvals))
    return spectral_radius


def torch_spectral_radius(mat):
    eigvals = torch.linalg.eigvals(mat)
    spectral_radius = torch.max(torch.abs(eigvals))
    return spectral_radius


# Eigenvalues: 1 ± i
ceig_mat = np.array([[1.0, -1.0],
                     [1.0, 1.0]])
jax_mat = jnp.array(ceig_mat)
torch_mat = torch.from_numpy(ceig_mat)

[
    ['function', 'result'],
    ['jax', jax_spectral_radius(jax_mat).item()],
    ['torch', torch_spectral_radius(torch_mat).item()],
    ['sqrt', np.sqrt(2)],
]
```
JAX’ Default Precision
| function | result |
|---|---|
| jax | 1.4142135381698608 |
| torch | 1.4142135623730951 |
| sqrt | 1.4142135623730951 |
Wait! JAX’ precision is seriously behind PyTorch here! Is PyTorch just more precise or what’s going on?
While both JAX and PyTorch have single-precision (32-bit) floating point numbers as their default `dtype`, NumPy uses double-precision (64-bit) floats by default.

Now, when we converted the matrix from NumPy to the respective frameworks, JAX' `jnp.array` created a new array from NumPy's, thus converting the `dtype` to JAX' default. This leaves us with `jax_mat.dtype` being `jnp.float32`. However, PyTorch's `torch.from_numpy` adapted the `dtype` exactly, which is why PyTorch had double the precision to work with.
With that knowledge, let's make the test more fair by converting `torch_mat` to single-precision as well:
```python
torch_mat = torch_mat.float()

[
    ['jax', jax_spectral_radius(jax_mat).item()],
    ['torch', torch_spectral_radius(torch_mat).item()],
]
```
| jax | 1.4142135381698608 |
| torch | 1.4142135381698608 |
Ahh, all is well; JAX is not lacking in terms of precision after all (at least in this small example).
Let’s check out what the square root of 2 is in single-precision to see just how precise we are:
```python
two_f32 = np.array(2, dtype=np.float32)
[['sqrt_f32', '{:.16f}'.format(np.sqrt(two_f32))]]
```
| sqrt_f32 | 1.4142135381698608 |
Gradients and Complex Numbers
Complex Differentiation
Let’s finally differentiate some complex numbers. You may not ever have seen this way to differentiate in PyTorch – it’s the functional way!
Just to clarify again, being able to take this gradient is a rather recent change in PyTorch: stable since PyTorch 1.9.0, June 2021. Originally, I wanted to show here that JAX is capable of differentiating something that PyTorch cannot. The competition has caught up, though!
```python
jax_grad = jax.grad(jax_spectral_radius)(jax_mat)

torch_mat.requires_grad_(True)
torch_rho = torch_spectral_radius(torch_mat)
torch_grad = torch.autograd.grad(torch_rho, torch_mat)

decimals = 3
[
    ['function', 'grad'],
    ['jax', round_tree(jax_grad.tolist(), decimals)],
    ['torch', round_tree(torch_grad[0].tolist(), decimals)],
]
```
| function | grad |
|---|---|
| jax | ((0.354 -0.354) (0.354 0.354)) |
| torch | ((0.354 -0.354) (0.354 0.354)) |
Setting Up Complex Input
```python
complex_mat = np.array([[1 + 1j, -1],
                        [1, 1]])
jax_complex_mat = jnp.array(complex_mat)
torch_complex_mat = torch.from_numpy(complex_mat).to(torch.complex64)

[
    ['function', 'result'],
    None,
    ['jax', jax_spectral_radius(jax_complex_mat).item()],
    ['torch', torch_spectral_radius(torch_complex_mat).item()],
]
```
| function | result |
|---|---|
| jax | 1.9021130800247192 |
| torch | 1.9021130800247192 |
Differentiating Through Complex Gradients
Due to JAX doing some heavy abstract syntax tree (AST) work, it includes a nice module with tree-related functions called `jax.tree_util`. We will use it to conjugate the tree of gradients we obtain from `jax.grad` (for no special reason at all).
```python
jax_grad = jax.grad(jax_spectral_radius)(jax_complex_mat)
jax_conj_grad = jax.tree_util.tree_map(jnp.conj, jax_grad)

torch_complex_mat.requires_grad_(True)
torch_rho = torch_spectral_radius(torch_complex_mat)
torch_grad = torch.autograd.grad(torch_rho, torch_complex_mat)

decimals = 3
[
    ['type', 'gradient'],
    None,
    ['jax', round_tree(jax_grad.tolist(), decimals)],
    ['jax conj', round_tree(jax_conj_grad.tolist(), decimals)],
    ['torch', round_tree(torch_grad[0].tolist(), decimals)],
]
```
| type | gradient |
|---|---|
| jax | (((0.38-0.616j) (-0.38-0.235j)) ((0.38+0.235j) (0.145-0.235j))) |
| jax conj | (((0.38+0.616j) (-0.38+0.235j)) ((0.38-0.235j) (0.145+0.235j))) |
| torch | (((0.38+0.616j) (-0.38+0.235j)) ((0.38-0.235j) (0.145+0.235j))) |
Oops! It seems there is a discrepancy here…
Without going too much into the math, when optimizing a function with complex inputs and real outputs, steepest-descent algorithms need the conjugate of the complex gradient in order to walk in the correct direction. As PyTorch is a deep learning framework first-and-foremost, it conjugates its gradients by default so users can go plug-and-play when fooling around with complex numbers and optimization.
6. jit: Compilation via XLA
6.1. Introducing jit
Side Effects in a jit Context
Just-in-Time-Compiling via XLA
```python
def print_hello():
    print('Hello!')  # Side effect!


jax.jit(print_hello)()
```
Hello!
```python
jit_print_hello = jax.jit(print_hello)
jit_print_hello()
jit_print_hello()
print('... Hello? :(')
```
… Hello? :(
Before we dive in here, a quick terminology heads-up: “to JIT something” means “to just-in-time-compile something”.
Multiple interesting things happened in these short snippets:
- We did not get any other “Hello!” after the first one in general.
- We did not get a second “Hello!” even though we applied `jax.jit` to `print_hello` a second time.
- We did get a “Hello!” for the first call of the JITted function.
Let's go through these in order:

- The `print` call is called a side effect in computer science. JAX does not care for these during JIT compilation; it only cares about math – or, to be more exact, about whatever comes in and whatever comes out of the function.
- The reason we do not get a second “Hello!” even though we apply `jax.jit` again (and we would expect the side effect to happen again before the function is compiled) is that JAX caches its compilation output in the background. So if we JIT the same function twice with the same arguments, the previous compilation output will be reused.
- JAX traces what happens inside the function on its first call, building a computational graph. This means the first call of the function executes just like a standard Python function (though even slower due to the computational graph building).
This also means that JITting functions that result in a large computational graph (for example a Python loop that is executed very many times) can take forever to JIT only because the first tracing of it takes so long. When you encounter this issue, you can replace your loop with control flow substitutes from the `jax.lax` module (see the short sketch below).
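As a minimal sketch (my own example, not from the slides) of such a substitute, `jax.lax.fori_loop` keeps the traced graph small where a Python loop would be unrolled into thousands of operations:

```python
@jax.jit
def lax_sum(n):
    # A Python `for i in range(n)` would be unrolled at trace time;
    # `fori_loop` becomes a single loop primitive instead.
    return jax.lax.fori_loop(0, n, lambda i, total: total + i, 0)


lax_sum(10_000)  # 49995000
```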
JITting State
Here, we'll see that `jax.jit` can also be used as a decorator. However, because we need to supply another argument to `jax.jit`, we cannot use it as a decorator quite that simply. Instead, we need to combine it with `functools.partial`. An explanation follows after the code block.
```python
class Counter:
    count = 0

    @functools.partial(jax.jit, static_argnums=(0,))
    def inc(self):
        self.count += 1


a = Counter()
print(a.count)
a.inc()
print(a.count)
a.inc()
print(a.count)
```
```
0
1
1
```
Applying `functools.partial` like this results in the following (actually anonymous) function:
```python
# Result of `functools.partial(jax.jit, static_argnums=(0,))`.
def partial_jit(*args, **kwargs):
    return jax.jit(*args, static_argnums=(0,), **kwargs)
```
This new `partial_jit` function wraps the `inc` method of `Counter`, resulting in the equivalent of the following:

```python
class Counter:
    count = 0

    def inc(self):
        self.count += 1


Counter.inc = jax.jit(Counter.inc, static_argnums=(0,))
```
I hope that helped make sense of the code. `static_argnums` basically tells `jax.jit` to recompile the JITted function for each different argument value at that position. In return, we get some freedoms (for example, we would not be able to JIT the function otherwise). We call the arguments at the positions designated by `static_argnums` static from now on. More on static and non-static arguments later (see also the small sketch below).
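A classic illustration (my own sketch, not from the slides) of why you sometimes need this freedom: shape-determining arguments must be static.

```python
def make_zeros(n):
    return jnp.zeros(n)


# `jax.jit(make_zeros)(3)` would fail: inside the trace, `n` is not a
# concrete value, so XLA cannot know the output shape.
jit_make_zeros = jax.jit(make_zeros, static_argnums=(0,))
jit_make_zeros(3)  # Compiles for n = 3.
jit_make_zeros(5)  # Recompiles for n = 5.
```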
JITting State Again
```python
a = Counter()
print(a.count)
a.inc()
print(a.count)
a.inc()
print(a.count)
```
```
0
1
1
```
Due to `self` being a static argument as specified via `static_argnums`, the function is recompiled for a new, different `self`4.
Benchmarking jit on randn
JITting our randn
We'll now use Python's `timeit` module to benchmark a JIT-compiled version of our old friend `randn` (remember we implemented this in the section on RNG in JAX). We implement a simple wrapper around it in order to initialize the JIT-compiled function before we benchmark it.
You will notice some `block_until_ready` calls. These are due to JAX' asynchronous execution engine. Whenever JAX executes code on a non-host device (such as a GPU), it happens asynchronously. This means that the main thread of the program continues to run ahead with a “mock” result, also called a “future”, while the actual result is computed in the background. Only when we actually query the result will we wait until it's available.

During benchmarking, we would get these futures immediately – that's not much use to us. So we call the `block_until_ready` function in order to wait until the result of the computation is actually available. You achieve the same in PyTorch using a call to `torch.cuda.synchronize`.
```python
jit_randn = jax.jit(randn, static_argnums=(1, 2))


def time_str(code, number=5000):
    # Initialize cache
    exec(code)
    return timeit.timeit(code, globals=globals(), number=number)


randn_time = time_str(
    'randn(rng_key, (100, 100))[1].block_until_ready()')
jit_randn_time = time_str(
    'jit_randn(rng_key, (100, 100))[1].block_until_ready()')

[
    ['function', 'time [s]'],
    ['randn', randn_time],
    ['jit_randn', jit_randn_time],
]
```
JIT Results
| function | time [s] |
|---|---|
| randn | 1.0273832870007027 |
| jit_randn | 0.7757850260022678 |
A 25 % reduction! Not bad at all, especially since we shouldn’t really have much room to optimize here.
Let’s see how PyTorch does:
```python
np_rng = np.random.default_rng()
np_randn_time = time_str('np_rng.normal(size=(100, 100))')
torch_randn_time = time_str(
    'torch.randn((100, 100)); '
    'torch.cuda.synchronize()'
)

[['np_randn', np_randn_time], ['torch_randn', torch_randn_time]]
```
| np_randn | 0.6026334530001805 |
| torch_randn | 0.25071304299990516 |
Apparently, PyTorch is super good at generating random numbers (3 times as fast as JIT-compiled JAX!). I did not analyze this and can’t say much more about this as there can be a myriad of reasons.
6.2. More about XLA
About XLA
- Works only with completely known (static) shapes.
- Can't work dynamically with non-static values! That means (for `jax.jit`):
  - No `if x > 0`.
  - No `jnp.unique`.
  - No `y[y % 2 == 0]` or `y[:x]`.
  - However, we can mark `x` (or what it depends on) as static.
  - Alternatively, “disable” JIT in a section via the experimental `host_callback` module.
- Most important optimization: operator/kernel fusion.
- Best to apply at the outermost location only5.
When XLA Recompiles
Function is recompiled for…
- different static argument values,
- different argument shapes, and
- different argument dtypes.
When you hit performance issues, constant recompilation may be the problem!
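You can observe this retracing yourself by abusing the side-effect behavior from before (a small sketch of my own, not from the slides): a `print` inside the function only runs while JAX traces, i.e. when it (re)compiles.

```python
@jax.jit
def traced_double(x):
    print('tracing for', x.shape, x.dtype)  # Only runs during tracing.
    return x * 2


traced_double(jnp.ones(3))             # Prints (first compilation).
traced_double(jnp.zeros(3))            # Same shape/dtype: cached, silent.
traced_double(jnp.ones(4))             # New shape: prints again.
traced_double(jnp.ones(3, dtype=int))  # New dtype: prints again.
```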
When XLA Recompiles (in Text)
Imagine a dictionary (hash map) `compcache` and a function with arguments `args`:

- For each argument `x` in `args`, collect the following in `cache_key`:
  - if it's static, `x` (identity);
  - if it's not static, `(x.shape, x.dtype)`.
- `compcache` maps from key `cache_key` to JIT output (value). Recompile and insert if `compcache` does not contain `cache_key`.

Any static `x` not hashable? Bzzzt, error!
When XLA Recompiles (in Code)
```python
import collections.abc

compcache = {}


def maybe_compile(compcache, func, args):
    cache_key = []
    for x in args:
        if is_static_arg(x):
            assert isinstance(x, collections.abc.Hashable)
            cache_key.append(x)  # Identity
        else:
            # Imagine this works on Python primitives.
            cache_key.append((x.shape, x.dtype))
    cache_key = tuple(cache_key)  # Lists are not hashable, tuples are.
    try:
        return compcache[cache_key]
    except KeyError:
        jit_output = xla_compile(func)
        compcache[cache_key] = jit_output
        return jit_output
```
Just an intuitive example! For instance, it fails with arbitrarily ordered keyword arguments; `cache_key` should really be a `dict` that also includes the argument names.
JIT and C++
- Possible to run JIT-compiled functions from C++ via the XLA runtime6 and the `jax_to_hlo` utility.
- Intermediate representation (IR) output from `jax.jit` will be JIT-compiled in the C++ program.
- However, a bit involved.
- Example in the JAX repository.
6.3. LLVM is Smart… And XLA?
In this section, I’d like to show you a clever optimization compilers do for you.
We’ll take a look at a simple sum implementation in C and the code generated from it. We will compare that with several implementations in Python, compiling the JAX version and seeing why (spoiler alert) XLA does not match up in terms of performance, even though it also uses LLVM for compilation.
Scalar Evolution with Clang
Giant’s Shoulders
```c
#include <stdio.h>

int sum(int limit) {
    int i, total;

    total = 0;
    for (i = 0; i < limit; ++i) {
        total += i;
    }
    return total;
}

int main(int argc, char** argv) {
    printf("%d\n", sum(50000));
    return 0;
}
```
1249975000
If you're following along, save the above to a file called `sum.c`.

This is the simplest sum function we can implement. Let's see what a modern C compiler does to this code by looking at LLVM's lower-level representation.
Outputting LLVM IR
Intermediate representation (IR) is a lower-level (in this case assembly-like) representation of the program. The compiler backend LLVM uses IR to achieve portability across assembly languages.
Don't worry too much about the `vim` call below – we are simply filtering the LLVM IR output so it only shows the definition of the `sum` function. The `sed` call strips trailing spaces.
```shell
clang -S -emit-llvm sum.c -O1
cat sum.ll \
    | vim - +'/^define.*sum(.*{$/,/^}$/p' \
          -es --not-a-term \
    | sed 's/ *$//'
```
LLVM Scalar Evolution
Warning: assembly-like language below! Don’t worry about reading this, I’ll give the gist of it below.
```llvm
define dso_local i32 @sum(i32 %0) local_unnamed_addr #0 {
  %2 = icmp sgt i32 %0, 0
  br i1 %2, label %3, label %13

3:                                                ; preds = %1
  %4 = add i32 %0, -1
  %5 = zext i32 %4 to i33
  %6 = add i32 %0, -2
  %7 = zext i32 %6 to i33
  %8 = mul i33 %5, %7
  %9 = lshr i33 %8, 1
  %10 = trunc i33 %9 to i32
  %11 = add i32 %10, %0
  %12 = add i32 %11, -1
  br label %13

13:                                               ; preds = %3, %1
  %14 = phi i32 [ 0, %1 ], [ %12, %3 ]
  ret i32 %14
}
```
So, without focusing too much on the details and interpreting this intuitively: please believe me that LLVM converted our `sum` function that used a for-loop into… the closed-form sum formula \( n \cdot (n - 1) / 2 \) (minus instead of plus due to the limit being excluded)! Isn't that amazing?
This form of optimization is called scalar evolution and is an induction-based technique for – as you can see – quite substantial performance improvements. If this topic caught your interest, the next section links to the source code, which includes references to the papers it implements.
If you really wanted to make sure the optimization happens on the machine code level, you could compile the code to an object file and disassemble it, using for example the `radare2` program.
More on Scalar Evolution
LLVM scalar evolution analysis (SCEV) source code with links to papers.
Weirdly enough, sums with a step other than 1 are not optimized even though a closed-form solution exists…
To explain a bit more, an earlier version had an `int sum(int limit, int step)` function that allowed a varying step size. However, LLVM did not optimize this function to the closed-form solution, even though it really should be able to (from what I could see in the comments of the scalar evolution source code).
Benchmarking C Sum
This section is about obtaining a C timing for the `sum` function and can safely be skipped.

In order to get some timings for C, which does not include a nice `timeit` module, what follows is a benchmark program for the above `sum` function allowing various timing methods. The idea is to mimic `timeit` with this. I saved this to a file called `benchmark_sum.c`.
```c
#include <stdio.h>
#include <time.h>

// The CPU cycle-based timings suffer from poor resolution compared to
// wall-time measurements.
#define MY_CLOCK 0
#define MY_CLOCK_GETTIME 1
#define MY_CLOCK_GETTIME_WALL 2
#define MY_TIME 3

#define MY_CLOCK_FUN MY_CLOCK_GETTIME_WALL

int sum(int limit) {
    int i, total;

    total = 0;
    for (i = 0; i < limit; ++i) {
        total += i;
    }
    return total;
}

int main(int args, char** argv) {
    double duration;
    int res;

#if MY_CLOCK_FUN == MY_CLOCK
    clock_t start_time, end_time;
    start_time = clock();
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME
    struct timespec start_time, end_time;
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start_time);
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME_WALL
    struct timespec start_time, end_time;
    clock_gettime(CLOCK_MONOTONIC, &start_time);
#elif MY_CLOCK_FUN == MY_TIME
    time_t start_time, end_time;
#endif

    res = sum(50000);

#if MY_CLOCK_FUN == MY_CLOCK
    end_time = clock();
    duration = (double) (end_time - start_time) / CLOCKS_PER_SEC;
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end_time);
    duration = (
        end_time.tv_sec + 1e-9 * end_time.tv_nsec
        - start_time.tv_sec - 1e-9 * start_time.tv_nsec
    );
#elif MY_CLOCK_FUN == MY_CLOCK_GETTIME_WALL
    clock_gettime(CLOCK_MONOTONIC, &end_time);
    duration = (
        end_time.tv_sec + 1e-9 * end_time.tv_nsec
        - start_time.tv_sec - 1e-9 * start_time.tv_nsec
    );
#elif MY_CLOCK_FUN == MY_TIME
    end_time = time();
    duration = difftime(end_time, start_time);
#endif

    // Use `res` so the `sum` call is not optimized away.
    printf("sum %d\n", res);
    printf("dur %.17g\n", duration);
    return 0;
}
```
Check that we get the sum formula optimization we want. I also manually checked to make sure that `main` does not optimize away the `sum` call.
```shell
clang -S -emit-llvm benchmark_sum.c -O1
cat benchmark_sum.ll \
    | vim - +'/^define.*sum(.*{$/,/^}$/p' \
          -es --not-a-term \
    | sed 's/ *$//'
```
```llvm
define dso_local i32 @sum(i32 %0) local_unnamed_addr #0 {
  %2 = icmp sgt i32 %0, 0
  br i1 %2, label %3, label %13

3:                                                ; preds = %1
  %4 = add i32 %0, -1
  %5 = zext i32 %4 to i33
  %6 = add i32 %0, -2
  %7 = zext i32 %6 to i33
  %8 = mul i33 %5, %7
  %9 = lshr i33 %8, 1
  %10 = trunc i33 %9 to i32
  %11 = add i32 %10, %0
  %12 = add i32 %11, -1
  br label %13

13:                                               ; preds = %3, %1
  %14 = phi i32 [ 0, %1 ], [ %12, %3 ]
  ret i32 %14
}
```
Execute the same number of times as we do with `time_str` and add up the timings.
```shell
clang -o benchmark_sum.o benchmark_sum.c -O1
./benchmark_sum.o
seq 5000 \
    | xargs -n 1 ./benchmark_sum.o \
    | awk '/dur/ {total+=$2} END {print total}'
```
```
sum 1249975000
dur 9.532824124368238e-130
4.76641e-126
```
I copy-pasted this number at a later location so that I was able to give a live surprise.
Clang/Math vs. XLA
How does XLA stack sum up?
Back to Python, let’s see whether XLA timings match Clang. Also, let’s acknowledge Python’s miserable performance.
```python
def py_sum_up(limit):
    return sum(range(limit))


def jax_sum_up(limit):
    return jnp.sum(jnp.arange(limit))


limit = 50000
python_time = time_str('py_sum_up(limit)')
jax_time = time_str('jax_sum_up(limit).block_until_ready()')

jit_sum_up = jax.jit(jax_sum_up, static_argnums=(0,))
jax_jit_time = time_str('jit_sum_up(limit).block_until_ready()')

[
    ['function', 'time [s]', 'result'],
    ['py_sum_up', python_time, py_sum_up(limit)],
    ['jax_sum_up', jax_time, jax_sum_up(limit).item()],
    ['jit_sum_up', jax_jit_time, jit_sum_up(limit).item()],
]
```
Sum Timings
| function | time [s] | result |
|---|---|---|
| py_sum_up | 3.0464433249981084 | 1249975000 |
| jax_sum_up | 0.44727893300296273 | 1249975000 |
| jit_sum_up | 0.1480330039994442 | 1249975000 |
JIT vs. Math
Here, we manually implement the sum formula optimization to get a better idea of how fast `jit_sum_up` could be in the best case.
```python
def sum_up_const(limit):
    # Since we exclude the limit, subtract one instead of adding.
    return limit * (limit - 1) // 2


[
    ['function', 'time [s]', 'result'],
    [
        'jit_sum_up',
        jax_jit_time,
        jit_sum_up(limit).item(),
    ],
    [
        'sum_up_const',
        time_str('sum_up_const(limit)'),
        sum_up_const(limit),
    ],
]
```
JIT vs. Math Results
| function | time [s] | result |
|---|---|---|
| jit_sum_up | 0.1480330039994442 | 1249975000 |
| sum_up_const | 0.0008629900003143121 | 1249975000 |
Reason? The optimization is too slow to apply and was disabled (but only on the CPU)!8
For reference, C with `-O1` reported around 5e-126 seconds in total – effectively too fast to measure.
Even More on JIT and Math
More JIT vs. Math
Just for fun, we try the same in PyTorch, using its JIT. We need to save the following code to a file `torch_sum.py` so PyTorch can parse the (TorchScript) source code.
```python
import torch

__all__ = ['torch_sum_up', 'torch_jit_sum_up']


def torch_sum_up(limit):
    return torch.sum(torch.arange(limit))


torch_jit_sum_up = torch.jit.script(torch_sum_up)
```
```python
from torch_sum import *

torch_limit = torch.tensor(limit)

[
    ['function', 'time [s]', 'result'],
    [
        'torch_sum_up',
        time_str(
            'torch_sum_up(torch_limit); '
            'torch.cuda.synchronize()'
        ),
        torch_sum_up(torch_limit).item(),
    ],
    [
        'torch_jit_sum_up',
        time_str(
            'torch_jit_sum_up(torch_limit); '
            'torch.cuda.synchronize()'
        ),
        torch_jit_sum_up(torch_limit).item(),
    ],
]
```
| function | time [s] | result |
|---|---|---|
| torch_sum_up | 0.15606207399832783 | 1249975000 |
| torch_jit_sum_up | 0.1805207170000358 | 1249975000 |
XLA-Optimized LLVM IR
If you've been executing the `jax.jit` code snippets, you can find XLA IR output in the `generated` directory (if you have set the `XLA_FLAGS` as in the behind-the-scenes setup code block).
7. vmap: No Batching Required
What is Batching
If you don’t know what batching is, here are some simple examples. First three non-batched versions, then a batched version. Whether you’re using R, Octave/MATLAB, NumPy, or PyTorch, you always want to batch your calculations for optimum performance. Especially when you are interested in taking gradients, batching greatly simplifies the computational graph.
For the setting, assume we have a small set of 15 3-dimensional edges and we want to sum up their norms because we need it for some computer graphics algorithm.
Non-batched:
```python
rng_key, edges = randn(rng_key, (15, 3))

norm_sum = 0
for edge in edges:
    norm_sum += jnp.linalg.norm(edge)
norm_sum
```
27.265522
Now, we can write this in a more pythonic way by using the built-in `sum` function. However, we are still not batching.
```python
sum(jnp.linalg.norm(edge) for edge in edges)
```
The following is a more NumPy-like way to write the non-batched version. It may even be faster than the pythonic version due to being able to use SIMD operations. Whether performance is gained or lost depends on the size of our dataset, though.
```python
jnp.sum(jnp.array([jnp.linalg.norm(edge) for edge in edges]))
```
Finally, we arrive at the batched version. We avoid any and all Python loops, calculating our norm sum in a much more efficient manner.
Batched:
```python
norms = jnp.linalg.norm(edges, axis=-1)  # shape: (15,)
norm_sum = jnp.sum(norms)
norm_sum
```
27.265522
vmap: No Batching Required
Now, what if I told you you no longer needed to do batching manually? Enter `jax.vmap`!

The following example calculates the spectral radius on a 128-size batch of 3×3 matrices. With the `assert` statement, we make sure our old version `jax_spectral_radius` is not batched already.

Notice also that you can combine `jax.jit` with `jax.vmap` – any of the function transformations in JAX are arbitrarily nestable; quite the magic! I'll let the below timings speak for themselves.

Why prefer the `jax.jit`-`jax.vmap` combo over a `jax.jit`ted loop? If anything, it keeps the computational graph simpler!
```python
rng_key, batch = jit_randn(rng_key, (128, 3, 3))
assert jax_spectral_radius(batch).shape != (128,)


def looped_spectral_radius(batch):
    return list(map(jax_spectral_radius, batch))


jit_looped_spectral_radius = jax.jit(looped_spectral_radius)
batched_spectral_radius = jax.vmap(jax_spectral_radius)
jit_batched_spectral_radius = jax.jit(batched_spectral_radius)
```
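The timing calls that produced the table below are not shown here; they presumably looked something like the following reconstruction using the `time_str` helper from before (my sketch, not the original code):

```python
[
    ['function', 'time [s]'],
    ['looped', time_str(
        '[r.block_until_ready() for r in looped_spectral_radius(batch)]')],
    ['jit_looped', time_str(
        '[r.block_until_ready() for r in jit_looped_spectral_radius(batch)]')],
    ['batched', time_str(
        'batched_spectral_radius(batch).block_until_ready()')],
    ['jit_batched', time_str(
        'jit_batched_spectral_radius(batch).block_until_ready()')],
]
```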
| function | time [s] |
|---|---|
| looped | 10.506742246001522 |
| jit_looped | 1.7194310759987275 |
| batched | 3.65071794799951 |
| jit_batched | 0.9865755659993738 |
More Batching
Because we can’t get enough of comparisons across frameworks, here are some more batched implementations.
Since we are going to JIT-compile PyTorch's TorchScript again, you need to save the `import torch` code block below to `torch_batching.py`.
By the way, if you were wondering: implementing looped versions with either `jnp.stack` or a preallocated result matrix did not improve results. In fact, the pre-allocating version took too long to JIT-compile, so I never got a number for that one. Below, you can find a version using `jax.lax.scan`.
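What follows is my own minimal sketch of such a `jax.lax.scan` version (an illustration, not necessarily the original code):

```python
def scan_spectral_radius(batch):
    def body(carry, mat):
        # No carry needed; just map over the leading axis one matrix at a time.
        return carry, jax_spectral_radius(mat)

    _, radii = jax.lax.scan(body, None, batch)
    return radii


jit_scan_spectral_radius = jax.jit(scan_spectral_radius)
```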
```python
import torch

__all__ = ['torch_batched_spectral_radius',
           'torch_jit_batched_spectral_radius']


def torch_batched_spectral_radius(mat):
    eigvals = torch.linalg.eigvals(mat)
    # `torch.max` with a dimension returns (values, indices).
    spectral_radius, _ = torch.max(torch.abs(eigvals), -1)
    return spectral_radius


torch_jit_batched_spectral_radius = torch.jit.script(
    torch_batched_spectral_radius)
```
```python
from torch_batching import *


def np_batched_spectral_radius(mat):
    eigvals = np.linalg.eigvals(mat)
    spectral_radius = np.max(np.abs(eigvals), -1)
    return spectral_radius


def jax_manually_batched_spectral_radius(mat):
    eigvals = jnp.linalg.eigvals(mat)
    spectral_radius = jnp.max(jnp.abs(eigvals), -1)
    return spectral_radius


jax_jit_manually_batched_spectral_radius = jax.jit(
    jax_manually_batched_spectral_radius)

np_batch = np.array(batch)
torch_batch = torch.from_numpy(np_batch)

[
    ['function', 'time [s]'],
    [
        'np_batched',
        time_str('np_batched_spectral_radius(np_batch)'),
    ],
    [
        'jax_manually_batched',
        time_str(
            'jax_manually_batched_spectral_radius(batch).block_until_ready()'),
    ],
    [
        'jax_jit_manually_batched',
        time_str(
            'jax_jit_manually_batched_spectral_radius(batch)'
            '.block_until_ready()'),
    ],
    [
        'torch_batched',
        time_str(
            'torch_batched_spectral_radius(torch_batch); '
            'torch.cuda.synchronize()'
        ),
    ],
    [
        'torch_jit_batched',
        time_str(
            'torch_jit_batched_spectral_radius(torch_batch); '
            'torch.cuda.synchronize()'
        ),
    ],
]
```
More Batching Results
| function | time [s] |
|---|---|
| np_batched | 0.8191060569988622 |
| jax_manually_batched | 1.0442584550000902 |
| jax_jit_manually_batched | 0.8894995779992314 |
| torch_batched | 1.276841124999919 |
| torch_jit_batched | 1.373109233998548 |
When comparing these timings, keep in mind that we are not running on the GPU.
8. pmap: Simple, Differentiable MPI
pmap
- Simple function transformation for splitting a computation across devices.
- Certain `jax.lax` primitives for reduction (`pmean`, `psum`, …).
- Like to stay old school? Differentiable `mpi4jax`9.
To activate for local tests (adjust `num_devices` as desired):
```python
import multiprocessing
import os

num_devices = multiprocessing.cpu_count()
os.environ['XLA_FLAGS'] = (
    os.getenv('XLA_FLAGS')
    + ' --xla_force_host_platform_device_count='
    + str(num_devices)
)
```
`jax.pmap` is actually not just related to `jax.vmap` in name – the functions do essentially the same thing, just differently: `jax.vmap` batches its function and can be imagined as a for-loop over the mapped-over axis. `jax.pmap` also batches its function, but is instead a for-loop whose iterations execute in parallel. That's really all there is to it; when you know `jax.vmap`, you know `jax.pmap`. Of course, the parallelism offers some extra functionality, which is exactly what we are going to be learning about in this section.
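A minimal sketch of my own (not from the slides) of `jax.pmap` together with a `jax.lax` reduction primitive:

```python
n_dev = jax.device_count()


@functools.partial(jax.pmap, axis_name='i')
def normalize(x):
    # Each device holds one row of `x`; `psum` sums element-wise across devices.
    return x / jax.lax.psum(x, 'i')


normalize(jnp.arange(float(4 * n_dev)).reshape(n_dev, 4))
```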
pmap and Axes
- Splits computation according to an axis; works over multiple devices and/or JAX processes.
- This broadcasting axis must be the size of our number of devices/processes, so reshape your data accordingly:
  - assuming 4 devices/processes and broadcasting axis 0, reshape a dataset of shape (128, 3, 3) to (4, 32, 3, 3). In code:
```python
world_size = jax.device_count()  # or `jax.process_count()`
dataset = dataset.reshape((world_size, -1) + dataset.shape[1:])
```
8.1. Write your own Horovod!
The following is a case study of using JAX to parallelize simple deep learning code. We are going to train a simple multilayer perceptron and teach it to calculate spectral radii.
I referenced the parallel deep learning framework Horovod because it seems to be the most-used tool at Jülich Supercomputing Centre for model-parallel training.
You can substitute “Horovod” with whatever you like to use, though; be it MPI, `tf.distribute.Strategy`, `torch.distributed`, `torch.distributed` with `torch.nn.parallel.DistributedDataParallel` (DDP), or whatever else you know and love. The principles are all the same, although our JAX implementation is more similar to Horovod than PyTorch DDP, for example (due to global instead of node-local splitting of the batch).
Non-Distributed Setup
Training Code
What follows is some boilerplate setup code for deep learning using JAX' built-in example machine learning libraries `jax.example_libraries.stax` and `jax.example_libraries.optimizers`. Notice how all state is explicitly handled with these. If you don't care for this, you can skip straight to the interesting part.

I also quickly want to mention why the code is a bit larger than it could be: I like my model to be able to work with dynamic batch sizes; however, if you follow JAX' official example, it would seem that the batch size needs to be fixed. By implementing two little extras, we are able to handle arbitrary batch sizes. First, we implement our own small `flatten` function/layer that flattens the whole input (unlike `stax.Flatten`). Second, we apply `jax.vmap` to our `model` function. The `model` function is always passed the model's parameters as well; that is why we do not want to `vmap` over the first argument (indicated by `in_axes=(None, [...])`). And that's it!
The only slight inconvenience with this setup is that we need to add an extra size-1 batch dimension in order to handle singular inputs. You’ll see this when we test our model later.
```python
from jax.example_libraries import stax
from jax.example_libraries import optimizers

input_shape = (2, 2)


def flatten():
    def init_fun(rng_key, input_shape):
        flattened_size = jnp.prod(jnp.array(list(input_shape)))
        return (flattened_size,), ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.ravel()

    return init_fun, apply_fun


def build_model(rng_key):
    model_init, model = stax.serial(
        flatten(),
        stax.Dense(64),
        stax.Relu,
        stax.Dense(64),
        stax.Relu,
        stax.Dense(64),
        stax.Relu,
        stax.Dense(1),
    )
    # Handle varying batch sizes.
    model = jax.vmap(model, in_axes=(None, 0))

    rng_key, subkey = jax.random.split(rng_key)
    output_shape, params = model_init(subkey, input_shape)
    assert output_shape == (1,)
    return rng_key, params, model


def build_opt(params):
    opt_init, update, get_params = optimizers.adam(3e-4)
    opt_state = opt_init(params)
    return opt_state, update, get_params


rng_key, params, model = build_model(rng_key)
orig_opt_state, opt_update, get_params = build_opt(params)
```
Interesting Part of the Training Code
Here we implement our update methods. Notice how we use `batched_spectral_radius` instead of `jit_batched_spectral_radius` in order to give XLA more optimization freedom. Also, here we see conjugating the possibly complex gradients in action.
```python
def batch_loss(params, batch):
    preds = model(params, batch)
    targets = batched_spectral_radius(batch)
    return jnp.mean(jnp.abs(preds - targets))


@jax.jit
def train_batch(step, opt_state, batch):
    params = get_params(opt_state)
    loss, grad = jax.value_and_grad(batch_loss)(params, batch)
    # Conjugate gradient for steepest-descent optimization.
    grad = jax.tree_util.tree_map(jnp.conj, grad)
    opt_state = opt_update(step, grad, opt_state)
    return opt_state, loss
```
Training a Spectral Radius MLP
A really simple deep learning training loop. We generate our batches on demand, taking care to update our `rng_key`, of course!
```python
opt_state = orig_opt_state

batch_size = 64
batch_shape = (batch_size,) + input_shape
steps = 10000
log_step_interval = 1000

start_time = time.perf_counter()

for step in range(steps):
    rng_key, batch = jit_randn(rng_key, batch_shape, dtype=complex)
    opt_state, loss = train_batch(step, opt_state, batch)
    if step % log_step_interval == 0:
        print('step ', step, '; loss ', loss, sep='')

end_time = time.perf_counter()
print('Training took', end_time - start_time, 'seconds.')
```
Training Results
```
step 0; loss 1.6026156
step 1000; loss 0.39599544
step 2000; loss 0.38151193
step 3000; loss 0.4386006
step 4000; loss 0.3645811
step 5000; loss 0.38383436
step 6000; loss 0.4037715
step 7000; loss 0.3104779
step 8000; loss 0.32223767
step 9000; loss 0.40970623
Training took 7.869086120001157 seconds.
```
Okay, that’s some sound old MLP training. Let’s get into parallelization already.
Multi-Node Distribution
A quick interlude on some extra distributed use cases and GPU memory pre-allocation. These are only interesting if you plan to distribute code yourself; they are mostly code snippets I wanted to leave here as a reference. To skip these sections, click here.
The following is some setup code you would use on a Slurm-managed cluster, for example. But first, a word of caution…
Multi-Node Distributed Setup
- Caution: Experimental, undocumented and not even used, much less tested, anywhere inside JAX!
Future versions will have a function `jax.distributed.initialize`, working much like PyTorch's:

```python
torch.distributed.init_process_group(
    [...],
    init_method="tcp://[...]",  # or "env://"
)
```
Multi-Node Distributed Setup Code
Adapted from JAX source code (clickable version of link below):
```python
# Adapted and compressed from
# https://github.com/google/jax/blob/4d6467727709de1c9ad220ac62783a18bcbf4990/jax/_src/distributed.py
def jax_distributed_initialize(
        coordinator_address, num_processes, process_id):
    if process_id == 0:
        global _service
        _service = \
            jax.lib.xla_extension.get_distributed_runtime_service(
                coordinator_address, num_processes)
    client = jax.lib.xla_extension.get_distributed_runtime_client(
        coordinator_address, process_id)
    client.connect()
    factory = functools.partial(
        jax.lib.xla_client.make_gpu_client, client, process_id)
    jax.lib.xla_bridge.register_backend_factory(
        'gpu', factory, priority=300)
```
Handling GPU Memory Pre-Allocation
If you ever had trouble with running out of GPU memory when using multi-process TensorFlow, you may have fixed it by enabling “GPU memory growing”. (The default is that TensorFlow pre-allocates a large block of memory in order to reduce memory fragmentation.)
JAX does the same, so in case you need it, what follows is the JAX equivalent to enabling GPU memory growing in TensorFlow.
Disable GPU Memory Pre-Allocation
Equivalent of the following TensorFlow:
```python
import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
```
In JAX:
```python
import os

# Before first JAX computation.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
```
Distributed Training
Distributing our Training Code
With all of that out of the way, let’s finally write some distributed training code!
… What's that? We only need to add a single line in our `train_batch` function?

Well, almost; we also need to apply the titular `jax.pmap` function transformation. I'll explain what's going on here after the code block.
```python
batch_axis = 'batch'


def distributed_train_batch(step, opt_state, batch):
    params = get_params(opt_state)
    loss, grad = jax.value_and_grad(batch_loss)(params, batch)
    # This is the only line we had to add to `train_batch`.
    loss, grad = jax.lax.pmean((loss, grad), batch_axis)
    # Conjugate gradient for steepest-descent optimization.
    grad = jax.tree_util.tree_map(jnp.conj, grad)
    opt_state = opt_update(step, grad, opt_state)
    return opt_state, loss


# `pmap` also `jit`s our function.
pmap_train_batch = jax.pmap(
    distributed_train_batch,
    batch_axis,
    in_axes=(None, None, 0),
    out_axes=(None, None),
)
```
This is already all the magic behind Horovod! (Ignoring the super-optimized communication.)
The first thing that should have caught your eye is the definition of `batch_axis` at the very top of the code block. This `batch_axis` is passed to both `jax.lax.pmean` and `jax.pmap` as the reduction axis and mapped-over axis, respectively. We need to do this because – and I hope you've started to notice a pattern here – `jax.pmap` is also nestable! By having to specify an axis for `jax.pmap`, the reduction operations in `jax.lax` will always have an axis to refer to. We use a string to name the axis here, but any hashable would do.

The call to `jax.lax.pmean` averages a tree of values over all processes. In our case, it averages the loss and gradient. We average the loss as well here because you usually want to log the globally averaged loss instead of the local loss to get a smoother overall picture (in Horovod, you need to enable this explicitly). The averaged gradient is then used to do the same update step on each process, so we don't need any more synchronization afterwards.

`jax.pmap` has some other arguments here we haven't seen yet, namely `in_axes` and `out_axes`. `jax.vmap` accepts these too, and they are very important! With them, you control the axis of each argument that the function transformation maps over. If you don't want to map over an argument (maybe you are passing a single constant that you don't want to copy until it has the size of the mapped-over axis), you specify `in_axes` at the position of the argument as `None`. We do this for the training step (a single integer) and the model parameters (a tree of matrices). However, we do want to map over the batch somehow; we specify the very first axis here.

`out_axes` is similar, but for the output. If we wanted to collect the different outputs of the function transformation for each mapped-over input, we would specify the return values that we want to collect and the axis we want to collect them over at the corresponding positions in `out_axes`. Since we already reduce over the mapped-over values with the `jax.lax.pmean` call, we will not have multiple different outputs and thus use `None` just like for `in_axes` to prevent the collection.
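The same `in_axes`/`out_axes` mechanics in a tiny `jax.vmap` sketch of my own (not from the slides):

```python
def scale(w, x):
    return w * x


# Map over the leading axis of `x` only; `w` is reused for every call.
batched_scale = jax.vmap(scale, in_axes=(None, 0), out_axes=0)
batched_scale(2.0, jnp.arange(6.0).reshape(3, 2))  # Result has shape (3, 2).
```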
Training a Spectral Radius MLP Distributively
Here, we reshape the original batch so we can split it over its first axis as desired. In the code, we obtain the `world_size` as the number of devices known to JAX. Going back to the multi-node setup code from before, if you initialize your distributed training like that, i.e. by starting multiple Python processes, you may want to use `world_size = jax.process_count()` instead. Like so often, this depends on your use case.
```python
opt_state = orig_opt_state

# Since we are only using a single JAX process with (possibly
# emulated) multiple devices, we use `jax.device_count`.
# Usually, this would be `jax.process_count`.
world_size = jax.device_count()
assert batch_size % world_size == 0
local_batch_size = batch_size // world_size

start_time = time.perf_counter()
print('Training with', world_size, '(possibly simulated) devices.')

for step in range(steps):
    rng_key, batch = jit_randn(rng_key, batch_shape, dtype=complex)
    batch = batch.reshape(
        (world_size, local_batch_size) + batch.shape[1:])
    opt_state, loss = pmap_train_batch(step, opt_state, batch)
    if step % log_step_interval == 0:
        print('step ', step, '; loss ', loss, sep='')

end_time = time.perf_counter()
print('Distributed training took', end_time - start_time, 'seconds.')
```
Distributed Training Results
```
Training with 2 (possibly simulated) devices.
step 0; loss 1.6025591
step 1000; loss 0.3969112
step 2000; loss 0.38199607
step 3000; loss 0.4381593
step 4000; loss 0.36498725
step 5000; loss 0.38379186
step 6000; loss 0.40361017
step 7000; loss 0.3104655
step 8000; loss 0.32245204
step 9000; loss 0.40979913
Distributed training took 7.309056613001303 seconds.
```
Even though I only used 2 simulated devices, training did speed up a bit already. Nice!
Did it Learn?
Let's test whether our model actually learned anything useful. Maybe the logged losses don't tell the whole story. Also, please notice that I JITted the `model` function for inference here; I wanted to show this off as it will really speed up your inference times. (Of course, JITting did not really make sense here since I only use the function once.)
```python
params = get_params(opt_state)
jit_model = jax.jit(model)

ceig_mat = jnp.array([[1.0, -1.0],
                      [1.0, 1.0]])
batched_ceig_mat = jnp.expand_dims(ceig_mat, 0)

[
    ['function', 'spectral radius sample'],
    [
        'jax_spectral_radius',
        jax_spectral_radius(ceig_mat).item(),
    ],
    [
        'model',
        jit_model(params, batched_ceig_mat)[0].item(),
    ],
]
```
| function | spectral radius sample |
|---|---|
| jax_spectral_radius | 1.4142135381698608 |
| model | 1.4250930547714233 |
9. Summary
Advantages
- Educational documentation.
- Familiar API.
- Interoperate with TensorFlow using the experimental `jax2tf` (included in JAX; despite the name, it also supports TensorFlow to JAX).
- Faster code!
- More explicit, less magic, less trouble understanding.
- Functional style will avoid headaches.
Look out for stabilization of `pjit` (previously `sharded_jit`); even simpler Horovod, ability to JIT huge functions!

`pjit` is actually much more cool than you would expect. The old name is a bit more descriptive here, as `pjit` – aside from being a more abstracted `pmap` – allows us to even split super large functions that do not fit in the memory of a single device.
Disadvantages
Initial hurdles:
- Getting used to it.
- Sometimes not as quick to write; however, payoff in the long term.
Better with time:
- Sometimes unpredictable or unstable.
- Lacking ecosystem.
Will never change:
- Hidden functionalities/undocumented APIs; some useful code (intentionally?) not public.
- Mutating state during JITting will cause headaches.
Backend in TensorFlow: code split and dependency.
The code split means that the TensorFlow repository also contains JAX-only code. So you have another code location to keep in mind if you need to dive deeper into the JAX JIT for some reason.
Neural Network Libraries
- `jax.example_libraries.stax`: Included in JAX, bare-bones and requires more manual work. However, simplicity is an advantage.
- `flax`: Most features, most user-friendly in my opinion (a tiny sketch follows further below).
- `haiku`: Goes against JAX' implicit/immutable state but converts models back to stateless transformations. Thus, maybe better user experience.
- `objax`: Similar API to `torch.nn.Module`.
- `trax`: Focus on sequence data and large-scale training.
These are all made by the same company!
Just to clarify, while the above statement can be interpreted as a cheap jab at Alphabet and the TensorFlow situation, I do believe that designing sensible APIs for JAX’ paradigms is hard. Most of these libraries can most likely be viewed as experiments with regard to this design.
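To give a flavor of one of these, here is a minimal `flax.linen` model sketch of my own (hypothetical, not from the talk); note how the parameters still live in an explicit, immutable tree:

```python
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(64)(x))
        return nn.Dense(1)(x)


model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))  # Explicit state.
preds = model.apply(params, jnp.ones((8, 4)))
```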
functorch
- JAX-like function transformations (`vmap`, `grad`, …)
- Stateless models
… for PyTorch! Experimental, but best of both worlds.
Whenever you have trouble batching something in PyTorch, it may help. It will probably be a while until it's included in PyTorch.
By the way, PyTorch also has a JIT!
Since this is the extended version, you did see some results with PyTorch’s JIT. All of those did not seem… promising. The non-JITted PyTorch code was consistently as fast or even faster than the JITted version.
However, I believe that PyTorch’s JIT compiler’s use case is a bit different, leaning more towards optimizing deep learning models for inference. I don’t understand why the super simple linear algebra we tried to JIT did not get optimized (un-optimized, instead!) at all, but I did not dive into PyTorch’s JIT and thus can’t say too much about this.
Thanks for Reading!
Thank you for your attention.
I hope you had as much fun as I had preparing.
10. Appendix
References
- JAX source repository (accessed from 2021-11-18 to 2021-11-26)
Extra Recommendations
I recommend this article and corresponding paper on limitations of XLA and TorchScript compilation.
Maybe of interest: there are already JAX libraries for differentiable rigid-body physics simulation and molecular dynamics. You will find other fields covered as well.
Footnotes:
Stable since PyTorch 1.9.0, released in June 2021.
Some of the Autograd developers took part in building JAX upon its ideas.
Why is the second `a` considered different from the first? The default `__hash__` implementation is based on the object's `id`, its location in memory. (Both of these are CPython implementation details.)
That way, the compiler gets the most options for optimization.
Also, an already `jit`ted inner function cannot be optimized further.
There are multiple XLA runtimes.
C family of languages, to be exact.
Not quite sure since they only disable --loop-unroll
.
The Message Passing Interface (MPI) is a standard in parallel computing.