diff --git a/KDEpy/bw_selection.py b/KDEpy/bw_selection.py index 5cf2659..fb694a6 100644 --- a/KDEpy/bw_selection.py +++ b/KDEpy/bw_selection.py @@ -243,10 +243,6 @@ def silvermans_rule(data, weights=None): Returns optimal smoothing (standard deviation) if the data is close to normal. - TODO: Extend to multidimensional: - https://docs.scipy.org/doc/scipy-0.13.0/reference/generated/scipy. - stats.gaussian_kde.html#r216 - Examples -------- >>> data = np.arange(9).reshape(-1, 1) @@ -256,31 +252,30 @@ def silvermans_rule(data, weights=None): if not len(data.shape) == 2: raise ValueError("Data must be of shape (obs, dims).") obs, dims = data.shape - if not dims == 1: - raise ValueError("Silverman's rule is only available for 1D data.") if weights is not None: warnings.warn("Silverman's rule currently ignores all weights") if obs == 1: - return 1 + return 1.0 if dims < 2 else np.asarray([1.0] * dims) if obs < 1: raise ValueError("Data must be of length > 0.") - sigma = np.std(data, ddof=1) + sigma = np.std(data, axis=0) # scipy.stats.norm.ppf(.75) - scipy.stats.norm.ppf(.25) -> 1.3489795003921634 - IQR = (np.percentile(data, q=75) - np.percentile(data, q=25)) / 1.3489795003921634 + IQR = (np.percentile(data, axis=0, q=75) - np.percentile(data, axis=0, q=25)) / 1.3489795003921634 - sigma = min(sigma, IQR) + sigma = np.min(np.stack([sigma, IQR]), axis=0) # The logic below is not related to silverman's rule, but if the data is constant # it's nice to return a value instead of getting an error. A warning will be raised. - if sigma > 0: - return sigma * (obs * 3 / 4.0) ** (-1 / 5) + if np.min(sigma) > 0: + res = sigma * (obs * 3 / 4.0) ** (-1 / 5) + return res[0] if dims < 2 else res else: # stats.norm.ppf(.99) - stats.norm.ppf(.01) = 4.6526957480816815 - IQR = (np.percentile(data, q=99) - np.percentile(data, q=1)) / 4.6526957480816815 - if IQR > 0: + IQR = (np.percentile(data, axis=0, q=99) - np.percentile(data, axis=0, q=1)) / 4.6526957480816815 + if np.min(IQR) > 0: bw = IQR * (obs * 3 / 4.0) ** (-1 / 5) warnings.warn( "Silverman's rule failed. Too many idential values. \ @@ -288,11 +283,11 @@ def silvermans_rule(data, weights=None): bw ) ) - return bw + return bw[0] if dims < 2 else bw # Here, all values are basically constant warnings.warn("Silverman's rule failed. Too many idential values. Setting bw = 1.0") - return 1.0 + return 1.0 if dims < 2 else np.asarray([1.0] * dims) _bw_methods = { diff --git a/KDEpy/tests/test_bw_selection.py b/KDEpy/tests/test_bw_selection.py index f5fde2d..583a5c5 100644 --- a/KDEpy/tests/test_bw_selection.py +++ b/KDEpy/tests/test_bw_selection.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from KDEpy.bw_selection import _bw_methods, improved_sheather_jones +from KDEpy.bw_selection import _bw_methods, improved_sheather_jones, silvermans_rule @pytest.fixture(scope="module") @@ -15,6 +15,11 @@ def data() -> np.ndarray: return np.random.randn(100, 1) +@pytest.fixture(scope="module") +def multidim_data() -> np.ndarray: + return np.random.randn(100, 2) + + @pytest.mark.parametrize("method", _bw_methods.values()) def test_equal_weights_dont_changed_bw(data, method): weights = np.ones_like(data).squeeze() * 2 @@ -23,6 +28,13 @@ def test_equal_weights_dont_changed_bw(data, method): np.testing.assert_almost_equal(bw_no_weights, bw_weighted) +def test_multidim_silvermans_rule_weights_dont_changed_bw(multidim_data): + weights = np.ones_like(multidim_data).squeeze() * 2 + bw_no_weights = silvermans_rule(multidim_data, weights=None) + bw_weighted = silvermans_rule(multidim_data, weights=weights) + np.testing.assert_almost_equal(bw_no_weights, bw_weighted) + + def test_isj_bw_weights_single_zero_weighted_point(data): data_with_outlier = np.concatenate((data.copy(), np.array([[1000]]))) weights = np.ones_like(data_with_outlier).squeeze() @@ -45,6 +57,19 @@ def test_isj_bw_weights_same_as_resampling(data, execution_number): ) +def test_onedim_silvermans_rule_shape(data): + sr_res = silvermans_rule(data) + # dims is a float + assert isinstance(sr_res, float) + + +def test_multidim_silvermans_rule_shape(multidim_data): + sr_res = silvermans_rule(multidim_data) + # dims shape is 2 + dim = sr_res.shape[0] + assert dim == 2 + + if __name__ == "__main__": # --durations=10 <- May be used to show potentially slow tests pytest.main(args=[__file__, "--doctest-modules", "-v", "--durations=15"])