Help and tips on how to implement an algorithm in JAX #21658
Hello, I am wondering how I might go about efficiently implementing the random walk algorithm above using JAX. One way I thought about was to vmap over a tensor of positions (N, 3) and then, in the vmapped function, use a `lax.while_loop` with the given conditions. This, however, is slow, especially for multiple walks. What I would ideally want to do is walk all my points simultaneously. This can easily be done with `lax.fori_loop`; however, it is not very straightforward to enforce the conditions on each element of the tensor when implementing it this way. Any help and tips would be greatly appreciated.
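A minimal sketch of the "walk all points at once" approach, assuming a hypothetical stopping rule because the original algorithm is not reproduced here: each walker takes Gaussian steps until it leaves a sphere of radius `RADIUS`. The per-element condition is tracked with a boolean `active` mask, and `jnp.where` freezes walkers that have already stopped. `RADIUS`, `MAX_STEPS`, `STEP_SIZE`, and `walk_all` are illustrative names, not part of the original post.

```python
import jax
import jax.numpy as jnp

# Hypothetical stopping rule (NOT the original algorithm): a walker stops
# once it leaves a sphere of radius RADIUS around the origin.
RADIUS = 1.0
MAX_STEPS = 1000
STEP_SIZE = 0.1


@jax.jit
def walk_all(key, positions):
    """Advance all (N, 3) walkers together for a fixed number of iterations."""

    def body(i, state):
        key, pos, active = state
        key, subkey = jax.random.split(key)
        steps = STEP_SIZE * jax.random.normal(subkey, shape=pos.shape)
        # Only walkers that are still active move; finished ones stay frozen.
        pos = jnp.where(active[:, None], pos + steps, pos)
        # Re-check each walker's condition after the step.
        active = active & (jnp.linalg.norm(pos, axis=1) < RADIUS)
        return key, pos, active

    active0 = jnp.linalg.norm(positions, axis=1) < RADIUS
    _, final_pos, active = jax.lax.fori_loop(
        0, MAX_STEPS, body, (key, positions, active0)
    )
    return final_pos, active


N = 1024
key = jax.random.PRNGKey(0)
starts = jnp.zeros((N, 3))
final_positions, still_active = walk_all(key, starts)
```

One drawback of this design: `lax.fori_loop` always runs all `MAX_STEPS` iterations, even after every walker has stopped. Using `lax.while_loop` with `jnp.any(active)` in the condition would let the loop exit early.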
Hi, thanks for the question! In general, JAX is not going to be particularly performant for loopy algorithms like this one. JAX works best with vectorized operations on accelerators, and iterative algorithms, where each step depends on the result of the previous one, rule out the kind of implicit parallelism that makes operations on GPU and TPU fast.
That said, if I were implementing this algorithm I'd use `jax.lax.while_loop` to implement the logic for a single value, and then use `vmap` to efficiently create a batched version. Then you don't have to worry about manually tracking the conditions of each element in the tensor, and the `vmap` solution will be as fast as any hand-written batched version.
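Here is a minimal sketch of that pattern, assuming a hypothetical stopping rule because the original random walk algorithm is not reproduced here: each walker takes Gaussian steps until it leaves a sphere of radius `RADIUS` or reaches `MAX_STEPS`. `walk_single` handles one position of shape (3,) with `jax.lax.while_loop`, and `jax.vmap` maps it over an (N, 3) array of starting positions, with one PRNG key per walker. All constants and function names are illustrative.

```python
import jax
import jax.numpy as jnp

# Hypothetical stopping rule (NOT the original algorithm): walk until the
# point leaves a sphere of radius RADIUS or MAX_STEPS steps have been taken.
RADIUS = 1.0
MAX_STEPS = 1000
STEP_SIZE = 0.1


def walk_single(key, pos):
    """Random walk for ONE walker; pos has shape (3,)."""

    def cond_fun(state):
        _, pos, n = state
        return (jnp.linalg.norm(pos) < RADIUS) & (n < MAX_STEPS)

    def body_fun(state):
        key, pos, n = state
        key, subkey = jax.random.split(key)
        step = STEP_SIZE * jax.random.normal(subkey, shape=pos.shape)
        return key, pos + step, n + 1

    init = (key, pos, jnp.int32(0))
    _, final_pos, n_steps = jax.lax.while_loop(cond_fun, body_fun, init)
    return final_pos, n_steps


# vmap turns the single-walker function into a batched one over (N, 3) inputs.
walk_batch = jax.jit(jax.vmap(walk_single))

N = 1024
keys = jax.random.split(jax.random.PRNGKey(0), N)  # one key per walker
starts = jnp.zeros((N, 3))
final_positions, n_steps = walk_batch(keys, starts)
```

Under `vmap`, the batched `while_loop` keeps iterating until every walker's condition is false and uses a select to leave finished walkers unchanged, which is essentially the masking you would otherwise write by hand.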