Pallas 101 - a multi-backend kernel language for JAX
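1. Why Pallas?
JAX works with pure functions (i.e., the same inputs always produce the same outputs, with no side effects), and JAX arrays are immutable. This model is neither flexible nor efficient for implementing kernels, which typically read and write memory buffers in place.
GEMM steps: input matrix →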
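A minimal sketch contrasting the two models, assuming a recent JAX release where `jax.experimental.pallas` is available. In plain JAX an "update" produces a new array and leaves the original untouched. A Pallas kernel instead receives mutable references (Refs) to its buffers and writes its result into the output Ref in place. `add_kernel` and `add` are illustrative names, and `interpret=True` is an assumption made so the sketch runs on CPU; on a real GPU/TPU backend it would normally be dropped.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# 1) Plain JAX: arrays are immutable. `.at[...].set(...)` does not modify `x`;
#    it returns a new array with the update applied.
x = jnp.zeros(4)
y = x.at[0].set(1.0)  # x is still all zeros; y holds the update

# 2) Pallas: the kernel body receives Refs (mutable views of buffers)
#    and stores its result into the output Ref in place.
def add_kernel(x_ref, y_ref, o_ref):
    # Load both input blocks, add them, and write into the output buffer.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(a, b):
    # With no grid or BlockSpecs, the whole array is processed as one block.
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(a.shape, a.dtype),
        interpret=True,  # assumption: interpreter mode so this runs on CPU
    )(a, b)

a = jnp.arange(8, dtype=jnp.float32)
b = jnp.ones(8, dtype=jnp.float32)
out = add(a, b)
```

The kernel itself is written in an imperative, in-place style, yet `add` remains an ordinary pure JAX function from the caller's point of view: it takes arrays and returns a new array.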