Meta
Author: Allen Wang
Requires Python: >=3.10,<4.0
Classifiers
Programming Language
- Python :: 3
- Python :: 3.10
- Python :: 3.11
- Python :: 3.12
Simple Xarray + JAX Integration
This is an experiment in integrating Xarray and JAX in a simple way, leveraging equinox.
import equinox as eqx
import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax as xj
# Construct a DataArray.
da = xr.DataArray(
xr.Variable(["x", "y"], jnp.ones((2, 3))),
coords={"x": [1, 2], "y": [3, 4, 5]},
name="foo",
attrs={"attr1": "value1"},
)
# Do some operations inside a JIT compiled function.
@eqx.filter_jit
def some_function(data):
    neg_data = -1.0 * data
    return neg_data * neg_data.coords["y"]  # Multiply data by coords.
da = some_function(da)
# Construct an xr.DataArray with the same structure as da but dummy data (useful for tree manipulation).
da_mask = jax.tree.map(lambda _: True, da)
# Use jax.grad.
@eqx.filter_jit
def fn(data):
    return (data**2.0).sum().data
grad = jax.grad(fn)(da)
# Convert to a custom XjDataArray, implemented as an equinox module.
# (Useful for avoiding potentially weird xarray interactions with JAX).
xj_da = xj.from_xarray(da)
# Convert back to an xr.DataArray.
da = xj.to_xarray(xj_da)
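The snippet above relies on xarray_jax registering xarray types as PyTree nodes, which is what lets jax.tree.map and jax.grad operate on them. As a rough illustration of what such a registration involves (a hypothetical sketch, not xarray_jax's actual code), the toy Labeled class below is registered with jax.tree_util.register_pytree_node so that its array is the only traced child, while its dimension names and attrs travel as static auxiliary data:

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Labeled:
    """Hypothetical toy labeled array; not an xarray or xarray_jax type."""

    data: jax.Array
    dims: tuple[str, ...]
    attrs: dict


def _flatten(obj):
    # Children are traced by JAX; aux data is static and must be hashable.
    return (obj.data,), (obj.dims, tuple(sorted(obj.attrs.items())))


def _unflatten(aux, children):
    dims, attr_items = aux
    return Labeled(children[0], dims, dict(attr_items))


jax.tree_util.register_pytree_node(Labeled, _flatten, _unflatten)


def loss(v):
    return (v.data**2).sum()


x = Labeled(jnp.arange(3.0), ("x",), {"units": "m"})
g = jax.grad(loss)(x)  # A Labeled holding d(loss)/d(data) = 2 * data.
value = jax.jit(loss)(x)  # Works because the aux data is hashable.
```

Because the dimension names sit in static auxiliary data in this sketch, a jax.tree.map that changes the array's rank would leave them untouched; the Sharp Edges section describes a mismatch of that kind.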
Installation
pip install xarray_jax
Status
- PyTree node registrations:
  - xr.Variable
  - xr.DataArray
  - xr.Dataset
- Minimal shadow types implemented as equinox modules to handle edge cases (note: these types are merely data structures that hold the data of the corresponding xarray types; they have none of the xarray methods):
  - XjVariable
  - XjDataArray
  - XjDataset
- xj.from_xarray and xj.to_xarray functions to convert between xj and xr types.
- Support for xr types with dummy data (useful for tree manipulation).
- Not yet supported: transformations that change the dimensionality of the data (see Sharp Edges).
Sharp Edges
Prefer eqx.filter_jit over jax.jit
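There are some edge cases with metadata that eqx.filter_jit handles but jax.jit does not. This is likely tied to how each treats non-array leaves: eqx.filter_jit traces only array leaves and treats everything else as static, while jax.jit attempts to trace every leaf.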
Operations that Increase the Dimensionality of the Data
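Operations that increase the dimensionality of the data (e.g. jnp.expand_dims applied through jax.tree.map) will cause problems downstream: the array gains an axis, but the dimension names stored on the xarray object are left unchanged, so the mismatch only surfaces the next time xarray uses the object.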
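import jax
import jax.numpy as jnp
import xarray as xr
import xarray_jax  # Importing xarray_jax registers the xarray PyTree nodes.
var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))
# This will not error, although the data is now (1, 3, 3) while dims is still ("x", "y").
var = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), var)
# The error from expanding the dimensionality will be triggered here.
var = var + 1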
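One way to sidestep this, sketched under the assumption that changing rank through xarray's own API is acceptable for your use case, is to add the new dimension with xr.Variable.set_dims, which updates the dimension names and the data together. This example does not go through jax.tree.map or xarray_jax, and has not been checked inside a jitted function:

```python
import jax.numpy as jnp
import numpy as np
import xarray as xr

var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))

# set_dims inserts the new "z" dimension with length 1 and keeps dims in sync with the data.
expanded = var.set_dims(("z", "x", "y"))

# Arithmetic works because dims and data agree.
result = expanded + 1
total = float(np.asarray(result.values).sum())
```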
Dispatching to jnp is not supported yet
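This is pending resolution of https://github.com/pydata/xarray/issues/7848. In the meantime, wrap jnp functions in xr.apply_ufunc rather than calling them directly on xarray objects.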
import jax.numpy as jnp
import xarray as xr
var = xr.Variable(dims=("x", "y"), data=jnp.ones((3, 3)))
# This will fail: jnp functions do not dispatch on xarray types.
jnp.square(var)
# This will work.
xr.apply_ufunc(jnp.square, var)
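If many jnp calls need this treatment, a small wrapper can route them through xr.apply_ufunc. The helper name lift is hypothetical (not part of xarray or xarray_jax), and this sketch only covers elementwise functions, since it passes no input_core_dims:

```python
import jax.numpy as jnp
import numpy as np
import xarray as xr


def lift(fn):
    """Hypothetical helper: route a jnp function through xr.apply_ufunc."""

    def wrapped(*args, **kwargs):
        # apply_ufunc unwraps xarray inputs, calls fn on the raw arrays,
        # and rewraps the result with the original dimensions.
        return xr.apply_ufunc(fn, *args, kwargs=kwargs)

    return wrapped


square = lift(jnp.square)
maximum = lift(jnp.maximum)

var = xr.Variable(dims=("x", "y"), data=jnp.arange(6.0).reshape(2, 3))
squared = square(var)
floored = maximum(var, 2.0)  # Non-xarray arguments are passed through unchanged.
```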
Distinction from the GraphCast Implementation
This experiment is largely inspired by the GraphCast implementation and directly reuses its _HashableCoords.
However, this experiment aims to:
- Take a more minimalist approach (and thus neglect some features, such as support for JAX arrays as coordinates).
- Find a solution more compatible with common JAX PyTree manipulation patterns that trigger errors with Xarray types. For example, it's common to use boolean masks to filter out elements of a PyTree, but this tends to fail with Xarray types.
Acknowledgements
This repo was made possible by great discussions within the JAX + Xarray open source community, especially this one. In particular, the author would like to acknowledge @shoyer, @mjwillson, and @TomNicholas.
Extras:
None
Dependencies:
- equinox (<0.12.0,>=0.11.7)
- jax (<0.5.0,>=0.4.33)
- xarray (<2025.0.0,>=2024.9.0)