JAX
Data & ML
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Release History
jax-v0.8.2 (Breaking): This release deprecates several core and interpreter symbols, removes `jax.experimental.si_vjp`, and changes `Tracer` inheritance, requiring code updates to use `jax.vjp` and the new `pcast` API.
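Code that relied on the removed `jax.experimental.si_vjp` can move to the long-standing `jax.vjp` API. This is a minimal sketch; the function `f` and its inputs are illustrative and do not come from the release notes.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # A scalar-valued function of two array arguments.
    return jnp.sum(jnp.sin(x) * y)

x = jnp.array([0.0, 1.0, 2.0])
y = jnp.array([1.0, 2.0, 3.0])

# jax.vjp returns the primal output plus a function that maps an output
# cotangent to one cotangent per input (returned as a tuple).
out, f_vjp = jax.vjp(f, x, y)
x_bar, y_bar = f_vjp(jnp.ones_like(out))
```

Here `x_bar` equals `cos(x) * y` and `y_bar` equals `sin(x)`, which are the gradients of `f` with respect to each argument.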
jax-v0.8.1 (1 fix, 3 features): JAX adds decorator-factory support for `jax.jit`, new `implementation` and `algorithm` options for linear algebra ops, fixes a GPU `eigh` bug, and deprecates several sharding and Cloud TPU utilities.
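The jax-v0.8.1 entry mentions decorator-factory support for `jax.jit`. This is a minimal sketch under one assumption: the feature means `jax.jit` can be called with only keyword options, as in `@jax.jit(static_argnames=["n"])`, and then returns a decorator. The `power` function is hypothetical. The fallback branch uses the older `functools.partial` idiom for releases that lack the feature.

```python
import functools

import jax
import jax.numpy as jnp

def power(x, n):
    # n drives a Python loop, so it must be static under jax.jit.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

try:
    # Decorator-factory form (assumed jax >= 0.8.1): jax.jit(...) with only
    # options returns a decorator, same as writing @jax.jit(static_argnames=["n"]).
    power_jit = jax.jit(static_argnames=["n"])(power)
except TypeError:
    # Older releases require `fun` positionally, so use functools.partial.
    power_jit = functools.partial(jax.jit, static_argnames=["n"])(power)
```

Both branches produce the same compiled function. The decorator-factory form simply removes the need for `functools.partial`.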
jax-v0.8.0 (Breaking, 5 features): JAX introduces several breaking changes, including a new default implementation of `jax.pmap` and the removal of many deprecated APIs. It also adds features such as namedtuple return values for `eig` and enhanced DLPack support.
jax-v0.7.2 (Breaking, 2 fixes): `jax.dlpack.from_dlpack` no longer accepts raw DLPack capsules. The release also raises the minimum NumPy and SciPy versions and includes several deprecations and bug fixes.
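Since raw capsules are no longer accepted, the usual pattern is to pass the producer object itself to `jax.dlpack.from_dlpack`. That object must implement the `__dlpack__` / `__dlpack_device__` protocol, as NumPy arrays do. This is a minimal sketch; the host array is illustrative, and it assumes a CPU-resident NumPy array.

```python
import numpy as np
import jax.dlpack

host = np.arange(4, dtype=np.float32)

# Pass the array object (which implements __dlpack__ and __dlpack_device__),
# not a capsule obtained by calling host.__dlpack__() yourself.
dev = jax.dlpack.from_dlpack(host)
```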
jax-v0.7.1 (Breaking, 5 features): JAX adds free-threaded Python 3.13t/3.14t wheels, a new `jax.set_mesh` API, and CUDA 12.9 builds, along with several deprecations, including `jax.sharding.use_mesh`.
jax-v0.7.0 (Breaking, 2 features): JAX 0.7 makes Shardy the default sharding partitioner (replacing GSPMD), switches autodiff to direct linearization, raises the minimum Python version to 3.11, and deprecates or removes several legacy APIs.
jax-v0.6.2 (Breaking, 1 feature): This release introduces the new `jax.tree.broadcast` helper and raises the minimum required versions of NumPy and SciPy.
jax-v0.6.1 (Breaking, 1 feature): This release adds `jax.lax.axis_size`, makes `PartitionSpec` and `ShapeDtypeStruct` behavior stricter, re-enables CUDA version checks, and deprecates `custom_jvp_call_jaxpr_p`.
jax-v0.6.0 (Breaking, 4 features): This release removes several legacy tracing and configuration options, raises the minimum CUDA/cuDNN versions, updates the package extras syntax, and deprecates many old APIs, while introducing stricter `jax.jit` calling conventions.
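This sketch assumes the stricter `jax.jit` calling convention means two things: the function is passed positionally, and every other option is passed by keyword. Under that reading, the form below is safe across releases. The `scale` function is illustrative.

```python
import jax
import jax.numpy as jnp

def scale(x, factor):
    return x * factor

# Function positionally, options by keyword (e.g. static_argnums=...),
# rather than passing options such as static_argnums positionally.
scale_jit = jax.jit(scale, static_argnums=(1,))
```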
jax-v0.5.3 (2 features): This release adds new options to slicing functions and to categorical sampling. The slicing option can reduce generated code size, and `jax.random.categorical` gains support for sampling without replacement.
jax-v0.5.2 (1 fix): This patch release for 0.5.1 fixes TPU metric logging and the `tpu-info` command.
jax-v0.5.1 (2 fixes, 3 features): This release adds experimental custom DCE (dead-code elimination) support, new low-level reduction ops, and column-pivoting QR. It also updates the CPU collective defaults, removes the `libtpu-nightly` dependency, deprecates internal functions that require debug info, and includes fixes for the TPU runtime and the compilation cache.
jax-v0.5.0 (Breaking, 2 fixes, 3 features): This "meso" release (JAX now uses effort-based versioning) adds multi-dimensional FFT support, FFI state registration, and debugging info for AOT lowering. It also makes a breaking change to PRNG semantics, drops macOS x86 wheels, and raises the minimum NumPy and SciPy versions.