chore: add extensive tests for alembic migrations
Askir committed Nov 29, 2024
1 parent 51c4cb9 commit e64a08e
Showing 13 changed files with 531 additions and 163 deletions.
13 changes: 9 additions & 4 deletions projects/pgai/pgai/alembic/autogenerate.py
@@ -27,7 +27,9 @@ class ExistingVectorizer:

 @comparators.dispatch_for("schema")
 def compare_vectorizers(
-    autogen_context: AutogenContext, upgrade_ops: UpgradeOps, schemas: list[str]
+    autogen_context: AutogenContext,
+    upgrade_ops: UpgradeOps,
+    schemas: list[str],  # noqa: ARG001
 ):
     """Compare vectorizers between model and database,
     generating appropriate migration operations."""
@@ -46,11 +48,14 @@ def compare_vectorizers(
     ).fetchall()

     for row in result:
-        parsed_vectorizer = Vectorizer.model_validate(row.vectorizer)  # type: ignore
+        parsed_vectorizer = Vectorizer.model_validate(row.vectorizer)  # type: ignore
         existing_vectorizer = ExistingVectorizer(
-            parsed_vectorizer.id, CreateVectorizerParams.from_db_config(parsed_vectorizer)
+            parsed_vectorizer.id,
+            CreateVectorizerParams.from_db_config(parsed_vectorizer),
         )
-        target_table = f"{parsed_vectorizer.target_schema}.{parsed_vectorizer.target_table}"
+        target_table = (
+            f"{parsed_vectorizer.target_schema}.{parsed_vectorizer.target_table}"
+        )
         existing_vectorizers[target_table] = existing_vectorizer
     # Get vectorizers from models
     model_vectorizers: dict[str, CreateVectorizerParams] = {}
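For orientation: the hook above is Alembic's comparator extension point. During `alembic revision --autogenerate`, every function registered via `comparators.dispatch_for("schema")` runs once, and any operations it appends to `upgrade_ops.ops` are rendered into the migration script. A minimal sketch of the mechanism (the body and the `compare_widgets` name are illustrative assumptions, not pgai code):

```python
from alembic.autogenerate import comparators
from alembic.autogenerate.api import AutogenContext
from alembic.operations.ops import UpgradeOps
from sqlalchemy import text


@comparators.dispatch_for("schema")
def compare_widgets(
    autogen_context: AutogenContext,
    upgrade_ops: UpgradeOps,
    schemas: list[str],
) -> None:
    # Alembic calls each registered schema comparator once per
    # autogenerate run, after its built-in table/column comparison.
    assert autogen_context.connection is not None
    rows = autogen_context.connection.execute(
        text("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
    ).fetchall()
    # Compare `rows` against the in-memory model here; any custom
    # MigrateOperation appended to upgrade_ops.ops lands in the script.
```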
2 changes: 1 addition & 1 deletion projects/pgai/pgai/alembic/operations.py
@@ -8,9 +8,9 @@
     ChunkingConfig,
     CreateVectorizerParams,
     DiskANNIndexingConfig,
-    OpenAIEmbeddingConfig,
     HNSWIndexingConfig,
     NoSchedulingConfig,
+    OpenAIEmbeddingConfig,
     ProcessingConfig,
     SchedulingConfig,
 )
99 changes: 79 additions & 20 deletions projects/pgai/pgai/configuration.py
@@ -1,5 +1,5 @@
 import textwrap
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
 from typing import Any, Literal, Protocol, runtime_checkable

 from alembic.autogenerate.api import AutogenContext
@@ -10,11 +10,11 @@
     LangChainCharacterTextSplitter,
     LangChainRecursiveCharacterTextSplitter,
 )
-from pgai.vectorizer.embeddings import OpenAI, Ollama
+from pgai.vectorizer.embeddings import Ollama, OpenAI
 from pgai.vectorizer.formatting import ChunkValue, PythonTemplate
-from pgai.vectorizer.indexing import DiskANNIndexing, HNSWIndexing
+from pgai.vectorizer.indexing import DiskANNIndexing, HNSWIndexing, NoIndexing
 from pgai.vectorizer.processing import ProcessingDefault
-from pgai.vectorizer.scheduling import TimescaleScheduling
+from pgai.vectorizer.scheduling import NoScheduling, TimescaleScheduling


 def equivalent_value(a: Any, b: Any, default: Any) -> bool:
@@ -153,9 +153,9 @@ def from_db_config(cls, config: Ollama) -> "OllamaEmbeddingConfig":
         return cls(
             model=config.model,
             dimensions=config.dimensions,
-            base_url = config.base_url,
-            truncate = config.truncate,
-            keep_alive = config.keep_alive
+            base_url=config.base_url,
+            truncate=config.truncate,
+            keep_alive=config.keep_alive,
         )

@@ -176,6 +176,19 @@ class ChunkingConfig:

     @override
     def __eq__(self, other: object) -> bool:
+        if not isinstance(other, ChunkingConfig):
+            return False
+        # Handle the separator special case
+        if self.separator is not None and other.separator is not None:
+            if isinstance(self.separator, str) and isinstance(other.separator, list):
+                other = replace(
+                    other,
+                    separator=[other.separator[0]]
+                    if len(other.separator) == 1
+                    else other.separator,
+                )
+            elif isinstance(self.separator, list) and isinstance(other.separator, str):
+                other = replace(other, separator=[other.separator])
         return equivalent_dataclass_with_defaults(self, other, self._defaults)

     def to_sql_argument(self) -> str:
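The separator branch above exists because a chunker's separator can round-trip through the database as a one-element list even when the model declared a plain string; `__eq__` normalizes the two shapes before delegating to the default comparison. A self-contained sketch of the same idea (`Cfg` is illustrative, not the pgai class):

```python
from dataclasses import dataclass


@dataclass
class Cfg:
    separator: str | list[str] | None = None

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Cfg):
            return False

        def norm(s: str | list[str] | None) -> list[str] | None:
            # A bare string and a one-element list mean the same separator.
            return [s] if isinstance(s, str) else s

        return norm(self.separator) == norm(other.separator)


assert Cfg(separator="\n") == Cfg(separator=["\n"])
assert Cfg(separator="\n") != Cfg(separator=["\n", "."])
```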
@@ -230,6 +243,23 @@ def from_db_config(
         )


+@dataclass
+class NoIndexingConfig:
+    @override
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, NoIndexingConfig)
+
+    def to_sql_argument(self) -> str:
+        return ", indexing => ai.indexing_none()"
+
+    def to_python_arg(self) -> str:
+        return format_python_arg("indexing", self)
+
+    @classmethod
+    def from_db_config(cls, config: NoIndexing) -> "NoIndexingConfig":  # noqa: ARG003
+        return cls()
+
+
 @dataclass
 class DiskANNIndexingConfig:
     min_rows: int | None = None
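`NoIndexingConfig` follows the same protocol as the other config dataclasses: it contributes one SQL fragment to the generated `ai.create_vectorizer()` call and one keyword argument to the rendered Python. A rough illustration of where that fragment lands (the surrounding SQL shape, table name, and chunking fragment are assumptions, not pgai's actual renderer output):

```python
# Each config object contributes a ", name => value" fragment; the
# params object concatenates them into the vectorizer DDL call.
fragments = [
    ", chunking => ai.chunking_character_text_splitter('body')",
    ", indexing => ai.indexing_none()",  # NoIndexingConfig.to_sql_argument()
]
sql = (
    "SELECT ai.create_vectorizer('public.docs'::regclass"
    + "".join(fragments)
    + ")"
)
print(sql)
```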
@@ -271,7 +301,7 @@ def to_sql_argument(self) -> str:

     def to_python_arg(self) -> str:
         return format_python_arg("indexing", self)
-
+
     @classmethod
     def from_db_config(cls, config: DiskANNIndexing) -> "DiskANNIndexingConfig":
         return cls(
@@ -289,7 +319,7 @@ def from_db_config(cls, config: DiskANNIndexing) -> "DiskANNIndexingConfig":
 @dataclass
 class HNSWIndexingConfig:
     min_rows: int | None = None
-    opclass: Literal["vector_cosine_ops", "vector_l2_ops", "vector_ip_ops"] | None = (
+    opclass: Literal["vector_cosine_ops", "vector_l1_ops", "vector_ip_ops"] | None = (
         None
     )
     m: int | None = None
@@ -324,7 +354,7 @@ def to_sql_argument(self) -> str:

     def to_python_arg(self) -> str:
         return format_python_arg("indexing", self)
-
+
     @classmethod
     def from_db_config(cls, config: HNSWIndexing) -> "HNSWIndexingConfig":
         return cls(
@@ -348,6 +378,10 @@ def to_sql_argument(self) -> str:
     def to_python_arg(self) -> str:
         return format_python_arg("scheduling", self)

+    @classmethod
+    def from_db_config(cls, config: NoScheduling) -> "NoSchedulingConfig":  # noqa: ARG003
+        return cls()
+

 @dataclass
 class SchedulingConfig:
@@ -376,7 +410,7 @@ def to_sql_argument(self) -> str:

     def to_python_arg(self) -> str:
         return format_python_arg("scheduling", self)
-
+
     @classmethod
     def from_db_config(cls, config: TimescaleScheduling) -> "SchedulingConfig":
         return cls(
@@ -430,7 +464,9 @@ class CreateVectorizerParams:
     source_table: str | None
     embedding: OpenAIEmbeddingConfig | OllamaEmbeddingConfig | None = None
     chunking: ChunkingConfig | None = None
-    indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None
+    indexing: DiskANNIndexingConfig | HNSWIndexingConfig | NoIndexingConfig | None = (
+        None
+    )
     formatting_template: str | None = None
     scheduling: SchedulingConfig | NoSchedulingConfig | None = None
     processing: ProcessingConfig | None = None
@@ -448,6 +484,7 @@ class CreateVectorizerParams:
         "enqueue_existing": True,
         "processing": ProcessingConfig(),
         "scheduling": NoSchedulingConfig(),
+        "indexing": NoIndexingConfig(),
         "queue_schema": "ai",
     }

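Registering `NoIndexingConfig()` as a default matters for diff noise: a model that never mentions indexing should not produce a spurious migration when the database reports `ai.indexing_none()`. A sketch of the intended equivalence (the semantics of `equivalent_value` are assumed from its call sites, not copied from pgai):

```python
from typing import Any


def equivalent_value(a: Any, b: Any, default: Any) -> bool:
    # Assumed semantics: None on either side compares equal to the
    # registered default, so "unset" and "explicitly default" match.
    if a == b:
        return True
    if a is None:
        return b == default
    if b is None:
        return a == default
    return False


# Model omitted `indexing`; the DB reports the none-indexing config.
# With the default registered, no migration op is generated.
assert equivalent_value(None, "indexing_none", default="indexing_none")
```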
@@ -555,20 +592,42 @@ def from_db_config(cls, vectorizer: Vectorizer) -> "CreateVectorizerParams":
         """
         embedding_config: None | OpenAIEmbeddingConfig | OllamaEmbeddingConfig = None
         if isinstance(vectorizer.config.embedding, OpenAI):
-            embedding_config = OpenAIEmbeddingConfig.from_db_config(vectorizer.config.embedding)
+            embedding_config = OpenAIEmbeddingConfig.from_db_config(
+                vectorizer.config.embedding
+            )
         if isinstance(vectorizer.config.embedding, Ollama):
-            embedding_config = OllamaEmbeddingConfig.from_db_config(vectorizer.config.embedding)
+            embedding_config = OllamaEmbeddingConfig.from_db_config(
+                vectorizer.config.embedding
+            )
         chunking_config = ChunkingConfig.from_db_config(vectorizer.config.chunking)
-        processing_config = ProcessingConfig.from_db_config(vectorizer.config.processing)
-        indexing_config: None | DiskANNIndexingConfig | HNSWIndexingConfig = None
+        processing_config = ProcessingConfig.from_db_config(
+            vectorizer.config.processing
+        )
+        indexing_config: (
+            None | DiskANNIndexingConfig | HNSWIndexingConfig | NoIndexingConfig
+        ) = None
         if isinstance(vectorizer.config.indexing, DiskANNIndexing):
-            indexing_config = DiskANNIndexingConfig.from_db_config(vectorizer.config.indexing)
+            indexing_config = DiskANNIndexingConfig.from_db_config(
+                vectorizer.config.indexing
+            )
         if isinstance(vectorizer.config.indexing, HNSWIndexing):
-            indexing_config = HNSWIndexingConfig.from_db_config(vectorizer.config.indexing)
-
+            indexing_config = HNSWIndexingConfig.from_db_config(
+                vectorizer.config.indexing
+            )
+        if isinstance(vectorizer.config.indexing, NoIndexing):
+            indexing_config = NoIndexingConfig.from_db_config(
+                vectorizer.config.indexing
+            )
+
         scheduling_config: None | NoSchedulingConfig | SchedulingConfig = None
         if isinstance(vectorizer.config.scheduling, TimescaleScheduling):
-            scheduling_config = SchedulingConfig.from_db_config(vectorizer.config.scheduling)
+            scheduling_config = SchedulingConfig.from_db_config(
+                vectorizer.config.scheduling
+            )
+        if isinstance(vectorizer.config.scheduling, NoScheduling):
+            scheduling_config = NoSchedulingConfig.from_db_config(
+                vectorizer.config.scheduling
+            )

         # Get formatting template
         formatting_template = None
2 changes: 1 addition & 1 deletion projects/pgai/pgai/sqlalchemy/__init__.py
@@ -8,9 +8,9 @@
     ChunkingConfig,
     CreateVectorizerParams,
     DiskANNIndexingConfig,
-    OpenAIEmbeddingConfig,
     HNSWIndexingConfig,
     NoSchedulingConfig,
+    OpenAIEmbeddingConfig,
     ProcessingConfig,
     SchedulingConfig,
 )
25 changes: 12 additions & 13 deletions projects/pgai/pgai/vectorizer/indexing.py
@@ -20,17 +20,17 @@ class DiskANNIndexing(BaseModel):
     num_bytes_per_index: The number of bytes to use per index.
     num_bytes_per_cluster: The number of bytes to use per cluster.
     """
-
+
     implementation: Literal["diskann"]
     min_rows: int
-    storage_layout: Literal["memory_optimized", "plain"]
-    num_neighbors: int
-    search_list_size: int
-    max_alpha: float
-    num_dimensions: int
-    num_bits_per_dimension: int
+    storage_layout: Literal["memory_optimized", "plain"] | None = None
+    num_neighbors: int | None = None
+    search_list_size: int | None = None
+    max_alpha: float | None = None
+    num_dimensions: int | None = None
+    num_bits_per_dimension: int | None = None
     create_when_queue_empty: bool


 class HNSWIndexing(BaseModel):
     """
@@ -44,19 +44,18 @@ class HNSWIndexing(BaseModel):
     num_bytes_per_vector: The number of bytes to use per vector.
     num_bytes_per_index: The number of bytes to use per index.
     """
-
+
     implementation: Literal["hnsw"]
     min_rows: int
-    opclass: Literal["vector_cosine_ops", "vector_l2_ops", "vector_ip_ops"]
+    opclass: Literal["vector_cosine_ops", "vector_l1_ops", "vector_ip_ops"]
     m: int
     ef_construction: int
     create_when_queue_empty: bool


 class NoIndexing(BaseModel):
     """
     No indexing configuration.
     """
+
     implementation: Literal["none"]
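Loosening these fields to `| None = None` means the pydantic model now accepts the sparse JSON that the database stores when a user leaves the tuning knobs unset. A trimmed sketch of the effect (subset of the fields from this diff; the raw payload is an assumed example):

```python
from typing import Literal

from pydantic import BaseModel


class DiskANNIndexing(BaseModel):
    implementation: Literal["diskann"]
    min_rows: int
    storage_layout: Literal["memory_optimized", "plain"] | None = None
    num_neighbors: int | None = None
    create_when_queue_empty: bool


# A stored config that omits the optional tuning knobs still validates.
cfg = DiskANNIndexing.model_validate(
    {"implementation": "diskann", "min_rows": 100000, "create_when_queue_empty": True}
)
assert cfg.storage_layout is None and cfg.num_neighbors is None
```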
15 changes: 8 additions & 7 deletions projects/pgai/pgai/vectorizer/scheduling.py
@@ -12,17 +12,18 @@ class TimescaleScheduling(BaseModel):
     interval: The interval at which to run the scheduling.
     retention_policy: The retention policy to use.
     """
-
+
     implementation: Literal["timescaledb"]
-    schedule_interval: datetime.timedelta
-    initial_start: str
-    job_id: int
+    schedule_interval: datetime.timedelta | None = None
+    initial_start: str | None = None
+    job_id: int | None = None
     fixed_schedule: bool
-    timezone: str
+    timezone: str | None = None


 class NoScheduling(BaseModel):
     """
     No scheduling configuration.
     """
-    implementation: Literal["none"]
+
+    implementation: Literal["none"]
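The `implementation` literals double as discriminators: given the stored JSON, pydantic can select `NoScheduling` vs. `TimescaleScheduling` from a tagged union. A sketch of that wiring (the `Config` container is assumed for illustration; pgai's actual config model may differ):

```python
import datetime
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class NoScheduling(BaseModel):
    implementation: Literal["none"]


class TimescaleScheduling(BaseModel):
    implementation: Literal["timescaledb"]
    fixed_schedule: bool
    schedule_interval: datetime.timedelta | None = None


class Config(BaseModel):
    scheduling: Annotated[
        Union[NoScheduling, TimescaleScheduling],
        Field(discriminator="implementation"),
    ]


cfg = Config.model_validate({"scheduling": {"implementation": "none"}})
assert isinstance(cfg.scheduling, NoScheduling)
```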
2 changes: 1 addition & 1 deletion projects/pgai/pgai/vectorizer/vectorizer.py
@@ -26,7 +26,7 @@
 from .formatting import ChunkValue, PythonTemplate
 from .indexing import DiskANNIndexing, HNSWIndexing, NoIndexing
 from .processing import ProcessingDefault
-from .scheduling import TimescaleScheduling, NoScheduling
+from .scheduling import NoScheduling, TimescaleScheduling

 logger = structlog.get_logger()
15 changes: 8 additions & 7 deletions projects/pgai/tests/vectorizer/conftest.py
@@ -1,6 +1,7 @@
 import os
 from pathlib import Path
+from typing import Any

 import pytest
 import tiktoken
 import vcr  # type:ignore
@@ -78,11 +79,11 @@ def postgres_container():
 def timescale_ha_container():
     load_dotenv()
     with PostgresContainer(
-        image="timescale/timescaledb-ha:pg16",
-        username="tsdbquerier",
-        password="my-password",
-        dbname="tsdb",
-        driver=None,
-        command="postgres -c shared_preload_libraries=timescaledb"
+        image="timescale/timescaledb-ha:pg16",
+        username="tsdbquerier",
+        password="my-password",
+        dbname="tsdb",
+        driver=None,
+        command="postgres -c shared_preload_libraries=timescaledb",
     ).with_env("OPENAI_API_KEY", os.environ["OPENAI_API_KEY"]) as postgres:
-        yield postgres
+        yield postgres
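For context, a test consumes this fixture like any other; `get_connection_url()` is the standard testcontainers accessor, though the test body here is an assumed example:

```python
def test_runs_against_timescale(timescale_ha_container):
    # The container from the fixture above exposes its DSN; a real test
    # would hand this URL to SQLAlchemy or psycopg.
    url = timescale_ha_container.get_connection_url()
    assert "tsdb" in url
```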
6 changes: 4 additions & 2 deletions projects/pgai/tests/vectorizer/extensions/conftest.py
@@ -78,7 +78,9 @@ def alembic_config(alembic_dir: Path, postgres_container: PostgresContainer) -> Config:


 @pytest.fixture
-def timescale_alembic_config(alembic_dir: Path, timescale_ha_container: PostgresContainer) -> Config:
+def timescale_alembic_config(
+    alembic_dir: Path, timescale_ha_container: PostgresContainer
+) -> Config:
     """Create a configured Alembic environment."""
     # Create alembic.ini from template
     ini_path = alembic_dir / "alembic.ini"
Expand Down Expand Up @@ -140,7 +142,7 @@ def initialized_engine(
with engine.connect() as conn:
conn.execute(text("DROP SCHEMA public CASCADE; CREATE SCHEMA public;"))
conn.commit()


@pytest.fixture
def initialized_engine_with_timescale(
Expand Down
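The commit title mentions extensive Alembic migration tests; the canonical round-trip such tests perform against these fixtures looks roughly like this (`command.upgrade` and `command.downgrade` are real Alembic APIs; the test itself is an assumed sketch, not copied from this commit):

```python
from alembic import command
from alembic.config import Config


def test_migration_roundtrip(alembic_config: Config):
    # Apply every migration, then unwind it; failures in either
    # direction surface broken autogenerate output or downgrade logic.
    command.upgrade(alembic_config, "head")
    command.downgrade(alembic_config, "base")
```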