XLA01 - architecture & workflows

https://openxla.org/xla

0. Intro

XLA in the whole JAX stack

image.png

Source: Yi Wang’s LinkedIn post

  • LLM workloads are dominated by matmuls. XLA (Accelerated Linear Algebra) optimizes linear algebra across multiple devices (TPU / GPU)
    • XLA works with n-d arrays of the same data type
  • XLA compiler: model code (PyTorch, JAX, TensorFlow) → optimized machine instructions on GPU / TPU / CPU, etc.
  • XLA uses MLIR (Multi-Level Intermediate Representation) to build a single compiler toolchain
    • MLIR is a hybrid IR infrastructure that lets users define “dialects” of ops at varying degrees of abstraction and gradually lower between these op sets
    • MLIR performs transformations at each level of granularity
    • StableHLO and CHLO are two examples of MLIR dialects.
  • Example of XLA ops: Add / Abs / AllToAll / AllReduce / CollectivePermute / Sort …
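To make the stack above concrete, here is a minimal JAX sketch (the function and shapes are invented for illustration): `jax.jit` traces the Python function and hands the result to XLA, which compiles it for the current backend.

```python
import jax
import jax.numpy as jnp

@jax.jit   # trace the function once, compile it with XLA, cache the executable
def mlp_layer(x, w):
    # matmul + elementwise op: exactly the pattern XLA is built to optimize and fuse
    return jax.nn.relu(x @ w)

x = jnp.ones((4, 8), jnp.float32)
w = jnp.ones((8, 16), jnp.float32)
y = mlp_layer(x, w)
print(y.shape)  # (4, 16)
```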

1. XLA input: StableHLO

  • model graphs are defined in StableHLO, which is XLA’s input
  • HLO = high level operations
    • an internal graph representation (IR) for the XLA compiler
    • “internal dialect” used by the XLA compiler to perform optimizations
    • not designed to be stable. may change frequently
  • StableHLO = public-interface HLO
    • a portability layer or “bridge” between ML frameworks (JAX / PyTorch) and the compiler
    • ML frameworks that produce StableHLO programs are compatible with ML compilers that consume StableHLO programs
    • versioned op set of HLOs → model graphs remain valid even as the underlying compiler evolves
    • provide backward and forward compatibility
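In JAX you can inspect the StableHLO that the framework emits before XLA touches it; a small sketch (the function is arbitrary):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# .lower() stops after the framework -> StableHLO step, before XLA optimization
lowered = jax.jit(f).lower(jnp.ones((3,), jnp.float32))
stablehlo_text = lowered.as_text()
print(stablehlo_text)  # an MLIR module using stablehlo.* ops
```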

2. XLA workflows

  1. StableHLO → internal HLO dialect: target-independent optimizations on StableHLO. Examples:
    1. common subexpression elimination
    2. target-independent op fusion
    3. buffer analysis: allocate runtime memory
  2. HLO sent to backend: target-specific optimization
    1. example: GPU may optimize op fusion for CUDA streams
  3. Target-specific code generation: done by backend
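As a sketch of what a pass like common subexpression elimination does (a toy Python model, not XLA’s implementation): identical (op, operands) nodes are merged so each value is computed once.

```python
# Toy CSE over a list of (name, op, operand_names) instructions.
def cse(ops):
    """Deduplicate instructions with identical (op, operands); return kept ops + renames."""
    seen = {}      # canonical (op, operands) -> surviving instruction name
    rename = {}    # eliminated name -> surviving name
    out = []
    for name, op, operands in ops:
        key = (op, tuple(rename.get(o, o) for o in operands))
        if key in seen:
            rename[name] = seen[key]   # duplicate: reuse the earlier result
        else:
            seen[key] = name
            out.append((name, op, key[1]))
    return out, rename

ops = [
    ("a", "add", ("x", "y")),
    ("b", "add", ("x", "y")),   # duplicate of a, gets eliminated
    ("c", "mul", ("a", "b")),   # rewritten to mul(a, a)
]
deduped, renames = cse(ops)
print(deduped)   # [('a', 'add', ('x', 'y')), ('c', 'mul', ('a', 'a'))]
print(renames)   # {'b': 'a'}
```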

3. HLO to executable (GPU)

  • The journey from initial HLO all the way to machine executable
  • Depending on the context, “HLO” can also refer to “HLO modules”
  • Thunk: a self-contained unit of work that the runtime executes
  • LLVM: a compiler backend, together with the IR language it takes as input
    • many compilers generate LLVM IR as a first step, and then LLVM generates machine code from it
    • this allows reusing code across compilers instead of duplicating it

image.png

3.1 Pre-optimization HLO

  • Pre-optimization HLO: no internal XLA ops (such as fusion and bitcast) are included
  • Ops don’t have a layout at this stage; if they do, it is ignored
  • produced by JAX / PyTorch
  • use the --xla_dump_to XLA flag to dump pre-optimization HLO to a file ending with before_optimizations.txt
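A sketch of dumping HLO from JAX (the dump directory is arbitrary; XLA_FLAGS must be set before the XLA backend initializes):

```python
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"  # set before using JAX

import jax
import jax.numpy as jnp

jax.jit(lambda x: x + 1)(jnp.ones(4))  # trigger a compilation
# /tmp/xla_dump now contains per-module files, including ones ending in
# before_optimizations.txt and after_optimizations.txt
print(sorted(os.listdir("/tmp/xla_dump"))[:3])
```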

image.png

3.2 optimize HLO Module

XLA GPU vs TPU pipeline

  • Lowering target: GPU lowers to thunks & LLVM IR; TPU lowers to TPU-specific machine code (binary)
  • Code generation: GPU uses IrEmitterUnnested + LLVM (PTX / HSACO); TPU uses a proprietary TPU compiler backend
  • Runtime abstraction: GPU uses ThunkSequence; TPU uses TpuExecutable (often via PJRT)
  • Optimization focus: GPU targets kernel fusion and memory coalescing; TPU targets systolic array utilization and interconnect scaling

here we only introduce the XLA:GPU pipeline

pre-optimization HLO → optimized HLO is done by several passes

pass execution order:

  1. sharding-related passes: some passes use Shardy
  2. optimization passes: legalization and simplification passes
  3. collective optimization passes: similar to step 2, but for collective ops
  4. layout assignment passes: each HLO op gets a layout like f32[10,20,30]{2,0,1} → controls how the tensor is stored in memory
    1. layout format: element type, logical dimensions of the shape, layout permutation in minor to major order
    2. goal: minimize # physical transpositions
    3. propagate layouts “down” and “up” the graph (w/ certain constraints)
    4. may have conflicting layouts (one from an operand, one from a user)
    5. copy will be added for conflicting layouts
  5. layout normalization passes
    1. try to rewrite shape to default layout {rank-1, rank-2, …, 0}
    2. copy (that changes layout) is rewritten as transpose + bitcast
      1. bitcast: a transpose with a layout that makes it a no-op physically
    3. some ops may still have non-default layouts, most notably gather and dot
  6. post layout assignment optimization passes
    1. Triton fusions (GEMM fusions + Softmax/Layernorm fusions) or rewrites to library calls.
    2. autotuning: e.g., finds the best tiling for fusions
  7. fusion passes: PriorityFusion +  Multi-Output fusion
    1. PriorityFusion: fusions guided by the cost model
    2. fusing ops/fusions that share an operand
    3. common subexpression elimination
  8. several post-fusion passes
    1. turning collectives into async ops, or enforcing a certain relative order of collectives, etc.
    2. CopyInsertion: add copies to ensure that in-place ops don’t overwrite data needed elsewhere

if you add the --xla_dump_to XLA flag, you’ll also see a file ending with after_optimizations.txt
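The minor-to-major layout notation used in the layout-assignment step can be made concrete with a small sketch (a toy model, not XLA’s code) that turns a layout into per-dimension element strides:

```python
# Derive element strides from an XLA-style minor-to-major layout.
def strides(dims, minor_to_major):
    """dims: logical shape; minor_to_major: dims listed fastest- to slowest-varying."""
    s = [0] * len(dims)
    step = 1
    for d in minor_to_major:   # walk from the minor (contiguous) dim outward
        s[d] = step
        step *= dims[d]
    return s

# f32[10,20,30]{2,1,0} is the default (row-major) layout:
print(strides([10, 20, 30], [2, 1, 0]))  # [600, 30, 1]
# f32[10,20,30]{2,0,1}: dim 2 is minor-most, then dim 0, then dim 1:
print(strides([10, 20, 30], [2, 0, 1]))  # [30, 300, 1]
```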

3.3 scheduling

  • for the HLO graph, any topological sort order is a valid execution order
  • scheduling determines the actual order
  • goal: reduce peak memory, given tensor lifetimes
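A toy sketch (not XLA’s scheduler; the graph and sizes are invented) of why the choice among valid topological orders matters: interleaving two independent chains keeps more large buffers live at once.

```python
# Simulate liveness under a schedule: a buffer is live from its producer
# until the last instruction that reads it.
def peak_memory(schedule, producers, sizes):
    last_use = {}
    for i, op in enumerate(schedule):
        for src in producers[op]:
            last_use[src] = i
    live, peak = 0, 0
    for i, op in enumerate(schedule):
        live += sizes[op]                  # allocate this op's output
        peak = max(peak, live)
        for src in producers[op]:
            if last_use[src] == i:
                live -= sizes[src]         # free inputs after their last use
    return peak

# Two independent chains: a -> b and c -> d, with large intermediates a and c.
producers = {"a": [], "b": ["a"], "c": [], "d": ["c"]}
sizes = {"a": 8, "b": 1, "c": 8, "d": 1}
print(peak_memory(["a", "b", "c", "d"], producers, sizes))  # 10: chains run back to back
print(peak_memory(["a", "c", "b", "d"], producers, sizes))  # 17: both large buffers live at once
```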

image.png

3.4 buffer assignment

  • assign buffer slices to each instruction in the HLO graph
  • this happens right before lowering to LLVM IR
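A deliberately naive sketch of buffer assignment (real XLA reuses and aliases buffers; names and sizes here are invented): each instruction’s output gets an (offset, size) slice in one flat allocation.

```python
# Bump-allocate a slice for each instruction's output in a single arena.
def assign_buffers(sizes):
    offsets, offset = {}, 0
    for name, size in sizes.items():
        offsets[name] = (offset, size)   # (byte offset, length) slice
        offset += size
    return offsets, offset               # per-op slices + total arena size

offsets, total = assign_buffers({"dot": 64, "add": 64, "tanh": 64})
print(offsets)  # {'dot': (0, 64), 'add': (64, 64), 'tanh': (128, 64)}
print(total)    # 192
```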

image.png

3.5 thunks

  • lower HLO graph to a seq of thunks for a specific backend (CPU or GPU)
  • Thunk: self-contained execution unit
    • compiled kernel launch
    • specific op
    • library call
    • control-flow construct
    • collective communication
  • thunk emission: scheduled HLO → thunk sequence
  • command buffers: optimizing GPU execution
    • CUDA Graphs: record a sequence of GPU ops (kernel launches, memory copies, etc.) and replay it to reduce CPU launch overhead
    • command buffer: abstraction of CUDA Graphs or HIP Graphs
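A toy Python model of a thunk sequence (the classes and thunk names are invented, not XLA’s): each thunk is a self-contained callable, and the runtime simply walks the sequence in order.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

@dataclass
class Thunk:
    name: str
    run: Callable[[Dict[str, list]], None]   # executes against a buffer table

def add_kernel(buffers):
    # stand-in for a compiled kernel launch writing into its output buffer
    buffers["out"] = [x + y for x, y in zip(buffers["a"], buffers["b"])]

seq: List[Thunk] = [
    Thunk("memcpy_h2d", lambda b: None),     # e.g. a copy thunk (no-op here)
    Thunk("fusion.add", add_kernel),         # e.g. a kernel-launch thunk
]

buffers = {"a": [1, 2], "b": [3, 4]}
for thunk in seq:   # the runtime executes the thunk sequence in order
    thunk.run(buffers)
print(buffers["out"])  # [4, 6]
```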

3.6 executable

  • final product: the bridge between compiler and runtime
  • includes all the info needed to run the compiled program on a target device (CPU / GPU)
  • modern runtimes like PJRT use slightly higher-level abstractions (see PjRtExecutable), but these ultimately wrap a backend-specific executable
  • what’s included?
    • compiled code
      • CPU: object files
      • GPU: PTX or HSACO code
    • execution Plan (ThunkSequence)
    • memory layout (BufferAssignment)
    • (optional) final, optimized HLO Module: for debugging and profiling

image.png


https://gdymind.github.io/2026/02/25/XLA01-architecture-workflows/
Author: gdymind
Posted on: February 25, 2026