From c2004cda38ddea4ca00528a233937adbec477122 Mon Sep 17 00:00:00 2001 From: Bojana Todorovic Date: Wed, 21 Aug 2024 11:32:40 +0200 Subject: [PATCH 1/2] WS tasks refactoring --- app/websocket_api.py | 20 +++++++++++++++----- utils/api_utils.py | 6 +----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/app/websocket_api.py b/app/websocket_api.py index 22cc6b2b..2012b53b 100644 --- a/app/websocket_api.py +++ b/app/websocket_api.py @@ -1,4 +1,5 @@ import celery.states +import asyncio from fastapi import APIRouter, WebSocket, WebSocketDisconnect from starlette.websockets import WebSocketState @@ -8,11 +9,20 @@ router = APIRouter() +async def process_and_send_result(action, payload, ws, logging_info): + result = await api_utils.process_incoming_ws_request(action, payload, ws, logging_info) + if result: + await ws.send_json(result) @router.websocket("/ws") async def websocket(ws: WebSocket): await ws.accept() ws.created_tasks = [] + + async def handle_pong(payload): + await ws.send_json({"pong": payload}) + logger.info("ANSWER TO PING WEBSOCKET MESSAGE IS SENT") + while ws.client_state != WebSocketState.DISCONNECTED: try: data = await ws.receive_json() @@ -20,12 +30,12 @@ async def websocket(ws: WebSocket): payload = data["payload"] logging_info = data.get("logging_info") - result = await api_utils.process_incoming_ws_request( - action, payload, ws, logging_info - ) + if action == "ping": + # Immediately handle pong response in its own task + asyncio.create_task(handle_pong(payload)) - if result: - await ws.send_json(result) + # Process the incoming WebSocket request + asyncio.create_task(process_and_send_result(action, payload, ws, logging_info)) except KeyError as e: logger.error(e) await ws.send_json({"error": "Invalid message format."}) diff --git a/utils/api_utils.py b/utils/api_utils.py index b155738b..10f43493 100644 --- a/utils/api_utils.py +++ b/utils/api_utils.py @@ -258,11 +258,7 @@ async def process_incoming_ws_request( ) -> typing.Dict: result = {} - if action == "ping": - await ws.send_json({"pong": payload}) - logger.info("ANSWER TO PING WEBSOCKET MESSAGE FOR IS SENT") - - elif action == "schedule_multiple_xpath_generations": + if action == "schedule_multiple_xpath_generations": logging_info = LoggingInfoModel(**logging_info) if ENV != "LOCAL": mongodb.create_initial_log_entry( From 927088188f8956f56928227343813843e673e892 Mon Sep 17 00:00:00 2001 From: Bojana Todorovic Date: Thu, 22 Aug 2024 10:31:07 +0200 Subject: [PATCH 2/2] Lint fix --- app/websocket_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/websocket_api.py b/app/websocket_api.py index 2012b53b..85d2c378 100644 --- a/app/websocket_api.py +++ b/app/websocket_api.py @@ -9,11 +9,13 @@ router = APIRouter() + async def process_and_send_result(action, payload, ws, logging_info): result = await api_utils.process_incoming_ws_request(action, payload, ws, logging_info) if result: await ws.send_json(result) + @router.websocket("/ws") async def websocket(ws: WebSocket): await ws.accept()