Memory consumption with non-pure functions #21541
LucaMantani asked this question in Q&A
Hi all,
I have been using JAX for my project and noticed an increase in memory usage when compiling with jax.jit.
I figured out that the behaviour boils down to what happens in this little snippet:
I run the script with
When the jitted function is called for the first time, memory usage doubles if the function called is fn(A), whereas calling fn2(A, A) does not cause this. Do I understand correctly that this is expected behaviour, because the first function is non-pure and therefore keeps its own copy of the contents of array A in memory?
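The two patterns being compared can be sketched as follows. This is a hypothetical reconstruction, not the original snippet: the bodies of fn and fn2, the array size N and the use of a matrix product are assumptions. fn reads A from the enclosing (global) scope, so when jax.jit traces it, A is captured as a constant of the traced program instead of being passed in as a runtime argument; fn2 receives both operands as explicit arguments.

```python
import jax
import jax.numpy as jnp

# Hypothetical size and bodies; the original snippet is not reproduced here.
N = 512
A = jnp.ones((N, N), dtype=jnp.float32)


def fn(x):
    # A is read from the enclosing (global) scope, so during tracing jit
    # sees it as a concrete array and captures it as a constant.
    return x @ A


def fn2(x, y):
    # Both operands are explicit arguments, so they stay runtime inputs.
    return x @ y


jit_fn = jax.jit(fn)
jit_fn2 = jax.jit(fn2)

out1 = jit_fn(A)       # corresponds to calling fn(A)
out2 = jit_fn2(A, A)   # corresponds to calling fn2(A, A)

# Inspect what each function captures when traced.
print(jax.make_jaxpr(fn)(A))
print(jax.make_jaxpr(fn2)(A, A))
```

In the JAX versions I am aware of, the jaxpr printed for fn lists A as a captured constant (before the `;`), while the one for fn2 has none; this is consistent with the compiled fn holding its own copy of A. The exact printout may differ between releases.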
I saw an old issue discussing something similar: #5071.
The problem for me is that I am building a complex function that is full of constant arrays, and I was jitting the functions without passing these arrays as inputs every time, because that was more convenient to write and more readable.
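One common way to keep many constant arrays out of the traced program while staying readable is to bundle them into a pytree (here a dict) and pass that as an explicit argument. This is a sketch under assumptions: the names consts and model, the array shapes and the computation are hypothetical, and it is not necessarily what the JAX maintainers would recommend for this case.

```python
import jax
import jax.numpy as jnp

# Hypothetical constant arrays standing in for the "many constants" case.
consts = {
    "W": jnp.eye(4, dtype=jnp.float32),
    "b": jnp.arange(4, dtype=jnp.float32),
    "scale": jnp.float32(2.0),
}


@jax.jit
def model(consts, x):
    # The constants arrive as a pytree argument, so jit treats them as
    # runtime inputs rather than capturing them from the enclosing scope.
    W, b, scale = consts["W"], consts["b"], consts["scale"]
    return scale * (x @ W + b)


x = jnp.ones((3, 4), dtype=jnp.float32)
y = model(consts, x)
print(y)
```

Because consts is a pytree, jit keys its cache on the tree structure, shapes and dtypes, so calling model again with the same consts and a new x of the same shape reuses the compiled program.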