Compatibility with jaxtyping #12

Open
nstarman opened this issue Mar 5, 2024 · 1 comment
Labels: question (User queries)

nstarman (Contributor) commented Mar 5, 2024

Hi! Really enjoying quax. I've been working to get galax potentials quaxified (most relevantly in GalacticDynamics/galax#187) and ran into a compatibility issue with jaxtyping's runtime type checking.

If a module has runtime type checking turned on, e.g. via install_import_hook("galax.potential", "beartype.beartype"), then quaxed functions don't pass Quantity objects through correctly: the type checker sees plain arrays instead. Here is an example from GalacticDynamics/galax#187:

>>> from jax_quantity import Quantity
>>> import galax.potential as gp
>>> import galax.units as gu
>>> pot = gp.KeplerPotential(m=Quantity(1e12, "Msun"), units=gu.galactic)
>>> pot._potential_energy(Quantity([1.0, 0, 0], "kpc"), Quantity(0, "Myr"), pot._G)
TypeCheckError: Type-check error whilst checking the parameters of _potential_energy.
The problem arose whilst typechecking argument 'q'.
Called with arguments: {'self': KeplerPotential(...), 'q': f64[3], 't': i64[], '_G': f64[]}
Parameter annotations: (self, q: Shaped[Quantity, '*batch 3'], t: Union[Shaped[Quantity, '*#batch '], Shaped[Quantity, '*#batch ']], /, _G: Float[Quantity, '']).

patrick-kidger (Owner) commented:

This looks expected to me. Once a Value passes through a quaxify boundary, it should look like an array to everything inside.

Put another way, the underlying function should be one that acts on arrays. It is wrapping it in quaxify that makes it able to consume Values.
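A minimal plain-Python sketch of that boundary, under loud assumptions: `ToyQuantity` and `toy_quaxify` are hypothetical stand-ins (not quax APIs) for a Quantity-like Value and for `quaxify`. The inner function is written and annotated purely against plain numbers, and only the wrapper lets it accept a Value. Real quax does not simply strip the Value; it traces the function and lets the Value's registered rules handle each primitive, so this toy only models what the inner function sees.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyQuantity:
    """Hypothetical stand-in for a Quantity-like quax Value: a number plus a unit."""
    value: float
    unit: str


def toy_quaxify(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical, heavily simplified quaxify boundary.

    Values are unwrapped before `fn` runs, so `fn` only ever sees plain numbers.
    (Real quax dispatches each JAX primitive on the Value type instead.)
    """
    def wrapped(*args: Any) -> Any:
        return fn(*(a.value if isinstance(a, ToyQuantity) else a for a in args))
    return wrapped


def scale(x: float, factor: float) -> float:
    # Written and annotated purely against plain numbers ("arrays").
    if isinstance(x, ToyQuantity):
        raise TypeError("a Value leaked through the boundary")
    return x * factor


scaled = toy_quaxify(scale)
print(scaled(ToyQuantity(1.5, "kpc"), 2.0))  # 3.0
```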

If you use the jaxtyping import hook, then the type-checking decorator is always placed at the bottom of the decorator list (this is needed for compatibility with jax.custom_{jvp,vjp}). In particular, that means the check happens inside the quaxify, where q is seen as a plain array (the f64[3] in the error above) rather than as a Quantity.
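The effect of decorator order can be sketched in plain Python, again with hypothetical stand-ins: `ToyQuantity`, `toy_quaxify`, and `toy_typecheck` are illustrations, not quax or jaxtyping APIs, and a real jaxtyping/beartype check also inspects dtype and shape. A checker at the bottom of the stack (where the import hook places it, per the comment above) runs inside the boundary and rejects a Quantity annotation; the same checker at the top sees the Quantity.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ToyQuantity:
    """Hypothetical stand-in for a Quantity-like quax Value."""
    value: float
    unit: str


def toy_quaxify(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical quaxify boundary: Values are unwrapped before fn runs."""
    def wrapped(*args: Any) -> Any:
        return fn(*(a.value if isinstance(a, ToyQuantity) else a for a in args))
    return wrapped


def toy_typecheck(expected: type) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical runtime checker: the first argument must be an `expected`."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        def wrapped(*args: Any) -> Any:
            if not isinstance(args[0], expected):
                raise TypeError(
                    f"expected {expected.__name__}, got {type(args[0]).__name__}"
                )
            return fn(*args)
        return wrapped
    return decorator


# Checker at the bottom of the decorator list (the import-hook placement):
# it runs inside the boundary and sees a float, not a ToyQuantity.
@toy_quaxify
@toy_typecheck(ToyQuantity)
def potential_inner(q):
    return -1.0 / q


# Checker at the top: it runs before the boundary and sees the ToyQuantity.
@toy_typecheck(ToyQuantity)
@toy_quaxify
def potential_outer(q):
    return -1.0 / q


def rejection(fn: Callable[..., Any], *args: Any) -> Optional[str]:
    """Return the TypeError message raised by fn(*args), or None if it passed."""
    try:
        fn(*args)
    except TypeError as err:
        return str(err)
    return None


q = ToyQuantity(2.0, "kpc")
print(rejection(potential_inner, q))  # expected ToyQuantity, got float
print(potential_outer(q))             # -0.5
```

In this model, the innermost check only passes if the inner function is annotated with array types rather than `Quantity`, which matches the point that the underlying function should act on arrays.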

patrick-kidger added the question (User queries) label Mar 6, 2024