JAX 0.9 Release Notes Summary

Introduction

JAX 0.9.0 was released on January 20, 2026. Since version 0.5.0, JAX has used Effort-Based Versioning (EffVer), in which version numbers signal the effort required to upgrade. In the roughly one year since 0.4.38 (December 2024), the meso version has been bumped five times (0.4 → 0.5 → 0.6 → 0.7 → 0.8 → 0.9), bringing many changes.

This article summarizes the key changes in JAX 0.9.0.

What is Effort-Based Versioning (EffVer)?

EffVer is the versioning scheme JAX has used since 0.5.0 (January 2025). Versions have the form MACRO.MESO.MICRO, and the component that changes indicates how much effort an upgrade requires:

  • MACRO (0.x → 1.x): Significant effort required
  • MESO (0.8 → 0.9): Small effort required (may include breaking changes)
  • MICRO (0.9.0 → 0.9.1): Almost no effort required

Unlike SemVer, where the bump is determined by API compatibility, EffVer determines it from the expected upgrade impact on users.
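To make the scheme concrete, here is a small, hypothetical helper (effver_bump is not part of JAX or any library) that reports which EffVer component changed between two version strings. It assumes plain numeric components and that the second version is the newer one.

```python
def effver_bump(old: str, new: str) -> str:
    """Hypothetical helper: report which EffVer component differs between
    two version strings. Assumes numeric components and new >= old."""
    old_parts = [int(p) for p in old.split(".")]
    new_parts = [int(p) for p in new.split(".")]
    # Compare MACRO, then MESO, then MICRO; the first difference decides.
    for name, a, b in zip(("macro", "meso", "micro"), old_parts, new_parts):
        if a != b:
            return name
    return "none"


print(effver_bump("0.8.2", "0.9.0"))  # meso: small effort, may break things
print(effver_bump("0.9.0", "0.9.1"))  # micro: almost no effort
print(effver_bump("0.9.1", "1.0.0"))  # macro: significant effort
```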

Key Changes in JAX 0.9.0

New Features

  • jax.thread_guard: A context manager that detects when devices are accessed from multiple threads in multi-controller environments
  • Explicit sharding support in jax.export: A new serialization format that can include NamedSharding, abstract meshes, and partition specs

Breaking Changes

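  • pmap no-rank-reduction is now the only behavior: The jax_pmap_no_rank_reduction configuration option has been removed. For an input of shape (8, 128), the per-device shards of arrays produced by pmap(f) (e.g. x.addressable_shards[i].data) have shape (1, 128) instead of the previous (128,), so the shard rank matches the global rank as with jit. The mapped function f itself still receives (128,)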
  • Layout/Format API renamed: Layout → Format, DeviceLocalLayout → Layout
  • jax.experimental.shard module removed: Its functions now live at jax.sharding.reshard, jax.sharding.auto_axes, and jax.sharding.explicit_axes
  • pjit_p → jit_p: jax.extend.core.primitives.pjit_p has been renamed to jit_p
  • lax.infeed / lax.outfeed removed: Both had been deprecated since JAX 0.6

Deprecations

  • jax.device_put_replicated / jax.device_put_sharded: Deprecated in favor of jax.device_put
  • jax.cloud_tpu_init: No longer needed as JAX now automatically initializes TPUs
  • jax.numpy.fix: Deprecated following its deprecation in NumPy v2.5.0. Use jax.numpy.trunc as a replacement

Three Sharding Modes

As of JAX 0.9, there are three approaches to sharding.

1. Auto Mode

mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Auto,))

The compiler (XLA, using the Shardy partitioner) propagates shardings automatically and inserts communication operations such as AllGather and ReduceScatter where needed.

2. Explicit Mode

mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Explicit,))

Sharding becomes part of JAX’s type system (it is visible via jax.typeof), and JAX itself propagates shardings at trace time instead of leaving those decisions to the compiler.

3. Manual Mode (shard_map)

from jax.sharding import PartitionSpec as P

@jax.shard_map(mesh=mesh, in_specs=P("data"), out_specs=P("data"))
def f(x):
    # x is this device's local block of the global array
    return x * 2

The function body operates on each device's local shard, and the user writes communication operations (collectives such as jax.lax.psum) explicitly.

Migrating from pmap to shard_map

pmap is now in maintenance mode, and migration to shard_map is recommended. The advantages of shard_map include:

  • Support for multi-dimensional meshes (pmap only supports 1D)
  • More efficient automatic differentiation
  • Better composability with other JAX APIs

The planned deprecation of pmap has a significant impact on users running large-scale distributed training, so early migration is advisable.
