JAX 0.9 Release Notes Summary
Introduction
JAX 0.9.0 was released on January 20, 2026. Since version 0.5.0, JAX has used Effort-Based Versioning (EffVer), in which version numbers indicate the “effort” required to upgrade. In the roughly one year since 0.4.38 (January 2025), there have been five meso-version bumps (0.4 → 0.5 → 0.6 → 0.7 → 0.8 → 0.9), bringing many changes.
This article summarizes the key changes in JAX 0.9.0.
What is Effort-Based Versioning (EffVer)?
EffVer is a versioning scheme adopted since JAX 0.5.0 (February 2025). It uses three numbers in the format MACRO.MESO.MICRO to indicate the effort required for upgrading.
- MACRO (0.x → 1.x): Significant effort required
- MESO (0.8 → 0.9): Some effort required (may include breaking changes)
- MICRO (0.9.0 → 0.9.1): Almost no effort required
Unlike SemVer, which signals API compatibility, EffVer chooses version numbers based on the upgrade impact from the user’s perspective.
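To make the scheme concrete, here is a small sketch that classifies a version bump by the EffVer component that changed. The helper `effver_bump` is hypothetical and not part of JAX; it only compares the three numeric components described above.

```python
def effver_bump(old: str, new: str) -> str:
    """Classify an upgrade under EffVer (hypothetical helper, not a JAX API)."""

    def parts(version: str) -> tuple:
        # Split "MACRO.MESO.MICRO" and pad missing components with 0,
        # so that "0.9" compares like "0.9.0".
        nums = [int(p) for p in version.split(".")]
        nums += [0] * (3 - len(nums))
        return tuple(nums[:3])

    o, n = parts(old), parts(new)
    if n[0] != o[0]:
        return "macro"  # significant effort
    if n[1] != o[1]:
        return "meso"   # some effort; may include breaking changes
    if n[2] != o[2]:
        return "micro"  # almost no effort
    return "none"


print(effver_bump("0.8.2", "0.9.0"))  # meso
print(effver_bump("0.9.0", "0.9.1"))  # micro
```

Under this reading, every upgrade from 0.4.38 to 0.9.0 crosses five meso boundaries, each of which may require some code changes.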
Key Changes in JAX 0.9.0
New Features
- jax.thread_guard: A context manager that detects when devices are accessed from multiple threads in multi-controller environments
- Explicit sharding support in jax.export: Support for a new serialization format including NamedSharding, abstract meshes, and partition specs
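For readers new to jax.export, the basic export, serialize, and deserialize round trip that the new format builds on looks roughly like the sketch below. It runs on a single device and uses only the long-standing `jax.export.export`, `Exported.serialize`, `jax.export.deserialize`, and `Exported.call` calls; it does not exercise the new NamedSharding or abstract-mesh fields, and the function `f` is purely illustrative.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):
    # Illustrative computation to export.
    return jnp.sin(x) * 2.0


# Export the jitted function for a concrete input shape and dtype.
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((4,), jnp.float32))

# Serialize to bytes (e.g., to store on disk), then load it back.
blob = exported.serialize()
restored = export.deserialize(blob)

x = jnp.arange(4, dtype=jnp.float32)
print(restored.call(x))
```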
Breaking Changes
- pmap no-rank-reduction is now the only behavior: The jax_pmap_no_rank_reduction configuration has been removed. When pmap(f) is applied to an input of shape (8, 128), f still receives (128,), but each per-device shard of the result (e.g., via addressable_shards) now has shape (1, 128) instead of the previous (128,)
- Layout/Format API renamed: Layout → Format, DeviceLocalLayout → Layout
- jax.experimental.shard module removed: Its contents moved to jax.sharding.reshard, jax.sharding.auto_axes, and jax.sharding.explicit_axes
- pjit_p → jit_p: jax.extend.core.primitives.pjit_p has been renamed to jit_p
- lax.infeed/lax.outfeed removed: These had been deprecated since JAX 0.6
Deprecations
- jax.device_put_replicated / jax.device_put_sharded: Unified into jax.device_put
- jax.cloud_tpu_init: No longer needed, as JAX now initializes TPUs automatically
- jax.numpy.fix: Following the NumPy v2.5.0 deprecation. Use jax.numpy.trunc as a replacement
Three Sharding Modes
As of JAX 0.9, there are three approaches to sharding.
1. Auto Mode
mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Auto,))
The XLA compiler, using the Shardy sharding-propagation system, automatically parallelizes the computation and inserts communication operations such as AllGather and ReduceScatter as needed.
2. Explicit Mode
mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Explicit,))
Sharding becomes part of JAX’s type system: each array’s sharding is visible at trace time, and JAX itself, rather than the compiler, propagates shardings through operations.
3. Manual Mode (shard_map)
from jax.sharding import PartitionSpec as P

@jax.shard_map(mesh=mesh, in_specs=P("data"), out_specs=P("data"))
def f(x):
    return x * 2
The function body operates on a device-local view of each shard, and users write communication operations (such as jax.lax.psum) explicitly.
Migrating from pmap to shard_map
pmap is now in maintenance mode, and migration to shard_map is recommended. The advantages of shard_map include:
- Support for multi-dimensional meshes (pmap only supports 1D)
- More efficient automatic differentiation
- Better composability with other JAX APIs
The planned deprecation of pmap has a significant impact on users running large-scale distributed training, so early migration is advisable.