JAX is Autograd and XLA, brought together for high-performance numerical computing and ML research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, JIT compile to GPU/TPU, and more.
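To make "composable transformations" concrete, here's a minimal sketch (my own toy function, not from the thread) stacking `jax.grad`, `jax.vmap`, and `jax.jit` on one function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 2  # scalar in, scalar out

df = jax.grad(f)              # differentiate: df(x) = 2x
df_batch = jax.vmap(df)       # vectorize over a batch of inputs
df_fast = jax.jit(df_batch)   # JIT-compile the batched gradient via XLA

xs = jnp.arange(3.0)
print(df_fast(xs))  # [0. 2. 4.]
```

The key point is that each transformation returns an ordinary Python function, so they compose freely.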
Accelerated Linear Algebra (XLA) is a domain-specific compiler for linear algebra that can accelerate math operations with potentially no source code changes.
Just-in-time (JIT) compilation is a technique in which, while a program runs, the runtime also compiles the code to machine code, so the next time that code executes it runs as compiled code (faster) ⚡️
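In JAX this looks like the sketch below (the `selu` function is just an illustration): the first call traces and compiles the function with XLA, and later calls reuse the compiled version.

```python
import jax
import jax.numpy as jnp

@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # Traced and compiled by XLA on the first call for this input shape;
    # subsequent calls with the same shape run the cached compiled code.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
selu(x)  # first call: trace + compile
selu(x)  # later calls: fast compiled execution
```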
9/11🧵
With all these concepts (loosely explained for brevity), we can understand what JAX is:
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
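A tiny sketch of what that looks like in practice (the linear-model `loss` is my own illustration, not a JAX API): you write NumPy-style code with `jax.numpy` and get gradients for free with `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: mean-squared-error loss for a linear model.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w) by default.
grad_loss = jax.grad(loss)

w = jnp.zeros(2)
x = jnp.eye(2)
y = jnp.array([1.0, 2.0])
print(grad_loss(w, x, y))  # gradient of the loss w.r.t. w
```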
10/11🧵
This month I'll keep posting about my journey learning JAX. 📚👓