Pallas 101 - multi-backend kernel for JAX
1. why pallas?
- JAX works with pure functions (i.e., the same inputs always produce the same outputs), and JAX arrays are immutable
- this is neither flexible nor efficient for kernel implementation
GEMM steps
- input matrix → tiles
- do GEMM with tile combinations
- write back results to the correct position
- loop over tiles
Difficulties with JAX
- need to update buffers many times (e.g., for accumulators)
- a pure function's updates are out-of-place → inefficient copies
- need an efficient way to specify which part of the input/output is the current working set
- need to run the kernel multiple times (with different tiles)
- need to manage memory explicitly
Pallas’s new abstractions: Ref, grid, BlockSpec, and pallas_call
2. layered design
- `jax.experimental.pallas`: abstractions commonly used by all backends
- `jax.experimental.pallas.tpu`: for TPU
- `jax.experimental.pallas.triton`: for Triton
- `jax.experimental.pallas.mosaic_gpu`: for Mosaic GPU
we usually import them as pl, pltpu, pltriton and plgpu
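A minimal import sketch matching those aliases (module paths as exposed in recent JAX releases; adjust to your version):

```python
import jax.experimental.pallas as pl                      # backend-agnostic core
from jax.experimental.pallas import tpu as pltpu          # TPU (Mosaic) specifics
from jax.experimental.pallas import triton as pltriton    # Triton (GPU) specifics
from jax.experimental.pallas import mosaic_gpu as plgpu   # Mosaic GPU specifics
```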
3. Pallas concepts
`pl` provides a unified namespace. Core components:
| Component | Usage |
|---|---|
| `Ref` | mutable view of memory |
| `grid` / `program_id` / `num_programs` | program and index |
| `BlockSpec` / `GridSpec` | specification for tiling |
| `pallas_call` | the interface to integrate a kernel into JAX |
`pltpu`, `pltriton` and `plgpu` provide backend-specific functions, for example:
- TPU: `VMEM`, `SMEM`
- Triton: `atomic_add`, `load`/`store` with mask
4. Ref
- normal JAX function: input/output are both immutable JAX arrays
- Pallas: input/output are `Ref`s (references to mutable buffers)
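A minimal sketch of a kernel operating on Refs (the doubling operation is just an illustrative assumption):

```python
def double_kernel(x_ref, o_ref):
    # read the whole block as a jax.Array, compute, then write back through the output Ref
    x = x_ref[...]
    o_ref[...] = x * 2.0
```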
- read: use `ref[...]` to retrieve a JAX array from a Ref
- write: `ref[...] = value`
- `Ref` also supports indexing/slicing like `ref[0, :]` (the first row)
- a kernel has no return values; it only operates on `Ref`s
sharp bits: explicitly distinguish references and values
- `Ref` cannot be used as a JAX array; you need to explicitly read it as above
- you cannot assign one `Ref` to another `Ref` like `o_ref[...] = x_ref` (no `[...]` on the right-hand side)
5. Ref.at is still a Ref
In JAX, we use `.at[...].set(...)` to update array values directly (since arrays are immutable)
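For comparison, the standard JAX update idiom (a small sketch):

```python
import jax.numpy as jnp

x = jnp.zeros((4, 4))
y = x.at[0, 0].set(1.0)   # out-of-place: returns a new array, x is unchanged
```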
However, `.at` on a `Ref` just provides a sub-view that is itself a `Ref`
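A sketch of the difference inside a kernel (the row indexing is an illustrative assumption):

```python
def kernel(x_ref, o_ref):
    row_view = x_ref.at[0, :]   # still a Ref: no data has been read yet
    row = row_view[...]         # reading the sub-view yields a jax.Array
    o_ref[0, :] = row
```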
use case 1: pass a portion of the current Ref to another function
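For example, a helper can take the sub-view and use it like any other Ref (a sketch; `normalize_row` is a hypothetical helper):

```python
import jax.numpy as jnp

def normalize_row(row_ref):
    # row_ref is a Ref covering a single row of the output
    row = row_ref[...]
    row_ref[...] = row / jnp.sum(row)

def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]
    normalize_row(o_ref.at[0, :])   # pass a sub-view, not a copied value
```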
use case 2: working with `load` / `store` on backends like Triton
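Judging from the explanation below, the code was roughly the following (a reconstructed sketch; `input_ref`, `block_row`, and `mask` are names taken from that explanation):

```python
# load one block of the input from HBM into fast memory, guarding the boundary
row = pltriton.load(
    input_ref.at[pl.ds(0, block_row)],  # sub-view: which slice to read
    mask=mask,                          # which elements are actually in bounds
    other=-float('inf'),                # fill value for masked-out elements
)
```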
- this line of code moves data efficiently from global memory into the fast local memory of the processor
- `input_ref.at[...]`: indexing
  - it specifies which part of the global input buffer you want to look at, without actually loading the data yet
  - `pl.ds(0, block_row)`: dynamic slicing for Pallas. the kernel selects a slice of the input starting at index 0 and spanning a length of `block_row`
- `pltriton.load(...)`: the memory op
  - it takes the `Ref` sub-view created by `.at` and pulls that specific chunk of data from high-bandwidth memory (HBM) into registers or SRAM
  - the result, `row`, becomes a standard `jax.Array` (or a Triton-compatible tensor) that you can perform math on
- `mask=mask`: the guardrail
  - in tiled kernels, the tile size often doesn't perfectly divide the matrix size
  - for example, if your tile is 64 but only 10 elements are left, you don't want to read out-of-bounds memory
  - the `mask` (usually a boolean array) tells the hardware which specific elements within that block are valid to read
- `other=-float('inf')`: the value to use for elements where the `mask` is `False`
6. grid & program
grid
- normal JAX program: execute once and return
- Pallas kernel: execute multiple times following the grid pattern
grid dimension
- 1d grid: `pl.pallas_call(kernel, grid=(4,), out_shape=...)`
  - conceptually, it means `for i in range(4): kernel(...)`
- 2d grid: `pl.pallas_call(kernel, grid=(3, 2), out_shape=...)`; conceptually it means

```python
for i in range(3):
    for j in range(2):
        kernel(...)
```

in this example, the kernel is executed 6 times, where each execution is called a program
program id
we can get the position of a program within the grid inside the kernel:

```python
def kernel(x_ref, o_ref):
    i = pl.program_id(axis=0)  # id along dimension 0
    j = pl.program_id(axis=1)  # id along dimension 1
```

we can get the dimension length using `pl.num_programs(axis)`:

```python
def kernel(x_ref, o_ref):
    total_rows = pl.num_programs(axis=0)  # length of dimension 0
```

- grid dimensions can be named: `grid=(('batch', 4), ('feature', 8))`
- an int grid will be promoted to a tuple automatically: `pl.pallas_call(kernel, grid=4, out_shape=...)`
- grid is a logical concept. it does NOT directly map to physical hardware (like GPU blocks or TPU cores); it's up to the backend to do the specific mapping
GridSpec: packs grid, in_specs, out_specs and scratch_shapes
```python
spec = pl.GridSpec(
    grid=(4, 4),
    # one BlockSpec for the single input
    in_specs=[pl.BlockSpec((128, 64), lambda i, j: (i, j))],
    out_specs=pl.BlockSpec((128, 64), lambda i, j: (i, j)),
    # temp buffer in TPU
    scratch_shapes=[pltpu.VMEM((128, 64), jnp.float32)],
)
result = pl.pallas_call(kernel, grid_spec=spec, out_shape=...)(x)
```

dynamic grid: some dimensions can be determined at run time
- in such cases, we can retrieve the dimension length with `pl.num_programs(axis)`
- Triton and Mosaic GPU don't support this
- Mosaic TPU supports this in certain cases
7. BlockSpec: tile mapping
for each program, what’s its input and output?
BlockSpec maps program_id to tile positions
key fields of BlockSpec
- `block_shape`: the tile shape that each program sees
- `index_map`: a function whose input is the program id and whose output is the block indices
- `memory_space`: where the tile lives in memory (backend-dependent)
- `pipeline_mode`: used to optimize pipelining
assume we have a (512, 256) matrix, and we want to tile it to (128, 64) chunks
the corresponding BlockSpec is
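A sketch of that BlockSpec for the (512, 256) → (128, 64) tiling:

```python
block_spec = pl.BlockSpec(
    block_shape=(128, 64),
    index_map=lambda i, j: (i, j),  # program (i, j) works on block (i, j)
)
```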
block indices → tile position. for example, when (i, j) is (2, 3)
- tile start is (2 × 128, 3 × 64) = (256, 192)
- tile end is (3 × 128, 4 × 64) = (384, 256)
- tile range is [256:384, 192:256]
None in block_shape
- it means: this dimension is 1, and we squeeze (remove) the dimension
- `block_shape=(None, 64)`: the kernel sees shape `(64,)`, not `(1, 64)`
- useful when processing row-by-row or column-by-column
default values
- the default `index_map` returns 0s as block indices
- the default `block_shape` is the full array shape, i.e., no tiling
8. pallas_call
- input: kernel and its config
- output: a callable kernel
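A sketch of the calling pattern (grid, shapes, and dtype are assumptions):

```python
import jax
import jax.numpy as jnp

out = pl.pallas_call(
    kernel,                                               # the kernel operating on Refs
    grid=(4,),                                            # run the kernel 4 times
    out_shape=jax.ShapeDtypeStruct((512,), jnp.float32),  # shape/dtype of the output
)(x)                                                      # pallas_call returns a callable; call it on JAX arrays
```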
some other params
- `interpret=True`: simulate on CPU for debugging
- `backend`: `triton`, `mosaic_gpu`, or `mosaic_tpu`
- `compiler_params`: backend-specific compiler options
- `input_output_aliases`: a specific output should reuse the memory already allocated for a specific input
9. example: vector add
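A runnable sketch of a blocked vector add (the block size of 128 and the input length are assumptions):

```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # each program adds one (128,)-block of x and y
    o_ref[...] = x_ref[...] + y_ref[...]

def vector_add(x, y):
    block = 128
    return pl.pallas_call(
        add_kernel,
        grid=(x.shape[0] // block,),
        in_specs=[
            pl.BlockSpec((block,), lambda i: (i,)),
            pl.BlockSpec((block,), lambda i: (i,)),
        ],
        out_specs=pl.BlockSpec((block,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(512, dtype=jnp.float32)
y = jnp.ones(512, dtype=jnp.float32)
print(vector_add(x, y)[:4])  # [1. 2. 3. 4.]
```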
- kernel parameter order: input Refs, then output Refs, then scratch Refs
- `in_specs`: one BlockSpec for each input; if an input needs no tiling, use `None` for it
- `out_shape`: the shape and dtype of the output
10. scratch and temp buffer
- specify `scratch_shapes` to get some temp Refs
- these temp Refs are not inputs or outputs
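A sketch of requesting a TPU VMEM scratch buffer (shapes and dtype assumed):

```python
import jax
import jax.numpy as jnp

out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    scratch_shapes=[pltpu.VMEM((128, 128), jnp.float32)],  # one extra Ref passed to the kernel
)(x)
```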
- we can read/write scratch_ref, but they won’t be outputs
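The kernel then receives the scratch Ref after the input/output Refs (a sketch):

```python
def kernel(x_ref, o_ref, scratch_ref):
    scratch_ref[...] = x_ref[...] * 2.0   # write to scratch
    o_ref[...] = scratch_ref[...] + 1.0   # read it back; the scratch itself is never returned
```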
- the scratch type depends on the backend
- TPU: `pltpu.VMEM` and `pltpu.SMEM`
  - VMEM: used by the VPU in the TensorCore; 128 MiB per TensorCore
  - SMEM: for the scalar core within the TensorCore; stores scalar values, addresses, and control-flow related data
- GPU (Mosaic GPU): `plgpu.SMEM`
- Triton: not supported
pl.run_scoped: allocate scratch inside the kernel
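A sketch of `pl.run_scoped` allocating a temporary VMEM buffer inside the kernel (shapes assumed):

```python
import jax.numpy as jnp

def kernel(x_ref, o_ref):
    def body(temp_ref):
        # temp_ref is only valid inside body
        temp_ref[...] = x_ref[...] * 2.0
        o_ref[...] = temp_ref[...]

    pl.run_scoped(body, pltpu.VMEM((128, 128), jnp.float32))
```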
- `temp_ref` only exists within `body`
- this makes it easier to manage the life cycle of temp Refs
11. control flow
kernel control flow primitives: `pl.when` and `pl.loop`
- normal Python `if` and `for` will be executed during tracing
- `pl.when` and `pl.loop`: converted to JAX IR nodes during tracing
pl.when
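A sketch of the decorator-style `pl.when` usage:

```python
import jax.numpy as jnp

def kernel(x_ref, o_ref):
    i = pl.program_id(axis=0)

    @pl.when(i == 0)
    def _():
        # only the first program along axis 0 executes this branch
        o_ref[...] = jnp.zeros(o_ref.shape, o_ref.dtype)
```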
pl.loop
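A sketch of the loop pattern using `lax.fori_loop`, which the notes below treat as interchangeable with `pl.loop` for kernel loops (the summation and the bound of 8 are illustrative assumptions):

```python
import jax

def kernel(x_ref, o_ref):
    # accumulate the first 8 elements of the block (8 is an assumed block length)
    def body(i, acc):
        return acc + x_ref[i]

    o_ref[0] = jax.lax.fori_loop(0, 8, body, 0.0)
```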
- both `pl.loop` and `lax.fori_loop` can be used for kernel loops
- a Python `for` will be unrolled during tracing (preferred for static and small loops)
- `lax.fori_loop` / `pl.loop`: generate a loop node in the IR during tracing
12. boundary check
- example: array length = 500, block size = 128
- the last block will be padded with undefined values (we cannot assume the values in the padding)
- in a block, there must be at least one element within the boundary; a masking sketch follows below
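A sketch of guarding the padded tail with a mask, using the array length 500 and block size 128 from the example above (the doubling is an assumption):

```python
import jax.numpy as jnp

LENGTH, BLOCK = 500, 128

def kernel(x_ref, o_ref):
    i = pl.program_id(axis=0)
    # global indices covered by this block; the last block (384..511) is partly padding
    idx = i * BLOCK + jnp.arange(BLOCK)
    valid = idx < LENGTH
    x = jnp.where(valid, x_ref[...], 0.0)  # never rely on the padded values
    o_ref[...] = x * 2.0
```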
13. compilation
stage 1: Python / JAX
- tracing: kernel → stateful jaxpr (JAX’s internal IR)
- why stateful? because it has side effects related to Ref
- configs like `grid` and `BlockSpec` are converted to `GridMapping` and `BlockMapping`
stage 2: backend-specific lowering
- jaxpr + grid_mapping are converted to backend IR
- GPU Triton: Triton MLIR
- GPU Mosaic: stable_mosaic_gpu MLIR
- TPU Mosaic: stable_mosaic MLIR
- these IRs will be serialized to bytecode, and be embedded into XLA custom call
stage 3: jaxlib/XLA
- compile to device code
- Triton: handled by XLA’s Triton compiler
- Mosaic GPU: to PTX (by jaxlib MLIR pipeline), then cubin (by PTXAS)
- Mosaic TPU: handled by TPU compiler
TPU Mosaic backend
jaxpr → stable_mosaic MLIR → custom call tpu_custom_call → TPU compiler
output: TPU executable
Source blog: Pallas: JAX 中有意思的多后端 Kernel 抽象 (Pallas: an interesting multi-backend kernel abstraction in JAX)