Properly determine if the cache is cleared
kddubey committed Nov 21, 2023
1 parent 30914a1 commit 4e444b5
Showing 2 changed files with 17 additions and 4 deletions.
15 changes: 12 additions & 3 deletions src/cappr/huggingface/classify.py
@@ -179,11 +179,21 @@ def _past_key_values_get(
 class _CAPPr:
     model: ModelForCausalLM
     logits_all: bool = True
-    past: tuple[BatchEncodingPT, CausalLMOutputWithPast] | None = None
+    _past: tuple[BatchEncodingPT, CausalLMOutputWithPast] | None = None
     batch_idxs: torch.Tensor | None = None
     update_cache: bool = False
     _is_cache_cleared: bool = False
+
+    @property
+    def past(self):
+        return self._past
+
+    @past.setter
+    def past(self, new_past: tuple[BatchEncodingPT, CausalLMOutputWithPast] | None):
+        if new_past is None:
+            self._is_cache_cleared = True
+        self._past = new_past
 
 
 class _CacheClearedError(Exception):
     """Raise to prevent a user from using a cached model outside of the context"""
@@ -197,7 +207,7 @@ def __init__(
         past: tuple[BatchEncodingPT, CausalLMOutputWithPast] | None = None,
         logits_all: bool = True,
     ):
-        self._cappr = _CAPPr(model, logits_all=logits_all, past=past)
+        self._cappr = _CAPPr(model, logits_all, past)
         """
         Contains data which controls the cache
         """
@@ -554,7 +564,6 @@ def cache(
         # hidden_states (usually taking up GPU RAM)—that should be cleared when we exit
         # the context
         model_with_cache._cappr.past = None
-        model_with_cache._cappr._is_cache_cleared = True
     else:
         model_with_cache._cappr.past = past
 
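The core of the change above is a standard property-setter pattern: every assignment to past now goes through a setter, so the object itself records when the cache is dropped, rather than each call site flipping _is_cache_cleared by hand (which is what the line deleted from cache used to do). A stripped-down sketch of the pattern, using a hypothetical _Cache class rather than the library's actual one:

    class _Cache:
        def __init__(self):
            self._past = None
            self._is_cache_cleared = False

        @property
        def past(self):
            return self._past

        @past.setter
        def past(self, new_past):
            # Assigning None from anywhere marks the cache as cleared, so no
            # call site can forget to set the flag separately.
            if new_past is None:
                self._is_cache_cleared = True
            self._past = new_past

    cache = _Cache()
    cache.past = ("encodings", "model output")  # populate: flag stays False
    cache.past = None                           # clear: flag flips to True
    assert cache._is_cache_cleared

A likely reason the second hunk switches from keyword to positional arguments: renaming the dataclass field from past to _past also renames the parameter that @dataclass generates for __init__, so passing past=past would raise a TypeError. A minimal illustration with a hypothetical Example dataclass:

    from dataclasses import dataclass

    @dataclass
    class Example:
        _past: int | None = None

    Example(_past=1)  # fine: the generated __init__ parameter is named _past
    Example(1)        # fine: positional arguments are unaffected
    # Example(past=1) would raise TypeError: unexpected keyword argument 'past'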
6 changes: 5 additions & 1 deletion tests/huggingface/test_huggingface_classify.py
@@ -250,10 +250,14 @@ def test_cache_nested(model_and_tokenizer, atol):
     assert torch.allclose(logits4, logits_correct(["a b c d"]), atol=atol)
 
     # Test clear_cache_on_exit
+    device = model_and_tokenizer[0].device
     with pytest.raises(
         classify._CacheClearedError, match="This model is no longer usable."
     ):
-        cached_a[0](input_ids=None, attention_mask=None)
+        cached_a[0](
+            input_ids=torch.tensor([[1]], device=device),
+            attention_mask=torch.tensor([[1]], device=device),
+        )
 
     with classify.cache(
         model_and_tokenizer, "a", clear_cache_on_exit=False
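The updated assertion passes minimal real tensors on the model's device instead of None, so the call reaches an actual forward pass and the error raised is specifically the cache-cleared guard. For reference, a sketch of the behavior this pins down, assuming (as the surrounding test suggests) that classify.cache yields a usable (model, tokenizer) pair and that clear_cache_on_exit defaults to True:

    import torch
    from cappr.huggingface import classify

    with classify.cache(model_and_tokenizer, "a") as cached_a:
        ...  # calls in here reuse the cached representation of the prefix "a"

    # The cache is cleared on exit by default, so a later forward pass raises
    # classify._CacheClearedError ("This model is no longer usable."):
    device = model_and_tokenizer[0].device
    cached_a[0](
        input_ids=torch.tensor([[1]], device=device),
        attention_mask=torch.tensor([[1]], device=device),
    )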
