diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..02c4382 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,19 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "/Users/alexloembe/NIST/usnistgov/oar-pdr-py/python/tests/nistoar/midas/dbio/test_inmem.py", + "console": "integratedTerminal", + "pythonPath": "/Users/alexloembe/anaconda3/envs/dbio/bin/python", + "env": { + "PYTHONPATH": "/Users/alexloembe/NIST/usnistgov/oar-pdr-py/python", + }, + "terminal": "integrated", + "justMyCode": false + } + ] + } + \ No newline at end of file diff --git a/metadata b/metadata index 9cab880..d96af33 160000 --- a/metadata +++ b/metadata @@ -1 +1 @@ -Subproject commit 9cab8808a12e9de38b566da961e005a5c4fe3ca0 +Subproject commit d96af33d1f2ec0520e906360c21dc478acb8867b diff --git a/python/nistoar/midas/dbio/base.py b/python/nistoar/midas/dbio/base.py index a78d5fc..e8ae42d 100644 --- a/python/nistoar/midas/dbio/base.py +++ b/python/nistoar/midas/dbio/base.py @@ -12,6 +12,10 @@ """ import time import math +import asyncio +import threading +import inspect +from concurrent.futures import ThreadPoolExecutor from abc import ABC, ABCMeta, abstractmethod, abstractproperty from copy import deepcopy from collections.abc import Mapping, MutableMapping, Set @@ -24,6 +28,7 @@ from nistoar.pdr.utils.prov import Action from .. import MIDASException from .status import RecordStatus +from .websocket import WebSocketServer from nistoar.pdr.utils.prov import ANONYMOUS_USER DAP_PROJECTS = "dap" @@ -822,14 +827,17 @@ class DBClient(ABC): the only allowed shoulder will be the default, ``grp0``. """ - def __init__(self, config: Mapping, projcoll: str, nativeclient=None, foruser: str = ANONYMOUS): + def __init__(self, config: Mapping, projcoll: str,websocket_server: WebSocketServer, nativeclient=None, foruser: str = ANONYMOUS): self._cfg = config self._native = nativeclient self._projcoll = projcoll self._who = foruser self._whogrps = None - + # Get the current stack frame + stack = inspect.stack() self._dbgroups = DBGroups(self) + self.websocket_server = websocket_server + @property def project(self) -> str: @@ -891,7 +899,17 @@ def create_record(self, name: str, shoulder: str = None, foruser: str = None) -> rec['name'] = name rec = ProjectRecord(self._projcoll, rec, self) rec.save() + # Send a message through the WebSocket + message = f"{name}" + self._send_message_async(message) return rec + + def _send_message_async(self, message): + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.create_task(self.websocket_server.send_message_to_clients(message)) + else: + asyncio.run(self.websocket_server.send_message_to_clients(message)) def _default_shoulder(self): out = self._cfg.get("default_shoulder") @@ -1282,7 +1300,7 @@ class DBClientFactory(ABC): an abstract class for creating client connections to the database """ - def __init__(self, config): + def __init__(self, config,websocket_server: WebSocketServer): """ initialize the factory with its configuration. The configuration provided here serves as the default parameters for the cient as these can be overridden by the configuration parameters @@ -1292,6 +1310,7 @@ def __init__(self, config): depend on the type of project being access (e.g. "dmp" vs. "dap"). """ self._cfg = config + self.websocket_server = websocket_server @abstractmethod def create_client(self, servicetype: str, config: Mapping = {}, foruser: str = ANONYMOUS): diff --git a/python/nistoar/midas/dbio/inmem.py b/python/nistoar/midas/dbio/inmem.py index 717238f..0612655 100644 --- a/python/nistoar/midas/dbio/inmem.py +++ b/python/nistoar/midas/dbio/inmem.py @@ -7,6 +7,7 @@ from collections.abc import Mapping, MutableMapping, Set from typing import Iterator, List from . import base +from .websocket import WebSocketServer from nistoar.base.config import merge_config @@ -15,9 +16,9 @@ class InMemoryDBClient(base.DBClient): an in-memory DBClient implementation """ - def __init__(self, dbdata: Mapping, config: Mapping, projcoll: str, foruser: str = base.ANONYMOUS): + def __init__(self, dbdata: Mapping, config: Mapping, projcoll: str,websocket=WebSocketServer ,foruser: str = base.ANONYMOUS): self._db = dbdata - super(InMemoryDBClient, self).__init__(config, projcoll, self._db, foruser) + super(InMemoryDBClient, self).__init__(config, projcoll,websocket, self._db, foruser) def _next_recnum(self, shoulder): if shoulder not in self._db['nextnum']: @@ -79,7 +80,6 @@ def select_records(self, perm: base.Permissions=base.ACLs.OWN, **cnsts) -> Itera for p in perm: if rec.authorized(p): yield deepcopy(rec) - break def adv_select_records(self, filter:dict, perm: base.Permissions=base.ACLs.OWN,) -> Iterator[base.ProjectRecord]: @@ -139,7 +139,7 @@ class InMemoryDBClientFactory(base.DBClientFactory): clients it creates. """ - def __init__(self, config: Mapping, _dbdata = None): + def __init__(self, config: Mapping,websocket_server:WebSocketServer, _dbdata = None): """ Create the factory with the given configuration. @@ -148,7 +148,8 @@ def __init__(self, config: Mapping, _dbdata = None): of the in-memory data structure required to use this input.) If not provided, an empty database is created. """ - super(InMemoryDBClientFactory, self).__init__(config) + super(InMemoryDBClientFactory, self).__init__(config,websocket_server) + self.websocket_server = websocket_server self._db = { base.DAP_PROJECTS: {}, base.DMP_PROJECTS: {}, @@ -161,8 +162,9 @@ def __init__(self, config: Mapping, _dbdata = None): def create_client(self, servicetype: str, config: Mapping={}, foruser: str = base.ANONYMOUS): + cfg = merge_config(config, deepcopy(self._cfg)) if servicetype not in self._db: self._db[servicetype] = {} - return InMemoryDBClient(self._db, cfg, servicetype, foruser) + return InMemoryDBClient(self._db, cfg, servicetype,self.websocket_server, foruser) diff --git a/python/nistoar/midas/dbio/websocket.py b/python/nistoar/midas/dbio/websocket.py new file mode 100644 index 0000000..ff5bb3a --- /dev/null +++ b/python/nistoar/midas/dbio/websocket.py @@ -0,0 +1,56 @@ +# websocket_server.py +import asyncio +import websockets +from concurrent.futures import ThreadPoolExecutor +import copy + +class WebSocketServer: + def __init__(self, host="localhost", port=8765): + self.host = host + self.port = port + self.server = None + self.clients = set() # Initialize the clients set + + async def start(self): + self.server = await websockets.serve(self.websocket_handler, self.host, self.port) + #print(f"WebSocket server started on ws://{self.host}:{self.port}") + + + async def websocket_handler(self, websocket): + # Add the new client to the set + self.clients.add(websocket) + try: + async for message in websocket: + await self.send_message_to_clients(message) + finally: + # Remove the client from the set when they disconnect + self.clients.remove(websocket) + + async def send_message_to_clients(self, message): + for client in self.clients: + print(client) + if self.clients: + for client in self.clients: + asyncio.create_task(client.send(message)) + + async def stop(self): + if self.server: + self.server.close() + await self.server.wait_closed() + self.server = None + + async def wait_closed(self): + if self.server: + await self.server.wait_closed() + + def __deepcopy__(self, memo): + # Create a shallow copy of the object + new_copy = copy.copy(self) + # Deep copy the attributes that are not problematic + new_copy.host = copy.deepcopy(self.host, memo) + new_copy.port = copy.deepcopy(self.port, memo) + new_copy.clients = copy.deepcopy(self.clients, memo) + # Do not copy the problematic attribute + new_copy.server = self.server + return new_copy + diff --git a/python/tests/nistoar/midas/dbio/test_inmem.py b/python/tests/nistoar/midas/dbio/test_inmem.py index 7cfcf4b..106c106 100644 --- a/python/tests/nistoar/midas/dbio/test_inmem.py +++ b/python/tests/nistoar/midas/dbio/test_inmem.py @@ -1,9 +1,11 @@ -import os, json, pdb, logging +import os, json, pdb, logging,asyncio from pathlib import Path import unittest as test +import websockets from nistoar.midas.dbio import inmem, base from nistoar.pdr.utils.prov import Action, Agent +from nistoar.midas.dbio.websocket import WebSocketServer testuser = Agent("dbio", Agent.AUTO, "tester", "test") testdir = Path(__file__).parents[0] @@ -26,12 +28,49 @@ with open(dmp_path, 'r') as file: dmp = json.load(file) + class TestInMemoryDBClientFactory(test.TestCase): + @classmethod + def initialize_websocket_server(cls): + websocket_server = WebSocketServer() + try: + cls.loop = asyncio.get_event_loop() + if cls.loop.is_closed(): + raise RuntimeError + except RuntimeError: + cls.loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls.loop) + cls.loop.run_until_complete(websocket_server.start()) + #print("WebSocketServer initialized:", websocket_server) + return websocket_server + + @classmethod + def setUpClass(cls): + cls.websocket_server = cls.initialize_websocket_server() + + @classmethod + def tearDownClass(cls): + # Ensure the WebSocket server is properly closed + cls.loop.run_until_complete(cls.websocket_server.stop()) + cls.loop.run_until_complete(cls.websocket_server.wait_closed()) + + # Cancel all lingering tasks + asyncio.set_event_loop(cls.loop) # Set the event loop as the current event loop + tasks = asyncio.all_tasks(loop=cls.loop) + for task in tasks: + task.cancel() + cls.loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + + # Close the event loop + cls.loop.close() + + + def setUp(self): self.cfg = {"goob": "gurn"} self.fact = inmem.InMemoryDBClientFactory( - self.cfg, {"nextnum": {"hank": 2}}) + self.cfg,self.websocket_server, {"nextnum": {"hank": 2}}) def test_ctor(self): self.assertEqual(self.fact._cfg, self.cfg) @@ -55,11 +94,44 @@ def test_create_client(self): class TestInMemoryDBClient(test.TestCase): + @classmethod + def initialize_websocket_server(cls): + websocket_server = WebSocketServer() + try: + cls.loop = asyncio.get_event_loop() + except RuntimeError: + cls.loop = asyncio.new_event_loop() + asyncio.set_event_loop(cls.loop) + cls.loop.run_until_complete(websocket_server.start()) + #print("WebSocketServer initialized:", websocket_server) + return websocket_server + + @classmethod + def setUpClass(cls): + cls.websocket_server = cls.initialize_websocket_server() + + @classmethod + def tearDownClass(cls): + # Ensure the WebSocket server is properly closed + cls.loop.run_until_complete(cls.websocket_server.stop()) + cls.loop.run_until_complete(cls.websocket_server.wait_closed()) + + # Cancel all lingering tasks + asyncio.set_event_loop(cls.loop) + tasks = asyncio.all_tasks(loop=cls.loop) + for task in tasks: + task.cancel() + cls.loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + + # Close the event loop + cls.loop.close() + def setUp(self): self.cfg = {"default_shoulder": "mds3"} self.user = "nist0:ava1" - self.cli = inmem.InMemoryDBClientFactory({}).create_client( + self.cli = inmem.InMemoryDBClientFactory({},self.websocket_server).create_client( base.DMP_PROJECTS, self.cfg, self.user) + def test_next_recnum(self): self.assertEqual(self.cli._next_recnum("goob"), 1) @@ -270,8 +342,6 @@ def test_adv_select_records(self): self.cli._db[base.DMP_PROJECTS][id] = rec.to_dict() - - id = "pdr0:0006" rec = base.ProjectRecord( base.DMP_PROJECTS, {"id": id, "name": "test 2", "status": { @@ -443,5 +513,53 @@ def test_record_action(self): self.assertEqual(acts[1]['type'], Action.COMMENT) +class TestWebSocketServer(test.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.websocket_server = WebSocketServer() + self.loop = asyncio.get_event_loop() + await self.websocket_server.start() + + # Initialize the InMemoryDBClientFactory with the websocket_server + self.cfg = {"default_shoulder": "mds3"} + self.user = "nist0:ava1" + self.cli = inmem.InMemoryDBClientFactory({},self.websocket_server).create_client( + base.DMP_PROJECTS, self.cfg, self.user) + + async def asyncTearDown(self): + await self.websocket_server.stop() + await self.websocket_server.wait_closed() + + + async def test_create_records_websocket(self): + messages = [] + + async def receive_messages(uri): + try: + async with websockets.connect(uri) as websocket: + while True: + message = await websocket.recv() + #print(f"Received message: {message}") + messages.append(message) + #print(f"Messages: {messages}") + # Break the loop after receiving the first message for this test + except Exception as e: + print(f"Failed to connect to WebSocket server: {e}") + + # Start the WebSocket client to receive messages + uri = 'ws://localhost:8765' + receive_task = asyncio.create_task(receive_messages(uri)) + await asyncio.sleep(2) + + #await self.websocket_server.send_message_to_clients("Connection established") + # Inject some data into the database + rec = self.cli.create_record("mine1") + await asyncio.sleep(2) + + #print(f"Messages: {messages}") + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0], "mine1") + + + if __name__ == '__main__': test.main() diff --git a/scripts/midas-uwsgi.py b/scripts/midas-uwsgi.py index 3e3fb8f..bbbdc0e 100644 --- a/scripts/midas-uwsgi.py +++ b/scripts/midas-uwsgi.py @@ -29,9 +29,10 @@ configuration data for (default: pdr-resolve); this is only used if OAR_CONFIG_SERVICE is used. """ -import os, sys, logging, copy +import os, sys, logging, copy,asyncio from copy import deepcopy + try: import nistoar except ImportError: @@ -43,6 +44,7 @@ import nistoar from nistoar.base import config +from nistoar.midas.dbio.websocket import WebSocketServer from nistoar.midas.dbio import MongoDBClientFactory, InMemoryDBClientFactory, FSBasedDBClientFactory from nistoar.midas import wsgi @@ -61,6 +63,16 @@ def _dec(obj): # determine where the configuration is coming from confsrc = _dec(uwsgi.opt.get("oar_config_file")) + +def initialize_websocket_server(): + websocket_server = WebSocketServer() + loop = asyncio.get_event_loop() + loop.run_until_complete(websocket_server.start()) + print("WebSocketServer initialized:", websocket_server) + return websocket_server + +websocket_server = initialize_websocket_server() + if confsrc: cfg = config.resolve_configuration(confsrc) @@ -129,9 +141,10 @@ def _dec(obj): print(f"dburl: {dburl}") factory = MongoDBClientFactory(cfg.get("dbio", {}), dburl) elif dbtype == "inmem": - factory = InMemoryDBClientFactory(cfg.get("dbio", {})) + factory = InMemoryDBClientFactory(cfg.get("dbio", {}),websocket_server=websocket_server) else: raise RuntimeError("Unsupported database type: "+dbtype) application = wsgi.app(cfg, factory) +websocket_server.start_in_thread() logging.info("MIDAS service ready with "+dbtype+" backend")