
New solvers #3

Open
4 tasks
patrick-kidger opened this issue May 28, 2023 · 6 comments
Labels
feature New feature

Comments

@patrick-kidger
Owner

patrick-kidger commented May 28, 2023

  • Stabilised Thomas algorithm. (For tridiagonal solves.)
  • Upper Hessenberg. (Useful inside GMRES? Needs to offer a pseudoinverse solution.)
  • Incremental GMRES
  • Handle "diagonal + low-rank" operators using Woodbury. (A rough sketch is below.)
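
For the last item, roughly the Woodbury route I have in mind (an illustrative sketch only; none of these names are lineax API):

```python
import jax.numpy as jnp

def diag_plus_low_rank_solve(diag, u, v, b):
    """Solve (D + U V^T) x = b via the Woodbury identity, where D = diag(diag).

    Woodbury: (D + U V^T)^{-1} b = D^{-1} b - D^{-1} U (I + V^T D^{-1} U)^{-1} V^T D^{-1} b

    diag: (n,) diagonal of D; u, v: (n, k) low-rank factors; b: (n,) right-hand side.
    Only a small k-by-k dense solve is needed, rather than an n-by-n one.
    """
    dinv_b = b / diag                                    # D^{-1} b
    dinv_u = u / diag[:, None]                           # D^{-1} U
    capacitance = jnp.eye(u.shape[1]) + v.T @ dinv_u     # small (k, k) "capacitance" matrix
    correction = dinv_u @ jnp.linalg.solve(capacitance, v.T @ dinv_b)
    return dinv_b - correction
```
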
@patrick-kidger added the feature label May 28, 2023
@joglekara

Thanks for the library. As they say, here's another one...

Do you have any intuition as to whether it would be useful to swap in the cusparse kernels for, say, a tridiagonal solve, and add a custom_linear_solve for the gradient, or is that more trouble than it's worth for the performance boost (if any!)?

At the moment, I have a custom-written, broadcast-based function similar to the one in here, but with a, b, c, d being 2D. I suppose I could have used vmap, but I decided to trust broadcasting.
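
Roughly the kind of thing I mean (a simplified sketch with a, b, c, d of shape (batch, m); not my exact code):

```python
import jax
import jax.numpy as jnp

def thomas_batched(a, b, c, d):
    """Plain (unstabilised) Thomas algorithm, broadcast over a leading batch axis.

    a, b, c: sub-, main- and super-diagonals, each of shape (batch, m)
    (a[:, 0] and c[:, -1] are unused).  d: right-hand sides, shape (batch, m).
    """
    # Put the "system" axis first so lax.scan sweeps along it, with each
    # step acting on the whole batch at once (this is the broadcasting).
    a_t, b_t, c_t, d_t = (jnp.moveaxis(arr, -1, 0) for arr in (a, b, c, d))
    zeros = jnp.zeros_like(b_t[0])

    def forward(carry, row):
        c_prev, d_prev = carry
        a_i, b_i, c_i, d_i = row
        denom = b_i - a_i * c_prev
        c_new = c_i / denom
        d_new = (d_i - a_i * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    _, (c_prime, d_prime) = jax.lax.scan(forward, (zeros, zeros), (a_t, b_t, c_t, d_t))

    def backward(x_next, row):
        c_i, d_i = row
        x_i = d_i - c_i * x_next
        return x_i, x_i

    # reverse=True sweeps from the last row to the first; outputs stay in input order.
    _, xs = jax.lax.scan(backward, zeros, (c_prime, d_prime), reverse=True)
    return jnp.moveaxis(xs, 0, -1)
```
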

But would something like this be better than what XLA comes up with?

https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-t-gtsv2stridedbatch

Thanks again

@patrick-kidger
Owner Author

I think that would definitely be useful. Note that we actually already have the gradient via our custom primitive (linear_solve_p) so the only difficulty with using cusparse might be adding batch dimensions; off the top of my head I'm not sure if they support that.
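
(For concreteness, the general pattern looks roughly like this with the public jax.lax.custom_linear_solve wrapping the built-in unbatched tridiagonal solve; a sketch only, not lineax's actual linear_solve_p implementation:)

```python
import jax
import jax.numpy as jnp
from jax.lax.linalg import tridiagonal_solve

def tridiag_solve_with_grad(dl, d, du, b):
    """Differentiable tridiagonal solve: the forward pass calls the tridiagonal
    routine, and reverse-mode AD solves the *transposed* tridiagonal system
    rather than differentiating through the solver's internals.

    dl, d, du: sub-/main-/super-diagonals of shape (m,), with dl[0] = du[-1] = 0.
    b: right-hand side of shape (m, n).
    """
    def matvec(x):
        # (A @ x)[i] = dl[i] x[i-1] + d[i] x[i] + du[i] x[i+1]
        return (d[:, None] * x
                + dl[:, None] * jnp.roll(x, 1, axis=0)
                + du[:, None] * jnp.roll(x, -1, axis=0))

    def solve(_, rhs):
        return tridiagonal_solve(dl, d, du, rhs)

    def transpose_solve(_, rhs):
        # A^T is also tridiagonal; its off-diagonals are shifted copies of A's.
        return tridiagonal_solve(jnp.roll(du, 1), d, jnp.roll(dl, -1), rhs)

    return jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve)
```
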

But at least in the unbatched case, I reckon that would probably give a performance boost. (Plus I believe their tridiagonal algorithm is stabilised, whereas our current Thomas algorithm is unstabilised.)

The only reason we haven't pursued this is that it's not straightforward to make this happen in JAX -- calling out to custom kernels like this isn't very well supported.

@joglekara

Couple of disparate thoughts

@patrick-kidger
Owner Author

Something like that!
@packquickly may recall why we didn't use the built-in tridiagonal solve algorithm. IIRC it doesn't support batching, and that has to be patched in? (E.g. by vmap'ing jax._src.lax.linalg._tridiagonal_solve_jax.)
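
Something like this is the kind of patch I mean (a toy sketch; it assumes the routine being vmap'ed is traceable or its primitive has a batching rule, which is exactly the open question):

```python
import jax
import jax.numpy as jnp
# The unbatched routine being wrapped; if the underlying primitive has no
# batching rule, the pure-JAX fallback would have to be vmap'ed instead.
from jax.lax.linalg import tridiagonal_solve

# Toy batch of 8 diagonally dominant tridiagonal systems of size m = 16.
batch, m = 8, 16
keys = jax.random.split(jax.random.PRNGKey(0), 4)
d  = 4.0 + jax.random.uniform(keys[0], (batch, m))              # main diagonals
dl = jax.random.uniform(keys[1], (batch, m)).at[:, 0].set(0.0)  # sub-diagonals, dl[..., 0] = 0
du = jax.random.uniform(keys[2], (batch, m)).at[:, -1].set(0.0) # super-diagonals, du[..., -1] = 0
b  = jax.random.uniform(keys[3], (batch, m, 1))                 # right-hand sides

# "Patching in" batching: map the single-system solve over the leading axis.
x = jax.vmap(tridiagonal_solve)(dl, d, du, b)   # shape (batch, m, 1)
```
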

@quattro
Contributor

quattro commented Aug 10, 2023

Would there be any interest in randomized solvers (e.g., via sketching or randomized trace estimation)?
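
E.g. even something as small as a Hutchinson-style trace estimator (a toy sketch, not a proposed API):

```python
import jax
import jax.numpy as jnp

def hutchinson_trace(matvec, n, key, num_samples=64):
    """Randomised trace estimate: E[z^T A z] = tr(A) for Rademacher-distributed z,
    using only matrix-vector products with A."""
    z = jax.random.rademacher(key, (num_samples, n), dtype=jnp.float32)
    return jnp.mean(jax.vmap(lambda zi: zi @ matvec(zi))(z))

# Usage: estimate tr(A) without materialising A as a dense matrix, e.g.
#   hutchinson_trace(lambda v: A @ v, n, jax.random.PRNGKey(0))
```
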

@patrick-kidger
Owner Author

Absolutely!
