JAX 0.9 リリースノートまとめ
6 min read
ML
はじめに
2026年1月20日に JAX 0.9.0 がリリースされました。JAX は 0.5.0 から Effort-Based Versioning (EffVer) を採用しており、バージョン番号はアップグレードに必要な「労力」を示します。0.4.38 (2025年1月) から約1年で 0.5 → 0.6 → 0.7 → 0.8 → 0.9 と5回のメソバージョンアップが行われ、多くの変更がありました。
本記事では JAX 0.9.0 の変更点をまとめます。
Effort-Based Versioning (EffVer) とは
JAX 0.5.0 (2025年2月) から採用されたバージョニング方式です。MACRO.MESO.MICRO の3つの数字で、アップグレードに必要な労力を示します。
- MACRO (0.x → 1.x): 大きな労力が必要
- MESO (0.8 → 0.9): 小さな労力が必要(破壊的変更あり)
- MICRO (0.9.0 → 0.9.0.1): ほぼ労力なし
SemVer と異なり、ユーザー視点での影響度に基づいてバージョンが決まります。
JAX 0.9.0 の主な変更点
新機能
jax.thread_guard: マルチコントローラー環境で、複数スレッドからデバイスが使用されることを検出するコンテキストマネージャjax.exportでの explicit sharding サポート:NamedSharding、abstract mesh、partition spec を含む新しいシリアライゼーションフォーマットに対応
破壊的変更
pmapの no-rank-reduction が唯一の動作に:jax_pmap_no_rank_reduction設定が削除。pmap(f)に shape(8, 128)を渡すと、fは(1, 128)を受け取る(以前の(128,)ではなく)- Layout/Format API のリネーム:
Layout→Format、DeviceLocalLayout→Layoutに名称変更 jax.experimental.shardモジュール削除:jax.sharding.reshard、jax.sharding.auto_axes、jax.sharding.explicit_axesに移動pjit_p→jit_p:jax.extend.core.primitives.pjit_pがjit_pにリネームlax.infeed/lax.outfeed削除: JAX 0.6 で非推奨化済み
非推奨化 (Deprecation)
jax.device_put_replicated/jax.device_put_sharded:jax.device_putに統一jax.cloud_tpu_init: JAX が自動でTPU初期化を行うため不要にjax.numpy.fix: NumPy v2.5.0 の非推奨化に追従。jax.numpy.truncを代替として使用
シャーディングの3つのモード
JAX 0.9 時点で、シャーディングには3つのアプローチがあります。
1. Auto モード
mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Auto,))
XLA の Shardy コンパイラが自動的に並列化を行い、AllGather、ReduceScatter 等の通信を挿入します。
2. Explicit モード
mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Explicit,))
シャーディングが JAX の型システムの一部となり、JAX がシャーディングの伝播を処理します。
3. Manual モード (shard_map)
@jax.shard_map(mesh=mesh, in_specs=P("data"), out_specs=P("data"))
def f(x):
return x * 2
デバイスごとのローカルビューで、通信はユーザーが明示的に記述します。
pmap から shard_map への移行
pmap はメンテナンスモードとなり、shard_map への移行が推奨されています。shard_map の利点は以下の通りです。
- 多次元メッシュのサポート(
pmapは1次元のみ) - より効率的な自動微分
- 他の JAX API との高い合成可能性
特に pmap の廃止予定は、大規模な分散学習を行っているユーザーにとって影響が大きいため、早めの対応が必要そうです。