JAX 0.9.1 Release Notes Summary
This article was translated from Japanese using an LLM.
Introduction
Following JAX 0.9.0, the JAX 0.9.1 release tag was published on March 2, 2026. Although 0.9.2 has already been released, here is a belated summary. As a MICRO version bump, the upgrade effort should be minimal, but it includes some important changes and new features worth covering.
Previous article: JAX 0.9 Release Notes Summary
Key Changes in JAX 0.9.1
Changes
isinstance Behavior Change for Non-Array Tracers
JAX tracers that are not of Array type (e.g., Ref type) will no longer report themselves as instances of jax.Array.
Previously, non-Array tracers such as Ref would return True for isinstance(tracer, jax.Array); this has now been fixed. If your code uses isinstance(x, jax.Array) inside traced functions to tell arrays apart from other values, check that it still behaves as intended.
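Here is a minimal sketch of the kind of check that is affected. Ordinary array tracers inside jax.jit still report themselves as jax.Array, so a check like this keeps working for them. Only tracers of non-Array types such as Ref now take the False branch. The Ref case itself is not exercised here, and the names double and seen are illustrative.

```python
import jax
import jax.numpy as jnp

# Records what isinstance reports for the traced argument (illustrative state)
seen = []

@jax.jit
def double(x):
    # Runs once at trace time: x is a tracer here, and ordinary array
    # tracers are still instances of jax.Array in 0.9.1.
    # Tracers of non-Array types (e.g. Ref) would now report False.
    seen.append(isinstance(x, jax.Array))
    return x * 2

out = double(jnp.arange(3.0))
```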
Stricter Input Validation for shard_map in Explicit Mode
When using jax.shard_map in Explicit mode, an error is now raised if the input’s PartitionSpec does not match the PartitionSpec specified in in_specs.
Previously, a mismatch silently triggered an implicit reshard. From 0.9.1 onwards, in_specs acts as an assertion about the input sharding, and a mismatch raises an error.
Example: The Problem with Implicit Reshard
The following example uses Explicit mode on a 2x2 device mesh. It demonstrates the behavior when there is a mismatch between the input array’s sharding and in_specs.
```python
# Requires 4 devices. On CPU you can simulate them by setting
# XLA_FLAGS=--xla_force_host_platform_device_count=4 before importing jax.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType

# Configure 4 devices as a 2x2 mesh (Explicit mode)
mesh = jax.make_mesh(
    (2, 2), ("data", "model"),
    axis_types=(AxisType.Explicit, AxisType.Explicit)
)

# Shard an (8, 8) array along the "data" axis only
x = jnp.ones((8, 8))
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# Specify 2-axis sharding ("data", "model") in shard_map's in_specs
# → The input is ("data", None), so the PartitionSpecs don't match
@jax.shard_map(
    mesh=mesh,
    in_specs=P("data", "model"),  # Mismatch with input sharding
    out_specs=P("data", "model")
)
def matmul_sharded(x):
    # Operates on each device's local (4, 4) block: block @ block.T
    return x @ x.T
```
Behavior in JAX 0.9.0 and earlier:
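```python
# No error. An implicit reshard from ("data", None) → ("data", "model") occurs.
# In this direction it is only a local slice, but other mismatches
# (e.g. ("data", "model") → ("data", None)) insert collectives such as
# AllGather behind the scenes, which can degrade performance.
result = matmul_sharded(x)
```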
Behavior from JAX 0.9.1 onwards:
```python
# Raises an error on the mismatch, roughly (exact wording may differ):
# ValueError: shard_map input has PartitionSpec('data', None)
# but in_specs specifies PartitionSpec('data', 'model')
result = matmul_sharded(x)  # Error raised
```
Correct Approaches
```python
# Approach 1: Explicitly reshard with jax.reshard before passing
x_resharded = jax.reshard(x, NamedSharding(mesh, P("data", "model")))
result = matmul_sharded(x_resharded)

# Approach 2: Omit in_specs so it is inferred from the input's sharding
# (this relies on the mesh axes being Explicit)
@jax.shard_map(
    mesh=mesh,
    out_specs=P("data", None)  # in_specs omitted
)
def matmul_auto(x):
    # Each local block is (4, 8), so each device returns a (4, 4) result
    return x @ x.T

result = matmul_auto(x)  # Input sharding is used as-is; result shape is (8, 4)
```
Verification Notebook
You can run the code above to verify the behavior in this notebook:
Significance of This Change
Implicit reshards were a source of unintended collective communication (AllGather, AllToAll, etc.) silently inserted by the compiler. These collectives transfer data between devices, which can significantly impact performance in large-scale distributed training.
| Aspect | Implicit reshard (0.9.0 and earlier) | Explicit reshard (0.9.1+) |
|---|---|---|
| Mismatch detection | None (silent reshard) | Immediate error detection |
| Hidden communication cost | Can occur | User-controlled |
| Debugging | Difficult (requires inspecting HLO) | Easy (clear error message) |
In large-scale model training, even a single unintended reshard can significantly reduce throughput. This change enables detection of sharding inconsistencies at an early development stage.
New Features
Compilation Cache Debug Setting
A new debug config jax_compilation_cache_check_contents has been added.
When enabled, it behaves as follows:
- When get() is called on a value that was not put() by the current process, it is treated as a cache miss, even if the value actually exists in the disk cache.
- When put() is called, it verifies that the contents match.
This is useful for debugging compilation cache integrity issues.
Summary
While JAX 0.9.1 is a MICRO version bump, the behavior change for shard_map in Explicit mode is particularly noteworthy for users working with distributed computing. By disallowing implicit reshards, sharding mismatches are now detected early, making debugging significantly easier.