New solvers #3
Thanks for the library. As they say, and another one... Do you have any intuition as to whether it would be useful to swap in the cuSPARSE kernels for, say, a tridiagonal solve, and add a …? At the moment, I have a custom-written broadcast-based function similar to the one in here, but with … Would something like this be better than what XLA comes up with? https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-t-gtsv2stridedbatch

Thanks again
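For comparison with the question above, here is a minimal sketch of the kind of pure-JAX batched tridiagonal solve that XLA would otherwise compile: the (unpivoted) Thomas algorithm written with `lax.scan` and batched with `jax.vmap`. The names `thomas_solve`, `batched_thomas_solve` and `to_dense` are hypothetical, not this library's API, and the sketch assumes each diagonal is stored as a length-`n` array with `lower[0] == 0` and `upper[-1] == 0`. It may also be worth checking whether `jax.lax.linalg.tridiagonal_solve` in your JAX version already dispatches to cuSPARSE on GPU.

```python
import jax
import jax.numpy as jnp
from jax import lax


def thomas_solve(lower, diag, upper, rhs):
    """Thomas algorithm for one tridiagonal system (no pivoting).

    All four arrays have shape (n,); lower[0] and upper[-1] are assumed to be zero.
    """

    def forward(carry, row):
        c_prev, d_prev = carry
        a, b, c, d = row
        m = b - a * c_prev  # pivot; nothing guards against it being zero or tiny
        c_new = c / m
        d_new = (d - a * d_prev) / m
        return (c_new, d_new), (c_new, d_new)

    zero = jnp.zeros((), dtype=rhs.dtype)
    _, (c_mod, d_mod) = lax.scan(forward, (zero, zero), (lower, diag, upper, rhs))

    def backward(x_next, row):
        c, d = row
        x = d - c * x_next
        return x, x

    # reverse=True walks from the last row to the first; outputs keep row order.
    _, x = lax.scan(backward, zero, (c_mod, d_mod), reverse=True)
    return x


# Batch over a leading axis: every input has shape (batch, n).
batched_thomas_solve = jax.jit(jax.vmap(thomas_solve))


def to_dense(lower, diag, upper):
    """Assemble the dense matrix; only used to check the result."""
    return jnp.diag(diag) + jnp.diag(lower[1:], k=-1) + jnp.diag(upper[:-1], k=1)


key_l, key_d, key_u, key_b = jax.random.split(jax.random.PRNGKey(0), 4)
batch, n = 8, 16
lower = jax.random.uniform(key_l, (batch, n), minval=-1.0, maxval=1.0).at[:, 0].set(0.0)
upper = jax.random.uniform(key_u, (batch, n), minval=-1.0, maxval=1.0).at[:, -1].set(0.0)
# Diagonal entries in [3, 4] make every system strictly diagonally dominant,
# which is the regime where the unpivoted Thomas algorithm is safe.
diag = jax.random.uniform(key_d, (batch, n), minval=3.0, maxval=4.0)
rhs = jax.random.normal(key_b, (batch, n))

x = batched_thomas_solve(lower, diag, upper, rhs)
x_ref = jnp.linalg.solve(jax.vmap(to_dense)(lower, diag, upper), rhs[..., None])[..., 0]
```

The sequential scan over rows is the likely bottleneck on a GPU, which is where a dedicated batched kernel could help; benchmarking both on the target hardware is the reliable way to find out.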
I think that would definitely be useful. Note that we actually already have the gradient via our custom primitive (…). But at least in the unbatched case, I reckon that would probably give a performance boost. (Plus I believe their tridiagonal algorithm is stabilised, whereas our current Thomas algorithm is unstabilised.) The only reason we haven't pursued this is that it's not straightforward to make this happen in JAX: calling out to custom kernels like this isn't very well supported.
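To illustrate the stability point: the Thomas algorithm divides by each pivot as it goes, so a zero or tiny diagonal entry (e.g. `diag[0] == 0`) breaks it even when the matrix is nonsingular. One standard stabilised alternative is Gaussian elimination with partial pivoting, as in LAPACK's `gtsv`; this is not a claim about which scheme cuSPARSE uses. Below is a minimal pure-JAX sketch; `pivoted_tridiagonal_solve` is a hypothetical name, and the same storage convention (`lower[0] == 0`, `upper[-1] == 0`) is assumed.

```python
import jax.numpy as jnp
from jax import lax


def pivoted_tridiagonal_solve(lower, diag, upper, rhs):
    """Tridiagonal solve by Gaussian elimination with partial pivoting.

    Inputs have shape (n,); lower[0] and upper[-1] are assumed to be zero.
    Row swaps create fill-in on a second superdiagonal (u2 below).
    """

    def eliminate(carry, row_below):
        # carry: current row i, entries in columns i, i+1 and its right-hand side.
        d0, u0, r0 = carry
        # row_below: row i+1, entries in columns i, i+1, i+2 and its right-hand side.
        l1, d1, u1, r1 = row_below
        swap = jnp.abs(l1) > jnp.abs(d0)  # pick the larger entry in column i
        pivot = jnp.where(swap, l1, d0)
        other = jnp.where(swap, d0, l1)
        fact = other / pivot
        # Pivot row: becomes the finalised row i of the upper-triangular factor.
        p_u = jnp.where(swap, d1, u0)
        p_u2 = jnp.where(swap, u1, jnp.zeros_like(u1))
        p_r = jnp.where(swap, r1, r0)
        # Non-pivot row; subtracting fact * pivot row zeroes its column-i entry.
        q_u = jnp.where(swap, u0, d1)
        q_u2 = jnp.where(swap, jnp.zeros_like(u1), u1)
        q_r = jnp.where(swap, r0, r1)
        new_row = (q_u - fact * p_u, q_u2 - fact * p_u2, q_r - fact * p_r)
        return new_row, (pivot, p_u, p_u2, p_r)

    init = (diag[0], upper[0], rhs[0])
    (d_last, _, r_last), (d, u, u2, r) = lax.scan(
        eliminate, init, (lower[1:], diag[1:], upper[1:], rhs[1:])
    )

    def back_substitute(carry, row):
        x1, x2 = carry  # x[i+1], x[i+2]
        d_i, u_i, u2_i, r_i = row
        x_i = (r_i - u_i * x1 - u2_i * x2) / d_i
        return (x_i, x1), x_i

    x_last = r_last / d_last
    _, x_head = lax.scan(
        back_substitute, (x_last, jnp.zeros_like(x_last)), (d, u, u2, r), reverse=True
    )
    return jnp.concatenate([x_head, x_last[None]])


# diag[0] == 0, so the unpivoted Thomas algorithm would divide by zero here,
# but the matrix [[0,1,0,0],[1,1,3,0],[0,2,0,1],[0,0,1,2]] is nonsingular.
lower = jnp.array([0.0, 1.0, 2.0, 1.0])
diag = jnp.array([0.0, 1.0, 0.0, 2.0])
upper = jnp.array([1.0, 3.0, 1.0, 0.0])
rhs = jnp.array([1.0, 2.0, 3.0, 4.0])
x = pivoted_tridiagonal_solve(lower, diag, upper, rhs)  # [-5., 1., 2., 1.]
```

Being a sequential scan, this sketch addresses stability only, not the GPU performance question.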
Couple of disparate thoughts
Something like that!
Would there be any interest in randomized solvers (e.g., via sketching or randomized trace estimation)? |
Absolutely!
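As a concrete example of the randomized trace estimation mentioned above, here is a minimal sketch of Hutchinson's estimator. It only needs matrix-vector products, so it would fit an operator-based interface. The function `hutchinson_trace` and its signature are hypothetical, not part of this library.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(matvec, n, key, num_samples=100):
    """Estimate trace(A) using only products v -> A @ v.

    Uses Rademacher probe vectors z (entries +1 or -1 with equal probability),
    for which E[z^T A z] = trace(A).
    """
    signs = jax.random.bernoulli(key, 0.5, (num_samples, n))
    probes = jnp.where(signs, 1.0, -1.0)
    products = jax.vmap(matvec)(probes)  # row k is A @ probes[k]
    quad_forms = jnp.sum(probes * products, axis=1)  # z_k^T A z_k
    return jnp.mean(quad_forms)


key_a, key_z = jax.random.split(jax.random.PRNGKey(0))
n = 50
A = 5.0 * jnp.eye(n) + 0.01 * jax.random.normal(key_a, (n, n))
estimate = hutchinson_trace(lambda v: A @ v, n, key_z, num_samples=1000)
exact = jnp.trace(A)
```

The estimator is unbiased and its variance comes only from the off-diagonal entries; for a diagonal matrix every sample is exact, since each Rademacher entry squares to 1.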