JAX 0.9 リリースノートまとめ

6 min read
ML

はじめに

2026年1月20日に JAX 0.9.0 がリリースされました。JAX は 0.5.0 から Effort-Based Versioning (EffVer) を採用しており、バージョン番号はアップグレードに必要な「労力」を示します。0.4.38 (2025年1月) から約1年で 0.5 → 0.6 → 0.7 → 0.8 → 0.9 と5回のメソバージョンアップが行われ、多くの変更がありました。

本記事では JAX 0.9.0 の変更点をまとめます。

Effort-Based Versioning (EffVer) とは

JAX 0.5.0 (2025年2月) から採用されたバージョニング方式です。MACRO.MESO.MICRO の3つの数字で、アップグレードに必要な労力を示します。

  • MACRO (0.x → 1.x): 大きな労力が必要
  • MESO (0.8 → 0.9): 小さな労力が必要(破壊的変更あり)
  • MICRO (0.9.0 → 0.9.0.1): ほぼ労力なし

SemVer と異なり、ユーザー視点での影響度に基づいてバージョンが決まります。

JAX 0.9.0 の主な変更点

新機能

  • jax.thread_guard: マルチコントローラー環境で、複数スレッドからデバイスが使用されることを検出するコンテキストマネージャ
  • jax.export での explicit sharding サポート: NamedSharding、abstract mesh、partition spec を含む新しいシリアライゼーションフォーマットに対応

破壊的変更

  • pmap の no-rank-reduction が唯一の動作に: jax_pmap_no_rank_reduction 設定が削除。pmap(f) に shape (8, 128) を渡すと、f(1, 128) を受け取る(以前の (128,) ではなく)
  • Layout/Format API のリネーム: LayoutFormatDeviceLocalLayoutLayout に名称変更
  • jax.experimental.shard モジュール削除: jax.sharding.reshardjax.sharding.auto_axesjax.sharding.explicit_axes に移動
  • pjit_pjit_p: jax.extend.core.primitives.pjit_pjit_p にリネーム
  • lax.infeed / lax.outfeed 削除: JAX 0.6 で非推奨化済み

非推奨化 (Deprecation)

  • jax.device_put_replicated / jax.device_put_sharded: jax.device_put に統一
  • jax.cloud_tpu_init: JAX が自動でTPU初期化を行うため不要に
  • jax.numpy.fix: NumPy v2.5.0 の非推奨化に追従。jax.numpy.trunc を代替として使用

シャーディングの3つのモード

JAX 0.9 時点で、シャーディングには3つのアプローチがあります。

1. Auto モード

mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Auto,))

XLA の Shardy コンパイラが自動的に並列化を行い、AllGather、ReduceScatter 等の通信を挿入します。

2. Explicit モード

mesh = jax.make_mesh((8,), ("data",), axis_types=(jax.sharding.AxisType.Explicit,))

シャーディングが JAX の型システムの一部となり、JAX がシャーディングの伝播を処理します。

3. Manual モード (shard_map)

@jax.shard_map(mesh=mesh, in_specs=P("data"), out_specs=P("data"))
def f(x):
    return x * 2

デバイスごとのローカルビューで、通信はユーザーが明示的に記述します。

pmap から shard_map への移行

pmap はメンテナンスモードとなり、shard_map への移行が推奨されています。shard_map の利点は以下の通りです。

  • 多次元メッシュのサポート(pmap は1次元のみ)
  • より効率的な自動微分
  • 他の JAX API との高い合成可能性

特に pmap の廃止予定は、大規模な分散学習を行っているユーザーにとって影響が大きいため、早めの対応が必要そうです。

参考

Share: