Skip to content

Commit

Permalink
Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse
Browse files Browse the repository at this point in the history
Fixes #2516
  • Loading branch information
adriangb committed Jun 13, 2024
1 parent 5a1bec3 commit 318fc78
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -198,20 +197,33 @@ async def dispatch(
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):
def __init__(
self,
content: ContentStream,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)

async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})

0 comments on commit 318fc78

Please sign in to comment.