From 62b5e4a64ef732fd08e29d80e4a9a1cb52263a0e Mon Sep 17 00:00:00 2001 From: Kush Dubey Date: Sat, 11 May 2024 04:43:12 -0700 Subject: [PATCH] nits --- docs/source/design_choices.rst | 5 ++++- setup.py | 2 ++ src/cappr/utils/_check.py | 2 +- src/cappr/utils/classify.py | 21 ++++++++++----------- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/docs/source/design_choices.rst b/docs/source/design_choices.rst index 93be22a..daadc2b 100644 --- a/docs/source/design_choices.rst +++ b/docs/source/design_choices.rst @@ -175,7 +175,10 @@ is incredibly powerful with inheritance. For an example, see the `tests `_ for llama-cpp models. -There are still a few testing todos. +There are still a few testing todos. One problem is that there are dependencies in the +tests; if ``test_log_probs_conditional`` fails, the rest of the tests will fail. +Ideally, for example, ``test_predict_proba`` assumes ``log_probs_conditional`` is +correct. Mistakes were made diff --git a/setup.py b/setup.py index 404cf69..1c7e095 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ from setuptools import setup, find_packages +# lol with open(os.path.join("src", "cappr", "__init__.py")) as f: for line in f: if line.startswith("__version__ = "): @@ -40,6 +41,7 @@ "pydata-sphinx-theme>=0.13.1", "pytest>=7.2.1", "pytest-cov>=4.0.0", + "pytest-sugar>=1.0.0", "ruff>=0.3.0", "sphinx>=6.1.3", "sphinx-copybutton>=0.5.2", diff --git a/src/cappr/utils/_check.py b/src/cappr/utils/_check.py index 0496c5d..84bafbc 100644 --- a/src/cappr/utils/_check.py +++ b/src/cappr/utils/_check.py @@ -27,7 +27,7 @@ def _is_reversible(object) -> bool: # Returns False for: # - set try: - reversed(object) # often a generator, so checking this is often cheap + reversed(object) # often a generator, so checking this is often free except TypeError: return False else: diff --git a/src/cappr/utils/classify.py b/src/cappr/utils/classify.py index 28f791c..1fcec0c 100644 --- a/src/cappr/utils/classify.py +++ b/src/cappr/utils/classify.py @@ -90,11 +90,11 @@ def _agg_log_probs_vectorized( ) for completion_idx in range(num_completions_per_prompt) ] - except (np.VisibleDeprecationWarning, ValueError): + except (np.VisibleDeprecationWarning, ValueError) as exception: raise ValueError( "log_probs has a constant # of completions, but there are a " "non-constant # of tokens. Vectorization is not possible." - ) + ) from exception # Now apply the vectorized function to each array in the list likelihoods: npt.NDArray[np.floating] = np.exp( [func(array, axis=1) for array in array_list] @@ -189,10 +189,7 @@ def agg_log_probs( # 2. Run the aggregation computation, vectorizing if possible try: likelihoods = _agg_log_probs_vectorized(log_probs, func) - except ( - ValueError, # log_probs is jagged - TypeError, # func doesn't take an axis argument - ): + except ValueError: # log_probs is doubly jagged likelihoods = _agg_log_probs(log_probs, func) # 3. If we wrapped it, unwrap it for user convenience @@ -239,7 +236,8 @@ def posterior_prob( `likelihoods` """ # Input checks and preprocessing - likelihoods = np.array(likelihoods) # it should not be jagged/inhomogenous + if not isinstance(likelihoods, np.ndarray): + likelihoods = np.array(likelihoods) # it cannot be jagged/inhomogenous if not isinstance(normalize, (Sequence, np.ndarray)): # For code simplicity, just repeat it # If likelihoods is 1-D, there's only a single probability distr to normalize @@ -260,7 +258,7 @@ def posterior_prob( else: posteriors_unnorm = likelihoods * prior marginals = posteriors_unnorm.sum(axis=axis, keepdims=True) - marginals[~normalize] = 1 # denominator of 1 <=> no normalization + marginals[~normalize] = 1 # denominator of 1 means no normalization return posteriors_unnorm / marginals @@ -399,7 +397,8 @@ def _predict_proba(log_probs_conditional): def wrapper( prompts: str | Sequence[str], completions: Sequence[str], *args, **kwargs ) -> npt.NDArray[np.floating]: - # Check inputs before making expensive model calls + # 1. Check inputs before making expensive model calls + # Check the prior prior = kwargs.get("prior", None) prior = _check.prior(prior, expected_length=len(completions)) @@ -417,7 +416,7 @@ def wrapper( "because discount_completions was not set." ) - # Do the expensive model calls + # 2. Do the expensive model calls log_probs_completions = log_probs_conditional( prompts, completions, *args, **kwargs ) @@ -435,7 +434,7 @@ def wrapper( **kwargs, ) - # Aggregate probs + # 3. Aggregate probs likelihoods = agg_log_probs(log_probs_completions) axis = 0 if is_single_input else 1 return posterior_prob(likelihoods, axis=axis, prior=prior, normalize=normalize)