
Commit

add normalized_emb
jq committed Jul 2, 2024
1 parent a8642aa commit dc887f1
Showing 1 changed file with 107 additions and 4 deletions.
@@ -3,6 +3,7 @@

from absl import flags
from absl import app
from keras.src.layers import LayerNormalization

os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" #VERY IMPORTANT!
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
@@ -240,6 +241,106 @@ def get_kv_creator(mpi_size: int,
    return de.CuckooHashTableCreator(saver=saver)


class DynamicLayerNormalization(LayerNormalization):

  def call(self, inputs):
    # TODO(b/229545225): Remove the RaggedTensor check.
    is_ragged = isinstance(inputs, tf.RaggedTensor)
    if is_ragged:
      inputs_lengths = inputs.nested_row_lengths()
      inputs = inputs.to_tensor()
    inputs = tf.cast(inputs, self.compute_dtype)
    # Compute the axes along which to reduce the mean / variance
    input_shape = tf.shape(inputs)
    ndims = input_shape.shape[0]  # Get the number of dimensions dynamically

    # Broadcasting only necessary for norm when the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape[dim]

    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          self.axis != [ndims - 1]):
        return tf.reshape(v, broadcast_shape)
      return v

    if not self._fused:
      input_dtype = inputs.dtype
      if (input_dtype in ("float16", "bfloat16") and
          self.dtype == "float32"):
        # If mixed precision is used, cast inputs to float32 so that
        # this is at least as numerically stable as the fused version.
        inputs = tf.cast(inputs, "float32")

      # Calculate the moments on the last axis (layer activations).
      mean, variance = tf.nn.moments(inputs, self.axis, keepdims=True)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      # Compute layer normalization using the batch_normalization
      # function.
      outputs = tf.nn.batch_normalization(
          inputs,
          mean,
          variance,
          offset=offset,
          scale=scale,
          variance_epsilon=self.epsilon,
      )
      outputs = tf.cast(outputs, input_dtype)
    else:
      # Collapse dims before self.axis, and dims in self.axis

      axis = sorted(self.axis)
      tensor_shape = tf.shape(inputs)
      pre_dim = tf.reduce_prod(tensor_shape[:axis[0]])
      in_dim = tf.reduce_prod(tensor_shape[axis[0]:])
      squeezed_shape = [1, pre_dim, in_dim, 1]
      # This fused operation requires reshaped inputs to be NCHW.
      data_format = "NCHW"

      inputs = tf.reshape(inputs, squeezed_shape)

      # self.gamma and self.beta have the wrong shape for
      # fused_batch_norm, so we cannot pass them as the scale and offset
      # parameters. Therefore, we create two constant tensors in correct
      # shapes for fused_batch_norm and later construct a separate
      # calculation on the scale and offset.
      scale = tf.ones([pre_dim], dtype=self.dtype)
      offset = tf.zeros([pre_dim], dtype=self.dtype)

      # Compute layer normalization using the fused_batch_norm function.
      outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
          inputs,
          scale=scale,
          offset=offset,
          epsilon=self.epsilon,
          data_format=data_format,
      )

      outputs = tf.reshape(outputs, tensor_shape)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      if scale is not None:
        outputs = outputs * tf.cast(scale, outputs.dtype)
      if offset is not None:
        outputs = outputs + tf.cast(offset, outputs.dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs = tf.reshape(outputs, input_shape)

    if is_ragged:
      outputs = tf.RaggedTensor.from_tensor(outputs, inputs_lengths)
    return outputs

class ChannelEmbeddingLayers(tf.keras.layers.Layer):

  def __init__(self,
@@ -371,7 +472,7 @@ def __init__(self,
        embedding_initializer=embedding_initializer,
        mpi_size=mpi_size,
        mpi_rank=mpi_rank)

    self.dynamic_layer_norm = DynamicLayerNormalization()
    self.dnn1 = tf.keras.layers.Dense(
        64,
        activation='relu',
@@ -427,22 +528,24 @@ def call(self, features):
        for key, value in feature_info_spec.items()
        if key in user_fea
    }
    user_latent = self.user_embedding(user_fea_info)
    movie_fea = ['movie_id', 'movie_genres', 'user_occupation_label']
    movie_fea = [i for i in features.keys() if i in movie_fea]
    movie_fea_info = {
        key: value
        for key, value in feature_info_spec.items()
        if key in movie_fea
    }
    user_latent = self.user_embedding(user_fea_info)
    movie_latent = self.movie_embedding(movie_fea_info)
    latent = tf.concat([user_latent, movie_latent], axis=1)

    x = self.dnn1(latent)

    normalized_emb = self.dynamic_layer_norm(latent)
    x = self.dnn1(normalized_emb)
    x = self.dnn2(x)
    x = self.dnn3(x)

    bias = self.bias_net(latent)
    bias = self.bias_net(normalized_emb)
    x = 0.2 * x + 0.8 * bias
    user_rating = tf.keras.layers.Lambda(lambda x: x, name='user_rating')(x)
    return {'user_rating': user_rating}
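
A minimal sketch of how the new DynamicLayerNormalization layer behaves when called directly, assuming the default axis=-1; the standalone instantiation, shapes, and row lengths below are illustrative assumptions and are not part of this commit:

import tensorflow as tf

# Assumes DynamicLayerNormalization from the diff above is in scope.
layer = DynamicLayerNormalization(axis=-1)

# Dense input whose leading dimensions are only known at run time.
x = tf.random.normal([4, 7, 16])
y = layer(x)  # normalized over the last axis, same shape as x

# Ragged input: rows are densified, normalized, then converted back to a
# RaggedTensor carrying the original row lengths.
r = tf.RaggedTensor.from_tensor(tf.random.normal([4, 7, 16]),
                                lengths=[7, 5, 3, 1])
ry = layer(r)  # ragged output with row lengths [7, 5, 3, 1]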
