diff --git a/skorch/tests/callbacks/test_scoring.py b/skorch/tests/callbacks/test_scoring.py index b67764114..d5658bb2c 100644 --- a/skorch/tests/callbacks/test_scoring.py +++ b/skorch/tests/callbacks/test_scoring.py @@ -965,7 +965,10 @@ def net(self, classifier_module, train_loss, valid_loss, classifier_data): n = 75 # n=75 with a 4/5 train/valid split -> 60/15 samples; with a # batch size of 10, that leads to train batch sizes of - # [10,10,10,10] and valid batich sizes of [10,5] + # [10,10,10,10] and valid batch sizes of [10,5]; all labels + # are set to 0 to ensure that the stratified split is exactly + # equal to the desired split + y = np.zeros_like(y) return net.fit(X[:n], y[:n]) @pytest.fixture