Help and tips on how to implement an algorithm in jax #21658

Answered by jakevdp
elientumba2019 asked this question in Q&A
Hi - thanks for the question! In general, JAX is not going to be particularly performant for loopy algorithms like this one. JAX works best with vectorized operations on accelerators, and iterative algorithms, where each step depends on the result of the previous one, rule out the kind of implicit parallelism that makes operations on GPU and TPU fast.

That said, if I were implementing this algorithm, I'd use `jax.lax.while_loop` to implement the logic for a single value, and then use `jax.vmap` to create an efficient batched version. That way you don't have to manually track whether each element in the tensor has finished iterating, and the `vmap` solution will be as fast as any hand-written batched version.
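Since the original algorithm isn't reproduced here, the following is a minimal sketch of that pattern using a hypothetical stand-in: counting Collatz steps, which has a data-dependent number of iterations per element. `collatz_steps` handles a single scalar with `lax.while_loop`, and `jax.vmap` batches it. For comparison, `collatz_steps_manual` (also hypothetical) is the kind of hand-written batched loop that tracks per-element state with a mask.

```python
import jax
import jax.numpy as jnp
from jax import lax


def collatz_steps(n):
    """Hypothetical stand-in algorithm: Collatz step count for ONE scalar n >= 1."""
    def cond(state):
        value, _ = state
        return value != 1  # keep iterating until this element reaches 1

    def body(state):
        value, steps = state
        value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return value, steps + 1

    _, steps = lax.while_loop(cond, body, (n, jnp.int32(0)))
    return steps


# vmap turns the scalar loop into a batched loop; jit compiles it.
batched_collatz_steps = jax.jit(jax.vmap(collatz_steps))


def collatz_steps_manual(n):
    """Hand-written batched version: explicitly mask out finished elements."""
    def cond(state):
        value, _ = state
        return jnp.any(value != 1)  # run until every element is done

    def body(state):
        value, steps = state
        active = value != 1
        new_value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        value = jnp.where(active, new_value, value)  # freeze finished elements
        steps = steps + active.astype(steps.dtype)
        return value, steps

    _, steps = lax.while_loop(cond, body, (n, jnp.zeros_like(n)))
    return steps


inputs = jnp.array([1, 6, 7, 27], dtype=jnp.int32)
steps = batched_collatz_steps(inputs)
manual_steps = jax.jit(collatz_steps_manual)(inputs)
```

Under `vmap`, a `while_loop` whose predicate depends on the batched value keeps running until the predicate is false for every element, and elements that have already finished are held fixed. That is essentially what the manual version does by hand, which is why the `vmap` version need not be slower.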

Answer selected by elientumba2019