gdymind's blog

Pallas examples by Sharad Vikram (Pallas author)

https://www.youtube.com/watch?v=NFKubflDb1A The code was written in 2023 (it may be slightly outdated, but the core concepts are still valid); presented by Sharad Vikram (Pallas author). 1. TPU architecture recap…
2026-03-08
#TPU #kernel

jax.jit, torch.compile & CUDA graph

1. jax.jit: jax.jit traces Python into a computational graph (a jaxpr) → XLA compiles the graph into an optimized HLO program for the target device. After compilation, Python is completely out of the loop…
2026-03-07
#JAX #TPU #kernel #GPU
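The trace-then-compile flow from that excerpt can be sketched in a few lines (a minimal illustration; the function `f` is my own example):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

# Tracing: jax.make_jaxpr shows the computational graph (jaxpr) that JAX
# extracts from the Python function before handing it to XLA.
print(jax.make_jaxpr(f)(1.0))

# jax.jit wraps the same trace-and-compile pipeline: the first call with a
# given shape/dtype compiles; later calls reuse the compiled XLA program,
# with Python out of the loop.
f_jit = jax.jit(f)
print(f_jit(1.0))
```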

KV cache in sliding-window attention

1. Longformer (https://arxiv.org/abs/2004.05150): the first sliding-window attention (SWA) paper, published in 2020. Traditional full attention: compute complexity $O(n^2)$, memory $O(n^2)$, where…
2026-03-02
#LLM inference #KV cache
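The KV-cache consequence of a sliding window can be sketched with a bounded buffer (window size and the key/value placeholders are made-up examples): since each new token only attends to the last W tokens, the cache need not grow with sequence length.

```python
from collections import deque

W = 4  # hypothetical window size

# With sliding-window attention, only the last W tokens' K/V pairs are
# ever read, so the cache can be a fixed-size ring buffer.
kv_cache = deque(maxlen=W)

for t in range(10):
    kv_cache.append((f"k{t}", f"v{t}"))  # append this step's K and V

print(len(kv_cache))    # stays at W (4), not at sequence length 10
print(kv_cache[0][0])   # oldest retained key: k6
```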

XLA02 - shapes, layout & tiling

https://openxla.org/xla/shapes https://openxla.org/xla/tiled_layout 1. XLA op format. HLO example: add.936 = bf16[8,1,1280,16384]{3,2,0,1:T(8,128)(2,1)} add(exponential.183, broadcas…
2026-02-26
#JAX #TPU #GPU #Pallas #Kernel
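One implication of the T(8,128) tiling in that layout string is padding: the tiled dimensions are rounded up to tile multiples. A small sketch (helper name is mine; the extra (2,1) sub-tiling used for bf16 does not change these multiples):

```python
import math

def padded_shape(shape, tile=(8, 128)):
    """Round the last len(tile) dims up to tile multiples, as T(8,128)
    tiling does (see openxla.org/xla/tiled_layout)."""
    out = list(shape)
    for i, t in enumerate(tile):
        d = len(shape) - len(tile) + i
        out[d] = math.ceil(out[d] / t) * t
    return tuple(out)

# The bf16 array from the excerpt is already tile-aligned
# (1280 % 8 == 0, 16384 % 128 == 0), so no padding is added:
print(padded_shape((8, 1, 1280, 16384)))  # (8, 1, 1280, 16384)
# A misaligned shape gets rounded up:
print(padded_shape((8, 1, 5, 100)))       # (8, 1, 8, 128)
```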

XLA01 - architecture & workflows

https://openxla.org/xla 0. Intro: XLA in the whole JAX stack. Source: Yi Wang’s LinkedIn post. LLMs are basically matmul. XLA (Accelerated Linear Algebra) optimizes linear algebra on multiple devices (TPU…
2026-02-25
#JAX #TPU #GPU #Pallas #Kernel

Knowledge Distillation 101

source: https://huggingface.co/blog/Kseniase/kd 1. History. Knowledge Distillation (KD): transfer knowledge from a teacher model to a smaller student model. DeepSeek-R1 proposed effective distillation imp…
2026-02-22
#Training

GPU mode - lecture2 - CUDA 101

https://www.youtube.com/watch?v=NQ-0D5Ti2dc&t=9s https://github.com/gpu-mode/lectures/tree/main/lecture_002 from the PMPP book. 1. Memory allocation. NVIDIA devices come with their own DRAM (device) global…
2026-02-19
#kernel #GPU

Pallas 101 - multi-backend kernel for JAX

1. Why Pallas? JAX works with pure functions (i.e., the same inputs always produce the same outputs) and JAX arrays are immutable, which is neither flexible nor efficient for kernel implementation. GEMM steps: input matrix →…
2026-02-19
#JAX #TPU #kernel
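A minimal Pallas kernel sketch of how that limitation is lifted (an element-wise add rather than a GEMM, run with `interpret=True` so it works without an accelerator): inside a kernel, refs are mutable views, unlike ordinary immutable JAX arrays.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs support in-place reads/writes, which is what makes
    # kernel-style code expressible on top of pure-functional JAX.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU interpreter for illustration; drop on TPU/GPU
    )(x, y)

x = jnp.arange(8.0)
print(add(x, x))  # elementwise doubling of x
```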

5D parallelism in LLM training

source: The Ultra-scale Playbook 0. High-level overview. Targeted at large-scale training (e.g., 512 GPUs). Tradeoff among the following factors: memory usage (params, optimizer states, gradients), compute ef…
2026-02-07
#Training
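The memory side of that tradeoff can be illustrated with back-of-envelope arithmetic (all sizes and parallel degrees below are made-up examples, not from the post):

```python
# Per-GPU parameter memory under tensor/pipeline/data parallelism.
params = 70e9          # hypothetical 70B-parameter model
bytes_per_param = 2    # bf16 weights

tp, pp, dp = 8, 4, 16  # hypothetical parallel degrees (8*4*16 = 512 GPUs)

# TP and PP shard the weights themselves; plain DP replicates them
# (ZeRO-3 would additionally shard over the dp dimension).
weight_mem_gb = params * bytes_per_param / (tp * pp) / 1e9
print(f"{weight_mem_gb:.3f} GB of weights per GPU")  # 4.375 GB
```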

Memory usage breakdown during Training

1. Memory Composition: Model Parameters; Intermediate Activations (forward pass), which will be used to calculate gradients during backward; Gradients (backward pass); Optimizer States. 2. Static Memory (Weigh…
2026-01-25
#Training
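The static part (weights, gradients, optimizer states) can be made concrete with the standard mixed-precision Adam accounting (the model size is a made-up example):

```python
# Static training memory per parameter with mixed-precision Adam.
params = 7e9  # hypothetical 7B-parameter model

bytes_per_param = (
    2 +  # bf16 weights
    2 +  # bf16 gradients
    4 +  # fp32 master weights
    4 +  # Adam first moment (fp32)
    4    # Adam second moment (fp32)
)  # = 16 bytes/param; activations come on top and scale with batch/seq

static_gb = params * bytes_per_param / 1e9
print(f"{bytes_per_param} B/param -> {static_gb:.0f} GB static")  # 112 GB
```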