Skip to content

Commit

Permalink
nits
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey committed May 11, 2024
1 parent b38ae4e commit 62b5e4a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
5 changes: 4 additions & 1 deletion docs/source/design_choices.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ is incredibly powerful with inheritance. For an example, see the `tests
<https://github.com/kddubey/cappr/blob/main/tests/llama_cpp/test_llama_cpp_classify.py>`_
for llama-cpp models.

There are still a few testing todos.
There are still a few testing todos. One problem is that there are dependencies
between the tests; if ``test_log_probs_conditional`` fails, the rest of the tests will
fail too. Ideally, each test would stand on its own—for example,
``test_predict_proba`` would be able to assume that ``log_probs_conditional`` is
correct, rather than failing whenever it is not.


Mistakes were made
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from setuptools import setup, find_packages


# lol
with open(os.path.join("src", "cappr", "__init__.py")) as f:
for line in f:
if line.startswith("__version__ = "):
Expand Down Expand Up @@ -40,6 +41,7 @@
"pydata-sphinx-theme>=0.13.1",
"pytest>=7.2.1",
"pytest-cov>=4.0.0",
"pytest-sugar>=1.0.0",
"ruff>=0.3.0",
"sphinx>=6.1.3",
"sphinx-copybutton>=0.5.2",
Expand Down
2 changes: 1 addition & 1 deletion src/cappr/utils/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _is_reversible(object) -> bool:
# Returns False for:
# - set
try:
reversed(object) # often a generator, so checking this is often cheap
reversed(object) # often a generator, so checking this is often free
except TypeError:
return False
else:
Expand Down
21 changes: 10 additions & 11 deletions src/cappr/utils/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ def _agg_log_probs_vectorized(
)
for completion_idx in range(num_completions_per_prompt)
]
except (np.VisibleDeprecationWarning, ValueError):
except (np.VisibleDeprecationWarning, ValueError) as exception:
raise ValueError(
"log_probs has a constant # of completions, but there are a "
"non-constant # of tokens. Vectorization is not possible."
)
) from exception
# Now apply the vectorized function to each array in the list
likelihoods: npt.NDArray[np.floating] = np.exp(
[func(array, axis=1) for array in array_list]
Expand Down Expand Up @@ -189,10 +189,7 @@ def agg_log_probs(
# 2. Run the aggregation computation, vectorizing if possible
try:
likelihoods = _agg_log_probs_vectorized(log_probs, func)
except (
ValueError, # log_probs is jagged
TypeError, # func doesn't take an axis argument
):
except ValueError: # log_probs is doubly jagged
likelihoods = _agg_log_probs(log_probs, func)

# 3. If we wrapped it, unwrap it for user convenience
Expand Down Expand Up @@ -239,7 +236,8 @@ def posterior_prob(
`likelihoods`
"""
# Input checks and preprocessing
likelihoods = np.array(likelihoods) # it should not be jagged/inhomogenous
if not isinstance(likelihoods, np.ndarray):
likelihoods = np.array(likelihoods) # it cannot be jagged/inhomogenous
if not isinstance(normalize, (Sequence, np.ndarray)):
# For code simplicity, just repeat it
# If likelihoods is 1-D, there's only a single probability distr to normalize
Expand All @@ -260,7 +258,7 @@ def posterior_prob(
else:
posteriors_unnorm = likelihoods * prior
marginals = posteriors_unnorm.sum(axis=axis, keepdims=True)
marginals[~normalize] = 1 # denominator of 1 <=> no normalization
marginals[~normalize] = 1 # denominator of 1 means no normalization
return posteriors_unnorm / marginals


Expand Down Expand Up @@ -399,7 +397,8 @@ def _predict_proba(log_probs_conditional):
def wrapper(
prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs
) -> npt.NDArray[np.floating]:
# Check inputs before making expensive model calls
# 1. Check inputs before making expensive model calls

# Check the prior
prior = kwargs.get("prior", None)
prior = _check.prior(prior, expected_length=len(completions))
Expand All @@ -417,7 +416,7 @@ def wrapper(
"because discount_completions was not set."
)

# Do the expensive model calls
# 2. Do the expensive model calls
log_probs_completions = log_probs_conditional(
prompts, completions, *args, **kwargs
)
Expand All @@ -435,7 +434,7 @@ def wrapper(
**kwargs,
)

# Aggregate probs
# 3. Aggregate probs
likelihoods = agg_log_probs(log_probs_completions)
axis = 0 if is_single_input else 1
return posterior_prob(likelihoods, axis=axis, prior=prior, normalize=normalize)
Expand Down

0 comments on commit 62b5e4a

Please sign in to comment.