chore: add extensive tests for alembic migrations
Askir committed Nov 29, 2024
1 parent 51c4cb9 commit 8acabab
Showing 6 changed files with 387 additions and 124 deletions.
45 changes: 39 additions & 6 deletions projects/pgai/pgai/configuration.py
@@ -1,5 +1,5 @@
import textwrap
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
from typing import Any, Literal, Protocol, runtime_checkable

from alembic.autogenerate.api import AutogenContext
@@ -12,9 +12,9 @@
)
from pgai.vectorizer.embeddings import OpenAI, Ollama
from pgai.vectorizer.formatting import ChunkValue, PythonTemplate
-from pgai.vectorizer.indexing import DiskANNIndexing, HNSWIndexing
+from pgai.vectorizer.indexing import DiskANNIndexing, HNSWIndexing, NoIndexing
from pgai.vectorizer.processing import ProcessingDefault
-from pgai.vectorizer.scheduling import TimescaleScheduling
+from pgai.vectorizer.scheduling import TimescaleScheduling, NoScheduling


def equivalent_value(a: Any, b: Any, default: Any) -> bool:
@@ -176,6 +176,14 @@ class ChunkingConfig:

    @override
    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, ChunkingConfig):
+            return False
+        # Handle the separator special case: the db returns a list, the Python config may use a bare string
+        if self.separator is not None and other.separator is not None:
+            if isinstance(self.separator, str) and isinstance(other.separator, list):
+                other = replace(other, separator=other.separator[0] if len(other.separator) == 1 else other.separator)
+            elif isinstance(self.separator, list) and isinstance(other.separator, str):
+                other = replace(other, separator=[other.separator])
        return equivalent_dataclass_with_defaults(self, other, self._defaults)

    def to_sql_argument(self) -> str:

@@ -229,6 +237,22 @@ def from_db_config(
            is_separator_regex=config.is_separator_regex,
        )

+@dataclass
+class NoIndexingConfig:
+    @override
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, NoIndexingConfig)
+
+    def to_sql_argument(self) -> str:
+        return ", indexing => ai.indexing_none()"
+
+    def to_python_arg(self) -> str:
+        return format_python_arg("indexing", self)
+
+    @classmethod
+    def from_db_config(cls, config: NoIndexing) -> "NoIndexingConfig":
+        return cls()


@dataclass
class DiskANNIndexingConfig:
@@ -289,7 +313,7 @@ def from_db_config(cls, config: DiskANNIndexing) -> "DiskANNIndexingConfig":
@dataclass
class HNSWIndexingConfig:
    min_rows: int | None = None
-    opclass: Literal["vector_cosine_ops", "vector_l2_ops", "vector_ip_ops"] | None = (
+    opclass: Literal["vector_cosine_ops", "vector_l1_ops", "vector_ip_ops"] | None = (
        None
    )
    m: int | None = None
@@ -347,6 +371,10 @@ def to_sql_argument(self) -> str:

    def to_python_arg(self) -> str:
        return format_python_arg("scheduling", self)

+    @classmethod
+    def from_db_config(cls, config: NoScheduling) -> "NoSchedulingConfig":
+        return cls()


@dataclass
@@ -430,7 +458,7 @@ class CreateVectorizerParams:
    source_table: str | None
    embedding: OpenAIEmbeddingConfig | OllamaEmbeddingConfig | None = None
    chunking: ChunkingConfig | None = None
-    indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None
+    indexing: DiskANNIndexingConfig | HNSWIndexingConfig | NoIndexingConfig | None = None
    formatting_template: str | None = None
    scheduling: SchedulingConfig | NoSchedulingConfig | None = None
    processing: ProcessingConfig | None = None
@@ -448,6 +476,7 @@ class CreateVectorizerParams:
"enqueue_existing": True,
"processing": ProcessingConfig(),
"scheduling": NoSchedulingConfig(),
"indexing": NoIndexingConfig(),
"queue_schema": "ai",
}

@@ -560,15 +589,19 @@ def from_db_config(cls, vectorizer: Vectorizer) -> "CreateVectorizerParams":
        embedding_config = OllamaEmbeddingConfig.from_db_config(vectorizer.config.embedding)
        chunking_config = ChunkingConfig.from_db_config(vectorizer.config.chunking)
        processing_config = ProcessingConfig.from_db_config(vectorizer.config.processing)
-        indexing_config: None | DiskANNIndexingConfig | HNSWIndexingConfig = None
+        indexing_config: None | DiskANNIndexingConfig | HNSWIndexingConfig | NoIndexingConfig = None
        if isinstance(vectorizer.config.indexing, DiskANNIndexing):
            indexing_config = DiskANNIndexingConfig.from_db_config(vectorizer.config.indexing)
        if isinstance(vectorizer.config.indexing, HNSWIndexing):
            indexing_config = HNSWIndexingConfig.from_db_config(vectorizer.config.indexing)
+        if isinstance(vectorizer.config.indexing, NoIndexing):
+            indexing_config = NoIndexingConfig.from_db_config(vectorizer.config.indexing)

        scheduling_config: None | NoSchedulingConfig | SchedulingConfig = None
        if isinstance(vectorizer.config.scheduling, TimescaleScheduling):
            scheduling_config = SchedulingConfig.from_db_config(vectorizer.config.scheduling)
+        if isinstance(vectorizer.config.scheduling, NoScheduling):
+            scheduling_config = NoSchedulingConfig.from_db_config(vectorizer.config.scheduling)

        # Get formatting template
        formatting_template = None
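
The separator handling added to ChunkingConfig.__eq__ is easiest to see with a concrete comparison. A minimal sketch, assuming ChunkingConfig accepts the keyword arguments used in the migration template later in this diff:

    from pgai.configuration import ChunkingConfig

    a = ChunkingConfig(chunk_column="content", separator=" ")
    b = ChunkingConfig(chunk_column="content", separator=[" "])
    # The database round-trips a single string separator as a one-element
    # list; after normalization the two forms compare equal.
    assert a == b
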
14 changes: 7 additions & 7 deletions projects/pgai/pgai/vectorizer/indexing.py
@@ -23,12 +23,12 @@ class DiskANNIndexing(BaseModel):

    implementation: Literal["diskann"]
    min_rows: int
-    storage_layout: Literal["memory_optimized", "plain"]
-    num_neighbors: int
-    search_list_size: int
-    max_alpha: float
-    num_dimensions: int
-    num_bits_per_dimension: int
+    storage_layout: Literal["memory_optimized", "plain"] | None = None
+    num_neighbors: int | None = None
+    search_list_size: int | None = None
+    max_alpha: float | None = None
+    num_dimensions: int | None = None
+    num_bits_per_dimension: int | None = None
    create_when_queue_empty: bool


@@ -47,7 +47,7 @@ class HNSWIndexing(BaseModel):

    implementation: Literal["hnsw"]
    min_rows: int
-    opclass: Literal["vector_cosine_ops", "vector_l2_ops", "vector_ip_ops"]
+    opclass: Literal["vector_cosine_ops", "vector_l1_ops", "vector_ip_ops"]
    m: int
    ef_construction: int
    create_when_queue_empty: bool
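
Loosening the DiskANN tuning fields to Optional means a partially specified config read back from the database now validates instead of raising. A rough sketch, with a hypothetical payload whose values are invented for illustration:

    from pgai.vectorizer.indexing import DiskANNIndexing

    # Only the required fields are present; every tuning knob is omitted.
    config = DiskANNIndexing.model_validate({
        "implementation": "diskann",
        "min_rows": 100000,
        "create_when_queue_empty": True,
    })
    assert config.num_neighbors is None  # defaults to None instead of failing
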
8 changes: 4 additions & 4 deletions projects/pgai/pgai/vectorizer/scheduling.py
@@ -14,11 +14,11 @@ class TimescaleScheduling(BaseModel):
"""

implementation: Literal["timescaledb"]
schedule_interval: datetime.timedelta
initial_start: str
job_id: int
schedule_interval: datetime.timedelta | None = None
initial_start: str | None = None
job_id: int | None = None
fixed_schedule: bool
timezone: str
timezone: str | None = None

class NoScheduling(BaseModel):
    """
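
The same pattern applies to TimescaleScheduling: only implementation and fixed_schedule remain required, so a minimal scheduling config parsed from the database validates cleanly. A sketch under the same assumptions as the indexing example above:

    from pgai.vectorizer.scheduling import TimescaleScheduling

    config = TimescaleScheduling.model_validate({
        "implementation": "timescaledb",
        "fixed_schedule": False,
    })
    assert config.schedule_interval is None and config.job_id is None
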
@@ -7,7 +7,7 @@ Create Date: {create_date}
from alembic import op
from pgai.alembic import CreateVectorizerOp, DropVectorizerOp
import sqlalchemy as sa
-from pgai.configuration import ChunkingConfig, DiskANNIndexingConfig, OpenAIEmbeddingConfig, ProcessingConfig, SchedulingConfig
+from pgai.configuration import *

# revision identifiers
revision = '{revision_id}'
@@ -25,30 +25,16 @@ def upgrade():
    )
    op.create_vectorizer(
        'blog_posts',
-        embedding=OpenAIEmbeddingConfig(
-            model='text-embedding-3-small',
-            dimensions=768,
-            chat_user='test_user',
-            api_key_name='test_key'
-        ),
-        chunking=ChunkingConfig(
-            chunk_column='content',
-            chunk_size=500,
-            chunk_overlap=10,
-            separator=' ',
-            is_separator_regex=True
-        ),
-        scheduling=SchedulingConfig(
-            schedule_interval="1h",
-            initial_start="2022-01-01T00:00:00Z",
-            fixed_schedule=True,
-            timezone="UTC"
-        ),
-        {indexing},
-        processing=ProcessingConfig(
-            batch_size=10,
-            concurrency=5
-        ),
+        {embedding}
+        ,
+        {chunking}
+        ,
+        {scheduling}
+        ,
+        {indexing}
+        ,
+        {processing}
+        ,
        target_schema='timescale',
        target_table='blog_posts_embedding',
        view_schema='timescale',
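
With every create_vectorizer argument now a placeholder, the test that renders this template controls the whole configuration. A hedged sketch of how a placeholder value is presumably produced — the deleted test below calls to_python_arg() this way, though the exact rendered string is an assumption:

    from pgai.configuration import DiskANNIndexingConfig

    indexing = DiskANNIndexingConfig(min_rows=10, num_neighbors=5)
    # Renders a keyword argument roughly like
    # "indexing=DiskANNIndexingConfig(min_rows=10, num_neighbors=5)",
    # which load_template substitutes for the {indexing} placeholder.
    print(indexing.to_python_arg())
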
83 changes: 1 addition & 82 deletions projects/pgai/tests/vectorizer/extensions/test_alembic.py
@@ -1,14 +1,9 @@
-from datetime import timedelta
from pathlib import Path

from alembic.command import downgrade, upgrade
from alembic.config import Config
from sqlalchemy import Engine, inspect, text

-from pgai.configuration import DiskANNIndexingConfig
-from pgai.vectorizer import Vectorizer
-from pgai.vectorizer.indexing import DiskANNIndexing
-from pgai.vectorizer.scheduling import TimescaleScheduling
from tests.vectorizer.extensions.conftest import load_template


@@ -129,80 +124,4 @@ def test_vectorizer_migration(
    # Verify everything is gone
    inspector = inspect(initialized_engine)
    tables = inspector.get_table_names()
-    assert "blog" not in tables
-
-
-
-def test_vectorizer_migration_all_fields(
-    alembic_config: Config,
-    initialized_engine: Engine,
-    cleanup_modules: None,  # noqa: ARG001
-):
-    """Test vectorizer creation with a bunch of fields"""
-    migrations_dir = Path(alembic_config.get_main_option("script_location"))  # type: ignore
-    versions_dir = migrations_dir / "versions"
-
-    with initialized_engine.connect() as conn:
-        conn.execute(
-            text(
-                """
-                CREATE schema timescale;
-                CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE;
-                """
-            )
-        )
-        conn.commit()
-
-    # First migration - create blog table
-    indexing_config = DiskANNIndexingConfig(
-        min_rows=10,
-        storage_layout='plain',
-        num_neighbors=5,
-        search_list_size=10,
-        max_alpha=0.5,
-        num_dimensions=10,
-        num_bits_per_dimension=10,
-        create_when_queue_empty=False
-    )
-    blog_content = load_template(
-        "migrations/001_create_blog_table.py.template",
-        revision_id="001",
-        revises="",
-        create_date="2024-03-19 10:00:00.000000",
-        down_revision="None",
-    )
-    with open(versions_dir / "001_create_blog_table.py", "w") as f:
-        f.write(blog_content)
-
-    # Second migration - create vectorizer
-    vectorizer_content = load_template(
-        "migrations/002_create_vectorizer_all_fields.py.template",
-        revision_id="002",
-        revises="001",
-        create_date="2024-03-19 10:01:00.000000",
-        down_revision="001",
-        indexing=indexing_config.to_python_arg()
-    )
-    with open(versions_dir / "002_create_vectorizer.py", "w") as f:
-        f.write(vectorizer_content)
-
-    # Run upgrade
-    upgrade(alembic_config, "head")
-
-    # Verify vectorizer exists
-    with initialized_engine.connect() as conn:
-        rows = conn.execute(
-            text("""
-                select pg_catalog.to_jsonb(v) as vectorizer from ai.vectorizer v
-            """)
-        ).fetchall()
-        assert len(rows) == 1
-        parsed_vectorizer = Vectorizer.model_validate(rows[0].vectorizer)  # type: ignore
-        assert parsed_vectorizer.target_table == "blog_posts_embedding"
-        assert isinstance(parsed_vectorizer.config.scheduling, TimescaleScheduling)
-        assert parsed_vectorizer.config.scheduling.fixed_schedule == True
-        assert parsed_vectorizer.config.scheduling.schedule_interval == timedelta(hours=1)
-
-        assert isinstance(parsed_vectorizer.config.indexing, DiskANNIndexing)
-        assert parsed_vectorizer.config.indexing.min_rows == 10
-        assert parsed_vectorizer.config.indexing.storage_layout == 'plain'
+    assert "blog" not in tables