diff --git a/tscv/_split.py b/tscv/_split.py index 31e6521..b516414 100644 --- a/tscv/_split.py +++ b/tscv/_split.py @@ -163,12 +163,16 @@ def __complement_masks(self, masks): def __complement_indices(self, indices, n_samples): before, after = self.gap_before, self.gap_after + allindices = np.arange(n_samples) for index in indices: - complement = np.arange(n_samples) - for i in index: - begin = max(i - before, 0) - end = min(i + after + 1, n_samples) - complement = np.setdiff1d(complement, np.arange(begin, end)) + # split index in subarrays of contiguous elements + subindexes = np.split(index, np.where(np.diff(index) != 1)[0] + 1) + complement = allindices + # find complement on arrays of contiguous elements + for subindex in subindexes: + begin = max(0, subindex[0] - before) + end = min(subindex[-1] + after + 1, n_samples) + complement = np.setdiff1d(complement, allindices[begin:end]) yield complement @abstractmethod