FP8 attention and all post fixes
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman committed Apr 24, 2024
1 parent 9e4091e commit 090e724
Showing 22 changed files with 2,991 additions and 464 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 120 files
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -6,7 +6,7 @@ set -e

: ${TE_PATH:=/opt/transformerengine}

-pip install pytest==6.2.5 onnxruntime==1.13.1
+pip install pytest==7.2 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
594 changes: 484 additions & 110 deletions tests/pytorch/fused_attn/test_fused_attn.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/pytorch/test_numerics.py
@@ -1091,7 +1091,7 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)

# Check output.
-atol = {torch.float32 : 2e-4,
+atol = {torch.float32 : 2.5e-4,
torch.half : 2e-3,
torch.bfloat16: 2e-2,
}
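For context on the tolerance table touched above, the sketch below shows how a dtype-keyed `atol` dictionary is typically applied with `torch.allclose`. It is a minimal illustration with hypothetical tensors (`te_out`, `ref_out`, and `check_close` are not names from the test), not the actual `_test_granular_accuracy` harness.

```python
# Minimal sketch of a dtype-dependent accuracy check; the tensors here are
# dummies standing in for Transformer Engine vs. reference module outputs.
import torch

atol = {torch.float32: 2.5e-4,   # loosened from 2e-4 in this commit
        torch.half: 2e-3,
        torch.bfloat16: 2e-2}

def check_close(te_out: torch.Tensor, ref_out: torch.Tensor) -> None:
    tol = atol[te_out.dtype]
    diff = (te_out.float() - ref_out.float()).abs().max()
    assert torch.allclose(te_out, ref_out, atol=tol), \
        f"max abs diff {diff.item():.3e} exceeds atol={tol}"

# Dummy usage: a perturbation well inside the float32 tolerance passes.
x = torch.randn(4, 8, dtype=torch.float32)
check_close(x, x + 1e-5)
```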
71 changes: 54 additions & 17 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -85,15 +85,25 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
-if ((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)
-&& (sm_arch_ >= 90)
-&& (max_seqlen_q == max_seqlen_kv)
-&& (num_attn_heads == num_gqa_groups)
-&& (max_seqlen_q <= 512)
-&& (head_dim == 64)
-&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
-&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
-&& (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)) {
+if (((q_dtype == NVTEDType::kNVTEFloat8E4M3)
+|| (q_dtype == NVTEDType::kNVTEFloat8E5M2))
+&& (sm_arch_ >= 90)
+&& (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
+&& (
+((cudnn_runtime_version >= 8900)
+&& (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)
+&& (max_seqlen_q == max_seqlen_kv)
+&& (max_seqlen_q <= 512)
+&& (head_dim == 64)
+&& (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK))
+|| ((cudnn_runtime_version >= 90100)
+&& (max_seqlen_q % 128 == 0)
+&& (max_seqlen_kv % 128 == 0)
+&& (head_dim == 128)
+&& ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
+|| (qkv_format == NVTE_QKV_Format::NVTE_SBHD))
+&& ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
+|| (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
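For readers untangling the boolean above: the FP8 backend is now selectable along two paths instead of one. The sketch below is a plain-Python paraphrase of that condition, not library code; the string values are stand-ins for the corresponding NVTE enums, and the argument names simply mirror those of `nvte_get_fused_attn_backend`.

```python
# Simplified restatement of the FP8 backend eligibility check above.
# Strings stand in for NVTE enum values; this is not Transformer Engine code.
def fp8_backend_eligible(q_dtype, sm_arch, bias_type, attn_mask_type,
                         qkv_layout, qkv_format, max_seqlen_q, max_seqlen_kv,
                         head_dim, cudnn_version):
    if q_dtype not in ("fp8e4m3", "fp8e5m2"):
        return False
    if sm_arch < 90 or bias_type != "no_bias":
        return False
    # Original FP8 path (cuDNN >= 8.9.0): packed T3HD QKV, short equal seqlens.
    legacy_path = (cudnn_version >= 8900
                   and qkv_layout == "t3hd"
                   and max_seqlen_q == max_seqlen_kv
                   and max_seqlen_q <= 512
                   and head_dim == 64
                   and attn_mask_type == "padding")
    # New path added in this commit (cuDNN >= 9.1.0): BSHD/SBHD formats,
    # head_dim 128, seqlens that are multiples of 128, causal or no mask.
    new_path = (cudnn_version >= 90100
                and max_seqlen_q % 128 == 0
                and max_seqlen_kv % 128 == 0
                and head_dim == 128
                and qkv_format in ("bshd", "sbhd")
                and attn_mask_type in ("causal", "no_mask"))
    return legacy_path or new_path
```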
@@ -269,7 +279,7 @@ void nvte_fused_attn_fwd_qkvpacked(
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_qkvpacked(
b, h, max_seqlen, d,
-is_training, attn_scale, dropout, qkv_layout,
+is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens,
@@ -379,7 +389,7 @@ void nvte_fused_attn_bwd_qkvpacked(
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_qkvpacked(
b, h, max_seqlen, d,
-attn_scale, dropout, qkv_layout,
+attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
@@ -476,7 +486,18 @@ void nvte_fused_attn_fwd_kvpacked(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
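The KV-packed FP8 forward enabled above takes the same `cu_seqlens_q`/`cu_seqlens_kv` tensors as the other variable-length paths: cumulative sequence lengths of size batch+1. A quick sketch with dummy lengths (not tied to any test in this commit):

```python
# cu_seqlens are the running totals of per-sequence lengths, starting at 0,
# so entries i and i+1 delimit sequence i in the packed Q/KV buffers.
import torch

seqlens_q = torch.tensor([128, 256, 128], dtype=torch.int32)   # dummy batch of 3
cu_seqlens_q = torch.zeros(len(seqlens_q) + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)

print(cu_seqlens_q)   # tensor([0, 128, 384, 512], dtype=torch.int32)
```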
@@ -580,7 +601,23 @@ void nvte_fused_attn_bwd_kvpacked(
NVTE_ERROR(err_msg);
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
NVTE_ERROR("The FP8 fused attention API only supports packed QKV input. \n");
#if (CUDNN_VERSION >= 8900)
const Tensor *input_M = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[0]);
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
output_dQ, output_dKV,
input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state,
wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
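The FP8 backward paths pull three auxiliary tensors out of `Aux_CTX_Tensors`: `M`, `ZInv`, and the RNG state. Conceptually, `M` and `ZInv` are per-row softmax statistics saved by the forward, roughly as in the toy sketch below (a conceptual illustration only; the real kernels store them per batch/head/query position in their own layout and with FP8 scaling).

```python
# Toy view of the softmax statistics the FP8 backward consumes, for a single
# row of scaled attention scores.
import torch

scores = torch.randn(16)                    # one row of scaled Q·K^T
M = scores.max()                            # row maximum (numerical stability)
ZInv = 1.0 / torch.exp(scores - M).sum()    # reciprocal of the softmax denominator

probs = torch.exp(scores - M) * ZInv        # reconstructs the forward's softmax row
assert torch.allclose(probs, torch.softmax(scores, dim=0), atol=1e-6)
```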
@@ -662,8 +699,8 @@ void nvte_fused_attn_fwd(
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
fused_attn_fp8_fwd(
-b, h_q, max_seqlen_q, max_seqlen_kv, d,
-is_training, attn_scale, dropout, qkv_layout,
+b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_output_S, output_O,
Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv,
@@ -775,8 +812,8 @@ void nvte_fused_attn_bwd(
const Tensor *input_ZInv = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[1]);
const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(Aux_CTX_Tensors->tensors[2]);
fused_attn_fp8_bwd(
-b, h_q, max_seqlen_q, max_seqlen_kv, d,
-attn_scale, dropout, qkv_layout,
+b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
+attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
input_M, input_ZInv,
input_S, input_output_dP,
@@ -76,7 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
scaling_factor, is_training,
dropout_probability, layout,
bias_type, mask_type,
-tensorType};
+tensorType, tensorType};

namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
@@ -147,7 +147,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
fe::graph::SDPA_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_attributes()
.set_name("flash_attention")
-.set_is_inference(!is_training)
+.set_is_inference(false)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);

@@ -199,19 +199,17 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);

-if (is_training) {
-Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
-.set_dim({b, h, s_q, 1})
-.set_stride({h * s_q, s_q, 1, 1});
-}
+Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
+.set_dim({b, h, s_q, 1})
+.set_stride({h * s_q, s_q, 1, 1});

std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
std::shared_ptr<fe::graph::Tensor_attributes>, // K
std::shared_ptr<fe::graph::Tensor_attributes>, // V
std::shared_ptr<fe::graph::Tensor_attributes>, // attn_scale
std::shared_ptr<fe::graph::Tensor_attributes> > // O
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
-auto Stats_tuple = is_training ? std::make_tuple(Stats) : std::make_tuple(nullptr);
+auto Stats_tuple = std::make_tuple(Stats);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto padding_tuple = is_padding ?
std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
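In the arbitrary-sequence-length graph above, `Stats` (shape `[b, h, s_q, 1]`, stride `[h*s_q, s_q, 1, 1]`) is now registered as an output unconditionally rather than only when `is_training` is set. As I read cuDNN's SDPA convention, `Stats` holds the per-row softmax statistics in log-sum-exp form, which ties back to the `M`/`ZInv` pair of the FP8 path as sketched below (conceptual only, not the kernel's layout).

```python
# Conceptual relation between the log-sum-exp "Stats" saved by the
# arbitrary-seqlen path and the (M, ZInv) pair of the FP8 path, per row.
import torch

scores = torch.randn(16)
stats = torch.logsumexp(scores, dim=0)                      # one Stats entry
M = scores.max()
ZInv = 1.0 / torch.exp(scores - M).sum()

assert torch.allclose(stats, M - torch.log(ZInv), atol=1e-6)
```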
@@ -258,11 +256,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
{K, devPtrK},
{V, devPtrV},
{attn_scale, &scaling_factor},
-{O, devPtrO}};
-
-if (is_training) {
-variant_pack[Stats] = devPtrSoftmaxStats;
-}
+{O, devPtrO},
+{Stats, devPtrSoftmaxStats}};

if (is_bias) {
variant_pack[bias] = devPtrBias;
@@ -321,7 +316,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
scaling_factor, true,
dropout_probability, layout,
bias_type, mask_type,
-tensorType};
+tensorType, tensorType};

namespace fe = cudnn_frontend;
using graph_and_tensors = std::tuple<std::shared_ptr<fe::graph::Graph>,
@@ -19,7 +19,7 @@ namespace transformer_engine {
#if (CUDNN_VERSION >= 8900)
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen,
-size_t head_size, bool is_training, float attn_scale,
+size_t head_dim, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
const Tensor *input_QKV, const Tensor *input_Bias,
