Skip to content

Commit

Permalink
Tree support and fever memory copies.
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed May 10, 2024
1 parent c5b26f0 commit 888a4ef
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 22 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
cloudpickle
elements
elements>=3.3.0
msgpack
numpy
pyzmq
30 changes: 25 additions & 5 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

sys.path.append(str(pathlib.Path(__file__).parent.parent))

import elements
import numpy as np
import pytest
import zerofun
Expand Down Expand Up @@ -446,14 +447,33 @@ def test_proxy_batched(self, Server, inner_addr, outer_addr, workers):

@pytest.mark.parametrize('Server', SERVERS)
@pytest.mark.parametrize('addr', ADDRESSES)
def test_empty_dict(self, Server, addr):
@pytest.mark.parametrize('data', (
{'a': np.zeros((3, 2), np.float32), 'b': np.ones((1,), np.uint8)},
{'a': 12, 'b': [np.ones((1,), np.uint8), 13]},
{'a': 12, 'b': ['c', [1, 2, 3]]},
[],
{},
12,
[[{}, []]],
))
def test_tree_data(self, Server, addr, data):
data = elements.tree.map(np.asarray, data)
print(data)
def tree_equal(tree1, tree2):
try:
comps = elements.tree.map(lambda x, y: np.all(x == y), tree1, tree2)
comps, _ = elements.tree.flatten(comps)
return all(comps)
except TypeError:
return False
addr = addr.format(port=zerofun.get_free_port())
client = zerofun.Client(addr, pings=0, maxage=1)
server = Server(addr)
def workfn(data):
assert data == {}
return {}
def workfn(indata):
assert tree_equal(indata, data)
return indata
server.bind('function', workfn)
with server:
client.connect(retry=False, timeout=1)
assert client.function({}).result() == {}
outdata = client.function(data).result()
assert tree_equal(outdata, data)
2 changes: 1 addition & 1 deletion zerofun/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '1.2.0'
__version__ = '2.0.0'

import multiprocessing as mp
try:
Expand Down
2 changes: 0 additions & 2 deletions zerofun/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def call(self, method, data):
self.queue.popleft().result()
except IndexError:
pass
assert isinstance(data, dict)
data = {k: np.asarray(v) for k, v in data.items()}
data = sockets.pack(data)
rid = self.socket.send_call(method, data)
self.send_per_sec.step(1)
Expand Down
1 change: 0 additions & 1 deletion zerofun/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def _work(self, method, addr, rid, payload, recvd):
result, logs = method.workfn(data)
else:
result = method.workfn(data)
result = result or {}
logs = None
if method.batched:
results = [
Expand Down
28 changes: 17 additions & 11 deletions zerofun/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import threading
import time

import elements
import numpy as np
import zmq


DEBUG = False
# DEBUG = True

Expand Down Expand Up @@ -203,7 +205,8 @@ def send_ping(self, addr):

def send_result(self, addr, rid, payload):
with self.lock:
self.socket.send_multipart([addr, Type.RESULT.value, rid, *payload])
self.socket.send_multipart(
[addr, Type.RESULT.value, rid, *payload], copy=False, track=True)

def send_error(self, addr, rid, text):
text = text.encode('utf-8')
Expand All @@ -216,23 +219,26 @@ def close(self):


def pack(data):
data = {k: np.asarray(v) for k, v in data.items()}
leaves, structure = elements.tree.flatten(data)
dtypes, shapes, buffers = [], [], []
items = sorted(data.items(), key=lambda x: x[0])
keys, vals = zip(*items) if items else ((), ())
dtypes = [v.dtype.str for v in vals]
shapes = [v.shape for v in vals]
buffers = [v.tobytes() for v in vals]
meta = (keys, dtypes, shapes)
for value in leaves:
value = np.asarray(value)
assert value.data.c_contiguous, (
"Array is not contiguous in memory. Use np.asarray(arr, order='C') " +
"before passing the data into pack().")
dtypes.append(value.dtype.str)
shapes.append(value.shape)
buffers.append(value.data)
meta = (structure, dtypes, shapes)
payload = [msgpack.packb(meta), *buffers]
return payload


def unpack(payload):
meta, *buffers = payload
keys, dtypes, shapes = msgpack.unpackb(meta)
vals = [
structure, dtypes, shapes = msgpack.unpackb(meta)
leaves = [
np.frombuffer(b, d).reshape(s)
for i, (d, s, b) in enumerate(zip(dtypes, shapes, buffers))]
data = dict(zip(keys, vals))
data = elements.tree.unflatten(leaves, structure)
return data
1 change: 0 additions & 1 deletion zerofun/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(self, fn, *args, name=None, start=False):
self._exitcode = None
self.exception = None
name = name or fn.__name__
self.old_name = name[:]
self.thread = threading.Thread(
target=self._wrapper, args=args, name=name, daemon=True)
self.started = False
Expand Down

0 comments on commit 888a4ef

Please sign in to comment.