Commit: Fix tests
laurentm committed Aug 11, 2022
1 parent 59a97b1 commit a543412
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions keras_fsl/models/head_models/learnt_norms.py
@@ -1,5 +1,5 @@
-import numpy as np
 import tensorflow as tf
+import numpy as np
 from keras import activations
 from keras.layers import (
     Concatenate,
@@ -11,7 +11,7 @@
 )
 from keras.mixed_precision import global_policy
 from keras.models import Model
-from tensorflow.python.keras.layers import Activation
+from keras.layers import Activation
 
 
 def LearntNorms(input_shape, use_bias=True, activation="sigmoid"):
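Aside, not part of the commit: tensorflow.python.keras is TensorFlow's private vendored copy of Keras, and layers imported from it belong to a different class hierarchy than those from the public keras package, so mixing the two namespaces in one model can fail when the model is built. A minimal sketch using only the public imports the commit switches to (layer sizes are illustrative, not from the source):

# Minimal sketch, assuming TF 2.x with standalone Keras 2.x installed:
# build the model exclusively from the public keras namespace.
from keras.layers import Activation, Dense, Input
from keras.models import Model

inputs = Input(shape=(10,))
hidden = Dense(4)(inputs)
outputs = Activation("sigmoid")(hidden)  # same class family as Dense/Input
model = Model(inputs, outputs)
model.summary()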
Empty file.
7 changes: 4 additions & 3 deletions keras_fsl/models/head_models/tests/learnt_norms_test.py
@@ -1,6 +1,7 @@
 import numpy as np
 import tensorflow as tf
 from absl.testing import parameterized
+from keras import mixed_precision
 from keras.optimizers import RMSprop
 from tensorflow.python.keras.keras_parameterized import TestCase, run_all_keras_modes, run_with_all_model_types
 
@@ -34,16 +35,16 @@ def test_should_fit(self, input_shape):
         ("float64", "float64", "float64"),
     )
     def test_last_activation_fp32_in_mixed_precision(self, mixed_precision_policy, expected_last_layer_dtype_policy):
-        policy = tf.keras.mixed_precision.Policy(mixed_precision_policy)
-        tf.keras.mixed_precision.set_policy(policy)
+        policy = mixed_precision.Policy(mixed_precision_policy)
+        mixed_precision.set_global_policy(policy)
         learnt_norms = LearntNorms(input_shape=(10,))
 
         # Check dtype policy of internal non-input layers
         for layer in learnt_norms.layers[2:-1]:
             assert layer._dtype_policy.name == mixed_precision_policy
 
         # Check dtype policy of last layer always at least FP32
-        assert learnt_norms.layers[-1]._dtype_policy.name == expected_last_layer_dtype_policy
+        assert learnt_norms.layers[-1].dtype_policy.name == expected_last_layer_dtype_policy
 
 
 if __name__ == "__main__":
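Aside, not part of the commit: a sketch of the behaviour the updated test asserts, using the public API the commit migrates to, where keras.mixed_precision.set_global_policy replaces the removed set_policy call and Layer.dtype_policy exposes the formerly private _dtype_policy attribute. Layer sizes and names here are illustrative assumptions:

# Sketch, assuming standalone Keras ~2.9 where set_global_policy and
# Layer.dtype_policy are public API.
from keras import mixed_precision
from keras.layers import Activation, Dense, Input
from keras.models import Model

mixed_precision.set_global_policy("mixed_float16")  # accepts a name or a Policy

inputs = Input(shape=(10,))
hidden = Dense(8)(inputs)  # picks up the global "mixed_float16" policy
# Pinning the last layer to float32 keeps the output numerically stable,
# which is what the test checks via layers[-1].dtype_policy.name.
outputs = Activation("sigmoid", dtype="float32")(hidden)
model = Model(inputs, outputs)

assert model.layers[1].dtype_policy.name == "mixed_float16"
assert model.layers[-1].dtype_policy.name == "float32"

mixed_precision.set_global_policy("float32")  # reset the global state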
