Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow opt-out of schema dump to dict #57

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions fastapi_events/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def _derive_event_name_and_payload_from_pydantic_model(
event_name_or_model: Union[EventName, PydanticModel],
event_name: EventName,
payload: Payload,
payload_schema_cls_dict_args: Dict[str, Any]
payload_schema_cls_dict_args: Dict[str, Any],
payload_schema_dump: bool
):
"""
Derive event_name and payload from Pydantic model
Expand All @@ -123,10 +124,13 @@ def _derive_event_name_and_payload_from_pydantic_model(

if not payload:
payload_schema_cls_dict_args = payload_schema_cls_dict_args or DEFAULT_PAYLOAD_SCHEMA_CLS_DICT_ARGS
if IS_PYDANTIC_V1:
payload = event_name_or_model.dict(**payload_schema_cls_dict_args)
if payload_schema_dump:
if IS_PYDANTIC_V1:
payload = event_name_or_model.dict(**payload_schema_cls_dict_args)
else:
payload = event_name_or_model.model_dump(**payload_schema_cls_dict_args)
else:
payload = event_name_or_model.model_dump(**payload_schema_cls_dict_args)
payload = event_name_or_model

return event_name, payload

Expand All @@ -135,7 +139,8 @@ def _validate_payload(
event_name: EventName,
payload: Payload,
payload_schema_registry: BaseEventPayloadSchemaRegistry,
payload_schema_cls_dict_args: Dict[str, Any]
payload_schema_cls_dict_args: Dict[str, Any],
payload_schema_dump: bool = True
):
"""
Validate payload if a corresponding payload schema is registered
Expand All @@ -146,10 +151,14 @@ def _validate_payload(
payload_schema_cls = payload_schema_registry.get(event_name)
if payload_schema_cls:
payload_schema_cls_dict_args = payload_schema_cls_dict_args or DEFAULT_PAYLOAD_SCHEMA_CLS_DICT_ARGS
if IS_PYDANTIC_V1:
payload = payload_schema_cls(**(payload or {})).dict(**payload_schema_cls_dict_args)
deserialized_payload = payload_schema_cls(**(payload or {}))
if payload_schema_dump:
if IS_PYDANTIC_V1:
payload = deserialized_payload.dict(**payload_schema_cls_dict_args)
else:
payload = deserialized_payload.model_dump(**payload_schema_cls_dict_args)
else:
payload = payload_schema_cls(**(payload or {})).model_dump(**payload_schema_cls_dict_args)
payload = deserialized_payload
else:
logger.debug("Payload schema for event %s not found. Skipping validation...", event_name)

Expand All @@ -163,7 +172,8 @@ def dispatch(
validate_payload: bool = True,
payload_schema_cls_dict_args: Optional[Dict[str, Any]] = None,
payload_schema_registry: Optional[BaseEventPayloadSchemaRegistry] = None,
middleware_id: Optional[int] = None
middleware_id: Optional[int] = None,
payload_schema_dump: bool = True
) -> None:
"""
A wrapper of the main dispatcher function with additional checks.
Expand All @@ -184,7 +194,8 @@ def dispatch(
event_name_or_model=event_name_or_model,
event_name=event_name,
payload=payload,
payload_schema_cls_dict_args=payload_schema_cls_dict_args
payload_schema_cls_dict_args=payload_schema_cls_dict_args,
payload_schema_dump=payload_schema_dump,
)

# Validate event payload with schema registered
Expand All @@ -194,7 +205,8 @@ def dispatch(
event_name=event_name,
payload=payload,
payload_schema_registry=payload_schema_registry,
payload_schema_cls_dict_args=payload_schema_cls_dict_args
payload_schema_cls_dict_args=payload_schema_cls_dict_args,
payload_schema_dump=payload_schema_dump
)

# OTEL
Expand Down
21 changes: 16 additions & 5 deletions tests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ async def test_suppression_of_events_in_req_res_cycle(
({"user_id": uuid.uuid4()}, True),
({}, True),
(None, True)))
@pytest.mark.parametrize(
"payload_schema_dump",
(True, False)
)
async def test_payload_validation_with_pydantic_in_req_res_cycle(
event_payload, should_raise_error, setup_mocks_for_events_in_req_res_cycle
event_payload, should_raise_error, payload_schema_dump, setup_mocks_for_events_in_req_res_cycle,
):
"""
Test if event payloads are properly validated when a payload schema is registered.
Expand All @@ -88,7 +92,8 @@ class _SignUpEventSchema(pydantic.BaseModel):
dispatch_fn = functools.partial(dispatch,
event_name=UserEvents.SIGNED_UP,
payload=event_payload,
payload_schema_registry=payload_schema)
payload_schema_registry=payload_schema,
payload_schema_dump=payload_schema_dump)

if should_raise_error:
with pytest.raises(pydantic.ValidationError):
Expand All @@ -100,8 +105,12 @@ class _SignUpEventSchema(pydantic.BaseModel):


@pytest.mark.asyncio
@pytest.mark.parametrize(
"payload_schema_dump",
(True, False)
)
async def test_dispatching_with_pydantic_model(
setup_mocks_for_events_in_req_res_cycle, mocker
payload_schema_dump, setup_mocks_for_events_in_req_res_cycle, mocker
):
payload_schema = EventPayloadSchemaRegistry()

Expand All @@ -114,12 +123,14 @@ class UserSignedUpEventSchema(pydantic.BaseModel):

username: str

dispatch(UserSignedUpEventSchema(username="USER_ABC"))
event = UserSignedUpEventSchema(username="USER_ABC")
expected_payload = {"username": "USER_ABC"} if payload_schema_dump else event
dispatch(event, payload_schema_dump=payload_schema_dump)

assert mocks["spy_event_store_ctx_var"].get.called
spy__dispatch.assert_called_with(
event_name="USER_SIGNED_UP",
payload={"username": "USER_ABC"}
payload=expected_payload
)


Expand Down
Loading