Pallas 101 - multi-backend kernel for JAX

1. why pallas?

  • JAX works with pure functions (i.e., the same inputs always produce the same outputs), and JAX arrays are immutable
  • this is neither flexible nor efficient for implementing kernels

GEMM steps

  • input matrix → tiles
  • do GEMM with tile combinations
  • write back results to the correct position
  • loop over tiles
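
A plain-Python / JAX sketch of these steps (pseudocode-style; the sizes M, N, K, the tile sizes, and A, B, C are illustrative names, not from the original post):

# tiled GEMM: C = A @ B, computed one (tile_m, tile_n) output tile at a time
for i in range(M // tile_m):
    for j in range(N // tile_n):
        acc = jnp.zeros((tile_m, tile_n), jnp.float32)  # accumulator for this tile
        for k in range(K // tile_k):
            a_tile = A[i*tile_m:(i+1)*tile_m, k*tile_k:(k+1)*tile_k]
            b_tile = B[k*tile_k:(k+1)*tile_k, j*tile_n:(j+1)*tile_n]
            acc = acc + a_tile @ b_tile  # repeated accumulator update
        # write the finished tile back to its position in C
        C = C.at[i*tile_m:(i+1)*tile_m, j*tile_n:(j+1)*tile_n].set(acc)

In pure JAX, each accumulator update and each .at[...].set(...) above is out-of-place, which leads directly to the difficulties below.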

Difficulties with JAX

  • need to update buffers many times (e.g., for accumulators)
    • a pure function's updates are out-of-place → inefficient copies
  • need an efficient way to specify which part of the input/output is the current working set
  • need to run the kernel multiple times (with different tiles)
  • need to manage memory explicitly

Pallas’s new abstractions: Ref, grid, BlockSpec, and pallas_call

2. layered design

  • jax.experimental.pallas: abstractions commonly used by all backends
  • jax.experimental.pallas.tpu as pltpu: for TPU
  • jax.experimental.pallas.triton: for Triton
  • jax.experimental.pallas.mosaic_gpu: for Mosaic GPU

we usually import them as pl, pltpu, pltriton and plgpu
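
For reference, the corresponding imports (importing a backend-specific module generally requires that backend's support in your JAX install):

import jax.experimental.pallas as pl
import jax.experimental.pallas.tpu as pltpu          # TPU backend
import jax.experimental.pallas.triton as pltriton    # Triton backend
import jax.experimental.pallas.mosaic_gpu as plgpu   # Mosaic GPU backend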

3. Pallas concepts

pl provides the unified namespace. Core components:

  • Ref: a mutable view of memory
  • grid / program_id / num_programs: program scheduling and indexing
  • BlockSpec / GridSpec: tiling specification
  • pallas_call: the interface that integrates a kernel into JAX

pltpu, pltriton, and plgpu provide backend-specific functions. Examples:

  • TPU: VMEM, SMEM
  • Triton: atomic_add, load/store with mask

4. Ref

  • normal JAX function: input/output are both immutable JAX arrays
  • Pallas: input/output are Refs (references to mutable buffers)
# normal JAX function
def jax_function(x: jax.Array) -> jax.Array:
    return x * 2

# Pallas kernel
def pallas_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2
  • read: use ref[...] to retrieve a JAX array from a Ref
  • write: ref[...] = value
  • a Ref also supports indexing/slicing, e.g. ref[0, :] (the first row)
  • a kernel has no return value; it only operates on Refs

sharp bits: explicitly distinguish references and values

  • a Ref cannot be used as a JAX array; you need to explicitly read it as shown above
  • you cannot assign one Ref to another Ref, e.g. o_ref[...] = x_ref (no [...] on the right-hand side)
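
A small sketch of this read/write discipline (copy_kernel is an illustrative name, not from the original post):

def copy_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]          # correct: read a value, then write it
    # o_ref[...] = x_ref             # wrong: x_ref is a Ref, not a value
    o_ref[0, :] = x_ref[0, :] * 2    # indexing/slicing works for reads and writes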

5. Ref.at is still a Ref

In JAX, we use .at[...].set(...) to produce an updated copy of an array:

new_array = array.at[0].set(value)

However, .at on a Ref just creates a sub-view, which is still a Ref:

sub_ref = x_ref.at[0:64] # still a Ref
value = sub_ref[...]
sub_ref[...] = new_value

use case 1: pass a portion of the current Ref to another function

def process_tile(in_tile_ref, out_tile_ref):
    out_tile_ref[...] = in_tile_ref[...] * 2

def kernel(x_ref, o_ref):
    for i in range(4):
        in_tile = x_ref.at[i*64:(i+1)*64, :]
        out_tile = o_ref.at[i*64:(i+1)*64, :]
        process_tile(in_tile, out_tile)

use case 2: to work with load / store with backends like Triton

row = pltriton.load(input_ref.at[pl.ds(0, block_row)],
                    mask=mask, other=-float('inf'))
  • this line of code focuses on efficiently moving data from global memory into the fast local memory of the processor
  • input_ref.at[...]: indexing
    • it specifies which part of the global input buffer you want to look at without actually loading the data yet
    • pl.ds(0, block_row): dynamic slicing for Pallas. the kernel is selecting a specific slice of the input, starting at index 0 and spanning a length of block_row
  • pltriton.load(...): the memory op
    • it takes the Ref sub-view created by .at and pulls that specific chunk of data from high-bandwidth memory (HBM) into registers or SRAM
    • the result, row, becomes a standard jax.Array (or a Triton-compatible tensor) that you can perform math on.
  • mask=mask: the Guardrail
    • in tiled kernels, the “tile size” often doesn’t perfectly divide the “matrix size.”
      • For example, if your tile is 64 but you only have 10 elements left, you don’t want to read out-of-bounds memory
    • the mask (usually a boolean array) tells the hardware which specific elements within that block are valid to read.
  • other=-float('inf'): the value to use for elements where the mask is False

6. grid & program

grid

  • normal JAX program: executes once and returns
  • Pallas kernel: executes multiple times, following the grid pattern

grid dimension

  • 1d grid: pl.pallas_call(kernel, grid=(4,), out_shape=...)
    • conceptually, it means for i in range(4): kernel(...)
  • 2d grid: pl.pallas_call(kernel, grid=(3, 2), out_shape=...)
    • conceptually it means

      for i in range(3):
          for j in range(2):
              kernel(...)
    • in this example, the kernel is executed 6 times, where each execution is called a program

program id

  • inside a kernel, we can get the program's index along each grid dimension

    def kernel(x_ref, o_ref):
        i = pl.program_id(axis=0)  # id along dimension 0
        j = pl.program_id(axis=1)  # id along dimension 1
  • we can get the length of a grid dimension using pl.num_programs(axis)

    def kernel(x_ref, o_ref):
        total_rows = pl.num_programs(axis=0)  # length of dimension 0
  • grid can be named: grid=(('batch', 4), ('feature', 8))

  • int grid will be promoted to tuple automatically: pl.pallas_call(kernel, grid=4, out_shape=...)

  • grid is a logical concept. it does NOT directly map to physical hardware (like GPU blocks or TPU cores); it's up to the backend to do the specific mapping

  • GridSpec: packs grid, in_specs, out_specs, and scratch_shapes together

    spec = pl.GridSpec(
        grid=(4, 4),
        # one BlockSpec for the single input
        in_specs=pl.BlockSpec((128, 64), lambda i, j: (i, j)),
        out_specs=pl.BlockSpec((128, 64), lambda i, j: (i, j)),
        # temp buffer on TPU
        scratch_shapes=[pltpu.VMEM((128, 64), jnp.float32)],
    )

    result = pl.pallas_call(kernel, grid_spec=spec, out_shape=...)(x)
  • dynamic grid: some dimensions can be determined at run time

    • in such cases, we can retrieve a dimension's length with num_programs(axis)
    • Triton and Mosaic GPU don't support this
    • Mosaic TPU supports this in certain cases

7. BlockSpec: tile mapping

for each program, what’s its input and output?

BlockSpec maps program_id to tile positions

key fields of BlockSpec

  • block_shape: the tile shape that each program sees
  • index_map: a function. input is program id, output is block indices
  • memory_space: where the tile is in the memory (related to backend)
  • pipeline_mode: used to optimize pipelining

assume we have a (512, 256) matrix, and we want to tile it to (128, 64) chunks

the corresponding BlockSpec is

pl.BlockSpec(
    block_shape=(128, 64),
    index_map=lambda i, j: (i, j)  # program_id to block indices
)

block indices → tile position. for example, when (i, j) is (2, 3)

  • tile start is (2*128, 3*64) = (256, 192)
  • tile end is (3*128, 4*64) = (384, 256)
  • tile range is [256:384, 192:256]

None in block_shape

  • it means: this dimension is 1, and we squeeze (remove) the dimension
  • block_shape=(None,64): kernel sees shape (64,), not (1,64)
  • useful when processing row-by-row or column-by-column
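
A minimal sketch of row-by-row processing with a squeezed dimension (the (8, 64) shape and the names below are illustrative):

# each program handles one row; the kernel sees shape (64,), not (1, 64)
row_spec = pl.BlockSpec((None, 64), lambda i: (i, 0))

def row_double_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] * 2      # x_ref[...] has shape (64,)

x = jnp.ones((8, 64), jnp.float32)   # example input
doubled = pl.pallas_call(
    row_double_kernel,
    grid=(8,),                       # one program per row
    in_specs=[row_spec],
    out_specs=row_spec,
    out_shape=jax.ShapeDtypeStruct((8, 64), jnp.float32),
)(x)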

default values

  • the default index_map returns all-zero block indices
  • the default block_shape is the full array shape (i.e., no tiling)

8. pallas_call

  • input: kernel and its config
  • output: a callable kernel
callable_kernel = pl.pallas_call(
    kernel,
    grid=...,
    in_specs=...,
    out_specs=...,
    out_shape=...,
    ...
)
result = callable_kernel(x, y, ...)

some other params

  • interpret=True: simulate on CPU for debugging
  • backend: triton, mosaic_gpu, mosaic_tpu
  • compiler_params
  • input_output_aliases: let a specific output reuse the buffer already allocated for a specific input (see the sketch below)
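
A sketch of input_output_aliases (inc_kernel is an illustrative name); the dict maps an input index to the output index that reuses its buffer:

def inc_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

inc = pl.pallas_call(
    inc_kernel,
    out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
    input_output_aliases={0: 0},     # output 0 reuses (donates) input 0's buffer
)
y = inc(jnp.zeros((128,), jnp.float32))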

9. example: vector add

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    x = x_ref[...]
    y = y_ref[...]
    result = x + y
    o_ref[...] = result

def add_vectors(x, y):
    block_size = 128
    n = x.shape[0]
    grid_size = (n + block_size - 1) // block_size  # ceiling division

    return pl.pallas_call(
        add_kernel,
        grid=(grid_size,),
        in_specs=(
            pl.BlockSpec((block_size,), lambda i: (i,)),  # for x
            pl.BlockSpec((block_size,), lambda i: (i,)),  # for y
        ),
        out_specs=pl.BlockSpec((block_size,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)
  • kernel params order: input Ref, output Ref, scratch Ref
  • in_specs: one BlockSpec for each input. If no tiling for an input, use None for it.
  • out_shape: the shape and type of the output
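
Usage is then an ordinary JAX call, for example:

x = jnp.arange(512, dtype=jnp.float32)
y = jnp.ones(512, dtype=jnp.float32)
out = add_vectors(x, y)              # out[:4] == [1., 2., 3., 4.]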

10. scratch and temp buffer

  • specify scratch_shapes to allocate temporary Refs
  • these temp Refs are not inputs or outputs
result = pl.pallas_call(
    kernel,
    grid=(4,),
    in_specs=[...],
    out_specs=...,
    out_shape=...,
    scratch_shapes=[pltpu.VMEM((128,), jnp.float32)]  # a temp buffer
)(x)
  • we can read/write scratch_ref, but they won’t be outputs
def kernel(x_ref, o_ref, scratch_ref):
    scratch_ref[...] = x_ref[...] * 2
    o_ref[...] = scratch_ref[...] + 1
  • scratch type depends on the backend
  • TPU: pltpu.VMEM and pltpu.SMEM
    • VMEM: used for VPU in TensorCore. 128MiB per TensorCore.
    • SMEM: for scalar core within the TensorCore. Storing scalar values, addresses, and control flow related data.
  • GPU: plgpu.SMEM
  • Triton: not supported

pl.run_scoped: allocate scratch inside the kernel

def kernel(x_ref, o_ref):
    def body(temp_ref):
        temp_ref[...] = x_ref[...] * 2
        o_ref[...] = temp_ref[...] + 1
    pl.run_scoped(body, pltpu.VMEM((128,), jnp.float32))
  • temp_ref only exists within body
  • this makes it easier to manage the lifecycle of temporary Refs

11. control flow

kernel control flow primitives: pl.when and pl.loop

  • normal Python if and for are executed during tracing
  • pl.when and pl.loop are converted to JAX IR nodes during tracing

pl.when

def kernel(x_ref, o_ref):
    @pl.when(pl.program_id(0) == 0)
    def _():
        # only init at the 1st program
        o_ref[...] = jnp.zeros((128,), dtype=jnp.float32)

pl.loop

def kernel(x_ref, o_ref):
    acc = jnp.zeros((64,), dtype=jnp.float32)

    def body(i, acc):
        return acc + x_ref[i, :]

    result = jax.lax.fori_loop(0, 4, body, acc)
    o_ref[...] = result
  • both pl.loop and lax.fori_loop can be used for kernel loop
  • a Python for loop will be unrolled during tracing (preferred for static and small loops)
  • lax.fori_loop / pl.loop: generate loop node in IR during tracing
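
A common pattern combines pl.when with the grid: zero the output in the first program of a reduction, then accumulate into the same block on every step. This is a sketch that assumes a backend where the grid runs sequentially and the same output block is revisited (e.g., TPU, or interpret mode); names and shapes are illustrative:

def sum_rows_kernel(x_ref, o_ref):
    # initialize the shared output block only once, in the first program
    @pl.when(pl.program_id(0) == 0)
    def _():
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)

    # every program adds its row into the same output block
    o_ref[...] += x_ref[...]

def sum_rows(x):  # x: (num_rows, 128) -> (128,)
    return pl.pallas_call(
        sum_rows_kernel,
        grid=(x.shape[0],),
        in_specs=[pl.BlockSpec((None, x.shape[1]), lambda i: (i, 0))],
        out_specs=pl.BlockSpec((x.shape[1],), lambda i: (0,)),
        out_shape=jax.ShapeDtypeStruct((x.shape[1],), x.dtype),
    )(x)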

12. boundary check

  • example: array length = 500, block size = 128
  • the last block will be padded with undefined values (we cannot assume the values in the padding)
  • in a block, there must be at least one element within the boundary
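
A minimal sketch of guarding the padded tail, reusing the vector-add setup from section 9 (the masking logic and names here are assumptions for illustration; the 1-D index computation may need a different form on TPU):

def make_masked_add_kernel(n, block_size):
    def add_kernel_masked(x_ref, y_ref, o_ref):
        i = pl.program_id(0)
        # global element indices covered by this program's block
        idx = i * block_size + jnp.arange(block_size)
        s = x_ref[...] + y_ref[...]
        # keep the real sums; overwrite the undefined padded tail with 0.0
        o_ref[...] = jnp.where(idx < n, s, 0.0)
    return add_kernel_masked

Masking mainly matters when padded values would feed a reduction (e.g., a sum or max); for a purely element-wise write, the out-of-range part of the last block does not end up in the output array anyway.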

13. compilation

stage 1: Python / JAX

  • tracing: kernel → stateful jaxpr (JAX’s internal IR)
    • why stateful? because it has side effects related to Refs
  • configs like grid and BlockSpec are converted to GridMapping and BlockMapping
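
One way to inspect the stage-1 result, assuming the add_vectors wrapper from section 9:

x = jnp.ones((512,), jnp.float32)
y = jnp.ones((512,), jnp.float32)
# the printed jaxpr contains a pallas_call primitive whose parameters carry
# the kernel's own (stateful) jaxpr together with the grid/block mappings
print(jax.make_jaxpr(add_vectors)(x, y))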

stage 2: backend-specific lowering

  • jaxpr + grid_mapping are converted to backend IR
    • GPU Triton: Triton MLIR
    • GPU Mosaic: stable_mosaic_gpu MLIR
    • TPU Mosaic: stable_mosaic MLIR
  • these IRs are serialized to bytecode and embedded into an XLA custom call

stage 3: jaxlib/XLA

  • compile to device code
  • Triton: handled by XLA’s Triton compiler
  • Mosaic GPU: to PTX (by jaxlib MLIR pipeline), then cubin (by PTXAS)
  • Mosaic TPU: handled by TPU compiler

TPU Mosaic backend

jaxpr → stable_mosaic MLIR → custom call tpu_custom_call → TPU compiler

output: TPU executable

Source blog: Pallas:JAX 中有意思的多后端 Kernel 抽象 (Pallas: an interesting multi-backend kernel abstraction in JAX)

