diff --git a/keras_fsl/models/head_models/learnt_norms.py b/keras_fsl/models/head_models/learnt_norms.py
index 2bfb390..e407597 100644
--- a/keras_fsl/models/head_models/learnt_norms.py
+++ b/keras_fsl/models/head_models/learnt_norms.py
@@ -1,5 +1,5 @@
-import numpy as np
 import tensorflow as tf
+import numpy as np
 from keras import activations
 from keras.layers import (
     Concatenate,
@@ -11,7 +11,7 @@
 )
 from keras.mixed_precision import global_policy
 from keras.models import Model
-from tensorflow.python.keras.layers import Activation
+from keras.layers import Activation
 
 
 def LearntNorms(input_shape, use_bias=True, activation="sigmoid"):
diff --git a/keras_fsl/models/head_models/tests/__init__.py b/keras_fsl/models/head_models/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/keras_fsl/models/head_models/tests/learnt_norms_test.py b/keras_fsl/models/head_models/tests/learnt_norms_test.py
index 4fd159b..23173b5 100644
--- a/keras_fsl/models/head_models/tests/learnt_norms_test.py
+++ b/keras_fsl/models/head_models/tests/learnt_norms_test.py
@@ -1,6 +1,7 @@
 import numpy as np
 import tensorflow as tf
 from absl.testing import parameterized
+from keras import mixed_precision
 from keras.optimizers import RMSprop
 from tensorflow.python.keras.keras_parameterized import TestCase, run_all_keras_modes, run_with_all_model_types
 
@@ -34,8 +35,8 @@ def test_should_fit(self, input_shape):
         ("float64", "float64", "float64"),
     )
     def test_last_activation_fp32_in_mixed_precision(self, mixed_precision_policy, expected_last_layer_dtype_policy):
-        policy = tf.keras.mixed_precision.Policy(mixed_precision_policy)
-        tf.keras.mixed_precision.set_policy(policy)
+        policy = mixed_precision.Policy(mixed_precision_policy)
+        mixed_precision.set_global_policy(policy)
         learnt_norms = LearntNorms(input_shape=(10,))
 
         # Check dtype policy of internal non-input layers
@@ -43,7 +44,7 @@ def test_last_activation_fp32_in_mixed_precision(self, mixed_precision_policy, e
             assert layer._dtype_policy.name == mixed_precision_policy
 
         # Check dtype policy of last layer always at least FP32
-        assert learnt_norms.layers[-1]._dtype_policy.name == expected_last_layer_dtype_policy
+        assert learnt_norms.layers[-1].dtype_policy.name == expected_last_layer_dtype_policy
 
 
 if __name__ == "__main__":
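Note (not part of the patch): the test change above switches from the tf.keras mixed-precision calls and the private `_dtype_policy` attribute to the public `keras.mixed_precision` API and the public `dtype_policy` property. Below is a minimal standalone sketch of that API, assuming an install where `keras` resolves to the Keras bundled with TensorFlow; the toy two-layer model and its names are illustrative only, not code from this repository.

    from keras import mixed_precision
    from keras.layers import Activation, Dense, Input
    from keras.models import Model

    # Compute in float16 while keeping variables in float32.
    mixed_precision.set_global_policy("mixed_float16")

    inputs = Input(shape=(10,))
    hidden = Dense(8)(inputs)                                  # inherits the global "mixed_float16" policy
    outputs = Activation("sigmoid", dtype="float32")(hidden)   # last activation pinned to float32, as the test expects
    model = Model(inputs, outputs)

    assert model.layers[1].dtype_policy.name == "mixed_float16"   # public property, replaces layer._dtype_policy
    assert model.layers[-1].dtype_policy.name == "float32"

    # Reset the global policy so later tests are unaffected.
    mixed_precision.set_global_policy("float32")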