Interest in adding an array-api module #10
That's really nice! However, once you've got user-facing docs etc., then I would be very happy to add a link to this from the README / advertise it as a downstream library in the Quax docs. (On an unrelated note, by the way, I just went poking through |
Sounds good, thanks!
Thanks for the heads up. Is there not a way to make it work, e.g. |
The problem with doing this Not that I can see a way for that to happen cleanly. If we were to go for that approach, then the following:

```python
def foo(x):
    z = Quantity(...)
    return x + z

quaxify(foo)(y)
```

would actually result in two nested

In terms of what that would end up doing: rather than hitting the dispatch rule for

What that means is that right now, the expectation is that you will create all your array-ish values outside the

What might we do about this? I'm not 100% happy with this approach, as it means that we have to pass all array-ish values as formal arguments. This makes it best suited for the approach of "transform an existing program" rather than "use array-ish values anywhere within a new program". But it was the only approach I could find that seemed to be largely footgun-free, behave consistently, etc.

For what it's worth, it should be possible to write some kind of

```python
def foo(x):
    z = Quantity(...)
    z2 = quax.wrap_arrayish_value_into_tracer(z)
    return x + z2

quaxify(foo)(y)
```

in other words, we get to pretend that we'd passed

I don't love this though, as it (a) seems like an easy footgun to forget, (b) leaves the original

I think we're still figuring out what using Quax looks like. Maybe a cleaner way to do things will present itself. |
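To make the "only formal arguments get wrapped" point concrete, here is a toy model in plain Python. It is a sketch under simplifying assumptions, not Quax's implementation: `toy_quaxify` and `ToyTracer` are hypothetical names, and the real `quaxify` wraps array leaves into tracers on a JAX trace rather than into a dataclass.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyTracer:
    """Toy stand-in for the tracer that quaxify wraps each input in."""
    value: Any


def toy_quaxify(fn: Callable) -> Callable:
    """Toy decorator (NOT Quax's implementation): wraps only the arguments."""
    def wrapped(*args):
        return fn(*(ToyTracer(a) for a in args))
    return wrapped


def foo(x):
    z = 2.0  # created inside the function, so the decorator never sees it
    return type(x).__name__, type(z).__name__


def foo_with_z_as_argument(x, z):
    # Passing z in as a formal argument means it gets wrapped too.
    return type(x).__name__, type(z).__name__


print(toy_quaxify(foo)(1.0))                          # ('ToyTracer', 'float')
print(toy_quaxify(foo_with_z_as_argument)(1.0, 2.0))  # ('ToyTracer', 'ToyTracer')
```

The second function corresponds to the current expectation of passing every array-ish value in as an argument; the first shows why a value built inside the body never reaches the dispatch rules.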
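Thanks for the detailed response!
I think this would be incredibly useful for Equinox, so that the following would work:

```python
import equinox as eqx
from quax import quaxify

# Quantity comes from jax-quantity (its import is omitted here).

class Model(eqx.Module):
    y: Quantity

    @quaxify
    def foo(self, x: Quantity):
        return x + self.y ** 2
```

|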
FWIW, I think the one you've linked there should work already. All PyTrees (both |
Awesome, then in |
I have confirmed that wrapping into tracers will be necessary.

```python
def _potential_energy(self, q: BatchQVec3, /, t: BatchFloatOrIntQScalar) -> BatchFloatQScalar:
    r = xp.linalg.vector_norm(q, axis=-1)
    out = -self._G * self.m(t) / (r + self.c(t))
    return out
```

Inserting a
So that
|
Perhaps a function |
Or is there some way |
I think I like this! I might even make it a classmethod on

```python
class Value:
    @classmethod
    def quaxify(cls, like_tracer, /, *args, **kwargs):
        ...  # implementation here
```

Or maybe we really jump off the API deep end and do something like

In either case the goal is to help discourage ever creating an unwrapped array-ish value in the first place. I'm not 100% sold on any of these approaches here; frankly, maybe just the

I'm not sure what you're getting at with this one, I'm afraid: it will already handle |
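One possible reading of the `Value.quaxify(like_tracer, /, *args, **kwargs)` classmethod, sketched as another plain-Python toy. The semantics are an assumption: here `like_tracer` only tells the classmethod which trace to join, and the value is wrapped the moment it is built, so no unwrapped instance is ever handed back to the caller. `ToyTracer` and its `trace` field are hypothetical.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyTracer:
    """Toy stand-in for a Quax tracer: a payload tagged with its trace."""
    value: Any
    trace: int  # identifies which (toy) quaxify call the tracer belongs to


class Value:
    """Toy stand-in for an array-ish value such as a Quantity."""

    def __init__(self, payload: Any):
        self.payload = payload

    @classmethod
    def quaxify(cls, like_tracer: ToyTracer, /, *args: Any, **kwargs: Any) -> ToyTracer:
        # Hypothetical semantics: construct the value, then wrap it into the
        # same trace as `like_tracer` before returning it.
        return ToyTracer(cls(*args, **kwargs), like_tracer.trace)


x = ToyTracer(Value(1.0), trace=3)  # pretend x arrived as a wrapped argument
z = Value.quaxify(x, 2.0)           # built already wrapped, on x's trace
print(z.trace, z.value.payload)     # 3 2.0
```

Compared with a free function such as `wrap_arrayish_value_into_tracer`, the classmethod form makes construction and wrapping a single step, which is what discourages the unwrapped footgun.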
I meant if all the other inputs to the method are tracer objects then it would auto-apply |
So if you have

```python
from quax import quaxify


class Foo:
    @quaxify
    def some_method(self, *args): ...
```

then it will already wrap all values into tracers. That's what

(It sounds like you're after some other non-

Let me know if you have any opinion on the other options though, btw, as I don't currently have strong feelings between them! |
@patrick-kidger, I think we're reaching that point for |
Nice! I've just opened #26 to add this to the docs. |
Based on quax, I wrote array-api-jax-compat for use with jax-quantity. Are you interested in upstreaming array-api-jax-compat as a submodule, e.g. `quax.array_api`? With the submodule, users won't have to `quaxify` any function in the array API themselves; they can just `import quax.array_api as xp`.