Skip to content

Commit

Permalink
Remove potentially wrong check on __reversed__
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey committed Oct 1, 2023
1 parent 6e9ac04 commit d150e1b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
15 changes: 10 additions & 5 deletions src/cappr/huggingface/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
self.pad_token_id = tokenizer.pad_token_id

def __enter__(self):
# Note: PreTrainedTokenizerBase is smart about setting auxiliary attributes,
# e.g., it updates tokenizer.special_tokens_map after setting
# tokenizer.pad_token_id.
if self.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

Expand Down Expand Up @@ -110,6 +113,8 @@ def __exit__(self, *args):

_DEFAULT_CONTEXTS_MODEL = (_model_eval_mode,)
"Model settings: set the model in eval mode."


_DEFAULT_CONTEXTS_TOKENIZER = (
_tokenizer_pad,
_tokenizer_pad_on_right,
Expand Down Expand Up @@ -138,13 +143,13 @@ def set_up_model_and_tokenizer(

# Instantiate each context manager with the object it is meant to manage: model
# contexts wrap the model, tokenizer contexts wrap the tokenizer.
init_contexts_model = [context(model) for context in contexts_model]
init_contexts_tokenizer = [context(tokenizer) for context in contexts_tokenizer]
int_contexts = init_contexts_model + init_contexts_tokenizer
for init_context in int_contexts:
init_contexts = init_contexts_model + init_contexts_tokenizer
for init_context in init_contexts:
init_context.__enter__()

yield

for init_context in int_contexts:
for init_context in init_contexts:
init_context.__exit__()


Expand Down Expand Up @@ -200,8 +205,8 @@ def logits_to_log_probs(
# logits.shape is (# texts, max # tokens in texts, vocab size)
log_probs = F.log_softmax(logits, dim=2)

# Only keep the log-prob from the vocab dimension whose index is is the
# next token's input ID.
# Only keep the log-prob from the vocab dimension whose index is the next token's
# input ID.
# input_ids.shape is (# texts, max # tokens in texts)
return (
log_probs[:, :logits_end_idx, :]
Expand Down
10 changes: 6 additions & 4 deletions src/cappr/utils/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@


def _is_reversible(object) -> bool:
# First, cheaply check if it implements __reversed__
if hasattr(object, "__reversed__"):
return True
# Some objects don't have the attribute, but still can be reversible, like a tuple.
# Returns True for:
# - list, tuple, dict keys, dict values
# - numpy array, torch Tensor
# - pandas and polars Series
# Returns False for:
# - set
# reversed(object) is often a generator, so checking this is often cheap.
try:
reversed(object)
Expand Down

0 comments on commit d150e1b

Please sign in to comment.