diff --git a/tscv/_split.py b/tscv/_split.py index 31e6521..2cbbe78 100644 --- a/tscv/_split.py +++ b/tscv/_split.py @@ -69,6 +69,14 @@ class GapCrossValidator(metaclass=ABCMeta): Implementations must define one of the following 4 methods: `_iter_train_indices`, `_iter_train_masks`, `_iter_test_indices`, `_iter_test_masks`. + + Parameters + ---------- + gap_before : int, default=0 + Gap before the test sets. + + gap_after : int, default=0 + Gap after the test sets. """ def __init__(self, gap_before=0, gap_after=0): @@ -106,6 +114,8 @@ def split(self, X, y=None, groups=None): # Since subclasses implement any of the following 4 methods, # none can be abstract. + # _iter_train_indices <- _iter_test_indices <- + # _iter_test_masks <- _iter_train_masks <- _iter_train_indices def _iter_train_indices(self, X=None, y=None, groups=None): """Generates integer indices corresponding to training sets. @@ -151,7 +161,8 @@ def __indices_to_masks(indices, n_samples): yield mask def __complement_masks(self, masks): - before, after = self.gap_before, self.gap_after + # switch gap_before and gap_after because of different viewpoints + before, after = self.gap_after, self.gap_before for mask in masks: complement = np.ones(len(mask), dtype=np.bool_) for i, masked in enumerate(mask): diff --git a/tscv/tests/test_split.py b/tscv/tests/test_split.py index 2b43ec0..ef8b8a4 100644 --- a/tscv/tests/test_split.py +++ b/tscv/tests/test_split.py @@ -61,8 +61,8 @@ def _iter_train_indices(self, X=None, y=None, groups=None): masks = cv._GapCrossValidator__complement_masks( [[False, True, True, True, False, False], [False, False, False, False, False, True]]) - assert_array_equal(next(masks), [True, False, False, False, False, False]) - assert_array_equal(next(masks), [True, True, True, True, True, False]) + assert_array_equal(next(masks), [False, False, False, False, True, True]) + assert_array_equal(next(masks), [True, True, True, False, False, False]) indices = cv._GapCrossValidator__complement_indices([[1, 2, 3], [5]], 7) assert_array_equal(next(indices), [0, 6]) @@ -73,12 +73,12 @@ def _iter_train_indices(self, X=None, y=None, groups=None): assert_array_equal(next(masks), [False, False, True, False, True]) masks = cv._iter_test_masks("abcde") - assert_array_equal(next(masks), [True, False, False, False, False]) - assert_array_equal(next(masks), [True, True, False, False, False]) + assert_array_equal(next(masks), [False, False, False, False, True]) + assert_array_equal(next(masks), [False, False, False, False, False]) indices = cv._iter_test_indices("abcde") - assert_array_equal(next(indices), [0]) - assert_array_equal(next(indices), [0, 1]) + assert_array_equal(next(indices), [4]) + assert_array_equal(next(indices), []) # Another dummy subclass class test2CV(GapCrossValidator):