Commit
chore: add more extensive tests for vectorizer creation and add ollama support
Askir committed Nov 28, 2024
1 parent 128f836 commit 51c4cb9
Showing 17 changed files with 136 additions and 63 deletions.
8 changes: 7 additions & 1 deletion projects/extension/Dockerfile
@@ -41,7 +41,6 @@ RUN set -e; \
dpkg -i pgvectorscale-postgresql-${PG_MAJOR}_${PGVECTORSCALE_VERSION}-Linux_"$TARGET_ARCH".deb; \
rm pgvectorscale-${PGVECTORSCALE_VERSION}-pg${PG_MAJOR}-"$TARGET_ARCH".zip pgvectorscale-postgresql-${PG_MAJOR}_${PGVECTORSCALE_VERSION}-Linux_"$TARGET_ARCH".deb


###############################################################################
# image for use in testing the pgai library
FROM base AS pgai-test-db
@@ -51,6 +50,13 @@ WORKDIR /pgai
COPY . .
RUN just build install

# Create a custom config file in docker-entrypoint-initdb.d
RUN mkdir -p /docker-entrypoint-initdb.d && \
echo "#!/bin/bash" > /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
echo "echo \"shared_preload_libraries = 'timescaledb'\" >> \${PGDATA}/postgresql.conf" >> /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
chmod +x /docker-entrypoint-initdb.d/configure-timescaledb.sh



###############################################################################
# image for use in extension development
4 changes: 2 additions & 2 deletions projects/pgai/pgai/alembic/operations.py
@@ -8,7 +8,7 @@
ChunkingConfig,
CreateVectorizerParams,
DiskANNIndexingConfig,
EmbeddingConfig,
OpenAIEmbeddingConfig,
HNSWIndexingConfig,
NoSchedulingConfig,
ProcessingConfig,
@@ -23,7 +23,7 @@ class CreateVectorizerOp(MigrateOperation):
def __init__(
self,
source_table: str | None = None,
embedding: EmbeddingConfig | None = None,
embedding: OpenAIEmbeddingConfig | None = None,
chunking: ChunkingConfig | None = None,
indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None,
formatting_template: str | None = None,
60 changes: 52 additions & 8 deletions projects/pgai/pgai/configuration.py
@@ -10,7 +10,7 @@
LangChainCharacterTextSplitter,
LangChainRecursiveCharacterTextSplitter,
)
from pgai.vectorizer.embeddings import OpenAI
from pgai.vectorizer.embeddings import OpenAI, Ollama
from pgai.vectorizer.formatting import ChunkValue, PythonTemplate
from pgai.vectorizer.indexing import DiskANNIndexing, HNSWIndexing
from pgai.vectorizer.processing import ProcessingDefault
@@ -82,7 +82,7 @@ def format_python_arg(config_type: str, instance: Any) -> str:


@dataclass
class EmbeddingConfig:
class OpenAIEmbeddingConfig:
model: str
dimensions: int
chat_user: str | None = None
@@ -109,7 +109,7 @@ def to_python_arg(self) -> str:
return format_python_arg("embedding", self)

@classmethod
def from_db_config(cls, openai_config: OpenAI) -> "EmbeddingConfig":
def from_db_config(cls, openai_config: OpenAI) -> "OpenAIEmbeddingConfig":
return cls(
model=openai_config.model,
dimensions=openai_config.dimensions or 1536,
@@ -118,6 +118,47 @@ def from_db_config(cls, openai_config: OpenAI) -> "EmbeddingConfig":
)


@dataclass
class OllamaEmbeddingConfig:
model: str
dimensions: int
base_url: str | None = None
truncate: bool | None = None
keep_alive: str | None = None

_defaults = {"dimensions": 1536, "api_key_name": "OPENAI_API_KEY"}

@override
def __eq__(self, other: object) -> bool:
return equivalent_dataclass_with_defaults(self, other, self._defaults)

def to_sql_argument(self) -> str:
params = [
f"'{self.model}'",
str(self.dimensions),
]
if self.base_url:
params.append(f"base_url=>'{self.base_url}'")
if self.truncate is False:
params.append("truncate=>false")
if self.keep_alive:
params.append(f"keep_alive=>'{self.keep_alive}'")
return f", embedding => ai.embedding_ollama({', '.join(params)})"

def to_python_arg(self) -> str:
return format_python_arg("embedding", self)

@classmethod
def from_db_config(cls, config: Ollama) -> "OllamaEmbeddingConfig":
return cls(
model=config.model,
dimensions=config.dimensions,
base_url=config.base_url,
truncate=config.truncate,
keep_alive=config.keep_alive
)


@dataclass
class ChunkingConfig:
chunk_column: str
@@ -339,7 +380,7 @@ def to_python_arg(self) -> str:
@classmethod
def from_db_config(cls, config: TimescaleScheduling) -> "SchedulingConfig":
return cls(
schedule_interval=config.schedule_interval,
schedule_interval=str(config.schedule_interval),
initial_start=config.initial_start,
fixed_schedule=config.fixed_schedule,
timezone=config.timezone,
@@ -387,7 +428,7 @@ def format_bool_param(name: str, value: bool) -> str:
@dataclass
class CreateVectorizerParams:
source_table: str | None
embedding: EmbeddingConfig | None = None
embedding: OpenAIEmbeddingConfig | OllamaEmbeddingConfig | None = None
chunking: ChunkingConfig | None = None
indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None
formatting_template: str | None = None
@@ -411,7 +452,7 @@ class CreateVectorizerParams:
}

# These fields are hard to compare
ignored_fields = ("queue_table", "grant_to")
ignored_fields = ("queue_table", "grant_to", "scheduling")

@override
def __eq__(self, other: object) -> bool:
@@ -512,8 +553,11 @@ def from_db_config(cls, vectorizer: Vectorizer) -> "CreateVectorizerParams":
Returns:
CreateVectorizerParams: A new instance configured from database settings
"""

embedding_config = EmbeddingConfig.from_db_config(vectorizer.config.embedding)
embedding_config: None | OpenAIEmbeddingConfig | OllamaEmbeddingConfig = None
if isinstance(vectorizer.config.embedding, OpenAI):
embedding_config = OpenAIEmbeddingConfig.from_db_config(vectorizer.config.embedding)
if isinstance(vectorizer.config.embedding, Ollama):
embedding_config = OllamaEmbeddingConfig.from_db_config(vectorizer.config.embedding)
chunking_config = ChunkingConfig.from_db_config(vectorizer.config.chunking)
processing_config = ProcessingConfig.from_db_config(vectorizer.config.processing)
indexing_config: None | DiskANNIndexingConfig | HNSWIndexingConfig = None
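As an aside (not part of the diff): given the OllamaEmbeddingConfig dataclass above, its to_sql_argument() renders the embedding keyword that gets spliced into the vectorizer-creation SQL roughly as follows. The model name and base URL here are made-up example values.

from pgai.configuration import OllamaEmbeddingConfig

# Hypothetical values, for illustration only
config = OllamaEmbeddingConfig(
    model="nomic-embed-text",
    dimensions=768,
    base_url="http://localhost:11434",
)

# Prints:
# ", embedding => ai.embedding_ollama('nomic-embed-text', 768, base_url=>'http://localhost:11434')"
print(config.to_sql_argument())
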
4 changes: 2 additions & 2 deletions projects/pgai/pgai/sqlalchemy/__init__.py
@@ -8,7 +8,7 @@
ChunkingConfig,
CreateVectorizerParams,
DiskANNIndexingConfig,
EmbeddingConfig,
OpenAIEmbeddingConfig,
HNSWIndexingConfig,
NoSchedulingConfig,
ProcessingConfig,
@@ -41,7 +41,7 @@ class EmbeddingModel(DeclarativeBase, Generic[T]):
class VectorizerField:
def __init__(
self,
embedding: EmbeddingConfig,
embedding: OpenAIEmbeddingConfig,
chunking: ChunkingConfig,
formatting_template: str | None = None,
indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None,
1 change: 1 addition & 0 deletions projects/pgai/pgai/vectorizer/embeddings.py
@@ -526,6 +526,7 @@ class Ollama(BaseModel, Embedder):

implementation: Literal["ollama"]
model: str
dimensions: int
base_url: str | None = None
truncate: bool = True
options: OllamaOptions | None = None
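For orientation, a minimal sketch of constructing the updated Ollama embedder model now that dimensions is required; the model name is invented, and the sketch assumes the fields not visible in this hunk keep their defaults.

from pgai.vectorizer.embeddings import Ollama

embedder = Ollama(
    implementation="ollama",
    model="nomic-embed-text",  # hypothetical model name
    dimensions=768,            # newly required field, no default
)
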
6 changes: 4 additions & 2 deletions projects/pgai/pgai/vectorizer/scheduling.py
@@ -1,3 +1,4 @@
import datetime
from typing import Literal

from pydantic import BaseModel
@@ -12,9 +13,10 @@ class TimescaleScheduling(BaseModel):
retention_policy: The retention policy to use.
"""

implementation: Literal["timescale"]
schedule_interval: str
implementation: Literal["timescaledb"]
schedule_interval: datetime.timedelta
initial_start: str
job_id: int
fixed_schedule: bool
timezone: str

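Because schedule_interval is now a datetime.timedelta, SchedulingConfig.from_db_config (in the configuration.py diff above) wraps it in str() before storing it. A quick sketch of what that conversion yields:

import datetime

print(str(datetime.timedelta(hours=1)))  # "1:00:00"
print(str(datetime.timedelta(days=1)))   # "1 day, 0:00:00"
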
2 changes: 1 addition & 1 deletion projects/pgai/tests/vectorizer/conftest.py
@@ -1,7 +1,6 @@
import os
from pathlib import Path
from typing import Any

import pytest
import tiktoken
import vcr # type:ignore
@@ -84,5 +83,6 @@ def timescale_ha_container():
password="my-password",
dbname="tsdb",
driver=None,
command="postgres -c shared_preload_libraries=timescaledb"
).with_env("OPENAI_API_KEY", os.environ["OPENAI_API_KEY"]) as postgres:
yield postgres

@@ -7,7 +7,7 @@ Create Date: {create_date}
from alembic import op
from pgai.alembic import CreateVectorizerOp
from pgai.configuration import (
EmbeddingConfig,
OpenAIEmbeddingConfig,
ChunkingConfig
)
from sqlalchemy import text
@@ -22,7 +22,7 @@ def upgrade():
op.create_vectorizer(
'blog',
target_table='blog_embeddings',
embedding=EmbeddingConfig(
embedding=OpenAIEmbeddingConfig(
model='text-embedding-3-small',
dimensions=768
),

@@ -7,7 +7,7 @@ Create Date: {create_date}
from alembic import op
from pgai.alembic import CreateVectorizerOp, DropVectorizerOp
import sqlalchemy as sa
from pgai.configuration import ChunkingConfig, DiskANNIndexingConfig, EmbeddingConfig, ProcessingConfig, SchedulingConfig
from pgai.configuration import ChunkingConfig, DiskANNIndexingConfig, OpenAIEmbeddingConfig, ProcessingConfig, SchedulingConfig

# revision identifiers
revision = '{revision_id}'
@@ -25,7 +25,7 @@ def upgrade():
)
op.create_vectorizer(
'blog_posts',
embedding=EmbeddingConfig(
embedding=OpenAIEmbeddingConfig(
model='text-embedding-3-small',
dimensions=768,
chat_user='test_user',
@@ -38,16 +38,13 @@
separator=' ',
is_separator_regex=True
),
indexing=DiskANNIndexingConfig(
min_rows=10,
storage_layout='plain',
num_neighbors=5,
search_list_size=10,
max_alpha=0.5,
num_dimensions=10,
num_bits_per_dimension=10,
create_when_queue_empty=False
scheduling=SchedulingConfig(
schedule_interval="1h",
initial_start="2022-01-01T00:00:00Z",
fixed_schedule=True,
timezone="UTC"
),
{indexing},
processing=ProcessingConfig(
batch_size=10,
concurrency=5
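The {indexing} placeholder replaces the index configuration this template previously hard-coded; one plausible substitution, reusing the DiskANNIndexingConfig arguments that were inlined before, is sketched below.

from pgai.configuration import DiskANNIndexingConfig

# One possible value for the {indexing} placeholder, mirroring the removed block
indexing = DiskANNIndexingConfig(
    min_rows=10,
    storage_layout='plain',
    num_neighbors=5,
    search_list_size=10,
    max_alpha=0.5,
    num_dimensions=10,
    num_bits_per_dimension=10,
    create_when_queue_empty=False,
)
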

@@ -2,7 +2,7 @@ from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer, Text
from pgai.sqlalchemy import (
VectorizerField,
EmbeddingConfig,
OpenAIEmbeddingConfig,
ChunkingConfig
)

@@ -16,7 +16,7 @@ class BlogPost(Base):
content = Column(Text, nullable=False)

content_embeddings = VectorizerField(
embedding=EmbeddingConfig(
embedding=OpenAIEmbeddingConfig(
model="{model}",
dimensions={dimensions}
),

@@ -4,7 +4,7 @@ from sqlalchemy import Column, Integer, Text
from pgai.configuration import DiskANNIndexingConfig, SchedulingConfig, ProcessingConfig
from pgai.sqlalchemy import (
VectorizerField,
EmbeddingConfig,
OpenAIEmbeddingConfig,
ChunkingConfig
)

@@ -18,7 +18,7 @@ class BlogPost(Base):
content = Column(Text, nullable=False)

content_embeddings = VectorizerField(
embedding=EmbeddingConfig(
embedding=OpenAIEmbeddingConfig(
model="text-embedding-3-small",
dimensions=768,
chat_user="test_user",