How to Use vmap with Variable-Length Sequences in JAX? #21437
-
I'm working with JAX and trying to vectorize a function that operates on variable-length sequences derived from indices. The sequence lengths come from a previously computed array of segment lengths. I'm running into JAX's static-shape requirements when trying to use `vmap`. For example, `curr_ind = ind[jnp.arange(a, b)]` fails with:

`ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]. It arose in the jnp.arange argument 'start'.`

and `repeated_matrix = jnp.tile(param, (b - a, 1))` fails with:

`TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape int32[].`

How can I use `vmap` to handle these operations for variable-length sequences without introducing padding or restructuring my arrays significantly?
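A minimal reproduction of the first error, assuming a small hypothetical `ind` array and hypothetical segment bounds `starts`/`ends`. Under `vmap`, `a` and `b` are abstract tracers, so the length of `jnp.arange(a, b)` (and hence the output shape) depends on runtime values, which JAX's static-shape model cannot express:

```python
import jax
import jax.numpy as jnp

ind = jnp.arange(10)  # hypothetical index array

def take_segment(a, b):
    # The output length b - a depends on the *values* of a and b,
    # which are tracers under vmap, so the shape is unknown at trace time.
    return ind[jnp.arange(a, b)]

starts = jnp.array([0, 3])  # hypothetical segment starts
ends = jnp.array([3, 7])    # hypothetical segment ends (exclusive)

error_name = None
try:
    jax.vmap(take_segment)(starts, ends)
except TypeError as e:  # JAX's tracer errors subclass TypeError
    error_name = type(e).__name__

print(error_name)
```

The same reasoning applies to `jnp.tile(param, (b - a, 1))`: the repetition count must be a concrete Python integer, not a traced value.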
-
Unfortunately, it's not possible to do what you're asking with `jax.vmap`. We've had some experiments along these lines (see e.g. #16541 and related work) but nothing that's yet ready to use. I suspect your best option here would be to pad all batches to the same length, and then use `vmap` on the padded version.
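A minimal sketch of the pad-then-`vmap` approach, assuming the segments are laid out contiguously in `ind` with their lengths in a `seg_lens` array. The names `seg_lens`, `param`, `MAX_LEN`, `process_segment`, and the `-1` padding value are hypothetical, chosen for illustration. Each segment becomes a fixed-length window of `MAX_LEN` entries plus a boolean mask, so every `vmap`ped call sees the same static shape:

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs: a flat index array split into contiguous segments.
ind = jnp.array([10, 11, 12, 20, 21, 30, 31, 32, 33])
seg_lens = jnp.array([3, 2, 4])
param = jnp.array([1.0, 2.0])

# Segment boundaries [a, b) derived from the lengths.
ends = jnp.cumsum(seg_lens)
starts = ends - seg_lens

# Padded length: must be a static Python int known at trace time.
MAX_LEN = int(seg_lens.max())

def process_segment(a, b):
    offsets = jnp.arange(MAX_LEN)            # fixed shape (MAX_LEN,)
    mask = offsets < (b - a)                 # True for real entries, False for padding
    idx = jnp.clip(a + offsets, 0, ind.shape[0] - 1)  # keep padded gathers in bounds
    curr_ind = jnp.where(mask, ind[idx], -1)          # -1 marks padding (arbitrary choice)
    repeated = jnp.tile(param, (MAX_LEN, 1))          # static reps -> shape (MAX_LEN, 2)
    repeated = jnp.where(mask[:, None], repeated, 0.0)  # zero out padded rows
    return curr_ind, repeated, mask

curr_inds, repeated, masks = jax.vmap(process_segment)(starts, ends)
print(curr_inds)
```

Downstream computations should consult `mask` (for example `jnp.sum(values * mask)`) so padded entries don't affect results. `MAX_LEN` has to stay a static Python integer; computing it from traced data inside `jit` would hit the same concretization error as the original code.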