JAX 0.9.1 Release Notes Summary

This article was translated from Japanese using an LLM.

Introduction

Following JAX 0.9.0, the JAX 0.9.1 release tag was published on March 2, 2026. Version 0.9.2 has already been released, so this summary is belated. As a micro version bump, 0.9.1 takes little effort to upgrade to, but it includes some important changes and a new feature worth covering.

Previous article: JAX 0.9 Release Notes Summary

Key Changes in JAX 0.9.1

Changes

isinstance Behavior Change for Non-Array Tracers

JAX tracers that are not of Array type (e.g., Ref type) will no longer report themselves as instances of Array.

Previously, isinstance(tracer, jax.Array) returned True even for non-Array tracers such as Ref. It now returns False for them. If your code branches on this type check, verify that it still works as intended.
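As a rough illustration of the kind of type check that is affected, here is a minimal sketch. The helper name require_array and the function double are hypothetical and not part of JAX. The sketch only exercises the Array case: an array-valued tracer inside jax.jit still passes the check. Going by the release note, a Ref tracer would now take the TypeError branch instead of passing silently.

```python
import jax
import jax.numpy as jnp


def require_array(x):
    """Hypothetical guard: reject anything that is not a jax.Array."""
    if not isinstance(x, jax.Array):
        raise TypeError(f"expected jax.Array, got {type(x).__name__}")
    return x


@jax.jit
def double(x):
    # Under jit, x is a tracer that stands in for an array,
    # so the isinstance check in require_array still passes.
    return require_array(x) * 2


result = double(jnp.arange(3))
```

If a guard like this sits in front of code that also receives Ref values, it is worth re-testing after the upgrade, since the check no longer treats them as arrays.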

Stricter Input Validation for shard_map in Explicit Mode

When using jax.shard_map in Explicit mode, an error is now raised if the input’s PartitionSpec does not match the PartitionSpec specified in in_specs.

Previously, a mismatch silently triggered an implicit reshard. From 0.9.1 onwards, in_specs acts as an assertion: a mismatch raises an error instead.

Example: The Problem with Implicit Reshard

The following example uses Explicit mode on a 2x2 device mesh. It demonstrates the behavior when there is a mismatch between the input array’s sharding and in_specs.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType

# Configure 4 devices as a 2x2 mesh (Explicit mode)
mesh = jax.make_mesh(
    (2, 2), ("data", "model"),
    axis_types=(AxisType.Explicit, AxisType.Explicit)
)

# Shard an (8, 8) array along the "data" axis only
x = jnp.ones((8, 8))
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# Specify 2-axis sharding ("data", "model") in shard_map's in_specs
# → Input is ("data", None), so PartitionSpec doesn't match
@jax.shard_map(
    mesh=mesh,
    in_specs=P("data", "model"),   # Mismatch with input sharding
    out_specs=P("data", "model")
)
def matmul_sharded(x):
    return x @ x.T

Behavior up to JAX 0.9.0:

# No error. Implicit reshard from ("data", None) → ("data", "model") occurs
# AllGather + Slice communication runs behind the scenes, causing performance degradation
result = matmul_sharded(x)

Behavior from JAX 0.9.1 onwards:

# ValueError: shard_map input has PartitionSpec('data', None)
# but in_specs specifies PartitionSpec('data', 'model')
result = matmul_sharded(x)  # Error raised

Correct Approaches

# Approach 1: Explicitly reshard with jax.reshard before passing
x_resharded = jax.reshard(x, NamedSharding(mesh, P("data", "model")))
result = matmul_sharded(x_resharded)

# Approach 2: Omit in_specs to auto-infer from input
@jax.shard_map(
    mesh=mesh,
    out_specs=P("data", None)  # in_specs omitted
)
def matmul_auto(x):
    return x @ x.T

result = matmul_auto(x)  # Input sharding is used as-is

Verification Notebook

You can run the code above to verify the behavior in this notebook:

Open In Colab

Significance of This Change

Implicit resharding let the compiler silently insert unintended collective communication (AllGather, ReduceScatter, etc.). These collectives transfer data between devices, which can significantly hurt performance in large-scale distributed training.

Aspect                    | Implicit reshard (up to 0.9.0)      | Explicit reshard (0.9.1+)
--------------------------|-------------------------------------|---------------------------
Mismatch detection        | None (silent reshard)               | Immediate error
Hidden communication cost | Can occur                           | User-controlled
Debugging                 | Difficult (requires inspecting HLO) | Easy (clear error message)

In large-scale model training, even a single unintended reshard can significantly reduce throughput. This change enables detection of sharding inconsistencies at an early development stage.

New Features

Compilation Cache Debug Setting

A new debug config jax_compilation_cache_check_contents has been added.

When enabled, it behaves as follows:

  • When get() is called on a value that has not been put() by the current process, it is treated as a cache miss — even if the value actually exists in the disk cache
  • When put() is called, it verifies that the contents match

This is useful for debugging compilation cache integrity issues.
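The two rules above can be modeled in a few lines of plain Python. This is a toy sketch, not JAX's implementation: the class CheckedCache, its fields, and the dict standing in for the disk cache are all hypothetical. It also assumes that "verifies that the contents match" means comparing the new value against an entry already stored under the same key.

```python
class CheckedCache:
    """Toy model of the described check-contents semantics (hypothetical)."""

    def __init__(self, disk):
        self._disk = disk        # dict standing in for the shared disk cache
        self._put_here = set()   # keys written by the current process

    def get(self, key):
        # Entries this process did not put() are reported as misses,
        # even when they exist in the disk cache.
        if key not in self._put_here:
            return None
        return self._disk.get(key)

    def put(self, key, value):
        # Assumed meaning of the check: a new value must match any
        # entry already stored under the same key.
        existing = self._disk.get(key)
        if existing is not None and existing != value:
            raise ValueError(f"cache contents mismatch for key {key!r}")
        self._disk[key] = value
        self._put_here.add(key)


disk = {"k": b"executable"}      # entry written by some earlier process
cache = CheckedCache(disk)
first_lookup = cache.get("k")    # miss: this process has not put() it yet
cache.put("k", b"executable")    # contents match, so the put succeeds
second_lookup = cache.get("k")   # hit
```

In JAX itself the setting would presumably be switched on like other config options, for example with jax.config.update("jax_compilation_cache_check_contents", True). That usage is an assumption based on the flag name, so check the documentation for your version.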

Summary

While JAX 0.9.1 is a micro version bump, the behavior change for shard_map in Explicit mode is particularly noteworthy for users working with distributed computing. Because implicit reshards are no longer allowed, sharding mismatches are detected early, which makes debugging significantly easier.
