Skip to content

Commit

Permalink
Improve naming of token and vocab transition keys in regex.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 authored and rlouf committed Jun 5, 2024
1 parent c11a595 commit ed44a47
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions outlines/fsm/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def _walk_fsm(
fsm_transitions: Dict[Tuple[int, int], int],
fsm_initial: int,
fsm_finals: Set[int],
token_trans_key_seq: Sequence[int],
token_transition_keys: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
Expand All @@ -424,7 +424,7 @@ def _walk_fsm(

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_trans_key_seq):
for i, trans_key in enumerate(token_transition_keys):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand All @@ -448,7 +448,7 @@ def _walk_fsm(

def walk_fsm(
fsm: BetterFSM,
token_trans_key_seq: Sequence[int],
token_transition_keys: Sequence[int],
start_state: int,
full_match: bool = True,
) -> List[int]:
Expand All @@ -462,7 +462,7 @@ def walk_fsm(

# Iterate over token transition key sequence. The transition key
# sequence represents the FSM traversal rules of the tokens symbols.
for i, trans_key in enumerate(token_trans_key_seq):
for i, trans_key in enumerate(token_transition_keys):
new_state = fsm_transitions.get((state, trans_key))

if new_state is None:
Expand Down Expand Up @@ -703,10 +703,10 @@ def get_token_transition_keys(
alphabet_symbol_mapping.get(symbol, alphabet_anything_value)
)

tok_trans_array = np.empty(len(token_transition_keys), dtype=np.int64)
token_transition_keys_array = np.empty(len(token_transition_keys), dtype=np.int64)
for j in range(len(token_transition_keys)):
tok_trans_array[j] = token_transition_keys[j]
return tok_trans_array
token_transition_keys_array[j] = token_transition_keys[j]
return token_transition_keys_array


@numba.njit(cache=True, nogil=True)
Expand All @@ -718,14 +718,14 @@ def get_vocabulary_transition_keys(
"""
Calculate the sequence transition keys for each token str within a vocabulary
"""
tokens_trans_keys = numba.typed.List.empty_list(numba.int64[:])
vocab_transition_keys = numba.typed.List.empty_list(numba.int64[:])
for token_str, _ in vocabulary:
trans_key_seq_array = get_token_transition_keys(
token_transition_keys = get_token_transition_keys(
alphabet_symbol_mapping, alphabet_anything_value, token_str
)
tokens_trans_keys.append(trans_key_seq_array)
vocab_transition_keys.append(token_transition_keys)

return tokens_trans_keys
return vocab_transition_keys


def create_fsm_index_end_to_end(
Expand Down

0 comments on commit ed44a47

Please sign in to comment.