From 6e4baf6732bb7deb7c22b4c528cd764061bba881 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Mon, 22 Apr 2024 14:31:56 +0000 Subject: [PATCH 1/6] POC of primitives selections Signed-off-by: Reese Wang --- transformer_engine/jax/cpp_extensions.py | 57 +++++++++++++++++++++++- transformer_engine/jax/layernorm.py | 12 ++--- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 3356aafef5..790be98aa9 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -8,9 +8,11 @@ from functools import partial, reduce import operator import os +import re import warnings import numpy as np +import jax import jax.numpy as jnp from jax.lib import xla_client from jax import core, dtypes @@ -124,6 +126,16 @@ def _check_valid_batch_dims(bdims): f"but got {dim=}" +def enable_primitive(primitive_name): + """ + Args: primitive name + Return: whether to enable this primitive + """ + pattern = os.getenv('NVTE_PRIMITIVES_RE', r'.*') + pattern = re.compile(pattern) + return pattern.match(primitive_name) is not None + + class BasePrimitive(metaclass=ABCMeta): """ jax primitive @@ -481,11 +493,37 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): register_primitive(LayerNormFwdPrimitive) +def native_layernorm(x, gamma, beta, zero_centered_gamma, eps): + """ + JAX native layernorm implementations + - bias is not None: layernorm + - bias is None: rmsnorm + """ + x_ = jnp.asarray(x, jnp.float32) + if beta is None: + mean = 0. + else: + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1. + if beta is None: + beta = 0. 
+ return jnp.asarray(normed_input * gamma + beta).astype(x.dtype) + + def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): """ Wrapper for TE layernorm fwd """ + if not enable_primitive(LayerNormFwdPrimitive.name): + x_ = jnp.asarray(x, jnp.float32) + mu = jnp.mean(x_, axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon) + return native_layernorm(x, gamma, beta, zero_centered_gamma, + epsilon), mu.flatten(), rsigma.flatten() return LayerNormFwdPrimitive.outer_primitive.bind(x, gamma, beta, @@ -691,10 +729,15 @@ def sharded_impl(dz, x, mu, rsigma, gamma): def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, - gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): + gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): """ Wrapper for TE layernorm bwd """ + if not enable_primitive(LayerNormBwdPrimitive.name): + _, vjp_func = jax.vjp( + partial(native_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon), x, + gamma, beta) + return vjp_func(dz) return LayerNormBwdPrimitive.outer_primitive.bind(dz, x, mu, @@ -2659,6 +2702,8 @@ def gelu(inputs: jnp.ndarray) -> jnp.ndarray: Return geglu(inputs) Assume inputs has two dimensions shape and the memory layout is (N..., H) """ + if not enable_primitive(GeluPrimitive.name): + return jax.nn.gelu(inputs) return GeluPrimitive.outer_primitive.bind(inputs) @@ -2770,8 +2815,11 @@ def partition(mesh, arg_infos, result_infos): def dgelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: """ dgelu fusion wrapper - Return dgeglu(inputs) + Return dgelu(inputs) """ + if not enable_primitive(GeluPrimitive.name): + _, vjp_func = jax.vjp(jax.nn.gelu, gelu_inputs) + return vjp_func(inputs)[0] return DGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs) @@ -4097,6 +4145,11 @@ def gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: j gated gelu wrapper Return FP8(geglu(x)) """ + if not enable_primitive(GeluFp8Primitive.name): + gelu_out = jax.nn.gelu(x) + cast_gelu_out = (gelu_out * scale).astype(out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(x)).astype(amax.dtype)) + return cast_gelu_out, updated_amax return GeluFp8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index acf49639d4..0518f0a8f1 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -72,17 +72,18 @@ def _layernorm_fwd_rule(x, mu = None else: raise ValueError(f"{layernorm_type=} is not supported.") - return output, (x, mu, rsigma, gamma) + return output, (x, mu, rsigma, gamma, beta) def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz): - x, mu, rsigma, gamma = ctx + x, mu, rsigma, gamma, beta = ctx if layernorm_type == 'layernorm': dx, dgamma, dbeta = layernorm_bwd(dz, x, mu, rsigma, gamma, + beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) elif layernorm_type == 'rmsnorm': @@ -215,8 +216,8 @@ def _layernorm_fp8_dot_fwd_rule( get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) ctx = (ln_out, casted_kernel, fp8_max, amax, scale, scale_inv, updated_x_amax, - updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, x_contracting_dims, - k_contracting_dims) + updated_kernel_amax, x.shape, kernel.shape, mu, rsigma, x, gamma, beta, + 
x_contracting_dims, k_contracting_dims) return output, ctx @@ -233,7 +234,7 @@ def _layernorm_fp8_dot_bwd_rule( grad): ln_out_, casted_kernel, fp8_max, amax, scale, scale_inv, \ updated_x_amax, updated_kernel_amax, \ - x_shape, kernel_shape, mu, rsigma, x, gamma, \ + x_shape, kernel_shape, mu, rsigma, x, gamma, beta, \ x_contracting_dims, k_contracting_dims = ctx ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1) @@ -270,6 +271,7 @@ def _layernorm_fp8_dot_bwd_rule( mu, rsigma, gamma, + beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) else: From c4301ddb121c4391287b85c4608f10ace8a9e932 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 24 Apr 2024 08:19:45 +0000 Subject: [PATCH 2/6] Add RMSNorm JAX graph and fixes a few ln bugs Signed-off-by: Reese Wang --- transformer_engine/jax/cpp_extensions.py | 39 +++++++++++++++++------- transformer_engine/jax/mlp.py | 10 +++--- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 790be98aa9..5467f164da 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -133,7 +133,9 @@ def enable_primitive(primitive_name): """ pattern = os.getenv('NVTE_PRIMITIVES_RE', r'.*') pattern = re.compile(pattern) - return pattern.match(primitive_name) is not None + result = pattern.match(primitive_name) is not None + print(f'{primitive_name=} {pattern=} {result=} {pattern.match(primitive_name)=}') + return result class BasePrimitive(metaclass=ABCMeta): @@ -496,23 +498,28 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): def native_layernorm(x, gamma, beta, zero_centered_gamma, eps): """ JAX native layernorm implementations - - bias is not None: layernorm - - bias is None: rmsnorm """ x_ = jnp.asarray(x, jnp.float32) - if beta is None: - mean = 0. - else: - mean = jnp.mean(x_, axis=-1, keepdims=True) + mean = jnp.mean(x_, axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) if zero_centered_gamma: gamma += 1. - if beta is None: - beta = 0. return jnp.asarray(normed_input * gamma + beta).astype(x.dtype) +def native_rmsnorm(x, gamma, zero_centered_gamma, eps): + """ + JAX native rmsnorm implementations + """ + x_ = jnp.asarray(x, jnp.float32) + var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + normed_input = x_ * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1. 
+ return jnp.asarray(normed_input * gamma).astype(x.dtype) + + def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): """ @@ -523,7 +530,7 @@ def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_ce mu = jnp.mean(x_, axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon) return native_layernorm(x, gamma, beta, zero_centered_gamma, - epsilon), mu.flatten(), rsigma.flatten() + epsilon), jnp.squeeze(mu, axis=-1), jnp.squeeze(rsigma, axis=-1) return LayerNormFwdPrimitive.outer_primitive.bind(x, gamma, beta, @@ -919,6 +926,12 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float): """ Wrapper for TE rmsnorm fwd """ + if not enable_primitive(RmsNormFwdPrimitive.name): + x_ = jnp.asarray(x, jnp.float32) + # mu = jnp.mean(x_, axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon) + return native_rmsnorm(x, gamma, zero_centered_gamma=False, + eps=epsilon), jnp.squeeze(rsigma, axis=-1) return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon) @@ -1105,6 +1118,10 @@ def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp """ Wrapper for TE layernorm bwd """ + if not enable_primitive(RmsNormBwdPrimitive.name): + _, vjp_func = jax.vjp(partial(native_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, + gamma) + return vjp_func(dz) return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) @@ -2817,7 +2834,7 @@ def dgelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: dgelu fusion wrapper Return dgelu(inputs) """ - if not enable_primitive(GeluPrimitive.name): + if not enable_primitive(DGeluPrimitive.name): _, vjp_func = jax.vjp(jax.nn.gelu, gelu_inputs) return vjp_func(inputs)[0] return DGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs) diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index 3b531a6150..3f2a454f17 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -264,7 +264,7 @@ def _layernorm_geglu_fp8_mlp_fwd_rule( get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1, + ctx = (x, ln_out, mu, rsigma, gamma, beta, dot_1_output, casted_geglu_out, casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims) @@ -284,7 +284,7 @@ def _layernorm_geglu_fp8_mlp_bwd_rule( ffn2_ckpt_name, # pylint: disable=unused-argument ctx, grad): - x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \ + x, ln_out, mu, rsigma, gamma, beta, dot_1_output, casted_geglu_out, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ x_contracting_dims, xt_batch_dims = ctx @@ -360,6 +360,7 @@ def _layernorm_geglu_fp8_mlp_bwd_rule( mu, rsigma, gamma, + beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) else: @@ -571,7 +572,7 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( dot_2_output += jnp.reshape(bias_2, bias_2_shape) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1, + ctx = (x, ln_out, mu, rsigma, 
gamma, beta, dot_1_output, casted_gelu_out, casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape) @@ -592,7 +593,7 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( ffn2_ckpt_name, # pylint: disable=unused-argument ctx, grad): - x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \ + x, ln_out, mu, rsigma, gamma, beta, dot_1_output, casted_gelu_out, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx @@ -673,6 +674,7 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( mu, rsigma, gamma, + beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon) else: From 9818b2bf08da5bd382f0ff831d92fa2ad338fc19 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Wed, 24 Apr 2024 16:17:10 +0000 Subject: [PATCH 3/6] Add JAX graph to softmax Signed-off-by: Reese Wang --- transformer_engine/jax/cpp_extensions.py | 57 +++++++++++++++++++++--- transformer_engine/jax/softmax.py | 10 ++--- 2 files changed, 57 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 5467f164da..21fcf718e2 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -134,7 +134,7 @@ def enable_primitive(primitive_name): pattern = os.getenv('NVTE_PRIMITIVES_RE', r'.*') pattern = re.compile(pattern) result = pattern.match(primitive_name) is not None - print(f'{primitive_name=} {pattern=} {result=} {pattern.match(primitive_name)=}') + print(f'{primitive_name=} {pattern=} {result=} {pattern.match(primitive_name)=}', flush=True) return result @@ -1444,6 +1444,9 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: scaled_softmax_forward wrapper Return FP16/BF16 tensor """ + if not enable_primitive(ScaledSoftmaxFwdPrimitive.name): + return jax.nn.softmax(scale_factor * logits) + return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) @@ -1514,12 +1517,20 @@ def partition(scale_factor, mesh, arg_infos, result_infos): register_primitive(ScaledSoftmaxBwdPrimitive) -def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, +# TODO(rewang): check if there is regression +def scaled_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: """ scaled_backward wrapper Return FP16/BF16 tensor """ + if not enable_primitive(ScaledSoftmaxBwdPrimitive.name): + + def scaled_softmax(logits): + return jax.nn.softmax(scale_factor * logits) + + _, vjp_func = jax.vjp(scaled_softmax, logits) + return vjp_func(dz)[0] return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(dz, softmax_out, scale_factor=scale_factor) @@ -1655,6 +1666,13 @@ def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray, scaled_masked_softmax_forward wrapper Return FP16/BF16 tensor """ + if not enable_primitive(ScaledMaskedSoftmaxFwdPrimitive.name): + if mask is not None: + logits += jax.lax.select(mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.).astype(logits.dtype)) + return jax.nn.softmax(logits * scale_factor) + return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(logits, mask, scale_factor=scale_factor) @@ -1726,17 +1744,29 @@ def partition(scale_factor, mesh, arg_infos, 
result_infos): register_primitive(ScaledMaskedSoftmaxBwdPrimitive) -def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: +def scaled_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, + mask: jnp.ndarray, scale_factor: float) -> jnp.ndarray: """ scaled_masked_backward wrapper Return FP16/BF16 tensor """ + if not enable_primitive(ScaledMaskedSoftmaxBwdPrimitive.name): + + def scaled_masked_softmax(logits): + if mask is not None: + logits += jax.lax.select(mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.).astype(logits.dtype)) + return jax.nn.softmax(logits * scale_factor) + + _, vjp_func = jax.vjp(scaled_masked_softmax, logits) + return vjp_func(dz)[0] return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(dz, softmax_out, scale_factor=scale_factor) +# TODO(rewang): -1e10? class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): """ Scaled Upper Triang Masked Softmax Fwd Primitive @@ -1819,6 +1849,12 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl scaled_upper_triang_masked_softmax_forward wrapper Return FP16/BF16 tensor """ + if not enable_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name): + mask = 1 - jnp.tril(jnp.ones_like(logits)) + logits += jax.lax.select(mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.).astype(logits.dtype)) + return jax.nn.softmax(logits * scale_factor) return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( logits, scale_factor=scale_factor) @@ -1893,11 +1929,22 @@ def partition(scale_factor, mesh, arg_infos, result_infos): def scaled_upper_triang_masked_softmax_bwd(dz: jnp.ndarray, softmax_out: jnp.ndarray, - scale_factor: float) -> jnp.ndarray: + logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: """ scaled_upper_triang_masked_backward wrapper Return FP16/BF16 tensor """ + if not enable_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name): + # TODO(rewang): nan + def scaled_upper_triang_masked_softmax(logits): + mask = 1 - jnp.tril(jnp.ones_like(logits)) + logits += jax.lax.select(mask > 0, + jnp.full(mask.shape, -1e10).astype(logits.dtype), + jnp.full(mask.shape, 0.).astype(logits.dtype)) + return jax.nn.softmax(logits * scale_factor) + + _, vjp_func = jax.vjp(scaled_upper_triang_masked_softmax, logits) + return vjp_func(dz)[0] return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( dz, softmax_out, scale_factor=scale_factor) diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index dece204f4d..d6962d99f9 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -70,18 +70,18 @@ def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type): else: output = scaled_softmax_fwd(logits, scale_factor) - return output, (output,) + return output, (output, logits, mask) def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz): - softmax_output, = ctx + softmax_output, logits, mask = ctx if softmax_type is SoftmaxType.SCALED_MASKED: - dgrad = scaled_masked_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor) elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED: - dgrad = scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor) else: - dgrad = 
scaled_softmax_bwd(dz, softmax_output, scale_factor) + dgrad = scaled_softmax_bwd(dz, softmax_output, logits, scale_factor) return (dgrad, None) From 2a3e80544375a37e9bbbbca88ea02e819b840e42 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 25 Apr 2024 08:48:33 +0000 Subject: [PATCH 4/6] Add layernorm/rmsnrom_fp8 jax graphs Signed-off-by: Reese Wang --- transformer_engine/jax/cpp_extensions.py | 59 +++++++++------------- transformer_engine/jax/layer_math.py | 64 ++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 34 deletions(-) create mode 100644 transformer_engine/jax/layer_math.py diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 21fcf718e2..3931a587d1 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -30,6 +30,7 @@ from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine_jax import NVTE_Fused_Attn_Backend +from . import layer_math from .sharding import all_reduce_max_along_all_axes_except_PP from .sharding import all_reduce_sum_along_dp_fsdp from .sharding import get_all_mesh_axes, num_of_devices @@ -495,31 +496,6 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): register_primitive(LayerNormFwdPrimitive) -def native_layernorm(x, gamma, beta, zero_centered_gamma, eps): - """ - JAX native layernorm implementations - """ - x_ = jnp.asarray(x, jnp.float32) - mean = jnp.mean(x_, axis=-1, keepdims=True) - var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) - normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) - if zero_centered_gamma: - gamma += 1. - return jnp.asarray(normed_input * gamma + beta).astype(x.dtype) - - -def native_rmsnorm(x, gamma, zero_centered_gamma, eps): - """ - JAX native rmsnorm implementations - """ - x_ = jnp.asarray(x, jnp.float32) - var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) - normed_input = x_ * jax.lax.rsqrt(var + eps) - if zero_centered_gamma: - gamma += 1. - return jnp.asarray(normed_input * gamma).astype(x.dtype) - - def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float): """ @@ -529,8 +505,8 @@ def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_ce x_ = jnp.asarray(x, jnp.float32) mu = jnp.mean(x_, axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon) - return native_layernorm(x, gamma, beta, zero_centered_gamma, - epsilon), jnp.squeeze(mu, axis=-1), jnp.squeeze(rsigma, axis=-1) + return layer_math.layernorm(x, gamma, beta, zero_centered_gamma, + epsilon), jnp.squeeze(mu, axis=-1), jnp.squeeze(rsigma, axis=-1) return LayerNormFwdPrimitive.outer_primitive.bind(x, gamma, beta, @@ -742,7 +718,7 @@ def layernorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp. 
""" if not enable_primitive(LayerNormBwdPrimitive.name): _, vjp_func = jax.vjp( - partial(native_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon), x, + partial(layer_math.layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon), x, gamma, beta) return vjp_func(dz) return LayerNormBwdPrimitive.outer_primitive.bind(dz, @@ -928,10 +904,9 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float): """ if not enable_primitive(RmsNormFwdPrimitive.name): x_ = jnp.asarray(x, jnp.float32) - # mu = jnp.mean(x_, axis=-1, keepdims=True) rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon) - return native_rmsnorm(x, gamma, zero_centered_gamma=False, - eps=epsilon), jnp.squeeze(rsigma, axis=-1) + return layer_math.rmsnorm(x, gamma, zero_centered_gamma=False, + eps=epsilon), jnp.squeeze(rsigma, axis=-1) return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon) @@ -1119,8 +1094,8 @@ def rmsnorm_bwd(dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp Wrapper for TE layernorm bwd """ if not enable_primitive(RmsNormBwdPrimitive.name): - _, vjp_func = jax.vjp(partial(native_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, - gamma) + _, vjp_func = jax.vjp(partial(layer_math.rmsnorm, zero_centered_gamma=False, eps=epsilon), + x, gamma) return vjp_func(dz) return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon) @@ -1672,7 +1647,6 @@ def scaled_masked_softmax_fwd(logits: jnp.ndarray, mask: jnp.ndarray, jnp.full(mask.shape, -1e10).astype(logits.dtype), jnp.full(mask.shape, 0.).astype(logits.dtype)) return jax.nn.softmax(logits * scale_factor) - return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(logits, mask, scale_factor=scale_factor) @@ -3831,6 +3805,15 @@ def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, ama """ Wrapper for TE layernorm fwd (fp8 out) """ + if not enable_primitive(LayerNormFwdFp8Primitive.name): + return layer_math.layernorm_fp8(x, + gamma, + beta, + scale, + amax, + out_dtype=out_dtype, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon) return LayerNormFwdFp8Primitive.outer_primitive.bind(x, gamma, beta, @@ -4064,6 +4047,14 @@ def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale """ Wrapper for TE rmsnorm fwd (fp8 out) """ + if not enable_primitive(RmsNormFwdFp8Primitive.name): + return layer_math.rmsnorm_fp8(x, + gamma, + scale, + amax, + out_dtype=out_dtype, + zero_centered_gamma=False, + eps=epsilon) return RmsNormFwdFp8Primitive.outer_primitive.bind(x, gamma, amax, diff --git a/transformer_engine/jax/layer_math.py b/transformer_engine/jax/layer_math.py new file mode 100644 index 0000000000..528d806e80 --- /dev/null +++ b/transformer_engine/jax/layer_math.py @@ -0,0 +1,64 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX implementation of layers""" +import jax +import jax.numpy as jnp + + +def layernorm(x, gamma, beta, zero_centered_gamma, eps): + """ + JAX native layernorm implementations + """ + x_ = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1. 
+ return jnp.asarray(normed_input * gamma + beta).astype(x.dtype) + + +def rmsnorm(x, gamma, zero_centered_gamma, eps): + """ + JAX native rmsnorm implementations + """ + x_ = jnp.asarray(x, jnp.float32) + var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + normed_input = x_ * jax.lax.rsqrt(var + eps) + if zero_centered_gamma: + gamma += 1. + return jnp.asarray(normed_input * gamma).astype(x.dtype) + + +def layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps): + """ + JAX native layernorm fp8 implementations + """ + x_ = jnp.asarray(x, jnp.float32) + mean = jnp.mean(x_, axis=-1, keepdims=True) + var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(var + eps) + normed_input = (x_ - mean) * rsigma + if zero_centered_gamma: + gamma += 1. + output = normed_input * gamma + beta + casted_output = (scale * output).astype(out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(output)).astype(amax.dtype)) + return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax + + +def rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps): + """ + JAX native rmsnorm fp8 implementations + """ + x_ = jnp.asarray(x, jnp.float32) + var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + rsigma = jax.lax.rsqrt(var + eps) + normed_input = x_ * rsigma + if zero_centered_gamma: + gamma += 1. + output = normed_input * gamma + casted_output = (scale * output).astype(out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(output)).astype(amax.dtype)) + return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax From 82f00aaa404e87488aafe6dac604a3c3c8976934 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 25 Apr 2024 15:49:36 +0000 Subject: [PATCH 5/6] Add geglu/cast/transpose/cast_transpose jax graphs Signed-off-by: Reese Wang --- transformer_engine/jax/cpp_extensions.py | 56 +++++--------- transformer_engine/jax/layer_math.py | 95 ++++++++++++++++++++++-- 2 files changed, 108 insertions(+), 43 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 3931a587d1..b61eb79729 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -31,6 +31,7 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend from . 
import layer_math +from .layer_math import _multidim_transpose, _normalize_axis_boundary from .sharding import all_reduce_max_along_all_axes_except_PP from .sharding import all_reduce_sum_along_dp_fsdp from .sharding import get_all_mesh_axes, num_of_devices @@ -2968,6 +2969,8 @@ def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray: Return FP8(geglu(inputs)) Assume inputs has two dimensions shape and the memory layout is (N, 2, H) """ + if not enable_primitive(GatedGeluPrimitive.name): + return layer_math.gated_gelu(inputs) return GatedGeluPrimitive.outer_primitive.bind(inputs) @@ -3085,47 +3088,15 @@ def partition(mesh, arg_infos, result_infos): register_primitive(DgatedGeluPrimitive) -def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray: +def dgated_gelu(inputs: jnp.ndarray, geglu_inputs: jnp.ndarray) -> jnp.ndarray: """ dgated_gelu fusion wrapper Return dgeglu(inputs) """ - return DgatedGeluPrimitive.outer_primitive.bind(inputs, gelu_inputs) - - -def _normalize_axis_boundary(axis, ndim): - return axis if axis >= 0 else ndim + axis - - -def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): - """ - te_cast_transpose_p multi-dims transpose - - static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be - involved into transpose, -1 means all axes involve into transpose. - transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for - transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary - - examples: - X in shape (dim0, dim1, dim2, dim3, dim4) - - static_axis_boundary == -1, transpose_axis_boundary == 2 - Xt = (dim2, dim3, dim4, dim0, dim1) - - static_axis_boundary == 0, transpose_axis_boundary == 2 - Xt = (dim0, dim2, dim3, dim4, dim1) - - static_axis_boundary == 0, transpose_axis_boundary == 3 - Xt = (dim0, dim3, dim4, dim1. dim2) - """ - if static_axis_boundary < 0: - static_axis_boundary = -1 # means no static axes - assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. 
- transpose_start_idx = static_axis_boundary + 1 - transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, len(shape)) - assert transpose_start_idx < transpose_axis_boundary - return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:], - *shape[transpose_start_idx:transpose_axis_boundary]) + if not enable_primitive(DgatedGeluPrimitive.name): + _, vjp_func = jax.vjp(layer_math.gated_gelu, geglu_inputs) + return vjp_func(inputs)[0] + return DgatedGeluPrimitive.outer_primitive.bind(inputs, geglu_inputs) class CastTransposePrimitive(BasePrimitive): @@ -3291,6 +3262,13 @@ def cast_transpose(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_ cast transpose wrapper Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale` """ + if not enable_primitive(CastTransposePrimitive.name): + return layer_math.cast_transpose(x, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) return CastTransposePrimitive.outer_primitive.bind( x, amax, @@ -3425,6 +3403,8 @@ def cast_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: j Cast wrapper Return FP8 tensor """ + if not enable_primitive(CastFP8Primitive.name): + return layer_math.cast_fp8(x, scale, amax, out_dtype=out_dtype) return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype) @@ -3548,6 +3528,8 @@ def transpose(x: jnp.ndarray, static_axis_boundary: int, """ transpose wrapper """ + if not enable_primitive(TransposePrimitive.name): + return layer_math.transpose(x, static_axis_boundary, transpose_axis_boundary) return TransposePrimitive.outer_primitive.bind(x, static_axis_boundary=static_axis_boundary, transpose_axis_boundary=transpose_axis_boundary) diff --git a/transformer_engine/jax/layer_math.py b/transformer_engine/jax/layer_math.py index 528d806e80..4b697b4a4d 100644 --- a/transformer_engine/jax/layer_math.py +++ b/transformer_engine/jax/layer_math.py @@ -8,7 +8,7 @@ def layernorm(x, gamma, beta, zero_centered_gamma, eps): """ - JAX native layernorm implementations + JAX native layernorm implementation """ x_ = jnp.asarray(x, jnp.float32) mean = jnp.mean(x_, axis=-1, keepdims=True) @@ -21,7 +21,7 @@ def layernorm(x, gamma, beta, zero_centered_gamma, eps): def rmsnorm(x, gamma, zero_centered_gamma, eps): """ - JAX native rmsnorm implementations + JAX native rmsnorm implementation """ x_ = jnp.asarray(x, jnp.float32) var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) @@ -31,9 +31,19 @@ def rmsnorm(x, gamma, zero_centered_gamma, eps): return jnp.asarray(normed_input * gamma).astype(x.dtype) +def quantize(x, scale, q_dtype): + """ + Quantize with scale + """ + dtype_max = (jnp.finfo(q_dtype).max).astype(x.dtype) + scale = scale.astype(x.dtype) + clipped_scaled_x = jnp.clip((x * scale), -dtype_max, dtype_max) + return clipped_scaled_x.astype(q_dtype) + + def layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps): """ - JAX native layernorm fp8 implementations + JAX native layernorm fp8 implementation """ x_ = jnp.asarray(x, jnp.float32) mean = jnp.mean(x_, axis=-1, keepdims=True) @@ -43,14 +53,14 @@ def layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, e if zero_centered_gamma: gamma += 1. 
output = normed_input * gamma + beta - casted_output = (scale * output).astype(out_dtype) + casted_output = quantize(output, scale, q_dtype=out_dtype) updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(output)).astype(amax.dtype)) return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax def rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps): """ - JAX native rmsnorm fp8 implementations + JAX native rmsnorm fp8 implementation """ x_ = jnp.asarray(x, jnp.float32) var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) @@ -59,6 +69,79 @@ def rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps): if zero_centered_gamma: gamma += 1. output = normed_input * gamma - casted_output = (scale * output).astype(out_dtype) + casted_output = quantize(output, scale, q_dtype=out_dtype) updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(output)).astype(amax.dtype)) return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax + + +def gated_gelu(inputs): + """ + JAX native gated gelu implementation + inputs: (N, 2, H) + """ + gelu_inputs, identity_inputs = jnp.split(inputs, [1], axis=-2) + gelu_outputs = jax.nn.gelu(gelu_inputs) + return jnp.squeeze(gelu_outputs * identity_inputs, axis=-2) + + +def cast_fp8(inputs, scale, amax, out_dtype): + """ + JAX native fp8 casting implementation + """ + casted_output = quantize(inputs, scale, q_dtype=out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype)) + return casted_output, updated_amax + + +def _normalize_axis_boundary(axis, ndim): + return axis if axis >= 0 else ndim + axis + + +def _multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary): + """ + te_cast_transpose_p multi-dims transpose + + static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be + involved into transpose, -1 means all axes involve into transpose. + transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for + transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary + + examples: + X in shape (dim0, dim1, dim2, dim3, dim4) + + static_axis_boundary == -1, transpose_axis_boundary == 2 + Xt = (dim2, dim3, dim4, dim0, dim1) + + static_axis_boundary == 0, transpose_axis_boundary == 2 + Xt = (dim0, dim2, dim3, dim4, dim1) + + static_axis_boundary == 0, transpose_axis_boundary == 3 + Xt = (dim0, dim3, dim4, dim1. dim2) + """ + if static_axis_boundary < 0: + static_axis_boundary = -1 # means no static axes + assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. 
+ transpose_start_idx = static_axis_boundary + 1 + transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, len(shape)) + assert transpose_start_idx < transpose_axis_boundary + return (*shape[:transpose_start_idx], *shape[transpose_axis_boundary:], + *shape[transpose_start_idx:transpose_axis_boundary]) + + +def transpose(inputs, static_axis_boundary, transpose_axis_boundary): + """ + JAX native transpose implementation + """ + axes = _multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary) + return jnp.transpose(inputs, axes=axes) + + +def cast_transpose(inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary): + """ + JAX native cast_transpose implementation + """ + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype)) + casted_output = quantize(inputs, scale, q_dtype=out_dtype) + casted_transposed_output = transpose(casted_output, static_axis_boundary, + transpose_axis_boundary) + return casted_output, casted_transposed_output, updated_amax From e58f1b84d2f3f9479fa137e5a63b43d55a31888d Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 25 Apr 2024 16:30:03 +0000 Subject: [PATCH 6/6] add dgelu_dbias_ct/dgated_gelu_ct JAX graphs Signed-off-by: Reese Wang --- transformer_engine/jax/cpp_extensions.py | 23 +++++++++++++++++++++++ transformer_engine/jax/layer_math.py | 17 +++++++++++++++++ transformer_engine/jax/mlp.py | 11 ++++++----- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index b61eb79729..594ac60bf3 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -4400,6 +4400,7 @@ def sharded_impl(dz, x, amax, scale, scale_inv): def dgelu_dbias_cast_transpose( dz: jnp.ndarray, x: jnp.ndarray, + bias: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, @@ -4413,6 +4414,17 @@ def dgelu_dbias_cast_transpose( if static_axis_boundary < 0: static_axis_boundary = -1 # means no static axes + if not enable_primitive(DGeluDBiasCastTransposePrimitive.name): + _, vjp_func = jax.vjp(layer_math.bias_gelu, x, bias) + gelu_grad, bias_grad = vjp_func(dz) + casted_gelu_grad, ct_gelu_grad, updated_amax = layer_math.cast_transpose( + gelu_grad, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + return casted_gelu_grad, ct_gelu_grad, bias_grad, updated_amax return DGeluDBiasCastTransposePrimitive.outer_primitive.bind( dz, x, @@ -4568,6 +4580,8 @@ def gated_gelu_fp8(x: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_ gated gelu wrapper Return FP8(geglu(x)) """ + if not enable_primitive(GatedGeluFp8Primitive.name): + return layer_math.gated_gelu_fp8(x, scale, amax, out_dtype=out_dtype) return GatedGeluFp8Primitive.outer_primitive.bind(x, amax, scale, @@ -4738,6 +4752,15 @@ def dgated_gelu_cast_transpose( cast transpose d_gated_gelu fusion wrapper Return FP8(dgeglu(inputs)) """ + if not enable_primitive(DgatedGeluCastTransposePrimitive.name): + _, vjp_func = jax.vjp(layer_math.gated_gelu, x) + dx, = vjp_func(dz) + return layer_math.cast_transpose(dx, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=-2) return DgatedGeluCastTransposePrimitive.outer_primitive.bind( dz, x, diff --git a/transformer_engine/jax/layer_math.py b/transformer_engine/jax/layer_math.py index 4b697b4a4d..bd72e26061 100644 
--- a/transformer_engine/jax/layer_math.py +++ b/transformer_engine/jax/layer_math.py @@ -74,6 +74,13 @@ def rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps): return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax +def bias_gelu(inputs, bias): + """ + JAX native bias_gelu implementation + """ + return jax.nn.gelu(inputs + bias) + + def gated_gelu(inputs): """ JAX native gated gelu implementation @@ -84,6 +91,16 @@ def gated_gelu(inputs): return jnp.squeeze(gelu_outputs * identity_inputs, axis=-2) +def gated_gelu_fp8(inputs, scale, amax, out_dtype): + """ + JAX native gated gelu fp8 implementation + """ + geglu_output = gated_gelu(inputs) + casted_output = quantize(geglu_output, scale, q_dtype=out_dtype) + updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(geglu_output)).astype(amax.dtype)) + return casted_output, updated_amax + + def cast_fp8(inputs, scale, amax, out_dtype): """ JAX native fp8 casting implementation diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index 3f2a454f17..260df75aad 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -572,10 +572,10 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( dot_2_output += jnp.reshape(bias_2, bias_2_shape) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - ctx = (x, ln_out, mu, rsigma, gamma, beta, dot_1_output, casted_gelu_out, casted_kernel_1, - casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax, - updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, - bias_1.shape, bias_2.shape) + ctx = (x, bias_1, ln_out, mu, rsigma, gamma, beta, dot_1_output, casted_gelu_out, + casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, + updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, + xt_batch_dims, bias_1.shape, bias_2.shape) return dot_2_output, ctx @@ -593,7 +593,7 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( ffn2_ckpt_name, # pylint: disable=unused-argument ctx, grad): - x, ln_out, mu, rsigma, gamma, beta, dot_1_output, casted_gelu_out, \ + x, bias_1, ln_out, mu, rsigma, gamma, beta, dot_1_output, casted_gelu_out, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx @@ -641,6 +641,7 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose( dgrad_2, dot_1_output, + bias_1, dgelu_amax, dgelu_scale, dgelu_scale_inv,
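
Note on the selection mechanism (illustrative, not part of the patches above): every TE primitive wrapper touched in this series is gated on enable_primitive(), which matches the primitive's registered name against the NVTE_PRIMITIVES_RE environment variable and otherwise falls back to the JAX-native graph. The standalone sketch below shows how that gate is expected to behave; the primitive name strings and the regex used to disable TE primitives are assumptions for demonstration only, and the layernorm helper simply mirrors layer_math.layernorm with zero_centered_gamma=False.

    # Sketch only: mirrors the NVTE_PRIMITIVES_RE gate from PATCH 1 and the
    # JAX-native layernorm math from layer_math.py; names marked as assumed
    # are not taken from the patches.
    import os
    import re

    import jax
    import jax.numpy as jnp


    def enable_primitive(primitive_name):
        """Mirror of the patched gate: match the name against NVTE_PRIMITIVES_RE."""
        pattern = re.compile(os.getenv('NVTE_PRIMITIVES_RE', r'.*'))
        return pattern.match(primitive_name) is not None


    def native_layernorm(x, gamma, beta, eps=1e-6):
        """Same math as layer_math.layernorm with zero_centered_gamma=False."""
        x_ = jnp.asarray(x, jnp.float32)
        mean = jnp.mean(x_, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed = (x_ - mean) * jax.lax.rsqrt(var + eps)
        return (normed * gamma + beta).astype(x.dtype)


    if __name__ == '__main__':
        # Example: reject names starting with "te_" so the wrappers take the
        # JAX-native path ("te_" prefix is an assumed naming convention).
        os.environ['NVTE_PRIMITIVES_RE'] = r'(?!te_).*'
        print(enable_primitive('te_layernorm_forward'))    # False -> native fallback
        print(enable_primitive('some_other_primitive'))    # True  -> TE primitive

        key = jax.random.PRNGKey(0)
        x = jax.random.normal(key, (4, 8), jnp.bfloat16)
        gamma = jnp.ones((8,), jnp.bfloat16)
        beta = jnp.zeros((8,), jnp.bfloat16)
        print(native_layernorm(x, gamma, beta).shape)       # (4, 8)

Setting NVTE_PRIMITIVES_RE to r'.*' (the default) keeps all TE primitives enabled, while a narrower pattern selectively routes individual wrappers through the pure-JAX graphs, which is what makes per-primitive comparison and debugging possible.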