jax 0.8.0


pip install jax

Released: Oct 15, 2025

Meta
Author: JAX team
Requires Python: >=3.11

Classifiers

Development Status
  • 5 - Production/Stable

Programming Language
  • Python :: 3.11
  • Python :: 3.12
  • Python :: 3.13
  • Python :: 3.14
  • Python :: Free Threading :: 3 - Stable

Transformable numerical computing at scale

What is JAX?

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via jax.grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
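A minimal sketch of mixing the two modes, using jax.grad, jax.jacfwd, jax.jacrev, and jax.hessian on a hypothetical function f(x) = sum(x^2 sin x). Forward and reverse mode should agree on the gradient, and composing forward-over-reverse gives the Hessian, which for this f is diagonal with entries 2 sin x + 4x cos x - x^2 sin x.

```python
import jax
import jax.numpy as jnp

def f(x):
  # Hypothetical scalar function of a vector: sum(x**2 * sin(x))
  return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.5, 1.0, 2.0])

g_rev = jax.grad(f)(x)    # reverse mode (backpropagation)
g_fwd = jax.jacfwd(f)(x)  # forward mode; for a scalar output this has shape (3,)

# Forward-over-reverse composition yields the Hessian
hess = jax.jacfwd(jax.jacrev(f))(x)

print(jnp.allclose(g_rev, g_fwd))               # True
print(jnp.allclose(hess, jax.hessian(f)(x)))    # True
print(hess.shape)                               # (3, 3)
```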

JAX uses XLA to compile and scale your NumPy programs on TPUs, GPUs, and other hardware accelerators. You can compile your own pure functions with jax.jit. Compilation and automatic differentiation can be composed arbitrarily.

Dig a little deeper, and you'll see that JAX is really an extensible system for composable function transformations at scale.

This is a research project, not an official Google product. Expect sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))  # compiled gradient evaluation function
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
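The snippet above only defines functions. A sketch of calling them follows; it repeats the definitions so it runs on its own and adds a hypothetical init_params helper with small random weights (the layer sizes and batch size are arbitrary choices). Because the loss sums over examples, the per-example gradients should sum to the batch gradient.

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jax.jit(jax.grad(loss))
perex_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0, 0)))

def init_params(key, sizes):
  # Hypothetical helper: a list of (W, b) pairs, W of shape (n_in, n_out)
  params = []
  for n_in, n_out in zip(sizes[:-1], sizes[1:]):
    key, subkey = jax.random.split(key)
    params.append((0.1 * jax.random.normal(subkey, (n_in, n_out)), jnp.zeros(n_out)))
  return params

params = init_params(jax.random.key(0), [4, 8, 2])
inputs = jax.random.normal(jax.random.key(1), (16, 4))   # 16 examples, 4 features
targets = jax.random.normal(jax.random.key(2), (16, 2))

batch_grads = grad_loss(params, inputs, targets)      # same structure as params
example_grads = perex_grads(params, inputs, targets)  # extra leading axis of size 16

print(example_grads[0][0].shape)  # (16, 4, 8)
# Per-example gradients summed over the batch match the batch gradient
print(jnp.allclose(example_grads[0][0].sum(0), batch_grads[0][0], atol=1e-4))
```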

Transformations

At its core, JAX is an extensible system for transforming numerical functions. Here are three: jax.grad, jax.jit, and jax.vmap.

Automatic differentiation with grad

Use jax.grad to efficiently compute reverse-mode gradients:

import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = jax.grad(tanh)
print(grad_tanh(1.0))
# prints 0.4199743

You can differentiate to any order with grad:

print(jax.grad(jax.grad(jax.grad(tanh)))(1.0))
# prints 0.62162673
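Both printed values can be checked against closed forms. With t = tanh(x), the first derivative is 1 - t^2, the second is -2t(1 - t^2), and differentiating once more in t gives the third: (1 - t^2)(6t^2 - 2). A short sketch comparing these with the autodiff results (it repeats the tanh definition so it runs on its own):

```python
import jax
import jax.numpy as jnp

def tanh(x):
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

x = 1.0
t = jnp.tanh(x)

first = jax.grad(tanh)(x)
third = jax.grad(jax.grad(jax.grad(tanh)))(x)

# Closed forms: tanh' = 1 - t**2 and tanh''' = (1 - t**2) * (6 * t**2 - 2)
print(jnp.allclose(first, 1 - t**2, atol=1e-5))                     # True
print(jnp.allclose(third, (1 - t**2) * (6 * t**2 - 2), atol=1e-5))  # True
```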

You're free to use differentiation with Python control flow:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = jax.grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

See the JAX Autodiff Cookbook and the reference docs on automatic differentiation for more.

Compilation with jit

Use jax.jit to compile your functions end-to-end with XLA, either as an @jax.jit decorator or as a higher-order function.

import jax
import jax.numpy as jnp

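def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)
fast_f(x).block_until_ready()  # warm-up call: traces and compiles once
# IPython/Jupyter %timeit magics; block_until_ready() waits for JAX's
# asynchronous dispatch so the timing includes the actual computation
%timeit -n10 -r3 fast_f(x).block_until_ready()
%timeit -n10 -r3 slow_f(x).block_until_ready()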
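The same compilation can be requested with the decorator form. In this small sketch, fast_g, slow_g, and the trace_count counter are illustrative names, not part of JAX. A Python side effect inside a jitted function runs only while JAX traces it, which makes visible that later calls with the same shape and dtype reuse the cached compiled function, while a new input shape triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def fast_g(x):
  global trace_count
  trace_count += 1  # Python side effect: runs only while JAX traces the function
  return x * x + x * 2.0

def slow_g(x):
  return x * x + x * 2.0

a = jnp.ones(3)
fast_g(a)                 # first call: traced and compiled
fast_g(jnp.zeros(3))      # same shape and dtype: cached version reused
fast_g(jnp.ones((2, 2)))  # new shape: traced again
print(trace_count)  # 2
print(jnp.allclose(fast_g(a), slow_g(a)))  # True
```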

Using jax.jit constrains the kind of Python control flow the function can use; see the tutorial on Control Flow and Logical Operators with JIT for more.
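For example, the abs_val function from the grad section branches on the value of x with a Python if, which fails under jit because x is an abstract tracer while the function is traced. A minimal sketch of two common rewrites, assuming only jnp.where and jax.lax.cond (the function names are illustrative):

```python
import jax
import jax.numpy as jnp

def abs_val(x):
  if x > 0:  # Python branch on the value of x
    return x
  else:
    return -x

try:
  jax.jit(abs_val)(1.0)
  raised = False
except jax.errors.ConcretizationTypeError:
  raised = True  # the Python `if` needs a concrete bool, but x is a tracer
print(raised)  # True

@jax.jit
def abs_val_where(x):
  # jnp.where computes both branches and selects elementwise
  return jnp.where(x > 0, x, -x)

@jax.jit
def abs_val_cond(x):
  # lax.cond picks one of two branch functions using a traced predicate
  return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)

print(abs_val_where(-3.0), abs_val_cond(-3.0))
```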

Auto-vectorization with vmap

vmap maps a function along array axes. But instead of just looping over function applications, it pushes the loop down onto the function’s primitive operations, e.g. turning matrix-vector multiplies into matrix-matrix multiplies for better performance.
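To see the matrix-vector to matrix-matrix rewrite concretely, here is a small sketch with a hypothetical matvec function written for a single vector. Batching it with vmap should produce one dot_general over the whole batch rather than a Python-level loop, which jax.make_jaxpr lets you inspect.

```python
import jax
import jax.numpy as jnp

def matvec(W, v):
  # Written for a single vector v of shape (3,)
  return W @ v

W = jax.random.normal(jax.random.key(0), (4, 3))
vs = jax.random.normal(jax.random.key(1), (10, 3))  # a batch of 10 vectors

batched_matvec = jax.vmap(matvec, in_axes=(None, 0))  # W shared, vs mapped over axis 0
batched = batched_matvec(W, vs)                       # shape (10, 4)
print(jnp.allclose(batched, vs @ W.T, atol=1e-5))

# The traced program contains a single batched dot_general, no loop
jaxpr = jax.make_jaxpr(batched_matvec)(W, vs)
print(jaxpr)
```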

Using vmap can save you from having to carry around batch dimensions in your code:

import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # only works on 1D inputs
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)
dists.shape  # (100, 100)
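As a sanity check, the vmapped result can be compared with a hand-batched version written with NumPy-style broadcasting. The assert inside l1_distance still passes under vmap, because each call sees per-example shapes.

```python
import jax
import jax.numpy as jnp

def l1_distance(x, y):
  assert x.ndim == y.ndim == 1  # only works on 1D inputs
  return jnp.sum(jnp.abs(x - y))

def pairwise_distances(dist1D, xs):
  return jax.vmap(jax.vmap(dist1D, (0, None)), (None, 0))(xs, xs)

xs = jax.random.normal(jax.random.key(0), (100, 3))
dists = pairwise_distances(l1_distance, xs)

# Hand-batched equivalent: (100, 1, 3) - (1, 100, 3) broadcasts to (100, 100, 3)
manual = jnp.sum(jnp.abs(xs[:, None, :] - xs[None, :, :]), axis=-1)
print(jnp.allclose(dists, manual))  # True
```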

By composing jax.vmap with jax.grad and jax.jit, we can get efficient Jacobian matrices, or per-example gradients:

per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
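A Jacobian can be assembled the same way. For a scalar output, jax.grad is a vector-Jacobian product (VJP) with a cotangent of 1.0, so vmapping a VJP over the rows of an identity matrix yields one Jacobian row per output; this is essentially how jax.jacrev works. A sketch with a hypothetical function f from R^3 to R^2:

```python
import jax
import jax.numpy as jnp

def f(x):
  # Hypothetical vector-valued function R^3 -> R^2
  return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

# One reverse-mode pass per output row, batched with vmap over basis cotangents
_, f_vjp = jax.vjp(f, x)
jac_rows = jax.vmap(lambda ct: f_vjp(ct)[0])(jnp.eye(2))

print(jac_rows.shape)                            # (2, 3)
print(jnp.allclose(jac_rows, jax.jacrev(f)(x)))  # True
```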

Scaling

To scale your computations across thousands of devices, you can use any composition of these:

Mode       View?        Explicit sharding?   Explicit collectives?
Auto       Global       no                   no
Explicit   Global       yes                  no
Manual     Per-device   yes                  yes

import jax
from jax.sharding import set_mesh, AxisType, PartitionSpec as P

# Assumes 8 devices and the loss function from the first example;
# params, inputs, and targets are defined elsewhere
mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
set_mesh(mesh)

# parameters are sharded for FSDP:
for W, b in params:
  print(f'{jax.typeof(W)}')  # f32[512@data,512]
  print(f'{jax.typeof(b)}')  # f32[512]

# shard data for batch parallelism:
inputs, targets = jax.device_put((inputs, targets), P('data'))

# evaluate gradients, automatically parallelized!
gradfun = jax.jit(jax.grad(loss))
param_grads = gradfun(params, inputs, targets)  # loss takes inputs and targets separately
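The FSDP snippet above assumes eight devices and the explicit-sharding API. As a hedged sketch of Auto mode that also runs on a single CPU, the block below uses the Mesh and NamedSharding constructors over whatever devices are present; row_norms is a hypothetical function. On CPU, setting XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing JAX is a common way to simulate eight devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)
# A 1D mesh named 'data' over all available devices (a single CPU works too)
mesh = Mesh(np.array(devices), ('data',))

# Batch size is a multiple of the device count so rows split evenly
host_batch = np.arange(4.0 * n * 3, dtype=np.float32).reshape(4 * n, 3)
batch = jax.device_put(host_batch, NamedSharding(mesh, P('data')))  # rows split over 'data'

@jax.jit
def row_norms(b):
  # Written as if for one big array; jit compiles it for the sharded input
  return jnp.sqrt(jnp.sum(b ** 2, axis=1))

out = row_norms(batch)
print(batch.sharding)
print(out.shape)
```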

See the tutorial and advanced guides for more.

Gotchas and sharp bits

See the Gotchas Notebook.

Installation

Supported platforms

             Linux x86_64  Linux aarch64  Mac aarch64   Windows x86_64  Windows WSL2 x86_64
CPU          yes           yes            yes           yes             yes
NVIDIA GPU   yes           yes            n/a           no              experimental
Google TPU   yes           n/a            n/a           n/a             n/a
AMD GPU      yes           no             n/a           no              experimental
Apple GPU    n/a           no             experimental  n/a             n/a
Intel GPU    experimental  n/a            n/a           no              no

Instructions

Platform          Instructions
CPU               pip install -U jax
NVIDIA GPU        pip install -U "jax[cuda13]"
Google TPU        pip install -U "jax[tpu]"
AMD GPU (Linux)   Follow AMD's instructions.
Mac GPU           Follow Apple's instructions.
Intel GPU         Follow Intel's instructions.
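After installing, a quick way to confirm which version, backend, and devices JAX picked up is a minimal check using jax.__version__, jax.default_backend, jax.devices, and the Array.devices method:

```python
import jax
import jax.numpy as jnp

print(jax.__version__)        # installed JAX version
print(jax.default_backend())  # e.g. 'cpu', 'gpu', or 'tpu'
print(jax.devices())          # devices JAX can see

# A tiny computation placed on the default device
x = jnp.ones((2, 2))
print(x.devices())
```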

See the documentation for information on alternative installation strategies. These include compiling from source, installing with Docker, using other versions of CUDA, a community-supported conda build, and answers to some frequently-asked questions.

Citing JAX

To cite this repository:

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from jax/version.py, and the year corresponds to the project's open-source release.

A nascent version of JAX, supporting only automatic differentiation and compilation to XLA, was described in a paper that appeared at SysML 2018. We're currently working on covering JAX's ideas and capabilities in a more comprehensive and up-to-date paper.

Reference documentation

For details about the JAX API, see the reference documentation.

For getting started as a JAX developer, see the developer documentation.

Release history

0.8.0 Oct 15, 2025
0.7.2 Sep 16, 2025
0.7.1 Aug 20, 2025
0.7.0 Jul 22, 2025
0.6.2 Jun 17, 2025
0.6.1 May 21, 2025
0.6.0 Apr 17, 2025
0.5.3 Mar 19, 2025
0.5.2 Mar 05, 2025
0.5.1 Feb 24, 2025
0.5.0 Jan 17, 2025
0.4.38 Dec 17, 2024
0.4.37 Dec 10, 2024
0.4.36 Dec 05, 2024
0.4.35 Oct 22, 2024
0.4.34 Oct 04, 2024
0.4.33 Sep 16, 2024
0.4.32 Sep 11, 2024
0.4.31 Jul 30, 2024
0.4.30 Jun 18, 2024
0.4.29 Jun 10, 2024
0.4.28 May 09, 2024
0.4.27 May 07, 2024
0.4.26 Apr 03, 2024
0.4.25 Feb 26, 2024
0.4.24 Feb 06, 2024
0.4.23 Dec 14, 2023
0.4.22 Dec 14, 2023
0.4.21 Dec 04, 2023
0.4.20 Nov 02, 2023
0.4.19 Oct 19, 2023
0.4.18 Oct 07, 2023
0.4.17 Oct 04, 2023
0.4.16 Sep 19, 2023
0.4.15 Aug 30, 2023
0.4.14 Jul 27, 2023
0.4.13 Jun 23, 2023
0.4.12 Jun 08, 2023
0.4.11 Jun 01, 2023
0.4.10 May 12, 2023
0.4.9 May 10, 2023
0.4.8 Mar 29, 2023
0.4.7 Mar 27, 2023
0.4.6 Mar 09, 2023
0.4.5 Mar 03, 2023
0.4.4 Feb 16, 2023
0.4.3 Feb 08, 2023
0.4.2 Jan 25, 2023
0.4.1 Dec 13, 2022
0.4.0 Dec 12, 2022
0.3.25 Nov 15, 2022
0.3.24 Nov 04, 2022
0.3.23 Oct 12, 2022
0.3.22 Oct 11, 2022
0.3.21 Oct 03, 2022
0.3.20 Sep 28, 2022
0.3.19 Sep 27, 2022
0.3.18 Sep 26, 2022
0.3.17 Aug 31, 2022
0.3.16 Aug 12, 2022
0.3.15 Jul 22, 2022
0.3.14 Jun 28, 2022
0.3.13 May 16, 2022
0.3.12 May 16, 2022
0.3.11 May 15, 2022
0.3.10 May 05, 2022
0.3.9 May 03, 2022
0.3.8 Apr 30, 2022
0.3.7 Apr 16, 2022
0.3.6 Apr 13, 2022
0.3.5 Apr 07, 2022
0.3.4 Mar 18, 2022
0.3.3 Mar 17, 2022
0.3.2 Mar 16, 2022
0.3.1 Feb 18, 2022
0.3.0 Feb 10, 2022
0.2.28 Feb 02, 2022
0.2.27 Jan 18, 2022
0.2.26 Dec 08, 2021
0.2.25 Nov 10, 2021
0.2.24 Oct 19, 2021
0.2.23 Oct 19, 2021
0.2.22 Oct 13, 2021
0.2.21 Sep 23, 2021
0.2.20 Sep 03, 2021
0.2.19 Aug 13, 2021
0.2.18 Jul 21, 2021
0.2.17 Jul 09, 2021
0.2.16 Jun 23, 2021
0.2.15 Jun 23, 2021
0.2.14 Jun 10, 2021
0.2.13 May 04, 2021
0.2.12 Apr 01, 2021
0.2.11 Mar 24, 2021
0.2.10 Mar 05, 2021
0.2.9 Jan 27, 2021
0.2.8 Jan 12, 2021
0.2.7 Dec 05, 2020
0.2.6 Nov 18, 2020
0.2.5 Oct 26, 2020
0.2.4 Oct 20, 2020
0.2.3 Oct 14, 2020
0.2.2 Oct 14, 2020
0.2.1 Oct 07, 2020
0.2.0 Sep 24, 2020
0.1.77 Sep 15, 2020
0.1.76 Sep 08, 2020
0.1.75 Jul 31, 2020
0.1.74 Jul 30, 2020
0.1.73 Jul 22, 2020
0.1.72 Jun 28, 2020
0.1.71 Jun 26, 2020
0.1.70 Jun 09, 2020
0.1.69 Jun 03, 2020
0.1.68 May 21, 2020
0.1.67 May 12, 2020
0.1.66 May 05, 2020
0.1.65 Apr 30, 2020
0.1.64 Apr 21, 2020
0.1.63 Apr 13, 2020
0.1.62 Mar 22, 2020
0.1.61 Mar 17, 2020
0.1.60 Mar 17, 2020
0.1.59 Feb 11, 2020
0.1.58 Jan 28, 2020
0.1.57 Jan 08, 2020
0.1.56 Jan 04, 2020
0.1.55 Dec 07, 2019
0.1.54 Dec 05, 2019
0.1.53 Nov 27, 2019
0.1.52 Nov 19, 2019
0.1.51 Nov 14, 2019
0.1.50 Nov 06, 2019
0.1.49 Oct 31, 2019
0.1.48 Oct 22, 2019
0.1.47 Oct 21, 2019
0.1.46 Sep 17, 2019
0.1.45 Sep 10, 2019
0.1.44 Aug 31, 2019
0.1.43 Aug 25, 2019
0.1.42 Aug 22, 2019
0.1.41 Aug 08, 2019
0.1.40 Aug 05, 2019
0.1.39 Jun 27, 2019
0.1.38 Jun 18, 2019
0.1.37 Jun 11, 2019
0.1.36 Jun 03, 2019
0.1.35 May 23, 2019
0.1.34 May 22, 2019
0.1.33 May 21, 2019
0.1.32 May 20, 2019
0.1.31 May 17, 2019
0.1.30 May 16, 2019
0.1.29 May 16, 2019
0.1.28 May 10, 2019
0.1.27 May 07, 2019
0.1.26 May 05, 2019
0.1.25 Apr 08, 2019
0.1.24 Apr 06, 2019
0.1.23 Apr 05, 2019
0.1.22 Mar 28, 2019
0.1.21 Mar 01, 2019
0.1.20 Feb 17, 2019
0.1.19 Feb 13, 2019
0.1.18 Feb 06, 2019
0.1.16 Jan 12, 2019
0.1.15 Dec 24, 2018
0.1.14 Dec 19, 2018
0.1.13 Dec 19, 2018
0.1.12 Dec 19, 2018
0.1.11 Dec 16, 2018
0.1.10 Dec 15, 2018
0.1.9 Dec 13, 2018
0.1.8 Dec 12, 2018
0.1.7 Dec 12, 2018
0.1.6 Dec 11, 2018
0.1.5 Dec 09, 2018
0.1.4 Dec 09, 2018
0.1.3 Dec 08, 2018
0.1.2 Dec 07, 2018
0.1.1 Dec 07, 2018
0.1 Dec 07, 2018
0.0 Dec 07, 2018

Wheel compatibility matrix

Platform: any
Python: 3

Dependencies:
jaxlib (<=0.8.0,>=0.8.0)
ml_dtypes (>=0.5.0)
numpy (>=2.0)
opt_einsum
scipy (>=1.13)