diff --git a/ignite/base/mixins.py b/ignite/base/mixins.py
index 3ecb2922f03..c1b47d1e1ce 100644
--- a/ignite/base/mixins.py
+++ b/ignite/base/mixins.py
@@ -1,11 +1,18 @@
 from collections import OrderedDict
 from collections.abc import Mapping
-from typing import Tuple
+from typing import List, Tuple


 class Serializable:
-    _state_dict_all_req_keys: Tuple = ()
-    _state_dict_one_of_opt_keys: Tuple = ()
+    _state_dict_all_req_keys: Tuple[str, ...] = ()
+    _state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)
+
+    def __init__(self) -> None:
+        self._state_dict_user_keys: List[str] = []
+
+    @property
+    def state_dict_user_keys(self) -> List:
+        return self._state_dict_user_keys

     def state_dict(self) -> OrderedDict:
         raise NotImplementedError
@@ -19,6 +26,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
                 raise ValueError(
                     f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
                 )
-        opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
-        if len(opts) > 0 and ((not any(opts)) or (all(opts))):
-            raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
+        for one_of_opt_keys in self._state_dict_one_of_opt_keys:
+            opts = [k in state_dict for k in one_of_opt_keys]
+            if len(opts) > 0 and ((not any(opts)) or (all(opts))):
+                raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")
+
+        for k in self._state_dict_user_keys:
+            if k not in state_dict:
+                raise ValueError(
+                    f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
+                )
diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index 865218af359..277498ac4a8 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -129,7 +129,16 @@ def compute_mean_std(engine, batch):
     """

     _state_dict_all_req_keys = ("epoch_length", "max_epochs")
-    _state_dict_one_of_opt_keys = ("iteration", "epoch")
+    _state_dict_one_of_opt_keys = (
+        (
+            "iteration",
+            "epoch",
+        ),
+        (
+            "max_epochs",
+            "max_iters",
+        ),
+    )

     # Flag to disable engine._internal_run as generator feature for BC
     interrupt_resume_enabled = True
@@ -310,6 +319,7 @@ def execute_something():
             for e in event_name:
                 self.add_event_handler(e, handler, *args, **kwargs)
             return RemovableEventHandle(event_name, handler, self)
+
         if isinstance(event_name, CallableEventWithFilter) and event_name.filter is not None:
             event_filter = event_name.filter
             handler = self._handler_wrapper(handler, event_name, event_filter)
@@ -332,6 +342,16 @@ def execute_something():

         return RemovableEventHandle(event_name, handler, self)

+    @staticmethod
+    def _assert_non_filtered_event(event_name: Any) -> None:
+        if (
+            isinstance(event_name, CallableEventWithFilter)
+            and event_name.filter != CallableEventWithFilter.default_event_filter
+        ):
+            raise TypeError(
+                "Argument event_name should not be a filtered event, please use event without any event filtering"
+            )
+
     def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
         """Check if the specified event has the specified handler.

@@ -675,7 +695,12 @@ def save_engine(_):
             a dictionary containing engine's state
         """
-        keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
+        keys: Tuple[str, ...] = self._state_dict_all_req_keys
+        keys += ("iteration",)
+        if self.state.max_epochs is not None:
+            keys += ("max_epochs",)
+        else:
+            keys += ("max_iters",)
         keys += tuple(self._state_dict_user_keys)
         return OrderedDict([(k, getattr(self.state, k)) for k in keys])

@@ -728,6 +753,8 @@ def load_state_dict(self, state_dict: Mapping) -> None:
                 f"Input state_dict: {state_dict}"
            )
             self.state.iteration = self.state.epoch_length * self.state.epoch
+        self._check_and_set_max_epochs(state_dict.get("max_epochs", None))
+        self._check_and_set_max_iters(state_dict.get("max_iters", None))

     @staticmethod
     def _is_done(state: State) -> bool:
@@ -864,12 +891,26 @@ def switch_batch(engine):

         epoch_length = self._get_data_length(data)
         if epoch_length is not None and epoch_length < 1:
-            raise ValueError("Input data has zero size. Please provide non-empty data")
+            raise ValueError(
+                "Argument epoch_length is invalid. Please, either set a"
+                " correct epoch_length value or check if input data has"
+                " non-zero size."
+            )

         if max_iters is None:
             if max_epochs is None:
                 max_epochs = 1
         else:
+            if max_iters < 1:
+                raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
+            if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
+                raise ValueError(
+                    "Argument max_iters should be greater than the current iteration "
+                    f"defined in the state: {max_iters} vs {self.state.iteration}. "
+                    "Please, set engine.state.max_iters = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_iters = max_iters
             if max_epochs is not None:
                 raise ValueError(
                     "Arguments max_iters and max_epochs are mutually exclusive."
@@ -932,6 +973,53 @@ def _setup_dataloader_iter(self) -> None:
         else:
             self._dataloader_iter = iter(self.state.dataloader)

+    def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
+        if max_epochs is not None:
+            if max_epochs < 1:
+                raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
+            if self.state.max_epochs is not None and max_epochs <= self.state.epoch:
+                raise ValueError(
+                    "Argument max_epochs should be greater than the current epoch "
+                    f"defined in the state: {max_epochs} vs {self.state.epoch}. "
+                    "Please, set engine.state.max_epochs = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_epochs = max_epochs
+
+    def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
+        if max_iters is not None:
+            if max_iters < 1:
+                raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
+            if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
+                raise ValueError(
+                    "Argument max_iters should be greater than the current iteration "
+                    f"defined in the state: {max_iters} vs {self.state.iteration}. "
+                    "Please, set engine.state.max_iters = None "
+                    "before calling engine.run() in order to restart the training from the beginning."
+                )
+            self.state.max_iters = max_iters
+
+    def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None:
+        # Can't we accept a redefinition?
+        if self.state.epoch_length is not None:
+            if epoch_length is not None:
+                if epoch_length != self.state.epoch_length:
+                    raise ValueError(
+                        "Argument epoch_length should be same as in the state, "
+                        f"but given {epoch_length} vs {self.state.epoch_length}"
+                    )
+        else:
+            if epoch_length is None:
+                epoch_length = self._get_data_length(data)
+
+            if epoch_length is not None and epoch_length < 1:
+                raise ValueError(
+                    "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
+                    "check if input data has non-zero size."
+                )
+
+            self.state.epoch_length = epoch_length
+
     def _setup_engine(self) -> None:
         self._setup_dataloader_iter()
@@ -976,17 +1064,19 @@ def _internal_run_as_gen(self) -> Generator:
                     self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

                     handlers_start_time = time.time()
-                    self._fire_event(Events.EPOCH_COMPLETED)
-                    epoch_time_taken += time.time() - handlers_start_time
-                    # update time wrt handlers
-                    self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
+                    if self.state.epoch_length is not None and self.state.iteration % self.state.epoch_length == 0:
+                        # max_iters can cause training to complete without an epoch ending
+                        self._fire_event(Events.EPOCH_COMPLETED)
+                        epoch_time_taken += time.time() - handlers_start_time
+                        # update time wrt handlers
+                        self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
+
+                        hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
+                        self.logger.info(
+                            f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}"
+                        )

                     yield from self._maybe_terminate_or_interrupt()

-                    hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
-                    self.logger.info(
-                        f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}"
-                    )
-
         except _EngineTerminateException:
             self._fire_event(Events.TERMINATE)
@@ -1064,6 +1154,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                         self.state.epoch_length = iter_counter
                         if self.state.max_iters is not None:
                             self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
+                            # Warn but will continue until max iters is reached
+                            warnings.warn(
+                                "Data iterator can not provide data anymore but required total number of "
+                                "iterations to run is not reached. "
+                                f"Current iteration: {self.state.iteration} vs total iterations to run: "
+                                f"{self.state.max_iters}"
+                            )
                     break

                 # Should exit while loop if we can not iterate
@@ -1106,7 +1203,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:

                     if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
                         self.should_terminate = True
-                        raise _EngineTerminateException()
+                        warnings.warn(
+                            "Required total number of iterations is reached: "
+                            f"{self.state.max_iters}. Stopping the run. "
+                            "The last epoch may be incomplete if max_iters is not a multiple of epoch_length."
+                        )
+                        break
+                        # raise _EngineTerminateException()

             except _EngineTerminateSingleEpochException:
                 self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
@@ -1231,6 +1334,13 @@ def _run_once_on_dataset_legacy(self) -> float:
                         self.state.epoch_length = iter_counter
                         if self.state.max_iters is not None:
                             self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
+                            # Warn but will continue until max iters is reached
+                            warnings.warn(
+                                "Data iterator can not provide data anymore but required total number of "
+                                "iterations to run is not reached. "
+                                f"Current iteration: {self.state.iteration} vs total iterations to run: "
+                                f"{self.state.max_iters}"
+                            )
                     break

                 # Should exit while loop if we can not iterate
@@ -1291,6 +1401,16 @@ def _run_once_on_dataset_legacy(self) -> float:

         return time.time() - start_time

+    def debug(self, enabled: bool = True) -> None:
+        """Enables/disables engine's logging debug mode"""
+        from ignite.utils import setup_logger
+
+        if enabled:
+            setattr(self, "_stored_logger", self.logger)
+            self.logger = setup_logger(level=logging.DEBUG)
+        elif hasattr(self, "_stored_logger"):
+            self.logger = getattr(self, "_stored_logger")
+

 def _get_none_data_iter(size: int) -> Iterator:
     # Sized iterator for data as None
diff --git a/tests/ignite/base/test_mixins.py b/tests/ignite/base/test_mixins.py
index 0f3a39811fb..734384f3847 100644
--- a/tests/ignite/base/test_mixins.py
+++ b/tests/ignite/base/test_mixins.py
@@ -3,6 +3,11 @@
 from ignite.base import Serializable


+class ExampleSerializable(Serializable):
+    _state_dict_all_req_keys = ("a", "b")
+    _state_dict_one_of_opt_keys = (("c", "d"), ("e", "f"))
+
+
 def test_state_dict():
     s = Serializable()
     with pytest.raises(NotImplementedError):
@@ -10,5 +15,26 @@ def test_state_dict():


 def test_load_state_dict():
-    s = Serializable()
-    s.load_state_dict({})
+    s = ExampleSerializable()
+    with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"):
+        s.load_state_dict("abc")
+
+    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
+        s.load_state_dict({})
+
+    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
+        s.load_state_dict({"a": 1})
+
+    with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
+        s.load_state_dict({"a": 1, "b": 2})
+
+    with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
+        s.load_state_dict({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})
+
+    with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
+        s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5, "f": 5})
+
+    s.state_dict_user_keys.append("alpha")
+    with pytest.raises(ValueError, match=r"Required user state attribute"):
+        s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 4})
diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py
index 13021242650..004e8573bf5 100644
--- a/tests/ignite/engine/test_engine.py
+++ b/tests/ignite/engine/test_engine.py
@@ -1,3 +1,4 @@
+import math
 import os
 import time
 from unittest.mock import call, MagicMock, Mock
@@ -503,6 +504,15 @@ def test__is_done(self):
         state = State(iteration=1000, max_epochs=10, epoch_length=100)
         assert Engine._is_done(state)

+        state = State(iteration=11, epoch=2, max_epochs=None, epoch_length=11, max_iters=22)
+        assert not Engine._is_done(state)
+
+        state = State(iteration=100, epoch=1, max_epochs=None, epoch_length=100, max_iters=250)
+        assert not Engine._is_done(state)
+
+        state = State(iteration=250, epoch=1, max_epochs=None, epoch_length=100, max_iters=250)
+        assert Engine._is_done(state)
+
     def test__setup_engine(self):
         engine = Engine(lambda e, b: 1)
         engine.state = State(iteration=10, epoch=1, max_epochs=100, epoch_length=100)
@@ -515,8 +525,21 @@ def test__setup_engine(self):

     def test_run_asserts(self):
         engine = Engine(lambda e, b: 1)

-        with pytest.raises(ValueError, match=r"Input data has zero size. Please provide non-empty data"):
+        with pytest.raises(
+            ValueError,
+            match=r"Argument epoch_length is invalid. Please, either set a correct epoch_length "
+            r"value or check if input data has non-zero size.",
+        ):
             engine.run([])

+        with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than or equal to the start epoch"):
+            engine.state.max_epochs = 5
+            engine.state.epoch = 5
+            engine.run([0, 1], max_epochs=3)
+
+        with pytest.raises(ValueError, match=r"Argument max_iters should be greater than the current"):
+            engine.state.max_iters = 100
+            engine.state.iteration = 100
+            engine.run([0, 1], max_iters=50)

     def test_state_get_event_attrib_value(self):
         state = State()
@@ -573,7 +596,21 @@ def check_completed_time():
             >= (sleep_time * epoch_length + extra_sleep_time) * max_epochs + extra_sleep_time
         )

-    def _test_check_triggered_events(self, data, max_epochs, epoch_length, exp_iter_stops=None):
+    def _test_check_triggered_events(
+        self,
+        data,
+        max_epochs=None,
+        epoch_length=None,
+        max_iters=None,
+        n_epoch_started=None,
+        n_epoch_completed=None,
+        n_iter_started=None,
+        n_iter_completed=None,
+        n_batch_started=None,
+        n_batch_completed=None,
+        exp_iter_stops=None,
+        n_terminate=None,
+    ):
         engine = Engine(lambda e, b: 1)
         events = [
             Events.STARTED,
@@ -585,6 +622,8 @@ def _test_check_triggered_events(self, data, max_epochs, epoch_length, exp_iter_stops=None):
             Events.GET_BATCH_STARTED,
             Events.GET_BATCH_COMPLETED,
             Events.DATALOADER_STOP_ITERATION,
+            Events.TERMINATE,
+            Events.TERMINATE_SINGLE_EPOCH,
         ]

         handlers = {e: MagicMock() for e in events}
@@ -592,18 +631,46 @@
         for e, handler in handlers.items():
             engine.add_event_handler(e, handler)

-        engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length)
+        engine.run(data, max_epochs=max_epochs, max_iters=max_iters, epoch_length=epoch_length)
+
+        if epoch_length is None:
+            epoch_length = engine.state.epoch_length
+
+        assert epoch_length is not None
+
+        if n_iter_started is None:
+            n_iter_started = max_epochs * epoch_length if max_epochs is not None else max_iters
+
+        if n_iter_completed is None:
+            n_iter_completed = max_epochs * epoch_length if max_epochs is not None else max_iters
+
+        if n_batch_started is None:
+            if data is None:
+                n_batch_started = 0
+            else:
+                n_batch_started = max_epochs * epoch_length if max_epochs is not None else max_iters
+
+        if n_batch_completed is None:
+            if data is None:
+                n_batch_completed = 0
+            else:
+                n_batch_completed = max_epochs * epoch_length if max_epochs is not None else max_iters
+
+        if n_terminate is None:
+            n_terminate = int(n_epoch_started != n_epoch_completed) if max_iters is not None else 0

         expected_num_calls = {
             Events.STARTED: 1,
             Events.COMPLETED: 1,
-            Events.EPOCH_STARTED: max_epochs,
-            Events.EPOCH_COMPLETED: max_epochs,
-            Events.ITERATION_STARTED: max_epochs * epoch_length,
-            Events.ITERATION_COMPLETED: max_epochs * epoch_length,
-            Events.GET_BATCH_STARTED: max_epochs * epoch_length if data is not None else 0,
-            Events.GET_BATCH_COMPLETED: max_epochs * epoch_length if data is not None else 0,
+            Events.EPOCH_STARTED: n_epoch_started if n_epoch_started is not None else max_epochs,
+            Events.EPOCH_COMPLETED: n_epoch_completed if n_epoch_completed is not None else max_epochs,
+            Events.ITERATION_STARTED: n_iter_started,
+            Events.ITERATION_COMPLETED: n_iter_completed,
+            Events.GET_BATCH_STARTED: n_batch_started,
+            Events.GET_BATCH_COMPLETED: n_batch_completed,
             Events.DATALOADER_STOP_ITERATION: (max_epochs - 1) if exp_iter_stops is None else exp_iter_stops,
+            Events.TERMINATE: n_terminate,
+            Events.TERMINATE_SINGLE_EPOCH: 0,
         }
         for n, handler in handlers.items():
@@ -619,6 +686,13 @@ def _test_run_check_triggered_events(self):
         )
         self._test_check_triggered_events(None, max_epochs=5, epoch_length=150, exp_iter_stops=0)

+        kwargs = dict(exp_iter_stops=4, n_epoch_started=5, n_epoch_completed=5)
+        self._test_check_triggered_events(list(range(20)), max_iters=100, epoch_length=20, **kwargs)
+        kwargs = dict(exp_iter_stops=2, n_epoch_started=5, n_epoch_completed=5)
+        self._test_check_triggered_events(list(range(20)), max_iters=50, epoch_length=10, **kwargs)
+        kwargs = dict(exp_iter_stops=2, n_epoch_started=3, n_epoch_completed=2)
+        self._test_check_triggered_events(list(range(20)), max_iters=55, epoch_length=25, **kwargs)
+
     def test_run_check_triggered_events_list(self):
         self._test_run_check_triggered_events()

@@ -632,27 +706,88 @@ def infinite_data_iterator():
         self._test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=50, exp_iter_stops=0)
         self._test_check_triggered_events(infinite_data_iterator(), max_epochs=5, epoch_length=150, exp_iter_stops=0)

+        kwargs = dict(exp_iter_stops=0, n_epoch_started=5, n_epoch_completed=5)
+        self._test_check_triggered_events(infinite_data_iterator(), max_iters=100, epoch_length=20, **kwargs)
+        kwargs = dict(exp_iter_stops=0, n_epoch_started=1, n_epoch_completed=0)
+        self._test_check_triggered_events(infinite_data_iterator(), max_iters=10, epoch_length=20, **kwargs)
+        kwargs = dict(exp_iter_stops=0, n_epoch_started=2, n_epoch_completed=1)
+        self._test_check_triggered_events(infinite_data_iterator(), max_iters=30, epoch_length=20, **kwargs)
+
-        def limited_data_iterator():
-            for i in range(100):
+        def limited_data_iterator(length=100):
+            for i in range(length):
                 yield i

         self._test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=100, exp_iter_stops=0)
         self._test_check_triggered_events(limited_data_iterator(), max_epochs=10, epoch_length=10, exp_iter_stops=0)

-        # These tests should fail
-        with pytest.raises(AssertionError):
-            with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
-                self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=100)
-
-        with pytest.raises(AssertionError):
-            with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
-                self._test_check_triggered_events(limited_data_iterator(), max_epochs=3, epoch_length=75)
-
-        with pytest.raises(AssertionError):
-            # Below test does not raise "Data iterator can not provide data anymore" warning as the last
-            # epoch is equal max_epochs
-            # with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
-            self._test_check_triggered_events(limited_data_iterator(), max_epochs=1, epoch_length=101)
+        kwargs = dict(exp_iter_stops=0, n_epoch_started=1, n_epoch_completed=1)
+        self._test_check_triggered_events(limited_data_iterator(), max_iters=20, epoch_length=20, **kwargs)
+        kwargs = dict(exp_iter_stops=0, n_epoch_started=2, n_epoch_completed=1)
+        self._test_check_triggered_events(limited_data_iterator(), max_iters=19, epoch_length=10, **kwargs)
+
+        with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
+            kwargs = dict(
+                exp_iter_stops=1,
+                n_epoch_started=2,
+                n_epoch_completed=1,
+                n_iter_started=20,
+                n_iter_completed=20,
+                n_batch_started=21,
+                n_batch_completed=20,
+                n_terminate=1,
+            )
+            self._test_check_triggered_events(limited_data_iterator(length=20), max_epochs=3, epoch_length=20, **kwargs)
+
+        with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
+            kwargs = dict(
+                exp_iter_stops=1,
+                n_epoch_started=2,
+                n_epoch_completed=1,
+                n_iter_started=20,
+                n_iter_completed=20,
+                n_batch_started=22,  # 22 and not 21: GET_BATCH_STARTED is fired once more up to epoch_length
+                n_batch_completed=20,
+                n_terminate=1,
+            )
+            self._test_check_triggered_events(limited_data_iterator(length=20), max_epochs=3, **kwargs)
+
+        with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
+            kwargs = dict(
+                exp_iter_stops=1,
+                n_epoch_started=2,
+                n_epoch_completed=1,
+                n_iter_started=20,
+                n_iter_completed=20,
+                n_batch_started=21,
+                n_batch_completed=20,
+                n_terminate=1,
+            )
+            self._test_check_triggered_events(limited_data_iterator(length=20), max_epochs=3, epoch_length=15, **kwargs)
+
+        kwargs = dict(
+            exp_iter_stops=1,
+            n_epoch_started=1,
+            n_epoch_completed=0,
+            n_iter_started=20,
+            n_iter_completed=20,
+            n_batch_started=21,
+            n_batch_completed=20,
+            n_terminate=1,
+        )
+        self._test_check_triggered_events(limited_data_iterator(length=20), max_epochs=1, epoch_length=21, **kwargs)
+
+        with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
+            kwargs = dict(
+                exp_iter_stops=1,
+                n_epoch_started=2,
+                n_epoch_completed=1,
+                n_iter_started=20,
+                n_iter_completed=20,
+                n_batch_started=21,
+                n_batch_completed=20,
+                n_terminate=1,
+            )
+            self._test_check_triggered_events(limited_data_iterator(length=20), max_iters=21, epoch_length=12, **kwargs)

     def test_run_check_triggered_events_on_iterator(self):
         self._test_run_check_triggered_events_on_iterator()
@@ -760,7 +895,14 @@ def run_evaluation(_):
             assert train_batches[epoch_length + i] != train_batches[2 * epoch_length + i]
             assert train_batches[i] == train_only_batches[i]

-    def test_engine_with_dataloader_no_auto_batching(self):
+    @pytest.mark.parametrize(
+        "kwargs",
+        [
+            {"max_epochs": None, "epoch_length": 10, "max_iters": 25},
+            {"max_epochs": 5, "epoch_length": 10, "max_iters": None},
+        ],
+    )
+    def test_engine_with_dataloader_no_auto_batching(self, kwargs):
         # tests https://github.com/pytorch/ignite/issues/941
         from torch.utils.data import BatchSampler, DataLoader, RandomSampler

@@ -775,9 +917,12 @@ def foo(e, b):
             counter[0] += 1

         engine = Engine(foo)
-        engine.run(data_loader, epoch_length=10, max_epochs=5)
+        engine.run(data_loader, **kwargs)

-        assert counter[0] == 50
+        if kwargs["max_epochs"]:
+            assert counter[0] == kwargs["epoch_length"] * kwargs["max_epochs"]
+        else:
+            assert counter[0] == kwargs["max_iters"]

     def test_run_once_finite_iterator_no_epoch_length(self):
         # FR: https://github.com/pytorch/ignite/issues/871
@@ -788,19 +933,41 @@ def finite_unk_size_data_iter():
             for i in range(unknown_size):
                 yield i

-        bc = BatchChecker(data=list(range(unknown_size)))
+        def _test(**kwargs):
+            bc = BatchChecker(data=list(range(unknown_size)))

-        engine = Engine(lambda e, b: bc.check(b))
+            engine = Engine(lambda e, b: bc.check(b))

-        completed_handler = MagicMock()
-        engine.add_event_handler(Events.COMPLETED, completed_handler)
+            epoch_completed_handler = MagicMock()
+            engine.add_event_handler(Events.EPOCH_COMPLETED, epoch_completed_handler)

-        data_iter = finite_unk_size_data_iter()
-        engine.run(data_iter)
+            completed_handler = MagicMock()
+            engine.add_event_handler(Events.COMPLETED, completed_handler)

-        assert engine.state.epoch == 1
-        assert engine.state.iteration == unknown_size
-        assert completed_handler.call_count == 1
+            data_iter = finite_unk_size_data_iter()
+            engine.run(data_iter, **kwargs)
+
+            assert bc.counter == engine.state.iteration
+            if len(kwargs) == 0:
+                assert engine.state.epoch == 1
+                assert engine.state.iteration == unknown_size
+                assert epoch_completed_handler.call_count == 1
+            else:
+                max_iters = kwargs["max_iters"]
+                if max_iters <= unknown_size:
+                    assert engine.state.epoch == 1
+                    assert engine.state.iteration == max_iters
+                else:
+                    assert engine.state.epoch == 2
+                    assert engine.state.iteration == unknown_size
+
+            assert completed_handler.call_count == 1
+
+        _test()
+        _test(max_iters=unknown_size)
+        _test(max_iters=unknown_size // 2)
+        with pytest.warns(UserWarning, match=r"Data iterator can not provide data anymore"):
+            _test(max_iters=unknown_size * 2)

     def test_run_finite_iterator_no_epoch_length(self):
         # FR: https://github.com/pytorch/ignite/issues/871
@@ -832,153 +999,181 @@ def finite_size_data_iter(size):
             for i in range(size):
                 yield i

-        bc = BatchChecker(data=list(range(known_size)))
+        def _test(**kwargs):
+            bc = BatchChecker(data=list(range(known_size)))

-        engine = Engine(lambda e, b: bc.check(b))
+            engine = Engine(lambda e, b: bc.check(b))

-        @engine.on(Events.ITERATION_COMPLETED(every=known_size))
-        def restart_iter():
-            engine.state.dataloader = finite_size_data_iter(known_size)
+            @engine.on(Events.ITERATION_COMPLETED(every=known_size))
+            def restart_iter():
+                engine.state.dataloader = finite_size_data_iter(known_size)

-        data_iter = finite_size_data_iter(known_size)
-        engine.run(data_iter, max_epochs=5)
+            data_iter = finite_size_data_iter(known_size)
+            engine.run(data_iter, **kwargs)

-        assert engine.state.epoch == 5
-        assert engine.state.iteration == known_size * 5
+            assert bc.counter == engine.state.iteration
+            if "max_epochs" in kwargs:
+                assert engine.state.epoch == kwargs["max_epochs"]
+                assert engine.state.iteration == known_size * kwargs["max_epochs"]
+            else:
+                max_iters = kwargs["max_iters"]
+                if max_iters <= known_size:
+                    assert engine.state.epoch == math.ceil(max_iters / known_size)
+                    assert engine.state.iteration == max_iters
+
+        _test(max_epochs=5)
+        _test(max_iters=known_size)
+        _test(max_iters=known_size // 2)

     def test_faq_inf_iterator_with_epoch_length(self):
-        # Code snippet from FAQ
-        # import torch
-
-        torch.manual_seed(12)
-
-        def infinite_iterator(batch_size):
-            while True:
-                batch = torch.rand(batch_size, 3, 32, 32)
-                yield batch
-
-        def train_step(trainer, batch):
-            # ...
-            s = trainer.state
-            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
-
-        trainer = Engine(train_step)
-        # We need to specify epoch_length to define the epoch
-        trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=3)
-
-        assert trainer.state.epoch == 3
-        assert trainer.state.iteration == 3 * 5
+        def _test(max_epochs, max_iters):
+            # Code snippet from FAQ
+            # import torch
+            torch.manual_seed(12)
+
+            def infinite_iterator(batch_size):
+                while True:
+                    batch = torch.rand(batch_size, 3, 32, 32)
+                    yield batch
+
+            def train_step(trainer, batch):
+                # ...
+                s = trainer.state
+                print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
+
+            trainer = Engine(train_step)
+            # We need to specify epoch_length to define the epoch
+            trainer.run(infinite_iterator(4), epoch_length=5, max_epochs=max_epochs, max_iters=max_iters)
+
+            assert trainer.state.epoch == 3
+            assert trainer.state.iteration == 3 * 5
+
+        _test(max_epochs=3, max_iters=None)
+        _test(max_epochs=None, max_iters=3 * 5)

     def test_faq_inf_iterator_no_epoch_length(self):
-        # Code snippet from FAQ
-        # import torch
-
-        torch.manual_seed(12)
-
-        def infinite_iterator(batch_size):
-            while True:
-                batch = torch.rand(batch_size, 3, 32, 32)
-                yield batch
-
-        def train_step(trainer, batch):
-            # ...
-            s = trainer.state
-            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
-
-        trainer = Engine(train_step)
-
-        @trainer.on(Events.ITERATION_COMPLETED(once=15))
-        def stop_training():
-            trainer.terminate()
-
-        trainer.run(infinite_iterator(4))
-
-        assert trainer.state.epoch == 1
-        assert trainer.state.iteration == 15
+        def _test(max_epochs, max_iters):
+            # Code snippet from FAQ
+            # import torch
+            torch.manual_seed(12)
+
+            def infinite_iterator(batch_size):
+                while True:
+                    batch = torch.rand(batch_size, 3, 32, 32)
+                    yield batch
+
+            def train_step(trainer, batch):
+                # ...
+                s = trainer.state
+                print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch.norm():.3f}")
+
+            trainer = Engine(train_step)
+
+            @trainer.on(Events.ITERATION_COMPLETED(once=15))
+            def stop_training():
+                trainer.terminate()
+
+            trainer.run(infinite_iterator(4), max_epochs=max_epochs, max_iters=max_iters)
+
+            assert trainer.state.epoch == 1
+            assert trainer.state.iteration == 15
+
+        _test(max_epochs=None, max_iters=None)
+        _test(max_epochs=None, max_iters=100)

     def test_faq_fin_iterator_unknw_size(self):
-        # Code snippet from FAQ
-        # import torch
-
-        torch.manual_seed(12)
-
-        def finite_unk_size_data_iter():
-            for i in range(11):
-                yield i
-
-        def train_step(trainer, batch):
-            # ...
-            s = trainer.state
-            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
-
-        trainer = Engine(train_step)
-
-        @trainer.on(Events.DATALOADER_STOP_ITERATION)
-        def restart_iter():
-            trainer.state.dataloader = finite_unk_size_data_iter()
-
-        data_iter = finite_unk_size_data_iter()
-        trainer.run(data_iter, max_epochs=5)
-
-        assert trainer.state.epoch == 5
-        assert trainer.state.iteration == 5 * 11
+        def _test(max_epochs, max_iters):
+            # Code snippet from FAQ
+            # import torch
+            torch.manual_seed(12)
+
+            def finite_unk_size_data_iter():
+                for i in range(11):
+                    yield i
+
+            def train_step(trainer, batch):
+                # ...
+                s = trainer.state
+                print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
+
+            trainer = Engine(train_step)
+
+            @trainer.on(Events.DATALOADER_STOP_ITERATION)
+            def restart_iter():
+                trainer.state.dataloader = finite_unk_size_data_iter()
+
+            data_iter = finite_unk_size_data_iter()
+            trainer.run(data_iter, max_epochs=max_epochs, max_iters=max_iters)
+
+            assert trainer.state.epoch == (5 if max_iters is None else math.ceil(max_iters / 11))
+            assert trainer.state.iteration == (5 * 11 if max_iters is None else max_iters)
+
+        _test(max_epochs=5, max_iters=None)
+        _test(max_epochs=None, max_iters=60)

-        # Code snippet from FAQ
-        # import torch
+        # # # # #

-        torch.manual_seed(12)
+        def _test(max_epochs, max_iters):
+            # import torch
+            torch.manual_seed(12)

-        def finite_unk_size_data_iter():
-            for i in range(11):
-                yield i
+            def finite_unk_size_data_iter():
+                for i in range(11):
+                    yield i

-        def val_step(evaluator, batch):
-            # ...
-            s = evaluator.state
-            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
+            def val_step(evaluator, batch):
+                # ...
+                s = evaluator.state
+                print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")

-        evaluator = Engine(val_step)
+            evaluator = Engine(val_step)

-        data_iter = finite_unk_size_data_iter()
-        evaluator.run(data_iter)
+            data_iter = finite_unk_size_data_iter()
+            evaluator.run(data_iter, max_epochs=max_epochs, max_iters=max_iters)

-        assert evaluator.state.epoch == 1
-        assert evaluator.state.iteration == 1 * 11
+            assert evaluator.state.epoch == 1
+            assert evaluator.state.iteration == 1 * 11
+
+        _test(max_epochs=None, max_iters=None)

     def test_faq_fin_iterator(self):
-        # Code snippet from FAQ
-        # import torch
-
-        torch.manual_seed(12)
-
-        size = 11
-
-        def finite_size_data_iter(size):
-            for i in range(size):
-                yield i
-
-        def train_step(trainer, batch):
-            # ...
-            s = trainer.state
-            print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
-
-        trainer = Engine(train_step)
-
-        @trainer.on(Events.ITERATION_COMPLETED(every=size))
-        def restart_iter():
-            trainer.state.dataloader = finite_size_data_iter(size)
-
-        data_iter = finite_size_data_iter(size)
-        trainer.run(data_iter, max_epochs=5)
-
-        assert trainer.state.epoch == 5
-        assert trainer.state.iteration == 5 * size
+        def _test(max_epochs, max_iters):
+            # Code snippet from FAQ
+            # import torch
+            torch.manual_seed(12)
+            size = 11
+
+            def finite_size_data_iter(size):
+                for i in range(size):
+                    yield i
+
+            def train_step(trainer, batch):
+                # ...
+                s = trainer.state
+                print(f"{s.epoch}/{s.max_epochs} : {s.iteration} - {batch:.3f}")
+
+            trainer = Engine(train_step)
+
+            @trainer.on(Events.ITERATION_COMPLETED(every=size))
+            def restart_iter():
+                trainer.state.dataloader = finite_size_data_iter(size)
+
+            data_iter = finite_size_data_iter(size)
+            trainer.run(data_iter, max_epochs=max_epochs, max_iters=max_iters)
+
+            assert trainer.state.epoch == 5
+            assert trainer.state.iteration == 5 * size
+
+        _test(max_epochs=5, max_iters=None)
+        _test(max_epochs=None, max_iters=5 * 11)

-        # Code snippet from FAQ
-        # import torch
+        # # # # #
+        # import torch
         torch.manual_seed(12)
-
         size = 11

         def finite_size_data_iter(size):
@@ -1227,6 +1422,82 @@ def check_iter_epoch(first_epoch_iter):
         assert engine.state.epoch == 10
         assert engine.state.iteration == 10 * real_epoch_length

+    def test_restart_training(self):
+        data = range(10)
+        engine = Engine(lambda e, b: 1)
+        state = engine.run(data, max_epochs=5)
+        with pytest.raises(
+            ValueError,
+            match=r"Argument max_epochs should be greater than or equal to the start epoch"
+            r" defined in the state: 2 vs 5. "
+            r"Please, .+ "
+            r"before calling engine.run\(\) in order to restart the training from the beginning.",
+        ):
+            engine.run(data, max_epochs=2)
+        state.max_epochs = None
+        engine.run(data, max_epochs=2)
+
+    def test_engine_multiple_runs(self):
+        engine = Engine(lambda e, b: 1)
+        engine.debug()
+
+        init_epoch = 0
+        init_iter = 0
+        epoch_length = None
+
+        @engine.on(Events.STARTED)
+        def assert_resume():
+            assert engine.state.epoch == init_epoch
+            assert engine.state.iteration == init_iter
+            assert engine.state.epoch_length == epoch_length
+
+        data = range(10)
+        epoch_length = len(data)
+        engine.run(data, max_epochs=2)
+        assert engine.state.epoch == 2
+        assert engine.state.iteration == 2 * epoch_length
+
+        engine.debug(False)
+
+        # Continue run with max_epochs
+        data = range(15)
+        init_epoch = 2
+        init_iter = 2 * epoch_length
+        engine.run(data, max_epochs=5)
+
+        assert engine.state.epoch == 5
+        assert engine.state.iteration == 5 * epoch_length
+
+        # Continue run with max_iters
+        data = range(15)
+        init_epoch = 5
+        init_iter = 5 * epoch_length
+        with pytest.raises(ValueError, match=r"State attributes max_iters and max_epochs are mutually exclusive"):
+            engine.run(data, max_iters=6 * epoch_length)
+
+        engine.state.max_epochs = None
+        engine.run(data, max_iters=6 * epoch_length)
+
+        assert engine.state.epoch == 6
+        assert engine.state.iteration == 6 * epoch_length
+
+    def test_engine_multiple_runs_2(self):
+        e = Engine(lambda _, b: None)
+        data = iter(range(100))
+
+        e.run(data, max_iters=50)
+        assert e.state.iteration == 50
+        assert e.state.epoch == 1
+        e.run(data, max_iters=52)
+        assert e.state.iteration == 52
+        # should be 1; if it is 2, this is a bug: https://github.com/pytorch/ignite/issues/1386
+        assert e.state.epoch == 2
+        e.run(data, max_iters=100)
+        assert e.state.iteration == 100
+        # should be 1; if it is 3, this is a bug: https://github.com/pytorch/ignite/issues/1386
+        assert e.state.epoch == 3
+
     @pytest.mark.parametrize(
         "interrupt_event, e, i",
diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py
index 4ccfb7ea772..1c4934eb6b0 100644
--- a/tests/ignite/engine/test_engine_state_dict.py
+++ b/tests/ignite/engine/test_engine_state_dict.py
@@ -14,6 +14,7 @@ def test_state_dict():
     assert "iteration" in sd and sd["iteration"] == 0
     assert "max_epochs" in sd and sd["max_epochs"] is None
     assert "epoch_length" in sd and sd["epoch_length"] is None
+    assert "max_iters" in sd and sd["max_iters"] is None

     def _test(state):
         engine.state = state
@@ -23,8 +24,14 @@ def _test(state):
         assert sd["epoch_length"] == engine.state.epoch_length
-        assert sd["max_epochs"] == engine.state.max_epochs
+        if state.max_epochs is not None:
+            assert sd["max_epochs"] == engine.state.max_epochs
+        else:
+            assert sd["max_iters"] == engine.state.max_iters

     _test(State(iteration=500, epoch_length=1000, max_epochs=100))
     _test(State(epoch=5, epoch_length=1000, max_epochs=100))
+    _test(State(epoch=5, epoch_length=1000, max_iters=500))


 def test_state_dict_with_user_keys():
@@ -40,22 +47,39 @@ def _test(state):
         )
         assert sd["iteration"] == engine.state.iteration
         assert sd["epoch_length"] == engine.state.epoch_length
-        assert sd["max_epochs"] == engine.state.max_epochs
+        if state.max_epochs is not None:
+            assert sd["max_epochs"] == engine.state.max_epochs
+        else:
+            assert sd["max_iters"] == engine.state.max_iters
         assert sd["alpha"] == engine.state.alpha
         assert sd["beta"] == engine.state.beta

     _test(State(iteration=500, epoch_length=1000, max_epochs=100, alpha=0.01, beta="Good"))
+    _test(State(iteration=500, epoch_length=1000, max_iters=2000, alpha=0.01, beta="Good"))


 def test_state_dict_integration():
-    engine = Engine(lambda e, b: 1)
-    data = range(100)
-    engine.run(data, max_epochs=10)
-    sd = engine.state_dict()
-    assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1
-    assert sd["iteration"] == engine.state.iteration == 10 * 100
-    assert sd["epoch_length"] == engine.state.epoch_length == 100
-    assert sd["max_epochs"] == engine.state.max_epochs == 10
+    def _test(max_epochs, max_iters):
+        engine = Engine(lambda e, b: 1)
+        data = range(100)
+        engine.run(data, max_epochs=max_epochs, max_iters=max_iters)
+        sd = engine.state_dict()
+        assert isinstance(sd, Mapping)
+        assert len(sd) == len(engine._state_dict_all_req_keys) + 1
+
+        if max_epochs is None and max_iters is None:
+            max_epochs = 1
+        n_iters = max_iters if max_iters is not None else max_epochs * 100
+        assert sd["iteration"] == engine.state.iteration == n_iters
+        assert sd["epoch_length"] == engine.state.epoch_length == 100
+        if engine.state.max_epochs is not None:
+            assert sd["max_epochs"] == engine.state.max_epochs == max_epochs
+        else:
+            assert sd["max_iters"] == engine.state.max_iters == max_iters
+
+    _test(max_epochs=10, max_iters=None)
+    _test(max_epochs=None, max_iters=None)
+    _test(max_epochs=None, max_iters=10 * 100)


 def test_load_state_dict_asserts():
@@ -93,11 +117,29 @@ def _test(sd):
         elif "epoch" in sd:
             assert sd["epoch"] == engine.state.epoch
         assert sd["epoch_length"] == engine.state.epoch_length
-        assert sd["max_epochs"] == engine.state.max_epochs
+        if "max_epochs" in sd:
+            assert sd["max_epochs"] == engine.state.max_epochs
+        else:
+            assert sd["max_iters"] == engine.state.max_iters

     _test({"max_epochs": 100, "epoch_length": 120, "iteration": 123})
     _test({"max_epochs": 100, "epoch_length": 120, "epoch": 5})

+    with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than the current"):
+        _test({"max_epochs": 10, "epoch_length": 120, "epoch": 50})
+
+    with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than the current"):
+        _test({"max_epochs": 10, "epoch_length": 120, "iteration": 5000})
+
+    _test({"max_iters": 500, "epoch_length": 120, "iteration": 123})
+    _test({"max_iters": 500, "epoch_length": 120, "epoch": 3})
+
+    with pytest.raises(ValueError, match=r"Argument max_iters should be greater than"):
+        _test({"max_iters": 500, "epoch_length": 120, "epoch": 5})
+
+    with pytest.raises(ValueError, match=r"Argument max_iters should be greater than"):
+        _test({"max_iters": 500, "epoch_length": 120, "iteration": 501})
+

 def test_load_state_dict_with_user_keys():
     engine = Engine(lambda e, b: 1)
@@ -142,8 +184,7 @@ def test_load_state_dict_with_params_overriding_integration():
     assert state.max_epochs == new_max_epochs
     assert state.iteration == state_dict["epoch_length"] * new_max_epochs
     assert state.epoch == new_max_epochs
-
-    with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than or equal to the start epoch"):
+    with pytest.raises(ValueError, match=r"Argument max_epochs should be greater than the current epoch"):
        engine.load_state_dict(state_dict)
        engine.run(data, max_epochs=3)