Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket #20

Open
wants to merge 9 commits into
base: integration
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "/Users/alexloembe/NIST/usnistgov/oar-pdr-py/python/tests/nistoar/midas/dbio/test_inmem.py",
"console": "integratedTerminal",
"pythonPath": "/Users/alexloembe/anaconda3/envs/dbio/bin/python",
"env": {
"PYTHONPATH": "/Users/alexloembe/NIST/usnistgov/oar-pdr-py/python",
},
"terminal": "integrated",
"justMyCode": false
}
]
}

25 changes: 22 additions & 3 deletions python/nistoar/midas/dbio/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
"""
import time
import math
import asyncio
import threading
import inspect
from concurrent.futures import ThreadPoolExecutor
from abc import ABC, ABCMeta, abstractmethod, abstractproperty
from copy import deepcopy
from collections.abc import Mapping, MutableMapping, Set
Expand All @@ -24,6 +28,7 @@
from nistoar.pdr.utils.prov import Action
from .. import MIDASException
from .status import RecordStatus
from .websocket import WebSocketServer
from nistoar.pdr.utils.prov import ANONYMOUS_USER

DAP_PROJECTS = "dap"
Expand Down Expand Up @@ -822,14 +827,17 @@ class DBClient(ABC):
the only allowed shoulder will be the default, ``grp0``.
"""

def __init__(self, config: Mapping, projcoll: str, nativeclient=None, foruser: str = ANONYMOUS):
def __init__(self, config: Mapping, projcoll: str,websocket_server: WebSocketServer, nativeclient=None, foruser: str = ANONYMOUS):
self._cfg = config
self._native = nativeclient
self._projcoll = projcoll
self._who = foruser
self._whogrps = None

# Get the current stack frame
stack = inspect.stack()
self._dbgroups = DBGroups(self)
self.websocket_server = websocket_server


@property
def project(self) -> str:
Expand Down Expand Up @@ -891,7 +899,17 @@ def create_record(self, name: str, shoulder: str = None, foruser: str = None) ->
rec['name'] = name
rec = ProjectRecord(self._projcoll, rec, self)
rec.save()
# Send a message through the WebSocket
message = f"{name}"
self._send_message_async(message)
return rec

def _send_message_async(self, message):
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.create_task(self.websocket_server.send_message_to_clients(message))
else:
asyncio.run(self.websocket_server.send_message_to_clients(message))

def _default_shoulder(self):
out = self._cfg.get("default_shoulder")
Expand Down Expand Up @@ -1282,7 +1300,7 @@ class DBClientFactory(ABC):
an abstract class for creating client connections to the database
"""

def __init__(self, config):
def __init__(self, config,websocket_server: WebSocketServer):
"""
initialize the factory with its configuration. The configuration provided here serves as
the default parameters for the cient as these can be overridden by the configuration parameters
Expand All @@ -1292,6 +1310,7 @@ def __init__(self, config):
depend on the type of project being access (e.g. "dmp" vs. "dap").
"""
self._cfg = config
self.websocket_server = websocket_server

@abstractmethod
def create_client(self, servicetype: str, config: Mapping = {}, foruser: str = ANONYMOUS):
Expand Down
14 changes: 8 additions & 6 deletions python/nistoar/midas/dbio/inmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Mapping, MutableMapping, Set
from typing import Iterator, List
from . import base
from .websocket import WebSocketServer

from nistoar.base.config import merge_config

Expand All @@ -15,9 +16,9 @@ class InMemoryDBClient(base.DBClient):
an in-memory DBClient implementation
"""

def __init__(self, dbdata: Mapping, config: Mapping, projcoll: str, foruser: str = base.ANONYMOUS):
def __init__(self, dbdata: Mapping, config: Mapping, projcoll: str,websocket=WebSocketServer ,foruser: str = base.ANONYMOUS):
self._db = dbdata
super(InMemoryDBClient, self).__init__(config, projcoll, self._db, foruser)
super(InMemoryDBClient, self).__init__(config, projcoll,websocket, self._db, foruser)

def _next_recnum(self, shoulder):
if shoulder not in self._db['nextnum']:
Expand Down Expand Up @@ -79,7 +80,6 @@ def select_records(self, perm: base.Permissions=base.ACLs.OWN, **cnsts) -> Itera
for p in perm:
if rec.authorized(p):
yield deepcopy(rec)
break

def adv_select_records(self, filter:dict,
perm: base.Permissions=base.ACLs.OWN,) -> Iterator[base.ProjectRecord]:
Expand Down Expand Up @@ -139,7 +139,7 @@ class InMemoryDBClientFactory(base.DBClientFactory):
clients it creates.
"""

def __init__(self, config: Mapping, _dbdata = None):
def __init__(self, config: Mapping,websocket_server:WebSocketServer, _dbdata = None):
"""
Create the factory with the given configuration.

Expand All @@ -148,7 +148,8 @@ def __init__(self, config: Mapping, _dbdata = None):
of the in-memory data structure required to use this input.) If
not provided, an empty database is created.
"""
super(InMemoryDBClientFactory, self).__init__(config)
super(InMemoryDBClientFactory, self).__init__(config,websocket_server)
self.websocket_server = websocket_server
self._db = {
base.DAP_PROJECTS: {},
base.DMP_PROJECTS: {},
Expand All @@ -161,8 +162,9 @@ def __init__(self, config: Mapping, _dbdata = None):


def create_client(self, servicetype: str, config: Mapping={}, foruser: str = base.ANONYMOUS):

cfg = merge_config(config, deepcopy(self._cfg))
if servicetype not in self._db:
self._db[servicetype] = {}
return InMemoryDBClient(self._db, cfg, servicetype, foruser)
return InMemoryDBClient(self._db, cfg, servicetype,self.websocket_server, foruser)

56 changes: 56 additions & 0 deletions python/nistoar/midas/dbio/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# websocket_server.py
import asyncio
import websockets
from concurrent.futures import ThreadPoolExecutor
import copy

class WebSocketServer:
def __init__(self, host="localhost", port=8765):
self.host = host
self.port = port
self.server = None
self.clients = set() # Initialize the clients set

async def start(self):
self.server = await websockets.serve(self.websocket_handler, self.host, self.port)
#print(f"WebSocket server started on ws://{self.host}:{self.port}")


async def websocket_handler(self, websocket):
# Add the new client to the set
self.clients.add(websocket)
try:
async for message in websocket:
await self.send_message_to_clients(message)
finally:
# Remove the client from the set when they disconnect
self.clients.remove(websocket)

async def send_message_to_clients(self, message):
for client in self.clients:
print(client)
if self.clients:
for client in self.clients:
asyncio.create_task(client.send(message))

async def stop(self):
if self.server:
self.server.close()
await self.server.wait_closed()
self.server = None

async def wait_closed(self):
if self.server:
await self.server.wait_closed()

def __deepcopy__(self, memo):
# Create a shallow copy of the object
new_copy = copy.copy(self)
# Deep copy the attributes that are not problematic
new_copy.host = copy.deepcopy(self.host, memo)
new_copy.port = copy.deepcopy(self.port, memo)
new_copy.clients = copy.deepcopy(self.clients, memo)
# Do not copy the problematic attribute
new_copy.server = self.server
return new_copy

128 changes: 123 additions & 5 deletions python/tests/nistoar/midas/dbio/test_inmem.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os, json, pdb, logging
import os, json, pdb, logging,asyncio
from pathlib import Path
import unittest as test
import websockets

from nistoar.midas.dbio import inmem, base
from nistoar.pdr.utils.prov import Action, Agent
from nistoar.midas.dbio.websocket import WebSocketServer

testuser = Agent("dbio", Agent.AUTO, "tester", "test")
testdir = Path(__file__).parents[0]
Expand All @@ -26,12 +28,49 @@
with open(dmp_path, 'r') as file:
dmp = json.load(file)


class TestInMemoryDBClientFactory(test.TestCase):

@classmethod
def initialize_websocket_server(cls):
websocket_server = WebSocketServer()
try:
cls.loop = asyncio.get_event_loop()
if cls.loop.is_closed():
raise RuntimeError
except RuntimeError:
cls.loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls.loop)
cls.loop.run_until_complete(websocket_server.start())
#print("WebSocketServer initialized:", websocket_server)
return websocket_server

@classmethod
def setUpClass(cls):
cls.websocket_server = cls.initialize_websocket_server()

@classmethod
def tearDownClass(cls):
# Ensure the WebSocket server is properly closed
cls.loop.run_until_complete(cls.websocket_server.stop())
cls.loop.run_until_complete(cls.websocket_server.wait_closed())

# Cancel all lingering tasks
asyncio.set_event_loop(cls.loop) # Set the event loop as the current event loop
tasks = asyncio.all_tasks(loop=cls.loop)
for task in tasks:
task.cancel()
cls.loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))

# Close the event loop
cls.loop.close()



def setUp(self):
self.cfg = {"goob": "gurn"}
self.fact = inmem.InMemoryDBClientFactory(
self.cfg, {"nextnum": {"hank": 2}})
self.cfg,self.websocket_server, {"nextnum": {"hank": 2}})

def test_ctor(self):
self.assertEqual(self.fact._cfg, self.cfg)
Expand All @@ -55,11 +94,44 @@ def test_create_client(self):

class TestInMemoryDBClient(test.TestCase):

@classmethod
def initialize_websocket_server(cls):
websocket_server = WebSocketServer()
try:
cls.loop = asyncio.get_event_loop()
except RuntimeError:
cls.loop = asyncio.new_event_loop()
asyncio.set_event_loop(cls.loop)
cls.loop.run_until_complete(websocket_server.start())
#print("WebSocketServer initialized:", websocket_server)
return websocket_server

@classmethod
def setUpClass(cls):
cls.websocket_server = cls.initialize_websocket_server()

@classmethod
def tearDownClass(cls):
# Ensure the WebSocket server is properly closed
cls.loop.run_until_complete(cls.websocket_server.stop())
cls.loop.run_until_complete(cls.websocket_server.wait_closed())

# Cancel all lingering tasks
asyncio.set_event_loop(cls.loop)
tasks = asyncio.all_tasks(loop=cls.loop)
for task in tasks:
task.cancel()
cls.loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))

# Close the event loop
cls.loop.close()

def setUp(self):
self.cfg = {"default_shoulder": "mds3"}
self.user = "nist0:ava1"
self.cli = inmem.InMemoryDBClientFactory({}).create_client(
self.cli = inmem.InMemoryDBClientFactory({},self.websocket_server).create_client(
base.DMP_PROJECTS, self.cfg, self.user)


def test_next_recnum(self):
self.assertEqual(self.cli._next_recnum("goob"), 1)
Expand Down Expand Up @@ -270,8 +342,6 @@ def test_adv_select_records(self):
self.cli._db[base.DMP_PROJECTS][id] = rec.to_dict()




id = "pdr0:0006"
rec = base.ProjectRecord(
base.DMP_PROJECTS, {"id": id, "name": "test 2", "status": {
Expand Down Expand Up @@ -443,5 +513,53 @@ def test_record_action(self):
self.assertEqual(acts[1]['type'], Action.COMMENT)


class TestWebSocketServer(test.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.websocket_server = WebSocketServer()
self.loop = asyncio.get_event_loop()
await self.websocket_server.start()

# Initialize the InMemoryDBClientFactory with the websocket_server
self.cfg = {"default_shoulder": "mds3"}
self.user = "nist0:ava1"
self.cli = inmem.InMemoryDBClientFactory({},self.websocket_server).create_client(
base.DMP_PROJECTS, self.cfg, self.user)

async def asyncTearDown(self):
await self.websocket_server.stop()
await self.websocket_server.wait_closed()


async def test_create_records_websocket(self):
messages = []

async def receive_messages(uri):
try:
async with websockets.connect(uri) as websocket:
while True:
message = await websocket.recv()
#print(f"Received message: {message}")
messages.append(message)
#print(f"Messages: {messages}")
# Break the loop after receiving the first message for this test
except Exception as e:
print(f"Failed to connect to WebSocket server: {e}")

# Start the WebSocket client to receive messages
uri = 'ws://localhost:8765'
receive_task = asyncio.create_task(receive_messages(uri))
await asyncio.sleep(2)

#await self.websocket_server.send_message_to_clients("Connection established")
# Inject some data into the database
rec = self.cli.create_record("mine1")
await asyncio.sleep(2)

#print(f"Messages: {messages}")
self.assertEqual(len(messages), 1)
self.assertEqual(messages[0], "mine1")



if __name__ == '__main__':
test.main()
Loading
Loading