
How to Use vmap with Variable-Length Sequences in JAX? #21437

Answered by jakevdp
ajithmoola asked this question in Q&A

Unfortunately, it's not possible to do what you're asking with jax.vmap. We've had some experiments along these lines (see e.g. #16541 and related work) but nothing that's yet ready to use.
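
Here is a minimal illustration of the limitation, under some assumptions. `seq_sum` and the `ragged` list are hypothetical stand-ins for the asker's per-sequence function and data. `jax.vmap` maps over a shared leading axis. If you pass it a Python list of arrays, it treats the list as a pytree and maps axis 0 of every leaf. Leaves of different lengths therefore produce inconsistent mapped sizes, and JAX raises a `ValueError` before the function is ever traced.

```python
import jax
import jax.numpy as jnp

def seq_sum(x):
    # Hypothetical per-sequence function, meant to act on one whole sequence.
    return jnp.sum(x)

# Hypothetical ragged batch: two sequences of lengths 3 and 5.
ragged = [jnp.arange(3.0), jnp.arange(5.0)]

try:
    # vmap flattens the list into leaves and maps axis 0 of each one,
    # so every leaf must have the same leading size (here 3 != 5).
    jax.vmap(seq_sum)(ragged)
    raised = False
except ValueError as err:
    raised = True
    print(err)  # JAX reports the inconsistent mapped axis sizes
```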

I suspect your best option here would be to pad all batches to the same length, and then use vmap on the padded version.
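
The pad-then-vmap approach can be sketched roughly as follows. This is one common way to do it, not an official JAX utility. The helpers `pad_sequences` and `masked_mean` are hypothetical. The example assumes 1-D float sequences and a per-sequence reduction that can ignore padded positions through a boolean mask.

```python
import jax
import jax.numpy as jnp
import numpy as np

def pad_sequences(seqs, pad_value=0.0):
    """Hypothetical helper: pad 1-D sequences to the longest length.

    Returns a (batch, max_len) array plus a boolean mask that is True
    at real positions and False at padded positions.
    """
    max_len = max(len(s) for s in seqs)
    padded = np.full((len(seqs), max_len), pad_value, dtype=np.float32)
    mask = np.zeros((len(seqs), max_len), dtype=bool)
    for i, s in enumerate(seqs):
        padded[i, : len(s)] = s
        mask[i, : len(s)] = True
    return jnp.asarray(padded), jnp.asarray(mask)

def masked_mean(x, mask):
    # Hypothetical per-sequence function: zero out padded entries
    # and divide by the number of real entries only.
    total = jnp.sum(jnp.where(mask, x, 0.0))
    count = jnp.sum(mask)
    return total / count

seqs = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0]), np.array([6.0])]
padded, mask = pad_sequences(seqs)

# Every row now has the same shape, so vmap can map over the batch axis.
means = jax.vmap(masked_mean)(padded, mask)
print(means)
```

Because the padded batch has a static shape, `jax.jit` compiles once per padded length. Padding to a fixed maximum, or to a small set of bucket lengths, avoids recompiling for every new batch. Keep in mind that padded positions still cost compute, and a row with no real entries would divide by zero in `masked_mean`.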

Replies: 1 comment, 5 replies

5 replies: @ajithmoola, @jakevdp, @ajithmoola, @ajithmoola, @jakevdp
Answer selected by ajithmoola
2 participants