Commit
refactor: use monorepo structure
alejandrodnm committed Oct 9, 2024
1 parent 57df731 commit db69e5e
Showing 114 changed files with 3,952 additions and 73 deletions.
4 changes: 3 additions & 1 deletion Dockerfile
@@ -62,6 +62,8 @@ RUN set -eux; \
ENV PIP_BREAK_SYSTEM_PACKAGES=1
COPY requirements-test.txt /build/requirements-test.txt
RUN pip install -r /build/requirements-test.txt
COPY projects/pgai/requirements.txt /build/requirements-pgai.txt
RUN pip install -r /build/requirements-pgai.txt
RUN rm -r /build

WORKDIR /pgai
WORKDIR /pgai
7 changes: 5 additions & 2 deletions Makefile
@@ -82,8 +82,12 @@ test-server:
vectorizer:
@./build.py vectorizer

.PHONY: test-vectorizer
test-vectorizer:
@cd projects/pgai && pytest

.PHONY: test
test:
test: test-vectorizer
@./build.py test

.PHONY: lint-sql
@@ -146,4 +150,3 @@ docker-shell:
.PHONY: psql-shell
psql-shell:
@docker exec -it -u postgres pgai /bin/bash -c "set -e; if [ -f .env ]; then set -a; source .env; set +a; fi; psql"

18 changes: 12 additions & 6 deletions build.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python3
import os
import platform
import subprocess
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
@@ -69,11 +69,11 @@ def project_dir() -> Path:


def sql_dir() -> Path:
return project_dir().joinpath("sql").resolve()
return src_extension_dir().joinpath("sql").resolve()


def src_dir() -> Path:
return project_dir().joinpath("src").resolve()
return project_dir().joinpath("projects").resolve()


def src_extension_dir() -> Path:
@@ -135,7 +135,7 @@ def output_sql_file() -> Path:


def tests_dir() -> Path:
return project_dir().joinpath("tests")
return project_dir().joinpath("projects/extension/tests")


def where_am_i() -> str:
@@ -459,7 +459,7 @@ def clean_vectorizer() -> None:

def uninstall_vectorizer() -> None:
subprocess.run(
f'pip3 uninstall -v -y vectorizer',
"pip3 uninstall -v -y vectorizer",
check=True,
shell=True,
env=os.environ,
@@ -497,7 +497,7 @@ def clean() -> None:

def test_server() -> None:
if where_am_i() == "host":
cmd = "docker exec -it -w /pgai/tests/vectorizer pgai fastapi dev server.py"
cmd = "docker exec -it -w /projects/extension/tests/vectorizer pgai fastapi dev server.py"
subprocess.run(cmd, shell=True, check=True, env=os.environ, cwd=project_dir())
else:
cmd = "fastapi dev server.py"
@@ -583,10 +583,16 @@ def docker_build_vectorizer() -> None:


def docker_run() -> None:
# Set TESTCONTAINERS_HOST_OVERRIDE when running on macOS.
env_var = ""
if platform.system() == "Darwin":
env_var = "-e TESTCONTAINERS_HOST_OVERRIDE=host.docker.internal"
cmd = " ".join(
[
"docker run -d --name pgai -p 127.0.0.1:5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust",
"-v /var/run/docker.sock:/var/run/docker.sock",
f"--mount type=bind,src={project_dir()},dst=/pgai",
env_var, # Include the environment variable if on macOS
"pgai",
"-c shared_preload_libraries='timescaledb, pgextwlist'",
"-c extwlist.extensions='ai,vector'",
34 files renamed without changes.
6 changes: 3 additions & 3 deletions tests/conftest.py → projects/extension/tests/conftest.py
@@ -42,16 +42,16 @@ def create_test_db(cur: psycopg.Cursor) -> None:
@pytest.fixture(scope="session", autouse=True)
def set_up_test_db() -> None:
# create a test user and test database owned by the test user
with psycopg.connect(f"postgres://[email protected]:5432/postgres", autocommit=True) as con:
with psycopg.connect("postgres://[email protected]:5432/postgres", autocommit=True) as con:
with con.cursor() as cur:
create_test_user(cur)
create_test_db(cur)
# grant some things to the test user in the test database
with psycopg.connect(f"postgres://[email protected]:5432/test", autocommit=True) as con:
with psycopg.connect("postgres://[email protected]:5432/test", autocommit=True) as con:
with con.cursor() as cur:
cur.execute("grant execute on function pg_read_binary_file(text) to test")
cur.execute("grant pg_read_server_files to test")
# use the test user to create the extension in the test database
with psycopg.connect(f"postgres://[email protected]:5432/test") as con:
with psycopg.connect("postgres://[email protected]:5432/test") as con:
with con.cursor() as cur:
cur.execute("create extension ai cascade")
8 files renamed without changes.
@@ -67,7 +67,7 @@ def dump_db() -> None:
def restore_db() -> None:
with psycopg.connect(db_url(user=USER, dbname="dst")) as con:
with con.cursor() as cur:
cur.execute(f"create extension ai cascade")
cur.execute("create extension ai cascade")
cmd = " ".join([
"psql",
f'''-d "{db_url(USER, "dst")}"''',
@@ -145,4 +145,3 @@ def test_dump_restore():
assert dst == src
after_dst() # make sure we can USE the restored db
assert count_vectorizers() == 2

28 files renamed without changes.
@@ -1,10 +1,10 @@
import json
import os
import subprocess
import json

import psycopg
from psycopg.rows import namedtuple_row
import pytest
from psycopg.rows import namedtuple_row

# skip tests in this module if disabled
enable_vectorizer_tests = os.getenv("ENABLE_VECTORIZER_TESTS")
@@ -249,7 +249,7 @@ def test_vectorizer_timescaledb():
assert actual == 3

# bob should have select on the source table
cur.execute(f"select has_table_privilege('bob', 'website.blog', 'select')")
cur.execute("select has_table_privilege('bob', 'website.blog', 'select')")
actual = cur.fetchone()[0]
assert actual

@@ -375,7 +375,7 @@ def test_vectorizer_timescaledb():
cur2.execute("begin transaction")
# lock 1 row from the queue
cur2.execute(f"select * from {vec.queue_schema}.{vec.queue_table} where title = 'how to grill a steak' for update")
locked = cur2.fetchone()
cur2.fetchone()
# check that vectorizer queue depth still gets the correct count
cur.execute("select ai.vectorizer_queue_pending(%s)", (vectorizer_id,))
actual = cur.fetchone()[0]
@@ -528,7 +528,7 @@ def test_drop_vectorizer():
assert actual == 0

# does the func that backed the trigger exist? (it should not)
cur.execute(f"""
cur.execute("""
select count(*)
from pg_proc
where oid = %s
@@ -537,7 +537,7 @@ def test_drop_vectorizer():
assert actual == 0

# does the timescaledb job exist? (it should not)
cur.execute(f"""
cur.execute("""
select count(*)
from timescaledb_information.jobs
where job_id = %s
@@ -625,7 +625,7 @@ def index_creation_tester(cur: psycopg.Cursor, vectorizer_id: int) -> None:
cur.execute(f"insert into {vectorizer.queue_schema}.{vectorizer.queue_table}(id) select generate_series(1, 5)")

# should NOT create index
cur.execute(f"""
cur.execute("""
select ai._vectorizer_should_create_vector_index(v)
from ai.vectorizer v
where v.id = %s
@@ -650,7 +650,7 @@ def index_creation_tester(cur: psycopg.Cursor, vectorizer_id: int) -> None:
cur.execute(f"delete from {vectorizer.queue_schema}.{vectorizer.queue_table}")

# SHOULD create index
cur.execute(f"""
cur.execute("""
select ai._vectorizer_should_create_vector_index(v)
from ai.vectorizer v
where v.id = %s
4 files renamed without changes.
13 changes: 9 additions & 4 deletions src/vectorizer/pyproject.toml → projects/pgai/pyproject.toml
@@ -3,16 +3,21 @@ requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "vectorizer"
name = "pgai"
dynamic = ["version", "dependencies"]
requires-python = ">=3.10"

[tool.setuptools.dynamic]
version = {attr = "vectorizer.__version__"}
version = {attr = "pgai.__version__"}
dependencies = {file = "requirements.txt"}

[tool.setuptools]
packages = ["vectorizer"]
packages = ["pgai"]

[project.scripts]
vectorizer = "vectorizer.cli:run"
vectorizer = "pgai.cli:run"

[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
]
2 changes: 2 additions & 0 deletions projects/pgai/pytest.ini
@@ -0,0 +1,2 @@
[pytest]
python_files = test_*.py
@@ -7,4 +7,5 @@ openai==1.44.0
python-dotenv==1.0.1
structlog==24.4.0
datadog-lambda
pgvector==0.3.3
pgvector==0.3.3
tiktoken==0.7.0
8 changes: 4 additions & 4 deletions src/vectorizer/setup.cfg → projects/pgai/setup.cfg
@@ -1,12 +1,12 @@
[metadata]
name = vectorizer
version = attr: vectorizer.__version__
name = pgai
version = attr: pgai.__version__

[options]
python_requires = >=3.10
packages = vectorizer
packages = pgai
install_requires = file: requirements.txt

[options.entry_points]
console_scripts =
vectorizer = vectorizer.cli:run
vectorizer = pgai.cli:run
2 files renamed without changes.
4 changes: 2 additions & 2 deletions src/vectorizer/vectorizer/cli.py → projects/pgai/src/cli.py
@@ -13,8 +13,8 @@
from psycopg.rows import dict_row, namedtuple_row

from .__init__ import __version__
from .secrets import Secrets
from .vectorizer import Vectorizer, Worker
from vectorizer.secrets import Secrets
from vectorizer.vectorizer import Vectorizer, Worker

load_dotenv()
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
Empty file.
File renamed without changes.
15 changes: 15 additions & 0 deletions projects/pgai/src/vectorizer/db.py
@@ -0,0 +1,15 @@
from dataclasses import dataclass


@dataclass
class ConnInfo:
host: str
port: int
role: str
password: str
db_name: str
ssl_mode: str = "require"

@property
def url(self) -> str:
return f"postgres://{self.role}:{self.password}@{self.host}:{self.port}/{self.db_name}?sslmode={self.ssl_mode}"
3 files renamed without changes.
85 changes: 85 additions & 0 deletions projects/pgai/src/vectorizer/lambda_handler.py
@@ -0,0 +1,85 @@
import asyncio
import logging
import os
from typing import Any

import structlog
from pydantic import AliasChoices, Field, ValidationError
from pydantic.dataclasses import dataclass

from . import db
from .env import get_bool_env
from .processing import CloudFunctions
from .secrets import Secrets
from .vectorizer import Vectorizer, Worker

TIKTOKEN_CACHE_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "tiktoken_cache"
)
structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.INFO))
logger = structlog.get_logger()


@dataclass
class UpdateEmbeddings:
db: db.ConnInfo
secrets: Secrets


@dataclass
class Event:
update_embeddings: UpdateEmbeddings
vectorizer: Vectorizer = Field(validation_alias=AliasChoices("payload"))


async def run_workers(
concurrency: int,
conn_info: db.ConnInfo,
vectorizer: Vectorizer,
) -> list[int]:
"""Runs the embedding tasks and wait for them to finish."""
# TODO: handle timeout so that lambdas are not killed by AWS
tasks = [
asyncio.create_task(Worker(conn_info.url, vectorizer).run())
for _ in range(concurrency)
]
return await asyncio.gather(*tasks)


def set_log_level(cf: CloudFunctions):
mapping = logging.getLevelNamesMapping()
if cf.log_level != "INFO" and cf.log_level in mapping:
structlog.configure(
wrapper_class=structlog.make_filtering_bound_logger(mapping[cf.log_level])
)


def lambda_handler(raw_event: dict[str, Any], _: Any) -> dict[str, int]:
"""Lambda entry point. Validates the config given via the event, and
starts the embedding tasks.
Args:
raw_event (dict): maps to the `Event` dataclass.
"""
try:
event = Event(**raw_event)
except ValidationError as e:
raise e

# The type error we are ignoring is because there's only one type available
# for Config.processing. We keep the check to signal intent, in case we add
# other types in the future.
if isinstance(event.vectorizer.config.processing, CloudFunctions): # type: ignore
set_log_level(event.vectorizer.config.processing)

event.vectorizer.config.embedding.set_api_key(event.update_embeddings.secrets)

os.environ["TIKTOKEN_CACHE_DIR"] = TIKTOKEN_CACHE_DIR
results = asyncio.run(
run_workers(
event.vectorizer.config.processing.concurrency,
event.update_embeddings.db,
event.vectorizer,
)
)
return {"statusCode": 200, "processed_tasks": sum(results)}
2 files renamed without changes.