Skip to content

Commit

Permalink
Drop references to finished processes in local manager (#90)
Browse files Browse the repository at this point in the history
* Drop references to finished processes in local manager

This ensures we are not going to exceed file descriptor limits on
Unix platforms.

* Refactor parts of local manager `get_message` to separate methods

* Make sure we are not on Windows before importing `resource`

* Improve a comment in local manager tests

* Modify open file descriptor limit for the duration of test
  • Loading branch information
xadrianzetx committed Jun 11, 2024
1 parent 4877cad commit 9b35c80
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 15 deletions.
41 changes: 28 additions & 13 deletions optuna_distributed/managers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def __init__(self, n_trials: int, n_jobs: int) -> None:

self._workers_to_spawn = min(self._n_jobs, n_trials)
self._trials_remaining = n_trials - self._workers_to_spawn
self._pool: dict[int, Connection] = {}
self._processes: list[Process] = []

self._connections: dict[int, Connection] = {}
self._processes: dict[int, Process] = {}

def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None:
trial_ids = [study.ask()._trial_id for _ in range(self._workers_to_spawn)]
Expand All @@ -59,14 +60,15 @@ def create_futures(self, study: Study, objective: ObjectiveFuncType) -> None:
trial = DistributedTrial(trial_id, Pipe(worker))
p = Process(target=_trial_runtime, args=(objective, trial), daemon=True)
p.start()
self._processes.append(p)
self._pool[trial_id] = master
worker.close()

self._processes[trial_id] = p
self._connections[trial_id] = master

def get_message(self) -> Generator[Message, None, None]:
while True:
messages: list[Message] = []
for incoming in wait(self._pool.values(), timeout=10):
for incoming in wait(self._connections.values(), timeout=10):
# FIXME: This assertion is true only for Unix systems.
# Some refactoring is needed to support Windows as well.
# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.connection.wait
Expand All @@ -76,39 +78,52 @@ def get_message(self) -> Generator[Message, None, None]:
messages.append(message)

except EOFError:
for trial_id, connection in self._pool.items():
if incoming == connection:
break
self._pool.pop(trial_id)
self._close_connection(incoming)

self._workers_to_spawn = min(self._n_jobs - len(self._pool), self._trials_remaining)
self._set_workers_to_spawn()
if messages:
yield from messages
else:
yield HeartbeatMessage()

def after_message(self, event_loop: "EventLoop") -> None:
if self._workers_to_spawn > 0:
self._join_finished_processes()
self.create_futures(event_loop.study, event_loop.objective)

self._trials_remaining -= self._workers_to_spawn
self._workers_to_spawn = 0

def get_connection(self, trial_id: int) -> IPCPrimitive:
return Pipe(self._pool[trial_id])
return Pipe(self._connections[trial_id])

def stop_optimization(self, patience: float) -> None:
for process in self._processes:
for process in self._processes.values():
if process.is_alive():
process.kill()
process.join(timeout=patience)

def should_end_optimization(self) -> bool:
return len(self._pool) == 0 and self._trials_remaining == 0
return len(self._connections) == 0 and self._trials_remaining == 0

def register_trial_exit(self, trial_id: int) -> None:
# Noop, as worker informs us about exit by closing connection.
...

def _close_connection(self, connection: Connection) -> None:
for trial_id, open_connection in self._connections.items():
if connection == open_connection:
break

self._connections.pop(trial_id).close()

def _set_workers_to_spawn(self) -> None:
self._workers_to_spawn = min(self._n_jobs - len(self._connections), self._trials_remaining)

def _join_finished_processes(self) -> None:
for trial_id in [tid for tid, p in self._processes.items() if p.exitcode is not None]:
self._processes.pop(trial_id).join()


def _trial_runtime(func: ObjectiveFuncType, trial: DistributedTrial) -> None:
message: Message
Expand Down
44 changes: 42 additions & 2 deletions tests/test_managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from contextlib import contextmanager
from dataclasses import dataclass
import multiprocessing
import sys
import time
from typing import Generator
from unittest.mock import Mock
import uuid

Expand Down Expand Up @@ -222,7 +224,7 @@ def test_local_stops_optimziation() -> None:
manager.stop_optimization(patience=10.0)
interrupted_execution_time = time.time() - stopped_at
assert interrupted_execution_time < uninterrupted_execution_time
for process in manager._processes:
for process in manager._processes.values():
assert not process.is_alive()


Expand Down Expand Up @@ -273,6 +275,44 @@ class _MockEventLoop:
message.process(study, manager)
manager.after_message(eventloop) # type: ignore
if not manager.should_end_optimization():
assert 0 < len(manager._pool) <= multiprocessing.cpu_count()
assert 0 < len(manager._connections) <= multiprocessing.cpu_count()
else:
break


@pytest.mark.skipif(sys.platform == "win32", reason="No file descriptor limits on Windows.")
def test_local_free_resources() -> None:
@contextmanager
def _limited_nofile(limit: int) -> Generator[None, None, None]:
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))
yield

resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))

@dataclass
class _MockEventLoop:
study: optuna.Study
objective: ObjectiveFuncType

study = optuna.create_study()
eventloop = _MockEventLoop(study, _objective_local_worker_pool_management)

# Try to run more trials than there are available file descriptors. Incorrectly managed
# optimization will fail by exceeding this limit.
n_trials = 1024
with _limited_nofile(n_trials):
manager = LocalOptimizationManager(n_trials=n_trials + 1, n_jobs=5)
manager.create_futures(study, objective=_objective_local_worker_pool_management)

try:
for message in manager.get_message():
message.process(study, manager)
manager.after_message(eventloop) # type: ignore
if manager.should_end_optimization():
break

except OSError:
pytest.fail("File descriptor limit reached")

0 comments on commit 9b35c80

Please sign in to comment.