Pallas examples by Sharad Vikram (Pallas author)

https://www.youtube.com/watch?v=NFKubflDb1A

  • code was written in 2023 (may be slightly outdated, but the core concepts are still valid)
  • presented by Sharad Vikram (Pallas author)

1. TPU architecture recap

(figure: TPU architecture diagram — HBM, VMEM, VREGs, scalar unit, SMEM)

  • VMEM: like an L1 cache for the TPU
  • VREGs: vector registers
  • data path: HBM → VMEM → vector registers (VREGs, in 8x128 tiles)
  • scalar unit & SMEM: behave like a regular CPU core with its own memory
  • this tutorial focuses on VMEM examples, not SMEM
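As a back-of-the-envelope illustration of the 8x128 VREG shape above (the helper below is hypothetical, not a JAX API):

```python
def num_vregs(rows, cols, sublanes=8, lanes=128):
    # ceil-divide each dimension by the 8x128 register tile shape;
    # a partially filled tile still occupies a whole VREG
    return -(-rows // sublanes) * -(-cols // lanes)

num_vregs(8, 128)    # one exactly-full register
num_vregs(256, 512)  # (256/8) * (512/128) registers
```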

2. add_vectors example

2.1 code

```python
import functools
import jax
import jax.numpy as jnp
from jax import random
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

jax.devices()

def add_vectors_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]
```
  • by default, the refs are VMEM refs. you can think of them roughly as pointers to VMEM slots
  • the input Refs (x_ref[...] and y_ref[...]) vector-load data from VMEM into VREGs
  • the output Ref o_ref[...] vector-stores data back to its VMEM slot
```python
@functools.partial(jax.jit, static_argnames=["debug"])
def add_vectors(x: jax.Array, y: jax.Array, *, debug: bool = False) -> jax.Array:
    return pl.pallas_call(add_vectors_kernel, out_shape=x, debug=debug)(x, y)
```
  • the code above does the orchestration
  • it moves data from HBM to VMEM, executes the kernel, and copies the output buffer back to HBM
  • jax.jit: traces your function and compiles it with XLA
    • tracing: replaces array arguments with abstract “tracers” (no concrete values, only shape and dtype)
    • purpose of tracing: compile once and reuse for different inputs
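A toy illustration of the “compile once, reuse” idea (this is not JAX’s actual machinery; ToyJit and its signature key are made up for illustration):

```python
class ToyJit:
    """Caches a 'compiled' function keyed on the abstract signature."""
    def __init__(self, f):
        self.f = f
        self.cache = {}       # (length, element type) -> compiled fn
        self.num_traces = 0

    def __call__(self, xs):
        sig = (len(xs), type(xs[0]))  # stand-in for (shape, dtype)
        if sig not in self.cache:
            self.num_traces += 1      # "tracing" only happens on a cache miss
            self.cache[sig] = self.f
        return self.cache[sig](xs)

double = ToyJit(lambda xs: [2 * x for x in xs])
double([1, 2, 3])   # traced on first call
double([4, 5, 6])   # same signature: reuses the cached version
```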

2.2 lambda and module blocks

let’s check the debug output (skip this section if not interested in details)

```python
add_vectors(jnp.arange(8), jnp.arange(8), debug=True)
# Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)
```

High-Level Pallas IR (lambda block)

```
{ lambda ; a:Ref{int32[8]} b:Ref{int32[8]} c:Ref{int32[8]}. let
    d:i32[8] <- a[:]
    e:i32[8] <- b[:]
    f:i32[8] = add d e
    c[:] <- f
  in () }
```
  • lambda ; a:Ref{int32[8]} ...: function args. Ref type indicates these are pointers to VMEM
  • d:i32[8] <- a[:]: load op. moves 8 integers from a (VMEM) into d (VREG).
  • f:i32[8] = add d e: vector addition of d and e, storing the result in register f
  • c[:] <- f: writes data from register f back into VMEM slot pointed to by c

Lower-Level MLIR (Mosaic module block)

the lambda block above is lowered to a Mosaic MLIR module

```mlir
module {
  func.func @main(
      %arg0: memref<8xi32, #tpu.memory_space<vmem>>,
      %arg1: memref<8xi32, #tpu.memory_space<vmem>>,
      %arg2: memref<8xi32, #tpu.memory_space<vmem>>) {
    %c0 = arith.constant 0 : index
    %0 = vector.load %arg0[%c0] : memref<8xi32, #tpu.memory_space<vmem>>, vector<8xi32>
    %c0_0 = arith.constant 0 : index
    %1 = vector.load %arg1[%c0_0] : memref<8xi32, #tpu.memory_space<vmem>>, vector<8xi32>
    %2 = arith.addi %0, %1 : vector<8xi32>
    %c0_1 = arith.constant 0 : index
    vector.store %2, %arg2[%c0_1] : memref<8xi32, #tpu.memory_space<vmem>>, vector<8xi32>
    return
  }
}
```
  • memref<8xi32, #tpu.memory_space<vmem>>
    • the arg type: a VMEM memory reference
  • vector.load
    • loads a chunk of data into vector registers in a single operation
    • %c0 (constant 0) indicates it is loading from the beginning of the buffer
  • arith.addi
    • the “Arithmetic Add Integer” instruction, performed on the vector registers
    • elementwise ops like this run on the VPU (the MXU handles matmuls)
  • vector.store: moves the computed vector result back from the registers into the output buffer in VMEM

2.3 HLO

the Mosaic module above gets serialized and embedded into the HLO below

```python
print(jax.jit(add_vectors).lower(jnp.arange(8), jnp.arange(8)).compiler_ir())
```

HLO block

```mlir
module @jit_add_vectors attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(
      %arg0: tensor<8xi32> {mhlo.sharding = "{replicated}"},
      %arg1: tensor<8xi32> {mhlo.sharding = "{replicated}"})
      -> tensor<8xi32> {
    %0 = call @add_vectors(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
    return %0 : tensor<8xi32>
  }

  func.func private @add_vectors(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
    %0 = call @wrapped(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
    return %0 : tensor<8xi32>
  }

  func.func private @wrapped(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
    %0 = call @apply_kernel(%arg0, %arg1) : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
    return %0 : tensor<8xi32>
  }

  func.func private @apply_kernel(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> tensor<8xi32> {
    %0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "..."} : (tensor<8xi32>, tensor<8xi32>) -> tensor<8xi32>
    return %0 : tensor<8xi32>
  }
}
```
  1. Entry Point (@main): what the XLA compiler sees when it schedules the task
  • tensor<8xi32>: unlike the memref in the lower-level MLIR (which points to a specific physical memory location), these are high-level XLA tensors
  • mhlo.sharding = "{replicated}": data is replicated across all TPU cores in the partition, rather than being sharded between them
  2. Wrapper Chain (@add_vectors → @wrapped → @apply_kernel)
  • JAX often generates several private functions during lowering. they act as administrative wrappers that handle:
    • shape checking: ensure the input tensors match the kernel’s expected dimensions
    • transformation state: maintain consistency if the kernel was transformed by other JAX primitives like vmap or grad
  3. Custom Call (@tpu_custom_call)
  • the most important part
  • stablehlo.custom_call
    • since Pallas kernels are not standard XLA ops (like a normal add or dot), XLA treats them as a “custom call”
    • it basically says: “I don’t know how to execute this myself, so I’m going to hand it off to a specialized backend”
  • @tpu_custom_call: targets the TPU’s custom execution pipeline
  • backend_config = "...":
    • in a real output, this string is very long. it contains the serialized Mosaic MLIR module (the one with vector.load and arith.addi)
    • when the TPU compiler receives this HLO, it unpacks the config string to find the exact instructions to run

3. matmul w/ BlockSpec

this example shows how to stream chunks of data (rather than copying all of it at once)

kernel definition

```python
def matmul_kernel(x_ref, y_ref, o_ref):
    @pl.when(pl.program_id(2) == 0)
    def initialize_output():
        o_ref[...] = jnp.zeros_like(o_ref)

    o_ref[...] += jnp.dot(x_ref[...], y_ref[...])
```
  • each invocation of the kernel over the grid is called a program
  • pl.program_id(axis): the index of the current program along that grid axis

and the wrapper

```python
@functools.partial(jax.jit, static_argnames=["bm", "bk", "bn", "debug"])
def matmul(
    x: jax.Array, y: jax.Array, *, bm: int = 128, bk: int = 128, bn: int = 128,
    debug: bool = False,
) -> jax.Array:
    m, k = x.shape
    k, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        in_specs=[
            pl.BlockSpec(lambda i, j, k: (i, k), (bm, bk)),
            pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn)),
        ],
        out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (bm, bn)),
        grid=(m // bm, n // bn, k // bk),
        debug=debug,
    )(x, y)
```
  • static_argnames
    • tells JAX “don’t trace these args; treat their concrete values as compile-time constants”
    • when their values change, JAX recompiles
  • BlockSpec
    • w/o BlockSpec: copy all data from HBM to VMEM, do all the compute, then copy all outputs back to HBM
    • w/ BlockSpec: we stream chunks of x and y as blocks instead
  • key fields
    • block_shape in BlockSpec: the tile shape that each program sees, like (bm, bk)
    • grid: how many times the kernel runs, plus the indexing scheme
    • index_map in BlockSpec: maps a grid index (the program_ids) to a specific block (tile)
  • this matmul implementation is pretty close to a peak-performance f32 matmul on TPU
  • 3d grid vs 2d grid
    • this implementation uses a 3d grid
      • the X and Y tiles shift along the K axis as k increments (a sliding window over the contraction dimension)
      • the output tile (i, j) stays fixed for all values of k
    • if we changed it to a 2d grid on TPU
      • we would need to unroll the accumulation over k inside the kernel
      • that requires loading an entire row/column into VMEM at once, which means more copies into VMEM
      • worse performance → we want better compute-data overlap in the pipeline
    • on GPU, in many cases we do want a 2d grid
  • tuning the block sizes is a major part of the work (i.e., how much data to transfer per this amount of computation)
  • TPU vs GPU: sequential vs parallel
    • GPU: kernels are launched in parallel, so the accumulation into o_ref is a sync point (it gets lowered to atomic updates)
    • TPU: programs run sequentially, so there is no contention for o_ref
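The 3d-grid behavior can be sketched in plain Python (a toy model, not Pallas: programs run sequentially in grid order, and the output tile (i, j) accumulates across k):

```python
def tiled_matmul(x, y, bm, bk, bn):
    m, kdim, n = len(x), len(x[0]), len(y[0])
    o = [[0.0] * n for _ in range(m)]
    # loop order mirrors grid=(m // bm, n // bn, k // bk)
    for i in range(m // bm):
        for j in range(n // bn):
            for k in range(kdim // bk):
                for r in range(i * bm, (i + 1) * bm):
                    for c in range(j * bn, (j + 1) * bn):
                        if k == 0:
                            # mirrors @pl.when(pl.program_id(2) == 0)
                            o[r][c] = 0.0
                        # accumulate this (bm, bk) x (bk, bn) block product
                        o[r][c] += sum(x[r][t] * y[t][c]
                                       for t in range(k * bk, (k + 1) * bk))
    return o
```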

4. batched matmul w/ vmap

batched matmul: add a new batch dimension using vmap

```python
batched_matmul = jax.vmap(functools.partial(matmul, bm=1024, bk=512, bn=1024))

k1, k2 = random.split(random.PRNGKey(1))
x = random.normal(k1, (4, 1024, 2048))
y = random.normal(k2, (4, 2048, 4096))
print(jax.block_until_ready(batched_matmul(x, y)))
```
  • we have a batch dimension of size 4

5. matmul_activation: higher order function

meta-programming: we can parametrize the kernel with an injected function (activation)

  • higher-order function: takes a function as an input arg, and/or returns a function as output
  • the activation can be ReLU or similar
  • here we parametrize matmul with an activation function

kernel definition

```python
def matmul_activation_kernel(x_ref, y_ref, o_ref, *, nk, activation):
    matmul_kernel(x_ref, y_ref, o_ref)
    if activation:
        @pl.when(pl.program_id(2) == nk - 1)
        def activate():
            o_ref[...] = activation(o_ref[...])
```
  • pl.program_id(2) == nk - 1: activation only runs after the final reduction step in tiled matmul

and the wrapper

```python
@functools.partial(jax.jit, static_argnames=["bm", "bk", "bn", "activation"])
def matmul(
    x: jax.Array, y: jax.Array, *, bm: int = 128, bk: int = 128, bn: int = 128,
    activation=None,
) -> jax.Array:
    m, k = x.shape
    k, n = y.shape
    return pl.pallas_call(
        functools.partial(matmul_activation_kernel, nk=k // bk, activation=activation),
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        in_specs=[
            pl.BlockSpec(lambda i, j, k: (i, k), (bm, bk)),
            pl.BlockSpec(lambda i, j, k: (k, j), (bk, bn)),
        ],
        out_specs=pl.BlockSpec(lambda i, j, k: (i, j), (bm, bn)),
        grid=(m // bm, n // bn, k // bk),
    )(x, y)
```

6. softmax

  • softmax is tricky to get right → it involves several reductions
  • the naive implementation here loads the entire feature row x_ref[...] into VMEM and does all the reductions there
  • when the features don’t fit into VMEM, or you want to fuse softmax with a matmul → use flash attention
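The max-subtraction trick that those reductions implement can be shown in plain Python (a toy sketch; the Pallas kernel does the same thing with jnp on whole blocks):

```python
import math

def softmax_stable(xs):
    # subtracting the row max keeps exp() from overflowing; the shift
    # cancels in the ratio, so the result is mathematically unchanged
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]

softmax_stable([1000.0, 1000.0])  # naive math.exp(1000.0) would overflow
```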

define kernel

```python
def softmax_kernel(x_ref, o_ref):
    x = x_ref[...]
    x_max = x.max(axis=-1)
    s = jnp.exp(x - x_max[..., None])
    l = s.sum(axis=-1)
    o_ref[...] = s / l[..., None]
```

wrapper

```python
@functools.partial(jax.jit, static_argnames=["bs"])
def softmax(x: jax.Array, *, bs: int = 128) -> jax.Array:
    n, d = x.shape
    return pl.pallas_call(
        softmax_kernel,
        out_shape=jax.ShapeDtypeStruct((n, d), x.dtype),
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (bs, d)),
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (bs, d)),
        grid=(n // bs,),
    )(x)
```

7. flash attention

  • here is a minimal flash attention implementation
  • it only handles f32 and a single sequence (no batch or head dimensions yet)
```python
def flash_attention_kernel(q_ref, k_ref, v_ref, m_ref, l_ref, o_ref):
    @pl.when(pl.program_id(1) == 0)
    def init():
        o_ref[...] = jnp.zeros_like(o_ref)
        m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
        l_ref[...] = jnp.zeros_like(l_ref)

    q, k, v = q_ref[...], k_ref[...], v_ref[...]
    m_prev, l_prev = m_ref[...], l_ref[...]

    qk = jnp.dot(q, k.T)
    m_curr = qk.max(axis=-1)
    s_curr = jnp.exp(qk - m_curr[..., None])
    l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
    o_curr = jnp.dot(s_curr, v) / l_curr

    m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
    m_next = jnp.maximum(m_prev, m_curr)
    alpha = jnp.exp(m_prev - m_next)
    beta = jnp.exp(m_curr - m_next)
    l_next = alpha * l_prev + beta * l_curr

    m_ref[...], l_ref[...] = m_next, l_next
    o_ref[...] = (l_prev * alpha * o_ref[...] + l_curr * beta * o_curr) / l_next
```
  • m_ref: the running max from previous blocks
  • l_ref: the running sum of exponentials from previous blocks
  • jax.lax.broadcast_in_dim(operand, shape, broadcast_dimensions)
    • reshapes/broadcasts a tensor to a target shape
    • while specifying which dimensions of the source map to which dimensions of the output
  • l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
    • s_curr has shape [Bq, Bk]
    • after .sum(axis=-1), you get the sum of exponentials per query row, shape [Bq]
    • l_prev has the block shape of the l output, [Bq, head_dim] here
    • this call broadcasts the 1D [Bq] result into that shape
    • with (0,) saying “axis 0 of the source maps to axis 0 of the target”
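The m/l bookkeeping above is the “online softmax” recurrence. A scalar pure-Python sketch (hypothetical helper, mirroring the alpha/beta rescaling in the kernel):

```python
import math

def online_softmax_stats(chunks):
    # running max m and running sum-of-exponentials l over one row,
    # processed one chunk (one K block) at a time
    m_prev, l_prev = -math.inf, 0.0
    for chunk in chunks:
        m_curr = max(chunk)
        l_curr = sum(math.exp(x - m_curr) for x in chunk)
        m_next = max(m_prev, m_curr)
        alpha = math.exp(m_prev - m_next)  # rescales the old sum
        beta = math.exp(m_curr - m_next)   # rescales the new sum
        m_prev, l_prev = m_next, alpha * l_prev + beta * l_curr
    return m_prev, l_prev
```

the final (m, l) match a two-pass computation over the whole row, which is what lets the kernel renormalize o_ref incrementally.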

schedule it using a 2d grid

```python
@functools.partial(jax.jit, static_argnames=["bq", "bk"])
def flash_attention(q, k, v, *, bq: int, bk: int):
    seq_len, head_dim = q.shape
    return pl.pallas_call(
        flash_attention_kernel,
        out_shape=[
            jax.ShapeDtypeStruct((seq_len, head_dim), q.dtype),  # m
            jax.ShapeDtypeStruct((seq_len, head_dim), q.dtype),  # l
            q,  # o
        ],
        in_specs=[
            pl.BlockSpec(lambda i, j: (i, 0), (bq, head_dim)),  # q
            pl.BlockSpec(lambda i, j: (j, 0), (bk, head_dim)),  # k
            pl.BlockSpec(lambda i, j: (j, 0), (bk, head_dim)),  # v
        ],
        out_specs=[
            pl.BlockSpec(lambda i, j: (i, 0), (bq, head_dim)),  # m
            pl.BlockSpec(lambda i, j: (i, 0), (bq, head_dim)),  # l
            pl.BlockSpec(lambda i, j: (i, 0), (bq, head_dim)),  # o
        ],
        grid=(seq_len // bq, seq_len // bk),
    )(q, k, v)[2]  # keep only o; m and l are intermediate state
```
  • Q tiles are indexed by i: for a given query block, Q stays fixed while we sweep over K/V
  • K and V tiles are indexed by j: they advance together through the sequence

call it

```python
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
seq_len = 32768
head_dim = 128
q = random.normal(k1, (seq_len, head_dim))
k = random.normal(k2, (seq_len, head_dim))
v = random.normal(k3, (seq_len, head_dim))

@jax.jit
def attention_reference(q, k, v):
    logits = jnp.einsum('sd,td->st', q, k)
    s = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum('st,td->sd', s, v)

%timeit attention_reference(q, k, v).block_until_ready()

fa = functools.partial(flash_attention, bq=512, bk=2048)

%timeit fa(q, k, v).block_until_ready()
```

8. batched MHA / MQA: nested vmap

8.1 MHA

we can just vmap the single-sequence one above to get mha

```python
mha = jax.vmap(fa)

# call mha
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
num_heads = 8
seq_len = 32768
head_dim = 128
q = random.normal(k1, (num_heads, seq_len, head_dim)) * 0.1
k = random.normal(k2, (num_heads, seq_len, head_dim)) * 0.1
v = random.normal(k3, (num_heads, seq_len, head_dim)) * 0.1

mha(q, k, v).block_until_ready()
```

8.2 batched MHA

vmap again to get batched mha

```python
batched_mha = jax.vmap(mha)

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
batch_size = 2
num_heads = 8
seq_len = 32768
head_dim = 128

q = random.normal(k1, (batch_size, num_heads, seq_len, head_dim))
k = random.normal(k2, (batch_size, num_heads, seq_len, head_dim))
v = random.normal(k3, (batch_size, num_heads, seq_len, head_dim))

print(jax.block_until_ready(batched_mha(q, k, v)))
```

8.3 MQA

vmap over q, but not over k and v

```python
mqa = jax.vmap(fa, in_axes=(0, None, None))

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
num_heads = 8
seq_len = 32768
head_dim = 128

q = random.normal(k1, (num_heads, seq_len, head_dim))
k = random.normal(k2, (seq_len, head_dim))
v = random.normal(k3, (seq_len, head_dim))

print(jax.block_until_ready(mqa(q, k, v)))
```
  • in_axes:
    • 0: map over the 0-th axis of q (the head dimension here)
    • None: don’t map over k and v (they are shared across the different q heads)
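A rough pure-Python analogue of what in_axes=(0, None, None) means (toy code, not JAX):

```python
def vmap_q_only(f):
    # map over the leading axis of the first arg; broadcast k and v
    return lambda qs, k, v: [f(q, k, v) for q in qs]

toy_attn = vmap_q_only(lambda q, k, v: q * k + v)  # stand-in for fa
toy_attn([1.0, 2.0], 10.0, 0.5)  # k and v are shared across both q entries
```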

8.4 batched MQA

vmap again for the new batch_size dimension

```python
batched_mqa = jax.vmap(mqa)

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
batch_size = 2
num_heads = 8
seq_len = 32768
head_dim = 128
q = random.normal(k1, (batch_size, num_heads, seq_len, head_dim)) * 0.1
k = random.normal(k2, (batch_size, seq_len, head_dim)) * 0.1
v = random.normal(k3, (batch_size, seq_len, head_dim)) * 0.1

print(jax.block_until_ready(batched_mqa(q, k, v)))
```

9. multiple TPUs: use shard_map

```python
from jax.experimental.shard_map import shard_map

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(jax.devices(), ['x'])

@jax.jit
@functools.partial(shard_map, mesh=mesh, in_specs=(P('x'), P('x'), P('x')),
                   out_specs=P('x'), check_rep=False)
def sharded_attention(q, k, v):
    print(q.shape)  # the batch dimension is evenly sharded along 'x'
    return batched_mqa(q, k, v)

# Execution
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
batch_size = 8
num_heads = 8
seq_len = 32768
head_dim = 128

q = random.normal(k1, (batch_size, num_heads, seq_len, head_dim)) * 0.1
k = random.normal(k2, (batch_size, seq_len, head_dim)) * 0.1
v = random.normal(k3, (batch_size, seq_len, head_dim)) * 0.1

print(sharded_attention(q, k, v))
```

10. DMA: compute-data overlap

below is a memory-copy kernel: the tensor's trip is HBM → VMEM → HBM

```python
def kernel(x_hbm_ref, y_hbm_ref):
    def body(x_ref, y_ref, sem):
        pltpu.async_copy(x_hbm_ref, x_ref, sem).wait()  # HBM -> VMEM
        y_ref[...] = x_ref[...]
        pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()  # VMEM -> HBM

    pltpu.run_scoped(body,
                     pltpu.VMEM((8, 128), jnp.float32),
                     pltpu.VMEM((8, 128), jnp.float32),
                     pltpu.SemaphoreType.DMA)

x = jnp.arange(8 * 128.).reshape((8, 128))
y = pl.pallas_call(
    kernel,
    in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
)(x)
print(x)
print(y)
```
  • pltpu.TPUMemorySpace.ANY
    • the runtime decides where to place the top-level buffers
    • we don’t care where, because the kernel manages the VMEM staging manually
    • we conservatively assume the buffers live in HBM, and explicitly schedule transfers into VMEM
  • pltpu.run_scoped allocates scoped resources and passes them into body:
    • x_ref, y_ref: (8, 128) f32 scratch buffers in VMEM
    • sem: a DMA semaphore for synchronizing the async copies, so other work can proceed while a transfer is in flight
    • run_scoped manages lifetimes automatically: it creates the resources → passes them to body → cleans up afterward
  • the .wait() calls make this synchronous. in a real pipeline you’d overlap compute with async DMAs by deferring the waits
  • async_copy() also works with SMEM
    • values read from SMEM can be used to schedule another DMA
    • so it offers another degree of dynamism and sparsity
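What “deferring the waits” buys you can be sketched as a toy double-buffered schedule (pure Python; the event names are made up, not a Pallas API):

```python
def pipelined_schedule(num_blocks):
    # kick off the copy for block i+1 before computing on block i,
    # so the DMA for the next block overlaps the current compute
    log = ["start_copy 0"]
    for i in range(num_blocks):
        if i + 1 < num_blocks:
            log.append(f"start_copy {i + 1}")  # DMA for the next block...
        log.append(f"wait_copy {i}")
        log.append(f"compute {i}")             # ...overlaps this compute
    return log
```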

11. collectives w/ Pallas

in JAX, you may have seen some collectives

  • lax.psum(): sums a value across all devices in a named axis
    • similarly we have lax.pmax() / lax.pmin() / lax.pmean()
  • lax.all_gather(): each device starts with a small array and ends up with a large array containing concatenated data
  • lax.all_to_all(): each device splits its input into chunks and sends one chunk to every other device (e.g., to change from DP sharding to TP sharding)
  • lax.ppermute(): moves data along a fixed communication ring. You specify a list of (source_index, destination_index) pairs
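As a reference point, lax.ppermute's semantics can be modeled in plain Python (a toy model, not the real collective):

```python
def ppermute_model(values, perm):
    # each (source, destination) pair moves one device's value;
    # destinations not named in perm receive nothing (None)
    out = [None] * len(values)
    for src, dst in perm:
        out[dst] = values[src]
    return out

# shift right around a 4-device ring
ring = [(i, (i + 1) % 4) for i in range(4)]
ppermute_model(['a', 'b', 'c', 'd'], ring)
```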

here is a very simple ppermute() implemented with Pallas

```python
from jax.experimental import mesh_utils, shard_map
from jax import lax
import numpy as np

def kernel(x_ref, y_ref, *, direction):
    def body(ready_sem, send_sem, recv_sem):
        my_id = lax.axis_index('x')
        num_devices = lax.psum(1, 'x')
        if direction == 'right':
            neighbor = lax.rem(my_id + 1, num_devices)
        else:
            neighbor = lax.rem(my_id - 1, num_devices)
        # neighbor might be negative here, so add num_devices just in case
        neighbor = jnp.where(neighbor < 0, neighbor + num_devices, neighbor)
        pltpu.semaphore_signal(ready_sem, device_id=neighbor)
        pltpu.semaphore_wait(ready_sem)
        send_done, recv_done = pltpu.async_remote_copy(
            x_ref, y_ref, send_sem, recv_sem, device_id=neighbor
        )
        send_done.wait()
        recv_done.wait()

    pltpu.run_scoped(body, pltpu.SemaphoreType.REGULAR,
                     pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA)
```
  • assuming a ring device topology, you may want to send your data to your left (or right) neighbor device

steps

  • query the current axis index: my_id = lax.axis_index('x')
  • then compute your neighbor index: lax.rem(my_id + 1, num_devices)
  • sync between you and your neighbor
    • pltpu.semaphore_signal(ready_sem, device_id=neighbor)
      • tells the neighbor “I’m ready”
      • ready_sem is a REGULAR (non-DMA) semaphore, used purely for sync
    • pltpu.semaphore_wait(ready_sem): blocks until the neighbor has signaled back
  • async_remote_copy()
    • initiates a DMA transfer of x_ref on this device to y_ref on the neighbor device
    • it returns two futures; two DMA semaphores (send_sem, recv_sem) track completion of the send and receive sides independently
  • wait for completion
    • send_done.wait() and recv_done.wait() block until the DMA has finished on both ends
    • assuming we are transferring clockwise
      • send_done.wait(): your send to your right neighbor has completed
      • recv_done.wait(): your left neighbor’s send to you has arrived

useful when you need custom compute-communication overlap

  • e.g., pipelining KV-cache transfers with attention computation

or when the standard lax.ppermute / lax.all_gather don’t give you the scheduling flexibility you need for latency-sensitive serving


Pallas examples by Sharad Vikram (Pallas author)
https://gdymind.github.io/2026/03/08/Pallas-examples-by-Pallas-author-Sharad-Vikram/
Author: gdymind, posted on March 8, 2026