Skip to content

Commit

Permalink
Cooperate with anyio
Browse files Browse the repository at this point in the history
`anyio` allows running `async def` test functions, but the wrapper
installed by Memray to add tracking around the test function breaks
`anyio`'s detection.

Work around this by using an `async def` wrapper when the function being
wrapped is a coroutine function.

Signed-off-by: Matt Wozniski <[email protected]>
  • Loading branch information
godlygeek committed Jun 30, 2024
1 parent 0654e6b commit fd24f48
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 32 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ optional-dependencies.lint = [
"mypy==0.991",
]
optional-dependencies.test = [
"anyio>=4.4.0",
"covdefaults>=2.2.2",
"pytest>=7.2",
"coverage>=7.0.5",
Expand Down
79 changes: 47 additions & 32 deletions src/pytest_memray/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import pickle
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
Expand Down Expand Up @@ -178,39 +179,53 @@ def _build_bin_path() -> Path:
if markers and "limit_leaks" in markers:
native = trace_python_allocators = True

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> object | None:
test_result: object | Any = None
@contextmanager
def memory_reporting() -> Generator[None, None, None]:
# Restore the original function. This is needed because some
# pytest plugins (e.g. flaky) will call our pytest_pyfunc_call
# hook again with whatever is here, which will cause the wrapper
# to be wrapped again.
pyfuncitem.obj = func

result_file = _build_bin_path()
with Tracker(
result_file,
native_traces=native,
trace_python_allocators=trace_python_allocators,
file_format=FileFormat.AGGREGATED_ALLOCATIONS,
):
yield

try:
result_file = _build_bin_path()
with Tracker(
result_file,
native_traces=native,
trace_python_allocators=trace_python_allocators,
file_format=FileFormat.AGGREGATED_ALLOCATIONS,
):
test_result = func(*args, **kwargs)
try:
metadata = FileReader(result_file).metadata
except OSError:
return None
result = Result(pyfuncitem.nodeid, metadata, result_file)
metadata_path = (
self.result_metadata_path
/ result_file.with_suffix(".metadata").name
)
with open(metadata_path, "wb") as file_handler:
pickle.dump(result, file_handler)
self.results[pyfuncitem.nodeid] = result
finally:
# Restore the original function. This is needed because some
# pytest plugins (e.g. flaky) will call our pytest_pyfunc_call
# hook again with whatever is here, which will cause the wrapper
# to be wrapped again.
pyfuncitem.obj = func
return test_result

pyfuncitem.obj = wrapper
metadata = FileReader(result_file).metadata
except OSError:
return
result = Result(pyfuncitem.nodeid, metadata, result_file)
metadata_path = (
self.result_metadata_path
/ result_file.with_suffix(".metadata").name
)
with open(metadata_path, "wb") as file_handler:
pickle.dump(result, file_handler)
self.results[pyfuncitem.nodeid] = result


@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
with memory_reporting():
return func(*args, **kwargs)


@functools.wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
with memory_reporting():
return await func(*args, **kwargs)

if inspect.iscoroutinefunction(func):
pyfuncitem.obj = async_wrapper
else:
pyfuncitem.obj = wrapper

yield

@hookimpl(hookwrapper=True)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_pytest_memray.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,3 +918,37 @@ def test_memory_alloc_fails():
result = pytester.runpytest("--memray")

assert result.ret == ExitCode.OK


def test_running_async_tests_with_anyio(pytester: Pytester) -> None:
xml_output_file = pytester.makefile(".xml", "")
pytester.makepyfile(
"""
import pytest
from memray._test import MemoryAllocator
allocator = MemoryAllocator()
@pytest.fixture
def anyio_backend():
return 'asyncio'
@pytest.mark.limit_leaks("5KB")
@pytest.mark.anyio
async def test_memory_alloc_fails():
for _ in range(10):
allocator.valloc(1024*10)
# No free call here
"""
)

result = pytester.runpytest("--junit-xml", xml_output_file)

assert result.ret != ExitCode.OK

root = ET.parse(str(xml_output_file)).getroot()
for testcase in root.iter("testcase"):
failure = testcase.find("failure")
assert failure.text == (
"Test was allowed to leak 5.0KiB per location"
" but at least one location leaked more"
)

0 comments on commit fd24f48

Please sign in to comment.