From ac34de0c56e6eaa8b6d04dcbfca993e5ce372add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 20 Jan 2024 17:23:42 +0200 Subject: [PATCH] Fixed inability to start tasks from async_generator_asend objects on asyncio --- docs/versionhistory.rst | 1 + src/anyio/_backends/_asyncio.py | 6 ++---- tests/test_taskgroups.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 5d39e459..5b8c6647 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -10,6 +10,7 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed passing ``total_tokens`` to ``anyio.CapacityLimiter()`` as a keyword argument not working on the ``trio`` backend (`#515 `_) +- Fixed inability to start tasks from ``async_generator_asend`` objects on asyncio **4.2.0** diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index e884f564..ff1fc982 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -18,7 +18,7 @@ ) from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] from collections import OrderedDict, deque -from collections.abc import AsyncIterator, Generator, Iterable +from collections.abc import AsyncIterator, Coroutine, Generator, Iterable from concurrent.futures import Future from contextlib import suppress from contextvars import Context, copy_context @@ -28,7 +28,6 @@ CORO_RUNNING, CORO_SUSPENDED, getcoroutinestate, - iscoroutine, ) from io import IOBase from os import PathLike @@ -45,7 +44,6 @@ Callable, Collection, ContextManager, - Coroutine, Mapping, Optional, Sequence, @@ -741,7 +739,7 @@ def task_done(_task: asyncio.Task) -> None: parent_id = id(self.cancel_scope._host_task) coro = func(*args, **kwargs) - if not iscoroutine(coro): + if not isinstance(coro, Coroutine): prefix = f"{func.__module__}." if hasattr(func, "__module__") else "" raise TypeError( f"Expected {prefix}{func.__qualname__}() to return a coroutine, but " diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index f4c87b39..ede649e4 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -1336,6 +1336,22 @@ async def wait_cancel() -> None: await cancelled.wait() +async def test_start_soon_from_asend() -> None: + started = False + + async def genfunc() -> AsyncGenerator[None, None]: + nonlocal started + started = True + yield + + generator = genfunc() + async with anyio.create_task_group() as task_group: + task_group.start_soon(generator.asend, None) + + assert started + await generator.aclose() + + class TestTaskStatusTyping: """ These tests do not do anything at run time, but since the test suite is also checked