JAX's JIT compilation assumes that the function you pass it is pure (see JAX Sharp Bits: Pure functions). Under this assumption, the compiler can eliminate any computation that doesn't affect the return value. When you jit-compile a function with no return value, all of its contents can be eliminated, and running an empty function will generally be much faster than running a function that does some number of computations. The root of the issue here is that you're JIT-compiling a function with impure semantics: that is, it appears to rely on side effects like in-place mutation of the …

Hope that helps!
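A minimal sketch of the effect described above. A made-up matrix workload stands in for the QP solve, and the function names are hypothetical. The version with no return value gives XLA nothing observable to compute, so it may drop the work entirely. The version that returns its result has to do it. Because JAX dispatches work asynchronously, the timed call also uses block_until_ready() so the clock stops only once the result exists.

```python
import time

import jax
import jax.numpy as jnp


def work_no_return(x):
    # The result is computed but never returned, so nothing is observable.
    # Under jit, XLA is free to eliminate this computation entirely.
    _ = jnp.sin(x) @ jnp.cos(x).T


def work_with_return(x):
    # Returning the result forces the computation to actually run.
    return jnp.sin(x) @ jnp.cos(x).T


f_none = jax.jit(work_no_return)
f_out = jax.jit(work_with_return)

x = jnp.ones((500, 500))

# Warm up: the first call of each jitted function includes compilation.
f_none(x)
f_out(x).block_until_ready()


def mean_time(fn, n=10):
    """Average wall-clock time of fn() over n calls, plus the last output."""
    start = time.perf_counter()
    for _ in range(n):
        out = fn()
    return (time.perf_counter() - start) / n, out


t_none, r_none = mean_time(lambda: f_none(x))
# block_until_ready() waits for the asynchronously dispatched result.
t_out, r_out = mean_time(lambda: f_out(x).block_until_ready())

print(f"no return:   {t_none * 1e3:.3f} ms, result = {r_none}")
print(f"with return: {t_out * 1e3:.3f} ms, result shape = {r_out.shape}")
```

The jitted function without a return value simply returns None, which is why the fast timing says nothing about how long the real computation takes.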
Hi.
I'm trying to solve multiple QPs using vmap and JIT. If the function runQP (which solves the QP) ends with the line 'return solutions', it takes 0.7 s. In contrast, without 'return sol.primal' it takes about 0.03 s, but then I can only access traced values. My questions are as follows:

1. Can anyone help me understand why including the 'return' line makes the computation so much slower? Is this normal, or have I made a mistake?
2. JIT uses traced values, so I expected jitted_runQP_vectorized to return traced values, but it returns the actual data values. Does this mean that the function runQP is not jitted properly?
3. Is there any other way to make this computation much faster? My goal is to solve and return the solutions within 20 ms.
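On the second question: tracers exist only while JAX traces the Python function to build the compiled program. The jitted wrapper then runs that program on the real inputs and hands back concrete arrays, so getting actual values back suggests the function was compiled and executed, not that jit failed. A small sketch (the function name scaled_sum is made up for illustration):

```python
import jax
import jax.numpy as jnp
import numpy as np

seen_inside = []


def scaled_sum(x):
    # This Python code runs only while tracing; x is a tracer here.
    # (Appending to a list is a side effect, so it happens once, at trace time.)
    seen_inside.append(type(x).__name__)
    return 2.0 * jnp.sum(x)


jitted = jax.jit(scaled_sum)

x = jnp.arange(3.0)
out1 = jitted(x)
out2 = jitted(x + 1.0)  # same shape/dtype: reuses the compiled program

print(seen_inside)       # a single tracer type name, from the one trace
print(out1, out2)        # concrete values
print(np.asarray(out1))  # NumPy conversion works because out1 is concrete
```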
I would sincerely appreciate any comments and suggestions!
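On the speed question, a common pattern is to return the solution, jit the vmapped solver, compile once outside the timed region, and time with block_until_ready(). The sketch below is only an illustration under stated assumptions: solve_eq_qp is a hypothetical stand-in for runQP that handles an equality-constrained QP through its KKT system (the real solver is not shown in the question, and inequality constraints would need an iterative method), and the batch and problem sizes are arbitrary. Whether 20 ms is reachable depends on the actual solver, problem size and hardware.

```python
import time

import jax
import jax.numpy as jnp


def solve_eq_qp(Q, c, A, b):
    """Hypothetical stand-in for runQP. Solves
        min 0.5 x^T Q x + c^T x   s.t.   A x = b
    via its KKT system (assumes Q positive definite, A full row rank)."""
    n, m = Q.shape[0], A.shape[0]
    kkt = jnp.block([[Q, A.T], [A, jnp.zeros((m, m), Q.dtype)]])
    rhs = jnp.concatenate([-c, b])
    sol = jnp.linalg.solve(kkt, rhs)
    return sol[:n]  # primal part; returning it keeps XLA from dropping the solve


# Vectorize over a batch of problems, then jit the batched function.
batched_solve = jax.jit(jax.vmap(solve_eq_qp))

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
batch, n, m = 256, 10, 3
M = jax.random.normal(k1, (batch, n, n))
Q = M @ jnp.swapaxes(M, 1, 2) + n * jnp.eye(n)  # positive definite per problem
c = jax.random.normal(k2, (batch, n))
A = jax.random.normal(k3, (batch, m, n))
b = jax.random.normal(k4, (batch, m))

batched_solve(Q, c, A, b).block_until_ready()  # compile once, untimed

start = time.perf_counter()
x = batched_solve(Q, c, A, b).block_until_ready()
print(f"{batch} QPs in {(time.perf_counter() - start) * 1e3:.2f} ms")

# Sanity check: A x should match b for every problem in the batch.
print(jnp.max(jnp.abs(jnp.einsum("bmn,bn->bm", A, x) - b)))
```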