Skip to content

ADR-0002: Equinox PyTree Pattern

Status

Accepted

Context

Geometric primitives need to carry differentiable parameters (center, radius, etc.) and be compatible with JAX transformations. The standard approaches are: (1) plain JAX arrays with functional APIs, (2) flax.struct, (3) equinox.Module.

Decision

All geometric primitives and composite shapes inherit from equinox.Module.

  • Parameters are declared as dataclass-style fields with jaxtyping annotations
  • The Module automatically registers as a PyTree, enabling seamless jit/vmap/grad
  • Follows the pattern established by Diffrax, Optimistix, and other scientific JAX libraries

Consequences

  • Clean, readable API: Disk(center=..., radius=...) instead of dictionaries
  • Automatic PyTree flattening/unflattening for all JAX transformations
  • Equinox becomes a core dependency (acceptable given ecosystem alignment)
  • Users familiar with Equinox/Diffrax will find the API immediately familiar