remove tensorflow warning around tf.function on per_field_where #902

Open
cmarlin opened this issue Dec 6, 2023 · 0 comments

Comments

Contributor

cmarlin commented Dec 6, 2023

Hello,
The @tf.function decorator on the inner function per_field_where (defined inside where in tf_agents/utils/nest_utils.py) triggers a TensorFlow retracing warning:

WARNING:tensorflow:5 out of the last 8 calls to <function where.<locals>.per_field_where at 0x7f07a6484b80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.

The following minimal Python script reproduces it:

import tensorflow as tf
from tf_agents.utils import nest_utils


def test(condition, true_outputs, false_outputs):
    # nest_utils.where defines a @tf.function-decorated per_field_where
    # locally, so each call here triggers a fresh trace.
    return nest_utils.where(condition, true_outputs, false_outputs)


if __name__ == "__main__":
    # Identical shapes and dtypes on every iteration, yet the warning fires.
    for _ in range(10):
        condition = tf.convert_to_tensor([True, True], dtype=tf.bool)
        true_outputs = tf.convert_to_tensor([0, 1], dtype=tf.int32)
        false_outputs = tf.convert_to_tensor([2, 3], dtype=tf.int32)
        test(condition, true_outputs, false_outputs)
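To isolate the mechanism without tf_agents, here is a minimal sketch comparing a @tf.function defined inside another function with one defined at module level. The function names and the trace_count dictionary are hypothetical; the counters work because Python side effects inside a tf.function run only while it is being traced, not when a cached graph is reused.

```python
import tensorflow as tf

# Python side effects run only during tracing, so these counters record
# how many times each variant was traced.
trace_count = {"local": 0, "module": 0}


@tf.function
def add_one_module_level(x):
    # Created once at import time; its trace cache persists across calls.
    trace_count["module"] += 1
    return x + 1


def add_one_local(x):
    # Same pattern as where.<locals>.per_field_where: a new tf.function
    # object (with an empty trace cache) is created on every call.
    @tf.function
    def inner(y):
        trace_count["local"] += 1
        return y + 1

    return inner(x)


for _ in range(5):
    x = tf.constant([1, 2], dtype=tf.int32)  # same shape and dtype each time
    add_one_local(x)
    add_one_module_level(x)

print(trace_count)  # {'local': 5, 'module': 1}
```

The local variant is traced on every call even though the input signature never changes, which matches case (1) in the warning text.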

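The cause appears to be that per_field_where is a local function: every call to where creates a new tf.function object whose trace cache starts empty, so every call is traced again (case (1) in the warning). It could be fixed by either: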

  • removing the tf.function decorator
  • moving per_field_where out of where to module level (e.g. as _per_field_where). Because it currently closes over the condition and condition_rank variables, these must be passed explicitly, so the return statement would become 'return tf.nest.map_structure(lambda t, f: _per_field_where(t, f, condition, condition_rank), true_outputs, false_outputs)'
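A minimal sketch of the second option follows. It assumes per_field_where broadcasts condition over each field's trailing dimensions before calling tf.where; the actual body in nest_utils.py may differ. The names _per_field_where, where_sketch and TRACE_COUNT are hypothetical and used only for illustration.

```python
import tensorflow as tf

# Hypothetical counter, incremented only while _per_field_where is traced.
TRACE_COUNT = [0]


@tf.function
def _per_field_where(t, f, condition, condition_rank):
    """Hypothetical module-level version of the local per_field_where.

    condition and condition_rank are passed explicitly instead of being
    captured from the enclosing where() call.
    """
    TRACE_COUNT[0] += 1
    # Assumed behavior: append size-1 axes so condition broadcasts against
    # fields whose rank exceeds condition's rank (plain Python loop, unrolled
    # at trace time because both ranks are Python ints).
    cond = condition
    for _ in range(t.shape.rank - condition_rank):
        cond = tf.expand_dims(cond, -1)
    return tf.where(cond, t, f)


def where_sketch(condition, true_outputs, false_outputs):
    """Hypothetical stand-in for nest_utils.where using the module-level helper."""
    condition = tf.convert_to_tensor(condition, dtype=tf.bool)
    condition_rank = condition.shape.rank  # static Python int
    return tf.nest.map_structure(
        lambda t, f: _per_field_where(t, f, condition, condition_rank),
        true_outputs,
        false_outputs,
    )


# Same loop as the reproducer, now calling the sketch.
for _ in range(10):
    condition = tf.convert_to_tensor([True, True], dtype=tf.bool)
    true_outputs = tf.convert_to_tensor([0, 1], dtype=tf.int32)
    false_outputs = tf.convert_to_tensor([2, 3], dtype=tf.int32)
    where_sketch(condition, true_outputs, false_outputs)

print(TRACE_COUNT[0])  # 1
```

Because _per_field_where is created once, its traces are reused; a new trace happens only for a new combination of input shapes, dtypes, or condition_rank value.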

Thanks
