JAX 101

Given the length of the official JAX tutorial, this note distills the core concepts into a quick reference for readers who have already gone through the original tutorial.

High-level JAX stack

[Figure: high-level JAX stack diagram]

Source: Yi Wang’s LinkedIn post

Quick start

  • jax.jit: compiles a function with XLA. Under jit, array values are traced, so value-dependent operations such as boolean-mask indexing fail:

import numpy as np
import jax.numpy as jnp
from jax import jit

def get_negatives(x):
    return x[x < 0]

x = jnp.array(np.random.randn(10))
jit(get_negatives)(x)
# NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[10]
  • jax.vmap: write your function with a single data point (e.g., one vector, one image). Then jax.vmap transforms it to process the entire batch.
    • can be composed with jax.jit
import jax
import jax.numpy as jnp

# works on a single vector
def predict(params, input_vector):
    return jnp.dot(params['W'], input_vector) + params['b']

key = jax.random.key(0)
params = {
    'W': jax.random.normal(key, (3, 5)),
    'b': jax.random.normal(key, (3,))
}
single_vector = jnp.ones(5)
result_single = predict(params, single_vector)

batch_of_vectors = jnp.ones((8, 5))
# Vectorize predict using vmap:
# in_axes=(None, 0):
# - None: Don't map over 'params' (W and b are shared across the batch).
# - 0: Map over the 0-th axis of 'input_vector' (the batch dimension).
# out_axes=0: The output should have the batch dimension at axis 0.
batched_predict = jax.vmap(predict, in_axes=(None, 0), out_axes=0)
results_batch = batched_predict(params, batch_of_vectors)

# Expected shape: (batch_size, output_dim) => (8, 3)
  • random key: never reuse keys (unless you want identical outputs). In order to generate different and independent samples, you must jax.random.split the key explicitly before passing it to a random function:
from jax import random

key = random.key(0)
for i in range(3):
    new_key, subkey = random.split(key)
    del key  # The old key is consumed by split() -- we must never use it again.

    val = random.normal(subkey)
    del subkey  # The subkey is consumed by normal().

    print(f"draw {i}: {val}")
    key = new_key  # new_key is safe to use in the next iteration.

The Sharp Bits (aka common mistakes)

https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html

  • JAX works with pure functions
    • all inputs are in function params, all outputs are function returns
    • given the same inputs, a function must always produce the same outputs (so side effects like print are not allowed inside)
    • don’t feed Python iterators to jitted functions or to JAX control-flow primitives (e.g., lax.fori_loop) → errors or unexpected results
  • in-place updates
    • JAX arrays are immutable. In-place updates like jax_array[1, :] = 1.0 yields an error
    • use .at instead: updated_array = jax_array.at[1, :].set(1.0), which is an out-of-place update that returns a new array.
      • similarly, we have other ops like new_jax_array = jax_array.at[::2, 3:].add(7.0)
  • out-of-bounds indexing
    • when the indexing op is an array index update (e.g. index_add or scatter-like primitives), updates at out-of-bounds indices will be skipped
      • index_add: jax_array.at[index].add(value) → lowered into a scatter operation on GPU/TPU
      • scatter primitives: you give indices, and updates, then JAX writes them into an array
    • when the op is an array index retrieval (e.g. NumPy indexing or gather-like primitives) the index is clamped to the bounds since something must be returned
      • e.g., jnp.arange(10)[11] → the last value (9) is returned (see the sketch after this list)
  • non-jax-array inputs
    • don’t pass Python lists, tuples, or other non-array types as inputs to JAX functions → performance issues
    • you can explicitly convert it to a jax array instead: jnp.sum(jnp.array(x))
  • dynamic shapes
    • code used within transforms (like jax.jit, jax.vmap, and jax.grad) requires all output arrays and intermediate arrays to have static shape
  • JAX enforces single-precision numbers by default (even requesting dtype=jnp.float64 yields float32)
    • to use double precision, enable the jax_enable_x64 configuration variable at startup, e.g. jax.config.update("jax_enable_x64", True)
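
A minimal sketch of the out-of-bounds behavior above (an assumed example, relying on JAX's default index modes): gather-like reads clamp the index, while scatter-like updates silently drop out-of-bounds writes.

import jax.numpy as jnp

x = jnp.arange(10)

# Retrieval (gather-like): the index is clamped, so the last element is returned.
print(x[11])              # 9

# Update (scatter-like): the out-of-bounds write is skipped; x is returned unchanged.
print(x.at[11].add(100))  # [0 1 2 3 4 5 6 7 8 9]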

JAX 101

1. JIT

  • JIT: makes a Python function execute efficiently by compiling it with XLA. Recall this part of the stack:

[Figure: high-level JAX stack (see the diagram above)]

  • How? → reducing a function to a sequence of primitives (can be viewed using jax.make_jaxpr())
  • Tracing does not capture side effects (like print(), appending the input to an external list, etc.). Don’t JIT-compile impure functions!
    • use jax.debug.print() to debug printing instead (but performance will be worse)
  • When tracing, JAX wraps each arg by a tracer object (by default with shape and dtype), which records all JAX ops during the function call.
  • The tracer records are used to reconstruct the entire function → output a jaxpr
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x, y):
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

# Running f():
#   x = JitTracer<float32[3,4]>
#   y = JitTracer<float32[4]>
#   result = JitTracer<float32[3]>
  • If we have a Python conditional, the jaxpr only records the branch that is actually taken (see the sketch below).
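    • a minimal sketch (an assumed example in the spirit of the official tutorial): the condition below depends only on the static attribute x.ndim, so ordinary Python control flow works under tracing, and the resulting jaxpr records only the branch that was taken

      import jax
      import jax.numpy as jnp

      def log2_if_rank_2(x):
          if x.ndim == 2:  # static attribute -> ordinary Python control flow is fine
              return jnp.log(x) / jnp.log(2.0)
          else:
              return x

      # Traced with a rank-1 input, so only the `else` branch appears in the jaxpr.
      print(jax.make_jaxpr(log2_if_rank_2)(jnp.array([1.0, 2.0, 3.0])))
      # { lambda ; a:f32[3]. let  in (a,) }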

  • JIT won’t work for the following code

    • traced values like x and n can only affect control flow via their static attributes (like shape or dtype), not via their values

      def g(x, n):
          i = 0
          while i < n:
              i += 1
          return x + i

      jax.jit(g)(10, 20)  # Raises an error
    • we can use structured control-flow operators like jax.lax.cond() instead (see section 5 below).

    • or we can JIT-compile only part of the function

      @jax.jit
      def loop_body(prev_i):
          return prev_i + 1

      def g_inner_jitted(x, n):
          i = 0
          while i < n:
              i = loop_body(i)
          return x + i
    • we can use static_argnums or static_argnames, but the function is re-compiled for every new value of the specified static input

      def f(x):
          if x < 3:
              return 3. * x ** 2
          else:
              return -4 * x

      f = jit(f, static_argnames='x')

      print(f(2.))
      from functools import partial

      @partial(jit, static_argnums=(1,))
      def f(x, neg):
          return -x if neg else x

      f(1, True)
  • JIT caching: the first call to a jitted function compiles it and caches the compiled code; later calls with the same input shapes and dtypes reuse the cache (see the sketch below).
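    • a minimal sketch (assumed example): the Python-level print below runs only while tracing, i.e. once per new combination of input shape and dtype; later calls with the same shapes/dtypes reuse the cache

      import jax
      import jax.numpy as jnp

      @jax.jit
      def g(x):
          print("tracing g for", x.shape, x.dtype)  # side effect -> visible only during tracing
          return x * 2

      g(jnp.ones(3))   # prints "tracing g for (3,) float32", then compiles and caches
      g(jnp.ones(3))   # cache hit: nothing is printed
      g(jnp.ones(4))   # new shape -> traced and compiled again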

2. Automatic vectorization

  • suppose we have a convolution function written for a single example

    x = jnp.arange(5)
    w = jnp.array([2., 3., 4.])

    def convolve(x, w):
        output = []
        for i in range(1, len(x)-1):
            output.append(jnp.dot(x[i-1:i+2], w))
        return jnp.array(output)

    convolve(x, w)
  • with jax.vmap(), it can receive batches of inputs (default batch dimension = 0)

    xs = jnp.stack([x, x])
    ws = jnp.stack([w, w])
    auto_batch_convolve = jax.vmap(convolve)
    auto_batch_convolve(xs, ws)
  • we can use in_axes and out_axes to specify batch dimensions

    auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)
    xst = jnp.transpose(xs)
    wst = jnp.transpose(ws)
    auto_batch_convolve_v2(xst, wst)
  • example: convolve a batch of vectors xs with a single, shared set of weights w

    batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
    batch_convolve_v3(xs, w)
  • jax.vmap can be combined with jax.jit

    jitted_batch_convolve = jax.jit(auto_batch_convolve)

3. Pytrees

3.1 basics

  • a Pytree is simply a container of arrays. It can be a list, a tuple, a dict, or a nested combination of these. JAX functions are designed to accept these structures as inputs and return them as outputs smoothly.

  • a leaf is anything that’s not a pytree, such as an array, but a single leaf is also a pytree

  • pytree: any tree-like structure built out of container-like Python objects

    • container-like: in pytree container registry; defaults are lists, tuples, and dicts
    • object whose type is not in the pytree container registry → leaf node
  • in ML, a pytree can contain: model weights, dataset entries, RL observations, etc.

  • jax.tree.leaves(): flattened leaves from the trees

    example_trees = [
        [1, 'a', object()],
        (1, (2, 3), ()),
        [1, {'k1': 2, 'k2': (3, 4)}, 5],
        {'a': 2, 'b': (2, 3)},
        jnp.array([1, 2, 3]),
    ]
    for pytree in example_trees:
        leaves = jax.tree.leaves(pytree)
        print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
    [1, 'a', <object object at 0x7d8bcaf08490>]   has 3 leaves: [1, 'a', <object object at 0x7d8bcaf08490>]
    (1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
    [1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
    {'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
    Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]
  • jax.tree.map(f, tree): apply a function to every leaf in a nested data structure

    list_of_lists = [
        [1, 2, 3],
        [1, 2],
        [1, 2, 3, 4]
    ]
    jax.tree.map(lambda x: x*2, list_of_lists)
    # [[2, 4, 6], [2, 4], [2, 4, 6, 8]]
    another_list_of_lists = list_of_lists
    jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
  • view pytree definition

    from jax.tree_util import tree_structure
    print(tree_structure(object))
    # PyTreeDef(*)
  • pytree and JAX transformations

    • many JAX functions operate over pytrees of arrays
    • all JAX transformations can be applied to functions whose inputs and outputs are pytrees of arrays (see the gradient sketch after this list)
  • in_axes and out_axes arguments to jax.vmap() can also be pytrees (structure must match the corresponding arguments)

    • example: function f has an input like (a1, {"k1": a2, "k2": a3})

      • a variable a1
      • a dict containing keys k1 (value a2) and k2 (value a3)
    • we can use jax.vmap(f, in_axes=(None, {"k1": None, "k2": 0}))

      • you want to vectorize (loop over) the data in k2, but not a1 or k1 (they are constants/broadcasted)
    • a single leaf value can be applied to an entire sub-pytree

      jax.vmap(f, in_axes=(None, 0))
      # equivalent to (None, {"k1": 0, "k2": 0})
      jax.vmap(f, in_axes=0)  # equivalent to (0, {"k1": 0, "k2": 0})
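
  • as an illustration of transformations over pytrees, a minimal sketch (assumed example): jax.grad() of a loss whose params argument is a dict returns gradients with the same dict structure

    import jax
    import jax.numpy as jnp

    def loss(params, x):
        pred = jnp.dot(x, params['W']) + params['b']
        return jnp.sum(pred ** 2)

    params = {'W': jnp.ones((5, 3)), 'b': jnp.zeros(3)}
    x = jnp.ones((8, 5))

    grads = jax.grad(loss)(params, x)  # differentiates w.r.t. the whole pytree
    print(jax.tree.map(lambda g: g.shape, grads))
    # {'W': (5, 3), 'b': (3,)}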

3.2 key paths

  • each leaf has a key path (i.e., a list of keys, list length = tree depth of the leaf)

  • key type depends on node type; e.g., key type for dicts is different from key type for tuples

    • SequenceKey(idx: int): For lists and tuples.
      • this key type is used for nodes that are ordered sequences and accessed by an integer index.
      • key content: An integer index specifying the position.
    • DictKey(key: Hashable): For dictionaries.
    • GetAttrKey(name: str): For namedtuples and, preferably, custom pytree nodes
  • jax.tree_util.tree_flatten_with_path(): similar to jax.tree.flatten(), but returns flattened key paths

    example_trees = [
        [1, 'a', object()],
        (1, (2, 3), ()),
        jnp.array([1, 2, 3]),
    ]
    jax.tree_util.tree_flatten_with_path(example_trees)
    # ([((SequenceKey(idx=0), SequenceKey(idx=0)), 1),
    #   ((SequenceKey(idx=0), SequenceKey(idx=1)), 'a'),
    #   ((SequenceKey(idx=0), SequenceKey(idx=2)), <object at 0x70d4e8c6b600>),
    #   ((SequenceKey(idx=1), SequenceKey(idx=0)), 1),
    #   ((SequenceKey(idx=1), SequenceKey(idx=1), SequenceKey(idx=0)), 2),
    #   ((SequenceKey(idx=1), SequenceKey(idx=1), SequenceKey(idx=1)), 3),
    #   ((SequenceKey(idx=2),), Array([1, 2, 3], dtype=int32))],
    #  PyTreeDef([[*, *, *], (*, (*, *), ()), *]))
  • jax.tree_util.tree_map_with_path(): similar to jax.tree.map(), but the mapped function also receives each leaf’s key path as input

  • jax.tree_util.keystr(): given a key path, returns a reader-friendly string

  • example

    import collections

    ATuple = collections.namedtuple("ATuple", ('name'))

    tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
    flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

    for key_path, value in flattened:
        print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')

    # Value of tree[0]: 1
    # Value of tree[1]['k1']: 2
    # Value of tree[1]['k2'][0]: 3
    # Value of tree[1]['k2'][1]: 4
    # Value of tree[2].name: foo

3.3 common mistakes

  • accidentally introducing tree nodes instead of leaves

    a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

    # Try to make another pytree with ones instead of zeros.
    shapes = jax.tree.map(lambda x: x.shape, a_tree)
    jax.tree.map(jnp.ones, shapes)
    # [(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
    #  (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]
    • the shape of an array is a tuple, which is itself a pytree node whose elements are leaves
    • so the map calls jnp.ones on 2 and 3 separately instead of on (2, 3) (a fix is sketched after this list)
  • jax.tree_util functions treat None as the absence of a pytree node, not as a leaf

    jax.tree.leaves([None, None, None])
    # []

    # fix
    jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
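
A fix for the shapes mistake above, as a minimal sketch (assumed example): map over the original arrays, so each array (rather than its shape tuple) is the leaf.

    a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
    jax.tree.map(lambda x: jnp.ones(x.shape), a_tree)
    # equivalently: jax.tree.map(jnp.ones_like, a_tree)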

3.4 common pytree patterns

  • transposing pytrees
    • transpose: turn a list of trees into a tree of lists

    • option 1 (basic): jax.tree.map()

      def tree_transpose(list_of_trees):
          return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

      # Convert a dataset from row-major to column-major.
      episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
      tree_transpose(episode_steps)
      # *xs are tuples: {'obs': (3, 4), 't': (1, 2)}
      # the lambda converts them to lists: {'obs': [3, 4], 't': [1, 2]}
    • option 2 (more complex but flexible): jax.tree.transpose()

      jax.tree.transpose(
          outer_treedef = jax.tree.structure([0 for e in episode_steps]),
          inner_treedef = jax.tree.structure(episode_steps[0]),
          pytree_to_transpose = episode_steps
      )

4. Intro to parallel programming

4.1 Basics

  • this tutorial covers SPMD programs: the same computation (e.g., a forward pass) applied to different data (e.g., different inputs in a batch) in parallel on different devices (e.g., TPUs)

  • we cover auto sharding (jax.jit()), explicit sharding, and fully manual sharding (jax.shard_map(), per-device sharding + explicit collectives)

  • jax.Array is designed with distributed data and computation in mind. It carries a jax.sharding.Sharding object describing how the data is laid out across devices.

  • by default, arrays are on a single device

    import jax
    import jax.numpy as jnp

    jax.config.update('jax_num_cpu_devices', 8)
    arr = jnp.arange(32.0).reshape(4, 8)
    arr.devices()
    # {CpuDevice(id=0)}
  • visualize sharding: jax.debug.visualize_array_sharding()

  • define sharding

from jax.sharding import PartitionSpec as P

# 8 devices arranged as a 2x4 mesh; 1st mesh dimension named 'x', 2nd 'y'
mesh = jax.make_mesh((2, 4), ('x', 'y'))
# array rows mapped to mesh axis 'x', columns to 'y'
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
arr_sharded = jax.device_put(arr, sharding)
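
  • a minimal usage sketch (assuming the 8-CPU-device setup above): inspect how arr_sharded is laid out across the 2x4 mesh

jax.debug.visualize_array_sharding(arr_sharded)
# prints a device grid showing which block of the array lives on which device(s)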

4.2 auto sharding via jax.jit()

  • you specify how inputs/outputs are sharded, and jax.jit() will automatically
    • partition everything inside the function accordingly
    • insert the necessary inter-device communication
  • example with element-wise function: same sharding as inputs
@jax.jit
def f_elementwise(x):
    return 2 * jnp.sin(x) + 1

result = f_elementwise(arr_sharded)
print("shardings match:", result.sharding == arr_sharded.sharding)
# True
  • example with reduce
@jax.jit
def f_contract(x):
    return x.sum(axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
# ┌───────┬───────┬───────┬───────┐
# │CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│
# └───────┴───────┴───────┴───────┘

4.3 Explicit sharding (sharding-in-types)

  • the JAX-level type of a value includes how the value is sharded
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")
# JAX-level type of some_array: int32[8]
  • we can query the type in jit
@jax.jit
def foo(x):
    print(f"{jax.typeof(x)}")
    return x + x
  • set up an explicit-sharding mesh
from jax.sharding import AxisType

mesh = jax.make_mesh((2, 4), ("X", "Y"),
                     axis_types=(AxisType.Explicit, AxisType.Explicit))
replicated_array = np.arange(8).reshape(4, 2)
sharded_array = jax.device_put(replicated_array,
                               jax.NamedSharding(mesh, P("X", None)))
print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
# replicated_array type: int32[4,2]
# sharded_array type: int32[4@X,2]
# -> 1st dimension sharded along X, the others replicated

4.4 Manual parallelism with shard_map

  • you write the function for a single shard, and jax.shard_map() constructs the full function

  • shard_map() maps over shards

    • in_specs determines the shard sizes
      • split 32 elements in arr to 8 devices (4 elements / device)
    • out_specs identifies how blocks are assembled back together
      • the outputs coming from the devices represent chunks split along axis 'x'
      • it must concatenate (glue) these 8 chunks back together in order to form the final global array
    mesh = jax.make_mesh((8,), ('x',))

    @jax.jit
    def f_elementwise(x):
        return 2 * jnp.sin(x) + 1

    f_elementwise_sharded = jax.shard_map(
        f_elementwise,
        mesh=mesh,
        in_specs=P('x'),
        out_specs=P('x'))

    arr = jnp.arange(32)
    f_elementwise_sharded(arr)
  • the function can only see a single shard

    x = jnp.arange(32)
    print(f"global shape: {x.shape=}")

    def f(x):
        print(f"device local shape: {x.shape=}")
        return x * 2

    y = jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
    # global shape: x.shape=(32,)
    # device local shape: x.shape=(4,)
  • be careful about aggregation-like functions

def f(x):
    return jnp.sum(x, keepdims=True)

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x)
# Array([  6,  22,  38,  54,  70,  86, 102, 118], dtype=int32)
  • you need jax.lax.psum() to get a single sum
def f(x):
    sum_in_shard = x.sum()
    return jax.lax.psum(sum_in_shard, 'x')

jax.shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P())(x)
# Array(496, dtype=int32)

4.5 Data / computation placement on devices

  • In JAX, the computation follows data placement
  • 2 placement properties
    • the device where the data resides
    • whether it is committed to the device or not
    • default: uncommitted on jax.devices()[0]
  • data placed with jax.device_put() becomes committed to that device, e.g. arr = jax.device_put(1, jax.devices()[2]); see the sketch below
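  • a minimal sketch (assumed example; requires at least 2 visible devices): uncommitted data follows committed data, so the computation below runs on the committed operand’s device

import jax
import jax.numpy as jnp

x = jnp.ones(4)                                     # uncommitted; lives on jax.devices()[0]
y = jax.device_put(jnp.ones(4), jax.devices()[1])   # committed to device 1
print((x + y).devices())                            # e.g. {CpuDevice(id=1)} -- computation follows the committed operand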

5. Structured control flow with JIT

If we want control flow that is traceable and avoids unrolling large loops, there are four structured control-flow operators: lax.cond(), lax.while_loop(), lax.fori_loop(), and lax.scan():

lax.cond()

semantics:

def cond(pred, true_fun, false_fun, operand):
    if pred:
        return true_fun(operand)
    else:
        return false_fun(operand)

example:

from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)
# --> array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)
# --> array([-1.], dtype=float32)

lax.while_loop()

semantics:

def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

example:

init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)

lax.fori_loop()

semantics:

def fori_loop(start, stop, body_fun, init_val):
    val = init_val
    for i in range(start, stop):
        val = body_fun(i, val)
    return val

example:

init_val = 0
start = 0
stop = 10
body_fun = lambda i, x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
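
lax.scan()

lax.scan() carries a state ("carry") through a loop and stacks the per-step outputs; a minimal sketch of its semantics and an example, assuming the standard jax.lax.scan() behavior:

semantics:

def scan(f, init, xs):
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

example:

xs = jnp.arange(5)
body_fun = lambda carry, x: (carry + x, carry + x)  # running sum as both carry and per-step output
final, cumsums = lax.scan(body_fun, 0, xs)
# final --> 10, cumsums --> [0, 1, 3, 6, 10]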


6. Static vs traced ops

  • the following raises an error: applying jnp.array and jnp.prod to the static value x.shape turns it into a traced value, so it can no longer serve as a static shape for reshape
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)
  • use numpy for operations that should be static (i.e. done at compile-time)
  • use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time)
  • fix:
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
    # np.prod will get 6 during compile-time,
    # while jnp.prod is like "I will calculate this product later on the GPU"
    return x.reshape((np.prod(x.shape),))

x = jnp.ones((2, 3))
f(x)

7. Stateful computation

  • JAX transformations (like jit(), vmap(), and grad()) rely on pure functions
  • how to make stateful computations? → add explicit state as arg
  • suppose we have
import jax
import jax.numpy as jnp

class Counter:
    def __init__(self):
        self.n = 0

    def count(self) -> int:
        self.n += 1
        return self.n

    def reset(self):
        self.n = 0
  • to use it with jax.jit(), we can do
CounterState = int

class CounterV2:
    def count(self, n: CounterState) -> tuple[int, CounterState]:
        # You could just return n+1, but here we separate its role as
        # the output and as the counter state for didactic purposes.
        return n+1, n+1

    def reset(self) -> CounterState:
        return 0

counter = CounterV2()
state = counter.reset()

for _ in range(3):
    value, state = counter.count(state)
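
Since count() is now pure (the state comes in as an argument and goes out as a return value), it can be JIT-compiled; a minimal usage sketch continuing the example above (the jitted variant is an assumed addition):

fast_count = jax.jit(counter.count)

state = counter.reset()
for _ in range(3):
    value, state = fast_count(state)
print(value)  # 3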
