diff --git a/aaq/serializers.py b/aaq/serializers.py index 2406188f..abac13e1 100644 --- a/aaq/serializers.py +++ b/aaq/serializers.py @@ -41,8 +41,8 @@ class ResponseFeedbackSerializer(serializers.Serializer): class SearchSerializer(serializers.Serializer): query_text = serializers.CharField(required=True) - generate_llm_response = serializers.BooleanField(required=False) - query_metadata = serializers.JSONField(required=False) + generate_llm_response = serializers.BooleanField(required=False, default=False) + query_metadata = serializers.JSONField(required=False, default=dict) class ContentFeedbackSerializer(serializers.Serializer): diff --git a/aaq/tests/helpers.py b/aaq/tests/helpers.py index 2a0ca636..4bdb9256 100644 --- a/aaq/tests/helpers.py +++ b/aaq/tests/helpers.py @@ -118,3 +118,32 @@ def post_search_return_empty(self, request): } return (200, {}, json.dumps(resp_body)) + + +class FakeAaqUdV2Api: + def post_urgency_detect_return_true(self, request): + resp_body = { + "details": { + "0": {"distance": 0.1, "urgency_rule": "Blurry vision and dizziness"}, + "1": {"distance": 0.2, "urgency_rule": "Nausea that lasts for 3 days"}, + }, + "is_urgent": True, + "matched_rules": [ + "Blurry vision and dizziness", + "Nausea that lasts for 3 days", + ], + } + + return (200, {}, json.dumps(resp_body)) + + def post_urgency_detect_return_false(self, request): + resp_body = { + "details": { + "0": {"distance": 0.1, "urgency_rule": "Baby okay"}, + "1": {"distance": 0.2, "urgency_rule": "Baby healthy"}, + }, + "is_urgent": False, + "matched_rules": ["Baby okay", "Baby healthy"], + } + + return (200, {}, json.dumps(resp_body)) diff --git a/aaq/tests/test_utils.py b/aaq/tests/test_utils.py new file mode 100644 index 00000000..617fa93d --- /dev/null +++ b/aaq/tests/test_utils.py @@ -0,0 +1,135 @@ +import json + +import responses +from django.contrib.auth import get_user_model +from rest_framework.test import APITestCase + +from ..utils import check_urgency_v2, search +from .helpers import FakeAaqApi, FakeAaqUdV2Api + + +class SearchFunctionTest(APITestCase): + + @responses.activate + def test_search_function(self): + user = get_user_model().objects.create_user("test") + self.client.force_authenticate(user) + + fakeAaqApi = FakeAaqApi() + responses.add_callback( + responses.POST, + "http://aaq_v2/search", + callback=fakeAaqApi.post_search, + content_type="application/json", + ) + + fakeAaqUdV2Api = FakeAaqUdV2Api() + responses.add_callback( + responses.POST, + "http://aaq_v2/check-urgency", + callback=fakeAaqUdV2Api.post_urgency_detect_return_true, + content_type="application/json", + ) + + query_text = "test query" + generate_llm_response = False + query_metadata = {} + + payload = { + "generate_llm_response": generate_llm_response, + "query_metadata": query_metadata, + "query_text": query_text, + } + + response = search(query_text, generate_llm_response, query_metadata) + + search_request = responses.calls[0] + + self.assertIn("message", response) + self.assertIn("body", response) + self.assertIn("feedback_secret_key", response) + self.assertIn("query_id", response) + self.assertEqual(response["query_id"], 1) + self.assertEqual(json.loads(search_request.request.body), payload) + assert response == { + "message": "*0* - Example content title\n" + "*1* - Another example content title", + "body": { + "0": {"text": "Example content text", "id": 23}, + "1": {"text": "Another example content text", "id": 12}, + }, + "feedback_secret_key": "secret-key-12345-abcde", + "query_id": 1, + "details": { + "0": {"distance": 0.1, "urgency_rule": "Blurry vision and dizziness"}, + "1": {"distance": 0.2, "urgency_rule": "Nausea that lasts for 3 days"}, + }, + "is_urgent": True, + "matched_rules": [ + "Blurry vision and dizziness", + "Nausea that lasts for 3 days", + ], + } + + @responses.activate + def test_urgency_check(self): + user = get_user_model().objects.create_user("test") + self.client.force_authenticate(user) + + fakeAaqUdV2Api = FakeAaqUdV2Api() + responses.add_callback( + responses.POST, + "http://aaq_v2/check-urgency", + callback=fakeAaqUdV2Api.post_urgency_detect_return_true, + content_type="application/json", + ) + + message_text = "Test message" + + response = check_urgency_v2(message_text) + + [request] = responses.calls + + self.assertIn("details", response) + self.assertIn("is_urgent", response) + self.assertIn("matched_rules", response) + self.assertEqual(json.loads(request.request.body), message_text) + + assert response == { + "details": { + "0": {"distance": 0.1, "urgency_rule": "Blurry vision and dizziness"}, + "1": {"distance": 0.2, "urgency_rule": "Nausea that lasts for 3 days"}, + }, + "is_urgent": True, + "matched_rules": [ + "Blurry vision and dizziness", + "Nausea that lasts for 3 days", + ], + } + + @responses.activate + def test_search_gibberish(self): + """ + Check that we get a response with an empty list in the search results part + """ + user = get_user_model().objects.create_user("test") + self.client.force_authenticate(user) + fakeAaqApi = FakeAaqApi() + responses.add_callback( + responses.POST, + "http://aaq_v2/search", + callback=fakeAaqApi.post_search_return_empty, + content_type="application/json", + ) + + query_text = "jgghkjfhtfftf" + generate_llm_response = False + query_metadata = {} + response = search(query_text, generate_llm_response, query_metadata) + + assert response.data == { + "message": "Gibberish Detected", + "body": {}, + "feedback_secret_key": "secret-key-12345-abcde", + "query_id": 1, + } diff --git a/aaq/tests/test_views.py b/aaq/tests/test_views.py index d61f58f1..fe32fd43 100644 --- a/aaq/tests/test_views.py +++ b/aaq/tests/test_views.py @@ -6,7 +6,7 @@ from rest_framework import status from rest_framework.test import APITestCase -from .helpers import FakeAaqApi, FakeAaqCoreApi, FakeAaqUdApi, FakeTask +from .helpers import FakeAaqApi, FakeAaqCoreApi, FakeAaqUdApi, FakeAaqUdV2Api, FakeTask class GetFirstPageViewTests(APITestCase): @@ -354,9 +354,9 @@ def test_search(self): """ Test that search returns data. """ - user = get_user_model().objects.create_user("test") self.client.force_authenticate(user) + fakeAaqApi = FakeAaqApi() responses.add_callback( responses.POST, @@ -365,85 +365,102 @@ def test_search(self): content_type="application/json", ) - payload = json.dumps( - { - "generate_llm_response": False, - "query_metadata": {"some_key": "query_metadata"}, - "query_text": "Breastfeeding", - } + fakeAaqUdV2Api = FakeAaqUdV2Api() + responses.add_callback( + responses.POST, + "http://aaq_v2/check-urgency", + callback=fakeAaqUdV2Api.post_urgency_detect_return_true, + content_type="application/json", ) + payload = { + "generate_llm_response": False, + "query_metadata": {}, + "query_text": "query_text", + } + response = self.client.post( - self.url, data=payload, content_type="application/json" + self.url, data=json.dumps(payload), content_type="application/json" ) self.assertEqual(response.status_code, 200) - self.assertIn("message", response.data) - self.assertIn("body", response.data) - self.assertIn("query_id", response.data) - self.assertIn("feedback_secret_key", response.data) + self.assertIn("message", response.json()) + self.assertIn("body", response.json()) + self.assertIn("query_id", response.json()) + self.assertIn("feedback_secret_key", response.json()) + self.assertIn("details", response.json()) + self.assertIn("is_urgent", response.json()) + self.assertIn("matched_rules", response.json()) assert response.json() == { - "message": "*0* - Example content title\n*1* -" - " Another example content title", + "message": "*0* - Example content title\n" + "*1* - Another example content title", "body": { "0": {"text": "Example content text", "id": 23}, "1": {"text": "Another example content text", "id": 12}, }, "feedback_secret_key": "secret-key-12345-abcde", "query_id": 1, + "details": { + "0": {"distance": 0.1, "urgency_rule": "Blurry vision and dizziness"}, + "1": {"distance": 0.2, "urgency_rule": "Nausea that lasts for 3 days"}, + }, + "is_urgent": True, + "matched_rules": [ + "Blurry vision and dizziness", + "Nausea that lasts for 3 days", + ], } @responses.activate - def test_search_gibberish(self): + def test_search_invalid_request_body(self): """ - Check that we get a response with an empty list in the search results part + Test search valid request. """ user = get_user_model().objects.create_user("test") self.client.force_authenticate(user) + + response = self.client.post( + self.url, data=json.dumps({}), content_type="application/json" + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.json(), {"query_text": ["This field is required."]}) + + @responses.activate + def test_request(self): + user = get_user_model().objects.create_user("test") + self.client.force_authenticate(user) + fakeAaqApi = FakeAaqApi() responses.add_callback( responses.POST, "http://aaq_v2/search", - callback=fakeAaqApi.post_search_return_empty, + callback=fakeAaqApi.post_search, content_type="application/json", ) - payload = json.dumps( - { - "generate_llm_response": False, - "query_metadata": {"some_key": "query_metadata"}, - "query_text": "yjyvcgrfeuyikbjmfb", - } - ) - - response = self.client.post( - self.url, data=payload, content_type="application/json" + fakeAaqUdV2Api = FakeAaqUdV2Api() + responses.add_callback( + responses.POST, + "http://aaq_v2/check-urgency", + callback=fakeAaqUdV2Api.post_urgency_detect_return_true, + content_type="application/json", ) - assert response.json() == { - "message": "Gibberish Detected", - "body": {}, - "feedback_secret_key": "secret-key-12345-abcde", - "query_id": 1, + payload = { + "generate_llm_response": "testing", + "query_metadata": {}, + "query_text": "query_text", } - @responses.activate - def test_search_invalid_request_body(self): - """ - Test search valid request. - """ - user = get_user_model().objects.create_user("test") - self.client.force_authenticate(user) - - payload = json.dumps({}) - response = self.client.post( - self.url, data=payload, content_type="application/json" + self.url, data=json.dumps(payload), content_type="application/json" ) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.json(), {"query_text": ["This field is required."]}) + self.assertEqual( + response.json(), {"generate_llm_response": ["Must be a valid boolean."]} + ) class ContentFeedbackViewTests(APITestCase): diff --git a/aaq/urls.py b/aaq/urls.py index ece9abd2..960eef45 100644 --- a/aaq/urls.py +++ b/aaq/urls.py @@ -35,7 +35,7 @@ ), re_path( r"^api/v2/search", - views.search, + views.aaq_search, name="aaq-search", ), ] diff --git a/aaq/utils.py b/aaq/utils.py new file mode 100644 index 00000000..7a0b9cbe --- /dev/null +++ b/aaq/utils.py @@ -0,0 +1,71 @@ +import urllib + +import requests +from django.conf import settings +from rest_framework import status +from rest_framework.response import Response + + +def check_urgency_v2(message_text): + url = urllib.parse.urljoin(settings.AAQ_V2_API_URL, "check-urgency") + headers = { + "Authorization": settings.AAQ_V2_AUTH, + "Content-Type": "application/json", + } + + response = requests.request("POST", url, json=message_text, headers=headers) + + return response.json() + + +def search(query_text, generate_llm_response, query_metadata): + url = urllib.parse.urljoin(settings.AAQ_V2_API_URL, "search") + payload = { + "query_text": query_text, + "generate_llm_response": generate_llm_response, + "query_metadata": query_metadata, + } + headers = { + "Authorization": settings.AAQ_V2_AUTH, + "Content-Type": "application/json", + } + + response = requests.request("POST", url, json=payload, headers=headers) + + query_id = response.json()["query_id"] + feedback_secret_key = response.json()["feedback_secret_key"] + search_results = response.json()["search_results"] + + if search_results == {}: + json_msg = { + "message": "Gibberish Detected", + "body": {}, + "feedback_secret_key": feedback_secret_key, + "query_id": query_id, + } + return Response(json_msg, status=status.HTTP_200_OK) + + json_msg = {} + body_content = {} + message_titles = [] + + for key, value in search_results.items(): + text = value["text"] + id = value["id"] + title = value["title"] + + body_content[key] = {"text": text, "id": id} + message_titles.append(f"*{key}* - {title}") + + json_msg = { + "message": "\n".join(message_titles), + "body": body_content, + "feedback_secret_key": feedback_secret_key, + "query_id": query_id, + } + + check_urgency_response = check_urgency_v2(query_text) + + json_msg.update(check_urgency_response) + + return json_msg diff --git a/aaq/views.py b/aaq/views.py index f7e1196b..81742e63 100644 --- a/aaq/views.py +++ b/aaq/views.py @@ -18,6 +18,7 @@ ) from .tasks import send_feedback_task, send_feedback_task_v2 +from .utils import search logger = logging.getLogger(__name__) @@ -170,55 +171,14 @@ def response_feedback(request, *args, **kwargs): @api_view(("POST",)) @renderer_classes((JSONRenderer,)) -def search(request, *args, **kwargs): +def aaq_search(request): + serializer = SearchSerializer(data=request.data) serializer.is_valid(raise_exception=True) query_text = serializer.validated_data["query_text"] generate_llm_response = serializer.validated_data["generate_llm_response"] query_metadata = serializer.validated_data["query_metadata"] - url = urllib.parse.urljoin(settings.AAQ_V2_API_URL, "search") - payload = { - "query_text": query_text, - "generate_llm_response": generate_llm_response, - "query_metadata": query_metadata, - } - headers = { - "Authorization": settings.AAQ_V2_AUTH, - "Content-Type": "application/json", - } - - response = requests.request("POST", url, json=payload, headers=headers) - - query_id = response.json()["query_id"] - feedback_secret_key = response.json()["feedback_secret_key"] - search_results = response.json()["search_results"] - - if search_results == {}: - json_msg = { - "message": "Gibberish Detected", - "body": {}, - "feedback_secret_key": feedback_secret_key, - "query_id": query_id, - } - return Response(json_msg, status=status.HTTP_200_OK) - - json_msg = {} - body_content = {} - message_titles = [] - - for key, value in search_results.items(): - text = value["text"] - id = value["id"] - title = value["title"] - body_content[key] = {"text": text, "id": id} - message_titles.append(f"*{key}* - {title}") - - json_msg = { - "message": "\n".join(message_titles), - "body": body_content, - "feedback_secret_key": feedback_secret_key, - "query_id": query_id, - } + response = search(query_text, generate_llm_response, query_metadata) - return Response(json_msg, status=status.HTTP_200_OK) + return Response(response, status=status.HTTP_200_OK)