[PyTorch] Avoid using LRU cache for cu_seqlens (#798)
* Try using global buffer for cu_seqlens

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Avoid using functools.lru_cache

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
ksivaman and vasunvidia committed Apr 24, 2024
1 parent 33be946 commit 8245067
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions transformer_engine/pytorch/attention.py
@@ -5,7 +5,6 @@
 """Attention."""
 import collections
 from contextlib import nullcontext
-import functools
 from importlib.metadata import version
 import math
 import os
@@ -278,8 +277,7 @@ def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor:
 
     return indices
 
-
-@functools.lru_cache
+_cu_seqlens_cache = {}
 def _get_full_cu_seqlens(
     batch_size: int,
     max_seqlen: int,
@@ -290,13 +288,16 @@ def _get_full_cu_seqlens(
     All sequences in batch have the maximum sequence length.
     """
-    return torch.arange(
-        0,
-        (batch_size + 1) * max_seqlen,
-        step=max_seqlen,
-        dtype=torch.int32,
-        device=device,
-    )
+    global _cu_seqlens_cache
+    if (batch_size, max_seqlen) not in _cu_seqlens_cache:
+        _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange(
+            0,
+            (batch_size + 1) * max_seqlen,
+            step=max_seqlen,
+            dtype=torch.int32,
+            device=device,
+        )
+    return _cu_seqlens_cache[(batch_size, max_seqlen)]
 
 
 @jit_fuser
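
For context, here is a minimal standalone sketch of how the new dict-based cache behaves. It is not part of the commit; it assumes _get_full_cu_seqlens can be imported from transformer_engine.pytorch.attention (the import path mirrors the file path shown in the diff and is not verified here).

# Sketch: repeated calls with the same (batch_size, max_seqlen) reuse the
# cached cu_seqlens tensor instead of re-allocating it.
# Assumption: _get_full_cu_seqlens is importable from this module path.
import torch

from transformer_engine.pytorch.attention import _get_full_cu_seqlens

batch_size, max_seqlen = 4, 512
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

first = _get_full_cu_seqlens(batch_size, max_seqlen, device)
second = _get_full_cu_seqlens(batch_size, max_seqlen, device)

# Cache hit: both calls return the same tensor object, so the cumulative
# sequence lengths [0, 512, 1024, 1536, 2048] are allocated only once.
assert first is second
print(first)

Note that, as the diff shows, the cache key is (batch_size, max_seqlen); the device argument is not part of the key, so the device used on the first call for a given key determines where the cached tensor lives.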
