JAX 0.9.1 リリースノートまとめ
はじめに
JAX 0.9.0 に続き、JAX 0.9.1 が2026/3/2にリリースtagが公開されました。すでに0.9.2も公開されているので遅くなりました。今回は MICRO バージョンアップのため、アップグレードに必要な労力は小さいですが、いくつかの重要な変更と新機能が含まれていましたので説明します。
前回の記事: JAX 0.9 リリースノートまとめ
JAX 0.9.1 の主な変更点
変更 (Changes)
Array 型でないトレーサーの isinstance 動作変更
Array 型でない JAX トレーサー(例: Ref 型)が、isinstance チェックで Array のインスタンスとして報告されなくなりました。
これまでは、Ref 型などの非 Array トレーサーも isinstance(tracer, jax.Array) で True を返していましたが、この動作が修正されました。型チェックに依存したコードを書いている場合は、意図通りに動作するか確認が必要です。
Explicit モードでの shard_map の入力検証強化
jax.shard_map を Explicit モードで使用する際、入力の PartitionSpec が in_specs で指定した PartitionSpec と一致しない場合にエラーが発生するようになりました。
以前の動作では暗黙的に reshard が行われていましたが、0.9.1 からは assert として機能し、不一致があればエラーになります。
具体例: 暗黙的 reshard の問題
以下は、2x2 のデバイスメッシュ上で Explicit モードを使う例です。入力配列のシャーディングと in_specs の不一致がある場合の動作を示します。
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType
# 4デバイスを 2x2 メッシュとして構成(Explicit モード)
mesh = jax.make_mesh(
(2, 2), ("data", "model"),
axis_types=(AxisType.Explicit, AxisType.Explicit)
)
# (8, 8) の配列を "data" 軸のみでシャーディング
x = jnp.ones((8, 8))
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
# shard_map の in_specs で ("data", "model") の2軸シャーディングを指定
# → 入力は ("data", None) なので PartitionSpec が不一致
@jax.shard_map(
mesh=mesh,
in_specs=P("data", "model"),
out_specs=P("data", "model")
)
def matmul_sharded(x):
return x @ x.T
JAX 0.9.0 以前の動作:
# エラーなし。暗黙的に ("data", None) → ("data", "model") へ reshard が発生
# AllGather + Slice の通信が裏で実行され、パフォーマンス劣化の原因に
result = matmul_sharded(x)
JAX 0.9.1 以降の動作:
# ValueError: shard_map input has PartitionSpec('data', None)
# but in_specs specifies PartitionSpec('data', 'model')
result = matmul_sharded(x) # エラー発生
正しい対処方法
# 方法1: jax.reshard で明示的に reshard してから渡す
x_resharded = jax.reshard(x, NamedSharding(mesh, P("data", "model")))
result = matmul_sharded(x_resharded)
# 方法2: in_specs を省略して入力から自動推論させる
@jax.shard_map(
mesh=mesh,
out_specs=P("data", None) # in_specs を省略
)
def matmul_auto(x):
return x @ x.T
result = matmul_auto(x) # 入力のシャーディングがそのまま使われる
検証ノートブック
上記のコードを実際に動かして検証できるノートブックはこちらです。
この変更の意義
暗黙的な reshard は、ユーザーが意図しない通信(AllGather、ReduceScatter 等)をコンパイラが挿入する原因でした。これらの通信はデバイス間のデータ転送を伴うため、大規模な分散学習ではパフォーマンスに大きな影響を与えます。
| 項目 | 暗黙的 reshard (0.9.0以前) | 明示的 reshard (0.9.1以降) |
|---|---|---|
| 不一致検出 | なし(黙って reshard) | エラーで即座に検出 |
| 隠れた通信コスト | 発生しうる | ユーザーが意識して制御 |
| デバッグ | 困難(HLO を見ないと気づけない) | 容易(エラーメッセージで判明) |
特に大規模モデルの分散学習では、意図しない reshard が1箇所あるだけでスループットが大きく低下することがあります。この変更により、シャーディングの不整合を開発初期段階で発見できるようになりました。
新機能 (New Features)
コンパイルキャッシュのデバッグ設定
新しいデバッグ設定 jax_compilation_cache_check_contents が追加されました。
この設定を有効にすると、以下の動作になります:
get()が呼ばれた際、現在のプロセスがput()していない値はキャッシュミスとして扱われる(実際にディスクキャッシュに存在していても)put()の際に、内容が一致するか検証する
これはコンパイルキャッシュのデバッグに有用で、キャッシュの整合性問題を調査する際に役立ちます。
まとめ
JAX 0.9.1 は MICRO バージョンアップですが、特に Explicit モードでの shard_map の動作変更は、分散処理を行っているユーザーにとって注意が必要です。暗黙的な reshard が許容されなくなったことで、シャーディングの不一致がより早期に検出されるようになり、デバッグの容易さが向上しました。