Skip to content

Commit

Permalink
add test_concurrency_correct_headers
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwojcik committed Nov 26, 2022
1 parent be36ecd commit 69aed8d
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import uuid

import httpx

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse

from starlette_context import context, plugins
from starlette_context.header_keys import HeaderKeys
from starlette_context.middleware import RawContextMiddleware
import pytest
import pytest_asyncio
from starlette.routing import Route
from asgi_lifespan import LifespanManager
from starlette.exceptions import HTTPException


def should_raise(number: int) -> bool:
return number % 2 == 0


@pytest_asyncio.fixture
async def app():
class CloudProviderException(HTTPException):
pass

async def cloud_provider_exception_handler(request: Request, exc: HTTPException):
return JSONResponse({"detail": exc.detail}, status_code=exc.status_code, headers={
HeaderKeys.request_id: context[HeaderKeys.request_id]
})

async def index(request: Request) -> JSONResponse:
number = request.path_params['number']
if should_raise(number):
raise CloudProviderException
return JSONResponse(content={"trace_id": context[HeaderKeys.request_id], "from": "view"})

middleware = [
Middleware(
RawContextMiddleware,
plugins=(
plugins.RequestIdPlugin(),
),
)
]
exception_handlers = {
CloudProviderException: cloud_provider_exception_handler
}
app = Starlette(
middleware=middleware,
routes=[Route("/{number:int}", index)],
exception_handlers=exception_handlers
)

async with LifespanManager(app):
yield app


@pytest.mark.asyncio
async def test_concurrency_correct_headers(app):
transport = httpx.ASGITransport(
app=app,
raise_app_exceptions=False
)
async with httpx.AsyncClient(app=app, transport=transport, base_url="http://test") as client:
for number in range(1, 101):
rid = uuid.uuid4().hex
resp = await client.get(f"/{number}", headers={HeaderKeys.request_id: rid})

if should_raise(number):
assert resp.status_code == 500
assert resp.headers[HeaderKeys.request_id] == rid
else:
assert resp.status_code == 200
d = resp.json()
assert d['from'] == 'view'
assert d['trace_id'] == rid
assert resp.headers[HeaderKeys.request_id] == rid

0 comments on commit 69aed8d

Please sign in to comment.