[WIP] Improved max_iters handling #3235

Open · wants to merge 8 commits into master
26 changes: 20 additions & 6 deletions ignite/base/mixins.py
@@ -1,11 +1,18 @@
from collections import OrderedDict
from collections.abc import Mapping
from typing import Tuple
from typing import List, Tuple


class Serializable:
    _state_dict_all_req_keys: Tuple = ()
    _state_dict_one_of_opt_keys: Tuple = ()
    _state_dict_all_req_keys: Tuple[str, ...] = ()
    _state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)

    def __init__(self) -> None:
        self._state_dict_user_keys: List[str] = []

    @property
    def state_dict_user_keys(self) -> List:
        return self._state_dict_user_keys

    def state_dict(self) -> OrderedDict:
        raise NotImplementedError
@@ -19,6 +26,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
                raise ValueError(
                    f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
                )
        opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
        if len(opts) > 0 and ((not any(opts)) or (all(opts))):
            raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")
        for one_of_opt_keys in self._state_dict_one_of_opt_keys:
            opts = [k in state_dict for k in one_of_opt_keys]
            if len(opts) > 0 and ((not any(opts)) or all(opts)):
                raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")

        for k in self._state_dict_user_keys:
            if k not in state_dict:
                raise ValueError(
                    f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
                )
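To make the new grouped check concrete, here is a minimal standalone sketch (a hypothetical helper, not part of the PR) of the rule load_state_dict now enforces: in each non-empty group of mutually exclusive optional keys, at least one key must be present, but not all of them at once.

from typing import Mapping, Tuple

def check_one_of_groups(state_dict: Mapping, groups: Tuple[Tuple[str, ...], ...]) -> None:
    # Mirrors Serializable.load_state_dict: every non-empty group must be represented,
    # but a state_dict providing all keys of a group is also rejected.
    for one_of_opt_keys in groups:
        opts = [k in state_dict for k in one_of_opt_keys]
        if len(opts) > 0 and ((not any(opts)) or all(opts)):
            raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")

groups = (("iteration", "epoch"), ("max_epochs", "max_iters"))
check_one_of_groups({"iteration": 10, "max_epochs": 5}, groups)    # passes
# check_one_of_groups({"iteration": 10, "epoch": 1, "max_epochs": 5}, groups)  # raises: both "iteration" and "epoch"
# check_one_of_groups({"max_epochs": 5}, groups)                   # raises: neither "iteration" nor "epoch"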
146 changes: 133 additions & 13 deletions ignite/engine/engine.py
@@ -129,7 +129,16 @@ def compute_mean_std(engine, batch):
    """

    _state_dict_all_req_keys = ("epoch_length", "max_epochs")
    _state_dict_one_of_opt_keys = ("iteration", "epoch")
    _state_dict_one_of_opt_keys = (
        (
            "iteration",
            "epoch",
        ),
        (
            "max_epochs",
            "max_iters",
        ),
    )

    # Flag to disable engine._internal_run as generator feature for BC
    interrupt_resume_enabled = True
@@ -310,6 +319,7 @@ def execute_something():
            for e in event_name:
                self.add_event_handler(e, handler, *args, **kwargs)
            return RemovableEventHandle(event_name, handler, self)

        if isinstance(event_name, CallableEventWithFilter) and event_name.filter is not None:
            event_filter = event_name.filter
            handler = self._handler_wrapper(handler, event_name, event_filter)
@@ -332,6 +342,16 @@ def execute_something():

        return RemovableEventHandle(event_name, handler, self)

    @staticmethod
    def _assert_non_filtered_event(event_name: Any) -> None:
        if (
            isinstance(event_name, CallableEventWithFilter)
            and event_name.filter != CallableEventWithFilter.default_event_filter
        ):
            raise TypeError(
                "Argument event_name should not be a filtered event, " "please use event without any event filtering"
            )
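A hedged usage sketch of the new helper (this assumes pytorch-ignite with this patch applied; _assert_non_filtered_event is introduced by this PR): plain events pass, filtered events are rejected.

from ignite.engine import Engine, Events

engine = Engine(lambda engine, batch: None)

engine._assert_non_filtered_event(Events.EPOCH_COMPLETED)  # plain event: no error
# engine._assert_non_filtered_event(Events.EPOCH_COMPLETED(every=2))  # filtered event: raises TypeError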

    def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
        """Check if the specified event has the specified handler.

@@ -675,7 +695,12 @@ def save_engine(_):
            a dictionary containing engine's state

        """
        keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
        keys: Tuple[str, ...] = self._state_dict_all_req_keys
        keys += ("iteration",)
        if self.state.max_epochs is not None:
            keys += ("max_epochs",)
        else:
            keys += ("max_iters",)
        keys += tuple(self._state_dict_user_keys)
        return OrderedDict([(k, getattr(self.state, k)) for k in keys])
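As a usage sketch of the proposed serialization (assuming pytorch-ignite with this patch): "iteration" is always stored, and the added branch picks "max_epochs" or "max_iters" depending on which bound is set on the state.

from ignite.engine import Engine

engine = Engine(lambda engine, batch: None)
engine.run(range(10), max_epochs=2)

sd = engine.state_dict()
print(sorted(sd.keys()))  # ['epoch_length', 'iteration', 'max_epochs']
# If state.max_epochs were None, the added else-branch would serialize "max_iters" instead.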

@@ -728,6 +753,8 @@ def load_state_dict(self, state_dict: Mapping) -> None:
                    f"Input state_dict: {state_dict}"
                )
            self.state.iteration = self.state.epoch_length * self.state.epoch
        self._check_and_set_max_epochs(state_dict.get("max_epochs", None))
        self._check_and_set_max_iters(state_dict.get("max_iters", None))

    @staticmethod
    def _is_done(state: State) -> bool:
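A hedged sketch of resuming from such a state (values are illustrative; under this PR the new _check_and_set_* helpers pick up "max_epochs" / "max_iters" from the provided mapping):

from ignite.engine import Engine

engine = Engine(lambda engine, batch: None)
engine.load_state_dict({"epoch_length": 10, "max_epochs": 5, "iteration": 25})

print(engine.state.epoch, engine.state.iteration, engine.state.max_epochs)  # 2 25 5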
@@ -864,12 +891,26 @@ def switch_batch(engine):

                epoch_length = self._get_data_length(data)
                if epoch_length is not None and epoch_length < 1:
                    raise ValueError("Input data has zero size. Please provide non-empty data")
                    raise ValueError(
                        "Argument epoch_length is invalid. Please, either set a"
                        " correct epoch_length value or check if input data has"
                        " non-zero size."
                    )

            if max_iters is None:
                if max_epochs is None:
                    max_epochs = 1
            else:
                if max_iters < 1:
                    raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
                if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
                    raise ValueError(
                        "Argument max_iters should be greater than the current iteration "
                        f"defined in the state: {max_iters} vs {self.state.iteration}. "
                        "Please, set engine.state.max_iters = None "
                        "before calling engine.run() in order to restart the training from the beginning."
                    )
                self.state.max_iters = max_iters
                if max_epochs is not None:
                    raise ValueError(
                        "Arguments max_iters and max_epochs are mutually exclusive."
@@ -932,6 +973,53 @@ def _setup_dataloader_iter(self) -> None:
        else:
            self._dataloader_iter = iter(self.state.dataloader)
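For context, a short usage sketch of the two run bounds validated in run() above and by the helpers below (public Engine.run API; passing both arguments raises the "mutually exclusive" error):

from ignite.engine import Engine

engine = Engine(lambda engine, batch: None)

engine.run(range(8), max_epochs=3)  # bounded by epochs
# On a fresh engine, the alternative is a bound on total iterations:
# engine.run(range(8), max_iters=20)
# Passing both raises: "Arguments max_iters and max_epochs are mutually exclusive."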

    def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
        if max_epochs is not None:
            if max_epochs < 1:
                raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
            if self.state.max_epochs is not None and max_epochs <= self.state.epoch:
                raise ValueError(
                    "Argument max_epochs should be greater than the current epoch "
                    f"defined in the state: {max_epochs} vs {self.state.epoch}. "
                    "Please, set engine.state.max_epochs = None "
                    "before calling engine.run() in order to restart the training from the beginning."
                )
            self.state.max_epochs = max_epochs

    def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
        if max_iters is not None:
            if max_iters < 1:
                raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
            if (self.state.max_iters is not None) and max_iters <= self.state.iteration:
                raise ValueError(
                    "Argument max_iters should be greater than the current iteration "
                    f"defined in the state: {max_iters} vs {self.state.iteration}. "
                    "Please, set engine.state.max_iters = None "
                    "before calling engine.run() in order to restart the training from the beginning."
                )
            self.state.max_iters = max_iters

    def _check_and_set_epoch_length(self, data: Iterable, epoch_length: Optional[int] = None) -> None:
        # Can't we accept a redefinition ?
        if self.state.epoch_length is not None:
            if epoch_length is not None:
                if epoch_length != self.state.epoch_length:
                    raise ValueError(
                        "Argument epoch_length should be same as in the state, "
                        f"but given {epoch_length} vs {self.state.epoch_length}"
                    )
        else:
            if epoch_length is None:
                epoch_length = self._get_data_length(data)

            if epoch_length is not None and epoch_length < 1:
                raise ValueError(
                    "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
                    "check if input data has non-zero size."
                )

            self.state.epoch_length = epoch_length

    def _setup_engine(self) -> None:
        self._setup_dataloader_iter()
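A standalone sketch of the "resume must move forward" rule shared by the two new helpers (hypothetical function name, mirroring _check_and_set_max_iters):

from typing import Optional

def check_and_set_max_iters(current_iteration: int, stored: Optional[int], new: Optional[int]) -> Optional[int]:
    # A new bound is only accepted if it is positive and, when a run is already
    # in progress, strictly ahead of the current iteration.
    if new is None:
        return stored
    if new < 1:
        raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
    if stored is not None and new <= current_iteration:
        raise ValueError(f"Argument max_iters should be greater than the current iteration: {new} vs {current_iteration}")
    return new

print(check_and_set_max_iters(current_iteration=0, stored=None, new=100))    # 100
print(check_and_set_max_iters(current_iteration=50, stored=100, new=150))    # 150
# check_and_set_max_iters(current_iteration=50, stored=100, new=30)          # raises ValueError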

@@ -976,17 +1064,19 @@ def _internal_run_as_gen(self) -> Generator:
                    self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

                    handlers_start_time = time.time()
                    self._fire_event(Events.EPOCH_COMPLETED)
                    epoch_time_taken += time.time() - handlers_start_time
                    # update time wrt handlers
                    self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
                    if self.state.epoch_length is not None and self.state.iteration % self.state.epoch_length == 0:
                        # max_iters can cause training to complete without an epoch ending
                        self._fire_event(Events.EPOCH_COMPLETED)
                        epoch_time_taken += time.time() - handlers_start_time
                        # update time wrt handlers
                        self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

                        hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
                        self.logger.info(
                            f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}"
                        )
                    yield from self._maybe_terminate_or_interrupt()

                    hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
                    self.logger.info(
                        f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}"
                    )

            except _EngineTerminateException:
                self._fire_event(Events.TERMINATE)
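A small worked sketch of the gating condition introduced above: with max_iters a run may stop mid-epoch, so EPOCH_COMPLETED should only fire when the iteration count lands exactly on an epoch boundary (numbers are illustrative):

epoch_length = 10

for iteration in (10, 20, 25):
    fires_epoch_completed = iteration % epoch_length == 0
    print(iteration, fires_epoch_completed)
# 10 True   -> a full epoch finished, EPOCH_COMPLETED fires
# 20 True   -> a full epoch finished, EPOCH_COMPLETED fires
# 25 False  -> run stopped by max_iters mid-epoch, no EPOCH_COMPLETED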

@@ -1064,6 +1154,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                    self.state.epoch_length = iter_counter
                    if self.state.max_iters is not None:
                        self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
                        # Warn but will continue until max iters is reached
                        warnings.warn(
                            "Data iterator can not provide data anymore but required total number of "
                            "iterations to run is not reached. "
                            f"Current iteration: {self.state.iteration} vs Total iterations to run :"
                            f" {self.state.max_iters}"
                        )
                    break

            # Should exit while loop if we can not iterate
@@ -1106,7 +1203,13 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:

            if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
                self.should_terminate = True
                raise _EngineTerminateException()
                warnings.warn(
                    "Data iterator can not provide data anymore but required total number of "
                    "iterations to run is not reached. "
                    f"Current iteration: {self.state.iteration} vs Total iterations to run : ? total_iters"
                )
                break
                # raise _EngineTerminateException()

            except _EngineTerminateSingleEpochException:
                self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
@@ -1231,6 +1334,13 @@ def _run_once_on_dataset_legacy(self) -> float:
                    self.state.epoch_length = iter_counter
                    if self.state.max_iters is not None:
                        self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
                        # Warn but will continue until max iters is reached
                        warnings.warn(
                            "Data iterator can not provide data anymore but required total number of "
                            "iterations to run is not reached. "
                            f"Current iteration: {self.state.iteration} vs Total iterations to run :"
                            f" {self.state.max_iters}"
                        )
                    break

            # Should exit while loop if we can not iterate
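The max_epochs derivation above as a worked example (illustrative values): once the data iterator is exhausted for the first time and the real epoch_length is known, the remaining budget in epochs is obtained with math.ceil.

import math

max_iters = 25
epoch_length = 10  # discovered when the iterator raised StopIteration

max_epochs = math.ceil(max_iters / epoch_length)
print(max_epochs)  # 3 -> the third epoch is cut short at iteration 25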
@@ -1291,6 +1401,16 @@ def _run_once_on_dataset_legacy(self) -> float:

        return time.time() - start_time

    def debug(self, enabled: bool = True) -> None:
        """Enables/disables engine's logging debug mode"""
        from ignite.utils import setup_logger

        if enabled:
            setattr(self, "_stored_logger", self.logger)
            self.logger = setup_logger(level=logging.DEBUG)
        elif hasattr(self, "_stored_logger"):
            self.logger = getattr(self, "_stored_logger")
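Usage sketch for the new debug helper (behaviour as added by this PR; setup_logger is ignite's existing utility):

from ignite.engine import Engine

engine = Engine(lambda engine, batch: None)

engine.debug()        # swap engine.logger for a DEBUG-level logger, keeping the old one
engine.run(range(4), max_epochs=1)
engine.debug(False)   # restore the previously stored logger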


def _get_none_data_iter(size: int) -> Iterator:
    # Sized iterator for data as None
30 changes: 28 additions & 2 deletions tests/ignite/base/test_mixins.py
@@ -3,12 +3,38 @@
from ignite.base import Serializable


class ExampleSerializable(Serializable):
    _state_dict_all_req_keys = ("a", "b")
    _state_dict_one_of_opt_keys = (("c", "d"), ("e", "f"))


def test_state_dict():
    s = Serializable()
    with pytest.raises(NotImplementedError):
        s.state_dict()


def test_load_state_dict():
    s = Serializable()
    s.load_state_dict({})

    s = ExampleSerializable()
    with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"):
        s.load_state_dict("abc")

    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
        s.load_state_dict({})

    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
        s.load_state_dict({"a": 1})

    with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
        s.load_state_dict({"a": 1, "b": 2})

    with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
        s.load_state_dict({"a": 1, "b": 2, "c": 3, "d": 4, "e": 5})

    with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
        s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 5, "f": 5})

    s.state_dict_user_keys.append("alpha")
    with pytest.raises(ValueError, match=r"Required user state attribute"):
        s.load_state_dict({"a": 1, "b": 2, "c": 3, "e": 4})
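Related usage sketch for the user-keys path exercised by the last assertion (state_dict_user_keys is ignite's existing mechanism; the handler name below is illustrative):

from ignite.engine import Engine, Events

engine = Engine(lambda engine, batch: None)
engine.state_dict_user_keys.append("alpha")

@engine.on(Events.STARTED)
def init_user_value(_):
    engine.state.alpha = 0.1

engine.run(range(4), max_epochs=1)
print(engine.state_dict()["alpha"])  # 0.1
# Loading a state_dict without "alpha" would raise: Required user state attribute 'alpha' is absent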