Skip to content

Commit

Permalink
WS tasks refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Bojana Todorovic committed Aug 21, 2024
1 parent 65956de commit c2004cd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
20 changes: 15 additions & 5 deletions app/websocket_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import celery.states
import asyncio
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState

Expand All @@ -8,24 +9,33 @@

router = APIRouter()

async def process_and_send_result(action, payload, ws, logging_info):
result = await api_utils.process_incoming_ws_request(action, payload, ws, logging_info)
if result:
await ws.send_json(result)

@router.websocket("/ws")
async def websocket(ws: WebSocket):
await ws.accept()
ws.created_tasks = []

async def handle_pong(payload):
await ws.send_json({"pong": payload})
logger.info("ANSWER TO PING WEBSOCKET MESSAGE IS SENT")

while ws.client_state != WebSocketState.DISCONNECTED:
try:
data = await ws.receive_json()
action = data["action"]
payload = data["payload"]
logging_info = data.get("logging_info")

result = await api_utils.process_incoming_ws_request(
action, payload, ws, logging_info
)
if action == "ping":
# Immediately handle pong response in its own task
asyncio.create_task(handle_pong(payload))

if result:
await ws.send_json(result)
# Process the incoming WebSocket request
asyncio.create_task(process_and_send_result(action, payload, ws, logging_info))
except KeyError as e:
logger.error(e)
await ws.send_json({"error": "Invalid message format."})
Expand Down
6 changes: 1 addition & 5 deletions utils/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,7 @@ async def process_incoming_ws_request(
) -> typing.Dict:
result = {}

if action == "ping":
await ws.send_json({"pong": payload})
logger.info("ANSWER TO PING WEBSOCKET MESSAGE FOR IS SENT")

elif action == "schedule_multiple_xpath_generations":
if action == "schedule_multiple_xpath_generations":
logging_info = LoggingInfoModel(**logging_info)
if ENV != "LOCAL":
mongodb.create_initial_log_entry(
Expand Down

0 comments on commit c2004cd

Please sign in to comment.