Skip to content

Commit

Permalink
Improved shutdown behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Jun 24, 2024
1 parent 2682b5e commit 68c60ea
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 42 deletions.
2 changes: 1 addition & 1 deletion zerofun/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '2.0.6'
__version__ = '2.1.0'

import multiprocessing as mp
try:
Expand Down
17 changes: 17 additions & 0 deletions zerofun/pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import concurrent.futures


class ThreadPool:
  """Small wrapper around `concurrent.futures.ThreadPoolExecutor`.

  Every call to `submit()` also empties the module-level thread registry of
  `concurrent.futures.thread`, so that the exit handlers registered by that
  module do not keep the interpreter hanging on pool threads at shutdown.
  """

  def __init__(self, workers, name):
    # `workers` is the maximum number of pool threads and `name` is the
    # prefix used for the names of those threads.
    self.pool = concurrent.futures.ThreadPoolExecutor(
        max_workers=workers, thread_name_prefix=name)

  def submit(self, fn, *args, **kwargs):
    """Schedule `fn(*args, **kwargs)` on the pool and return its future."""
    pending = self.pool.submit(fn, *args, **kwargs)
    # Prevent daemon threads from hanging due to exit handlers registered by
    # the concurrent.futures module. The registry is refilled whenever the
    # executor spawns a worker, hence it is cleared after each submission.
    registry = concurrent.futures.thread._threads_queues
    registry.clear()
    return pending

  def close(self, wait=False):
    """Shut the executor down, optionally blocking until work has finished."""
    self.pool.shutdown(wait)
14 changes: 9 additions & 5 deletions zerofun/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Process:
current_name = None

def __init__(self, fn, *args, name=None, start=False, pass_running=False):
name = name or fn.__name__
name = name or getattr(fn, '__name__', None)
fn = cloudpickle.dumps(fn)
inits = cloudpickle.dumps(self.initializers)
context = mp.get_context()
Expand All @@ -38,7 +38,7 @@ def pid(self):
def running(self):
running = self.process.is_alive()
if running:
assert self.exitcode is None, self.exitcode
assert self.exitcode is None, (self.name, self.exitcode)
return running

@property
Expand Down Expand Up @@ -78,9 +78,13 @@ def kill(self):
self.process.terminate()
self.process.join(timeout=0.1)
if self.running:
os.kill(self.pid, signal.SIGKILL)
self.process.join(timeout=1.0)
assert not self.running, self.name
try:
os.kill(self.pid, signal.SIGKILL)
self.process.join(timeout=0.1)
except ProcessLookupError:
pass
if self.running:
print(f'Process {self.name} did not shut down yet.')

def __repr__(self):
attrs = ('name', 'pid', 'running', 'exitcode')
Expand Down
27 changes: 15 additions & 12 deletions zerofun/server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import time
import concurrent.futures
import time
import traceback
from collections import deque, namedtuple

import elements
import numpy as np

from . import sockets
from . import pool as poollib
from . import thread


Expand All @@ -23,9 +25,9 @@ def __init__(
self.errors = errors
self.ipv6 = ipv6
self.methods = {}
self.default_pool = concurrent.futures.ThreadPoolExecutor(workers, 'work')
self.default_pool = poollib.ThreadPool(workers, 'work')
self.other_pools = []
self.done_pool = concurrent.futures.ThreadPoolExecutor(1, 'log')
self.done_pool = poollib.ThreadPool(1, 'log')
self.result_set = set()
self.done_queue = deque()
self.done_proms = deque()
Expand All @@ -35,7 +37,7 @@ def __init__(

def bind(self, name, workfn, donefn=None, workers=0, batch=0):
if workers:
pool = concurrent.futures.ThreadPoolExecutor(workers, name)
pool = poollib.ThreadPool(workers, name)
self.other_pools.append(pool)
else:
workers = self.workers
Expand All @@ -50,8 +52,6 @@ def start(self):

def check(self):
self.loop.check()
for pool in [self.default_pool] + self.other_pools:
assert not pool._broken
[not x.done() or x.result() for x in self.result_set.copy()]
[not x.done() or x.result() for x in self.done_proms.copy()]
if self.exception:
Expand All @@ -61,12 +61,11 @@ def check(self):

def close(self):
self._print('Shutting down')
concurrent.futures.wait(self.result_set)
concurrent.futures.wait(self.done_proms)
self.loop.stop()
self.default_pool.shutdown()
self.default_pool.close()
self.done_pool.close()
for pool in self.other_pools:
pool.shutdown()
pool.close()

def run(self):
try:
Expand Down Expand Up @@ -148,11 +147,15 @@ def _handle_results(self, socket, now):
for addr, rid, payload in zip(addr, rid, payload):
socket.send_result(addr, rid, payload)
for recvd in recvd:
self.agg.add('result_time', now - recvd, ('min', 'avg', 'max'))
self.agg.add(method.name, now - recvd, ('min', 'avg', 'max'))
else:
socket.send_result(addr, rid, payload)
self.agg.add('result_time', now - recvd, ('min', 'avg', 'max'))
self.agg.add(method.name, now - recvd, ('min', 'avg', 'max'))
except Exception as e:
print(f'Exception in server {self.name}:')
typ, tb = type(e), e.__traceback__
full = ''.join(traceback.format_exception(typ, e, tb)).strip('\n')
print(full)
if method.batched:
for addr, rid in zip(future.addr, future.rid):
socket.send_error(addr, rid, repr(e))
Expand Down
20 changes: 6 additions & 14 deletions zerofun/thread.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ctypes
import threading

from . import utils
Expand All @@ -10,7 +9,7 @@ def __init__(self, fn, *args, name=None, start=False):
self.fn = fn
self._exitcode = None
self.exception = None
name = name or fn.__name__
name = name or getattr(fn, '__name__', None)
self.thread = threading.Thread(
target=self._wrapper, args=args, name=name, daemon=True)
self.started = False
Expand All @@ -28,7 +27,7 @@ def ident(self):
def running(self):
running = self.thread.is_alive()
if running:
assert self.exitcode is None, self.exitcode
assert self.exitcode is None, (self.name, self.exitcode)
return running

@property
Expand All @@ -52,17 +51,10 @@ def join(self, timeout=None):
def kill(self):
if not self.running:
return
thread = self.thread
if hasattr(thread, '_thread_id'):
thread_id = thread._thread_id
else:
thread_id = [k for k, v in threading._active.items() if v is thread][0]
result = ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread_id), ctypes.py_object(SystemExit))
if result > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread_id), None)
utils.kill_thread(self.thread)
self.thread.join(0.1)
if self.running:
print(f'Thread {self.name} did not shut down yet.')

def __repr__(self):
attrs = ('name', 'ident', 'running', 'exitcode')
Expand All @@ -73,7 +65,7 @@ def _wrapper(self, *args):
try:
self.fn(*args)
except (SystemExit, KeyboardInterrupt):
pass
return
except Exception as e:
utils.warn_remote_error(e, self.name)
self._exitcode = 1
Expand Down
51 changes: 41 additions & 10 deletions zerofun/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import ctypes
import multiprocessing as mp
import os
import socket
import sys
import threading
import time
import traceback

import elements
import embodied
import numpy as np
import psutil


Expand All @@ -16,10 +19,10 @@ def get_print_lock():
return _PRINT_LOCK


_PORTS = iter(range(5000, 8000))
def get_free_port():
rng = np.random.default_rng()
while True:
port = next(_PORTS)
port = int(rng.integers(5000, 8000))
if port_free(port):
return port

Expand All @@ -29,7 +32,7 @@ def port_free(port):
return s.connect_ex(('localhost', int(port)))


def run(workers, duration=None):
def run(workers, duration=None, exit_after=False):
try:

for worker in workers:
Expand All @@ -52,10 +55,14 @@ def run(workers, duration=None):
for worker in workers:
if worker.exitcode not in (None, 0):
time.sleep(0.1) # Wait for workers to print their error messages.
msg = f'Terminated workers due to crash in {worker.name}.'
msg = f'Shutting down workers due to crash in {worker.name}.'
print(msg)
worker.check()
raise RuntimeError(msg) # In case the check did not raise.
if exit_after:
for worker in workers:
if hasattr(worker, 'pid'):
kill_subprocs(worker.pid)
worker.check() # Raise the forwarded exception.
raise RuntimeError(msg) # In case exception was not forwarded.
time.sleep(0.1)

finally:
Expand All @@ -66,6 +73,16 @@ def run(workers, duration=None):
[x.kill() for x in workers]


def assert_no_children(parent=None):
  """Report live threads and child processes, then kill the child processes.

  Despite its name, the visible code raises no assertion: it always prints
  the report and then delegates to `kill_subprocs(parent)`.

  Args:
    parent: Forwarded to `psutil.Process` and `kill_subprocs`. Presumably a
      pid, with None meaning the current process -- confirm against psutil.
  """
  children = list(psutil.Process(parent).children(recursive=True))
  live_threads = list(threading.enumerate())
  # Note that the headline always names the calling process via os.getpid(),
  # independent of which `parent` was inspected.
  report = '\n'.join([
      f'Process {os.getpid()} should have no children.',
      f'Threads: {live_threads}',
      f'Subprocs: {children}',
  ])
  print(report)
  kill_subprocs(parent)


def kill_subprocs(parent=None):
try:
procs = list(psutil.Process(parent).children(recursive=True))
Expand Down Expand Up @@ -115,6 +132,20 @@ def proc_alive(pid):
return False


def kill_thread(thread):
  """Asynchronously raise `SystemExit` inside another Python thread.

  The exception is only scheduled via the CPython C API; the call returns
  immediately, so callers must not assume the target has stopped on return.

  Args:
    thread: Either an integer thread identifier, or a thread object. For an
      object, its `_thread_id` attribute is used when present; otherwise the
      identifier is looked up in the registry of active threads.

  Returns:
    None. If a thread object is not (or no longer) active, nothing is done.
  """
  if isinstance(thread, int):
    thread_id = int(thread)
  elif hasattr(thread, '_thread_id'):
    thread_id = thread._thread_id
  else:
    # Snapshot the registry first so that threads starting or exiting
    # concurrently cannot change the dict while we scan it.
    active = list(threading._active.items())
    thread_id = next((k for k, v in active if v is thread), None)
    if thread_id is None:
      # The thread already finished (or never started), for example because
      # it exited between the caller's liveness check and this call.
      return
  set_async_exc = ctypes.pythonapi.PyThreadState_SetAsyncExc
  # The C function takes the thread id as `unsigned long`.
  affected = set_async_exc(
      ctypes.c_ulong(thread_id), ctypes.py_object(SystemExit))
  if affected > 1:
    # More than one thread state was modified, which should never happen;
    # passing NULL clears the pending exception again.
    set_async_exc(ctypes.c_ulong(thread_id), None)


def warn_remote_error(e, name, lock=get_print_lock):
lock = lock() if callable(lock) else lock
typ, tb = type(e), e.__traceback__
Expand All @@ -124,8 +155,8 @@ def warn_remote_error(e, name, lock=get_print_lock):
msg += 'Call check() to reraise in main process. '
msg += f'Worker stack trace:\n{full}'
with lock:
elements.print(msg, 'red')
if sys.version_info.minor >= 11:
embodied.print(msg, color='red')
if hasattr(e, 'add_note'):
e.add_note(f'\nWorker stack trace:\n\n{full}')


Expand Down

0 comments on commit 68c60ea

Please sign in to comment.