diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 0000000..ca39b92 --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,87 @@ +name: Run Tests API and Worker + +on: + pull_request: + workflow_dispatch: + +env: + NATS_TOKEN: test + +jobs: + test: + name: Run tests + runs-on: ubuntu-latest + steps: + - name: 👀 Checkout code + uses: actions/checkout@v2 + with: + submodules: true + + - name: Setup apt cache + uses: actions/cache@v2 + with: + path: /var/cache/apt/archives + key: ${{ runner.os }}-apt-${{ hashFiles('/etc/apt/sources.list') }} + + - name: 😭 Install system dependencies + run: | + sudo apt-get update && sudo apt-get install -y \ + netcat \ + unzip \ + libgeos-dev \ + libcurl4-openssl-dev \ + libssl-dev \ + binutils \ + curl \ + git \ + autoconf \ + automake \ + build-essential \ + libtool \ + gcc \ + libmagic-dev \ + poppler-utils \ + tesseract-ocr \ + libreoffice \ + libpq-dev \ + pandoc + + - name: 🔽 Download and Install NATS Server + run: | + curl -L https://github.com/nats-io/nats-server/releases/download/v2.10.22/nats-server-v2.10.22-linux-amd64.zip -o nats-server.zip + unzip nats-server.zip -d nats-server && sudo cp nats-server/nats-server-v2.10.22-linux-amd64/nats-server /usr/bin + + - name: 🛠️ Set up NATS arguments + run: | + nohup nats-server \ + --addr 0.0.0.0 \ + --port 4222 \ + --auth "$NATS_TOKEN" > nats.log 2>&1 & + + - name: 🔍 Verify NATS Server is Running + run: | + sleep 1 # Give the server some time to start + if nc -zv localhost 4222; then + echo "✅ NATS Server is running on port 4222." + else + echo "❌ Failed to start NATS Server." + cat nats.log + exit 1 + fi + + - name: 🔨 Install the latest version of rye + uses: eifinger/setup-rye@v4 + with: + enable-cache: true + + - name: 🎯 Cache hit! + if: steps.setup-rye.outputs.cache-hit == 'true' + run: echo "Rye cache was restored" + + - name: 🔄 Sync dependencies + run: | + UV_INDEX_STRATEGY=unsafe-first-match rye sync --no-lock + + - name: 🚀 Run tests + run: | + rye test -p megaparse-sdk diff --git a/libs/megaparse_sdk/megaparse_sdk/client.py b/libs/megaparse_sdk/megaparse_sdk/client.py index ca3086c..ddf978c 100644 --- a/libs/megaparse_sdk/megaparse_sdk/client.py +++ b/libs/megaparse_sdk/megaparse_sdk/client.py @@ -1,17 +1,20 @@ import asyncio +import enum import logging from io import BytesIO from pathlib import Path -from typing import Any +from types import TracebackType +from typing import Any, Self import httpx import nats -from nats.errors import TimeoutError +from nats.errors import NoRespondersError, TimeoutError from megaparse_sdk.config import ClientNATSConfig, MegaParseConfig from megaparse_sdk.schema.mp_exceptions import ( DownloadError, InternalServiceError, + MemoryLimitExceeded, ModelNotSupported, ParsingException, ) @@ -67,25 +70,65 @@ async def close(self): await self.session.aclose() +class ClientState(enum.Enum): + # First state of the client + UNOPENED = 1 + # Client has either sent a request, or is within a `with` block. + OPENED = 2 + # Client has either exited the `with` block, or `close()` called. + CLOSED = 3 + + class MegaParseNATSClient: - def __init__(self, config: ClientNATSConfig = ClientNATSConfig()): + def __init__(self, config: ClientNATSConfig): self.nc_config = config self.max_retries = self.nc_config.max_retries self.backoff = self.nc_config.backoff if self.nc_config.ssl_config: self.ssl_ctx = load_ssl_cxt(self.nc_config.ssl_config) + # Client connection + self._state = ClientState.UNOPENED + self._nc = None async def _get_nc(self): if self._nc is None: self._nc = await nats.connect( - self.nc_config.nats_endpoint, tls=self.ssl_ctx + self.nc_config.endpoint, + tls=self.ssl_ctx, + connect_timeout=self.nc_config.connect_timeout, + reconnect_time_wait=self.nc_config.reconnect_time_wait, + max_reconnect_attempts=self.nc_config.max_reconnect_attempts, ) return self._nc return self._nc + async def __aenter__(self: Self) -> Self: + if self._state != ClientState.UNOPENED: + msg = { + ClientState.OPENED: "Cannot open a client instance more than once.", + ClientState.CLOSED: ( + "Cannot reopen a client instance, client was closed." + ), + }[self._state] + raise RuntimeError(msg) + + self._state = ClientState.OPENED + + await self._get_nc() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + self._state = ClientState.CLOSED + await self.aclose() + async def parse_url(self, url: str): url_inp = ParseUrlInput(url=url) - await self._send_req(MPInput(input=url_inp)) + return await self._send_req(MPInput(input=url_inp)) async def parse_file(self, file: Path | BytesIO) -> str: if isinstance(file, Path): @@ -108,9 +151,10 @@ async def _send_req(self, inp: MPInput) -> str: for attempt in range(self.max_retries): try: return await self._send_req_inner(inp) - except TimeoutError: - logger.error(f"Timeout error parsing. Retrying {attempt} time") + except (TimeoutError, NoRespondersError) as e: + logger.error(f"Sending req error: {e}. Retrying for {attempt} time") if attempt < self.max_retries - 1: + logger.debug(f"Backoff for {2**self.backoff}s") await asyncio.sleep(2**self.backoff) raise ParsingException @@ -122,15 +166,17 @@ async def _send_req_inner(self, inp: MPInput): timeout=self.nc_config.timeout, ) response = MPOutput.model_validate_json(raw_response.data.decode("utf-8")) + return self._handle_mp_output(response) + + def _handle_mp_output(self, response: MPOutput) -> str: if response.output_type == MPOutputType.PARSE_OK: assert response.result, "Parsing OK but response is None" return response.result elif response.output_type == MPOutputType.PARSE_ERR: assert response.err, "Parsing OK but response is None" - match response.err.mp_err_code: case MPErrorType.MEMORY_LIMIT: - raise ModelNotSupported + raise MemoryLimitExceeded case MPErrorType.INTERNAL_SERVER_ERROR: raise InternalServiceError case MPErrorType.MODEL_NOT_SUPPORTED: @@ -139,9 +185,8 @@ async def _send_req_inner(self, inp: MPInput): raise DownloadError case MPErrorType.PARSING_ERROR: raise ParsingException - else: - raise ValueError(f"unknown service response type: {response}") + raise ValueError(f"unknown service response type: {response}") - async def close(self): + async def aclose(self): nc = await self._get_nc() await nc.close() diff --git a/libs/megaparse_sdk/megaparse_sdk/config.py b/libs/megaparse_sdk/megaparse_sdk/config.py index fdf5f1b..97ffe38 100644 --- a/libs/megaparse_sdk/megaparse_sdk/config.py +++ b/libs/megaparse_sdk/megaparse_sdk/config.py @@ -17,9 +17,9 @@ class MegaParseConfig(BaseSettings): class SSLConfig(BaseModel): - ca_cert_file: FilePath ssl_key_file: FilePath ssl_cert_file: FilePath + ca_cert_file: FilePath | None = None class ClientNATSConfig(BaseSettings): @@ -28,7 +28,10 @@ class ClientNATSConfig(BaseSettings): ) subject: Literal["parsing"] = "parsing" endpoint: str = "https://tests@nats.tooling.quivr.app:4222" - timeout: int = 600 + timeout: float = 600 max_retries: int = 5 - backoff: int = 3 + backoff: float = 3 + connect_timeout: int = 5 + reconnect_time_wait: int = 1 + max_reconnect_attempts: int = 20 ssl_config: SSLConfig | None = None diff --git a/libs/megaparse_sdk/megaparse_sdk/utils/load_ssl.py b/libs/megaparse_sdk/megaparse_sdk/utils/load_ssl.py index d4f3c5d..f7bb9e2 100644 --- a/libs/megaparse_sdk/megaparse_sdk/utils/load_ssl.py +++ b/libs/megaparse_sdk/megaparse_sdk/utils/load_ssl.py @@ -5,7 +5,8 @@ def load_ssl_cxt(ssl_config: SSLConfig): context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - context.load_verify_locations(cafile=ssl_config.ca_cert_file) + if ssl_config.ca_cert_file: + context.load_verify_locations(cafile=ssl_config.ca_cert_file) context.load_cert_chain( certfile=ssl_config.ssl_cert_file, keyfile=ssl_config.ssl_key_file ) diff --git a/libs/megaparse_sdk/tests/README.md b/libs/megaparse_sdk/tests/README.md new file mode 100644 index 0000000..e69de29 diff --git a/libs/megaparse_sdk/tests/certs/rootCA.pem b/libs/megaparse_sdk/tests/certs/rootCA.pem new file mode 100644 index 0000000..6011345 --- /dev/null +++ b/libs/megaparse_sdk/tests/certs/rootCA.pem @@ -0,0 +1,29 @@ +-----BEGIN CERTIFICATE----- +MIIFCzCCA3OgAwIBAgIQESt0eck2KvFrAMyiDyceujANBgkqhkiG9w0BAQsFADCB +nTEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTkwNwYDVQQLDDBhbWlu +ZUBhbWluZXMtTWFjQm9vay1Qcm8ubG9jYWwgKGFtaW5lIGRpcmhvdXNzaSkxQDA+ +BgNVBAMMN21rY2VydCBhbWluZUBhbWluZXMtTWFjQm9vay1Qcm8ubG9jYWwgKGFt +aW5lIGRpcmhvdXNzaSkwHhcNMjQxMTE5MTAwMTA5WhcNMzQxMTE5MTAwMTA5WjCB +nTEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTkwNwYDVQQLDDBhbWlu +ZUBhbWluZXMtTWFjQm9vay1Qcm8ubG9jYWwgKGFtaW5lIGRpcmhvdXNzaSkxQDA+ +BgNVBAMMN21rY2VydCBhbWluZUBhbWluZXMtTWFjQm9vay1Qcm8ubG9jYWwgKGFt +aW5lIGRpcmhvdXNzaSkwggGiMA0GCSqGSIb3DQEBAQUAA4IBjwAwggGKAoIBgQCw +6TX1kvqVMb8ZUQVT/vuDsedmbYgSFn68yJRlmE9BsqG7TLQHl2Kw6VQqZBSIkeZG +CypmUysX/3qrvICeArIdmmsrWUTDYPoauw/a/RY0I07rALj3YR0Y7039Hxf/UPT9 +xlUtnM2NafkZyp6WRjEN0N4ETvJDIbUQiosiiPilxhwRbJURhT/JPskaw+OM2Sw5 +dFAT20zkYC5VIc4wJBFLAMG0XzI6Sy/4wI1WdRBXd2UMpQU4u7TyD0RB4mnHorV6 +kXjtLKD/KWSrSG1nnum9SB9eVatbRD+TUgoclwAKedrlCDEM4EsXVVuUuYCizQNb ++H3BSPfj1upUW5eKfgAyB+8r4QGf2yCY9O8NMMrJ1K5Qv4vSuWAU2tZqAyE8Z4Ke +UtHsl/M0zIvIKwyki2N/rieL/m6lTzS3dwSf9vv7eePEvxd8SBClSF07MUzyxkZ5 +UYNxaK5t2ZRADZ6n/9/hAQsMscCkHiX1N2ypBFV+86Pr78BC48JgIyCMwuiBN4sC +AwEAAaNFMEMwDgYDVR0PAQH/BAQDAgIEMBIGA1UdEwEB/wQIMAYBAf8CAQAwHQYD +VR0OBBYEFFdsN4L0DOS2tdn5PNLSV6DP9eJeMA0GCSqGSIb3DQEBCwUAA4IBgQBj +KosfLfW/ZH80NM16pvpyRF3mCi+q+I+P8zrfilMYJBH4EEdEGAUgTO5do1kJXeel +Wky+FNxaP6KCNiT+0amypKg+yjBlnqLKVdnEgR5s12ZfmerV59stx1A/c/bYMEAS +re6xskBkowP2cVQHAC2dy/0Ov+lZsiNaPV2bQx6KUJurveebUQsH3uF3ZEhnUVQ6 +rt5+JGY4x9Tr1YMhvHqEDTrsipPdDB1MyW1SnCkqSXrz+DPXGd8BW0O0hpM5la81 +J+rfZGinbcUgXM6JMLIHDxLc4Xxzm4NijFzXhbR3XPXqEwsnZOuxcYYFgUGs3FwS +4ro+34a/O4uKS2KV8wsUWj/tWD2rLpduDgag4WSipCvWtaNve8gPdUiyPxUqxyoZ +aFAFg/izXwmRntogJtV0Zvo3fqAaQQDl8t2s21IIx0wmgHzgmkswb5OwFg3dOn/S +lmaH8v7FCBP7jHx/NCPTT5Sy/1EMRATmhFDUZ8Bod/TIlV3e+FCVqlX3kBBRbAU= +-----END CERTIFICATE----- diff --git a/libs/megaparse_sdk/tests/test_nats_client.py b/libs/megaparse_sdk/tests/test_nats_client.py new file mode 100644 index 0000000..42f6fe4 --- /dev/null +++ b/libs/megaparse_sdk/tests/test_nats_client.py @@ -0,0 +1,197 @@ +import asyncio +import logging +from pathlib import Path + +import nats +import pytest +import pytest_asyncio +from megaparse_sdk.client import ClientState, MegaParseNATSClient +from megaparse_sdk.config import ClientNATSConfig, SSLConfig +from megaparse_sdk.schema.mp_exceptions import ( + DownloadError, + InternalServiceError, + MemoryLimitExceeded, + ModelNotSupported, + ParsingException, +) +from megaparse_sdk.schema.mp_inputs import MPInput, ParseFileInput, ParseUrlInput +from megaparse_sdk.schema.mp_outputs import ( + MPErrorType, + MPOutput, + MPOutputType, + ParseError, +) +from nats.aio.client import Client + +logger = logging.getLogger(__name__) + +NATS_URL = "nats://test@127.0.0.1:4222" +NATS_SUBJECT = "parsing" +SSL_CERT_FILE = "./tests/certs/client-cert.pem" +SSL_KEY_FILE = "./tests/certs/client-key.pem" +CA_CERT_FILE = "./tests/certs/rootCA.pem" + + +@pytest.fixture(scope="session") +def ssl_config() -> SSLConfig: + return SSLConfig( + ca_cert_file=CA_CERT_FILE, + ssl_key_file=SSL_KEY_FILE, + ssl_cert_file=SSL_CERT_FILE, + ) + + +@pytest.fixture(scope="session") +def nc_config(ssl_config: SSLConfig) -> ClientNATSConfig: + config = ClientNATSConfig( + subject=NATS_SUBJECT, + endpoint=NATS_URL, + ssl_config=ssl_config, + timeout=0.5, + max_retries=1, + backoff=-1, + connect_timeout=1, + reconnect_time_wait=1, + max_reconnect_attempts=1, + ) + return config + + +@pytest_asyncio.fixture(scope="function") +async def nats_service(nc_config: ClientNATSConfig): + # TODO: fix TLS handshake to work in CI + # ssl_config = load_ssl_cxt(nc_config.ssl_config) + nc = await nats.connect( + nc_config.endpoint, + tls=ssl_config, + connect_timeout=nc_config.connect_timeout, + reconnect_time_wait=nc_config.reconnect_time_wait, + max_reconnect_attempts=nc_config.max_reconnect_attempts, + ) + yield nc + await nc.drain() + + +@pytest.mark.asyncio +async def test_client_state_transition(nc_config: ClientNATSConfig): + mpc = MegaParseNATSClient(nc_config) + assert mpc._state == ClientState.UNOPENED + async with mpc: + assert mpc._state == ClientState.OPENED + assert mpc._state == ClientState.CLOSED + + with pytest.raises(RuntimeError): + async with mpc: + pass + + +@pytest.mark.asyncio(loop_scope="session") +async def test_client_parse_file(nats_service: Client, nc_config: ClientNATSConfig): + async def message_handler(msg): + parsed_input = MPInput.model_validate_json(msg.data.decode("utf-8")).input + assert isinstance(parsed_input, ParseFileInput) + output = MPOutput(output_type=MPOutputType.PARSE_OK, result="test") + await nats_service.publish(msg.reply, output.model_dump_json().encode("utf-8")) + + await nats_service.subscribe(NATS_SUBJECT, "worker", cb=message_handler) + + file_path = Path("./tests/pdf/sample_table.pdf") + async with MegaParseNATSClient(nc_config) as mp_client: + resp = await mp_client.parse_file(file=file_path) + assert resp == "test" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_client_parse_url(nats_service: Client, nc_config: ClientNATSConfig): + async def message_handler(msg): + parsed_input = MPInput.model_validate_json(msg.data.decode("utf-8")).input + assert isinstance(parsed_input, ParseUrlInput) + output = MPOutput(output_type=MPOutputType.PARSE_OK, result="url") + await nats_service.publish(msg.reply, output.model_dump_json().encode("utf-8")) + + await nats_service.subscribe(NATS_SUBJECT, "worker", cb=message_handler) + + async with MegaParseNATSClient(nc_config) as mp_client: + resp = await mp_client.parse_url(url="this://this") + assert resp == "url" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_client_parse_timeout(nats_service: Client, ssl_config: SSLConfig): + nc_config = ClientNATSConfig( + subject=NATS_SUBJECT, + endpoint=NATS_URL, + ssl_config=ssl_config, + timeout=0.1, + max_retries=1, + backoff=1, + ) + + async def service(msg): + await asyncio.sleep(2 * nc_config.timeout) + + await nats_service.subscribe(NATS_SUBJECT, "worker", cb=service) + + file_path = Path("./tests/pdf/sample_table.pdf") + with pytest.raises(ParsingException): + async with MegaParseNATSClient(nc_config) as mp_client: + await mp_client.parse_file(file=file_path) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_client_parse_timeout_retry(nats_service: Client, ssl_config: SSLConfig): + nc_config = ClientNATSConfig( + subject=NATS_SUBJECT, + endpoint=NATS_URL, + ssl_config=ssl_config, + timeout=0.1, + max_retries=2, + backoff=-5, + ) + + msgs = [] + + async def service(msg): + msgs.append(msg) + await asyncio.sleep(2 * nc_config.timeout) + + await nats_service.subscribe(NATS_SUBJECT, "worker", cb=service) + + file_path = Path("./tests/pdf/sample_table.pdf") + with pytest.raises(ParsingException): + async with MegaParseNATSClient(nc_config) as mp_client: + await mp_client.parse_file(file=file_path) + assert len(msgs) == 2 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize( + "mp_error_type, exception_class", + [ + ("MEMORY_LIMIT", MemoryLimitExceeded), + ("INTERNAL_SERVER_ERROR", InternalServiceError), + ("MODEL_NOT_SUPPORTED", ModelNotSupported), + ("DOWNLOAD_ERROR", DownloadError), + ("PARSING_ERROR", ParsingException), + ], +) +async def test_client_parse_file_excp( + nats_service: Client, nc_config: ClientNATSConfig, mp_error_type, exception_class +): + async def message_handler(msg): + parsed_input = MPInput.model_validate_json(msg.data.decode("utf-8")).input + assert isinstance(parsed_input, ParseFileInput) + err = ParseError(mp_err_code=MPErrorType[mp_error_type], message="") + output = MPOutput( + output_type=MPOutputType.PARSE_ERR, + err=err, + result=None, + ) + await nats_service.publish(msg.reply, output.model_dump_json().encode("utf-8")) + + await nats_service.subscribe(NATS_SUBJECT, "worker", cb=message_handler) + + file_path = Path("./tests/pdf/sample_table.pdf") + with pytest.raises(exception_class): + async with MegaParseNATSClient(nc_config) as mp_client: + await mp_client.parse_file(file=file_path) diff --git a/libs/megaparse_sdk/tests/test_service.py b/libs/megaparse_sdk/tests/test_service.py deleted file mode 100644 index 21854a3..0000000 --- a/libs/megaparse_sdk/tests/test_service.py +++ /dev/null @@ -1,39 +0,0 @@ -import nats -import pytest -from megaparse_sdk.schema.mp_inputs import ( - FileInput, - MPInput, - ParseFileConfig, - ParseFileInput, -) -from megaparse_sdk.schema.mp_outputs import MPOutput, MPOutputType -from megaparse_sdk.utils.load_ssl import load_ssl_cxt - - -@pytest.mark.asyncio -async def test_parse_file_nats(): - NATS_URL = "nats://test@localhost:4222" - NATS_SUBJECT = "parse.file" - SSL_CERT_FILE = "./tests/certs/client-cert.pem" - SSL_KEY_FILE = "./tests/certs/client-key.pem" - CA_CERT_FILE = "./test/certs/rootCA.pem" - ctx = load_ssl_cxt( - cert_file=SSL_CERT_FILE, ca_cert_file=CA_CERT_FILE, key_file=SSL_KEY_FILE - ) - nc = await nats.connect(NATS_URL, tls=ctx) - - with open("./tests/pdf/sample_table.pdf", "rb") as f: - data = f.read() - file_input = ParseFileInput( - file_input=FileInput(file_name="test.pdf", file_size=len(data), data=data), - parse_config=ParseFileConfig(), - ) - file = MPInput(input=file_input) - - raw_response = await nc.request( - NATS_SUBJECT, file.model_dump_json().encode("utf-8"), timeout=300 - ) - response = MPOutput.model_validate_json(raw_response.data.decode("utf-8")) - assert response.output_type == MPOutputType.PARSE_OK - - await nc.close() diff --git a/pyproject.toml b/pyproject.toml index 353456f..010c4d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,3 +75,15 @@ section-order = [ "local-folder", ] known-first-party = [] + + +[tool.pytest.ini_options] +addopts = "--tb=short -ra -v" +asyncio_default_fixture_loop_scope = "session" +filterwarnings = ["ignore::DeprecationWarning"] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "base: these tests require quivr-core with extra `base` to be installed", + "tika: these tests require a tika server to be running", + "unstructured: these tests require `unstructured` dependency", +]