From 51c4cb9611768b1cca2b9ddd9123a399e2924fe5 Mon Sep 17 00:00:00 2001
From: Jascha
Date: Thu, 28 Nov 2024 13:24:36 -0800
Subject: [PATCH] chore: add more extensive tests for vectorizer creation and add ollama support

---
 projects/extension/Dockerfile                 |  8 ++-
 projects/pgai/pgai/alembic/operations.py      |  4 +-
 projects/pgai/pgai/configuration.py           | 60 ++++++++++++++++---
 projects/pgai/pgai/sqlalchemy/__init__.py     |  4 +-
 projects/pgai/pgai/vectorizer/embeddings.py   |  1 +
 projects/pgai/pgai/vectorizer/scheduling.py   |  6 +-
 projects/pgai/tests/vectorizer/conftest.py    |  2 +-
 .../002_create_vectorizer.py.template         |  4 +-
 ...2_create_vectorizer_all_fields.py.template | 19 +++---
 .../fixtures/models/blog_post.py.template     |  4 +-
 .../models/blog_post_all_fields.py.template   |  4 +-
 .../vectorizer/extensions/test_alembic.py     | 45 ++++++++++----
 .../extensions/test_alembic_autogenerate.py   | 22 +++----
 .../vectorizer/extensions/test_sqlalchemy.py  |  4 +-
 .../test_sqlalchemy_large_embeddings.py       |  4 +-
 .../test_sqlalchemy_no_relationship.py        |  4 +-
 .../test_sqlalchemy_relationship.py           |  4 +-
 17 files changed, 136 insertions(+), 63 deletions(-)

diff --git a/projects/extension/Dockerfile b/projects/extension/Dockerfile
index fefbfd95..b3367346 100644
--- a/projects/extension/Dockerfile
+++ b/projects/extension/Dockerfile
@@ -41,7 +41,6 @@ RUN set -e; \
     dpkg -i pgvectorscale-postgresql-${PG_MAJOR}_${PGVECTORSCALE_VERSION}-Linux_"$TARGET_ARCH".deb; \
    rm pgvectorscale-${PGVECTORSCALE_VERSION}-pg${PG_MAJOR}-"$TARGET_ARCH".zip pgvectorscale-postgresql-${PG_MAJOR}_${PGVECTORSCALE_VERSION}-Linux_"$TARGET_ARCH".deb
 
-
 ###############################################################################
 # image for use in testing the pgai library
 FROM base AS pgai-test-db
@@ -51,6 +50,13 @@ WORKDIR /pgai
 COPY . .
 RUN just build install
+# Create a custom config file in docker-entrypoint-initdb.d
+RUN mkdir -p /docker-entrypoint-initdb.d && \
+    echo "#!/bin/bash" > /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
+    echo "echo \"shared_preload_libraries = 'timescaledb'\" >> \${PGDATA}/postgresql.conf" >> /docker-entrypoint-initdb.d/configure-timescaledb.sh && \
+    chmod +x /docker-entrypoint-initdb.d/configure-timescaledb.sh
+
+
 
 
 ###############################################################################
 # image for use in extension development
diff --git a/projects/pgai/pgai/alembic/operations.py b/projects/pgai/pgai/alembic/operations.py
index a550089f..359b4e89 100644
--- a/projects/pgai/pgai/alembic/operations.py
+++ b/projects/pgai/pgai/alembic/operations.py
@@ -8,7 +8,7 @@
     ChunkingConfig,
     CreateVectorizerParams,
     DiskANNIndexingConfig,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
     HNSWIndexingConfig,
     NoSchedulingConfig,
     ProcessingConfig,
@@ -23,7 +23,7 @@ class CreateVectorizerOp(MigrateOperation):
     def __init__(
         self,
         source_table: str | None = None,
-        embedding: EmbeddingConfig | None = None,
+        embedding: OpenAIEmbeddingConfig | None = None,
         chunking: ChunkingConfig | None = None,
         indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None,
         formatting_template: str | None = None,
diff --git a/projects/pgai/pgai/configuration.py b/projects/pgai/pgai/configuration.py
index 5e96bbc7..b189e88d 100644
--- a/projects/pgai/pgai/configuration.py
+++ b/projects/pgai/pgai/configuration.py
@@ -10,7 +10,7 @@
     LangChainCharacterTextSplitter,
     LangChainRecursiveCharacterTextSplitter,
 )
-from pgai.vectorizer.embeddings import OpenAI
+from pgai.vectorizer.embeddings import OpenAI, Ollama
 from pgai.vectorizer.formatting import ChunkValue, PythonTemplate
 from pgai.vectorizer.indexing import DiskANNIndexing, HNSWIndexing
 from pgai.vectorizer.processing import ProcessingDefault
@@ -82,7 +82,7 @@ def format_python_arg(config_type: str, instance: Any) -> str:
 
 
 @dataclass
-class EmbeddingConfig:
+class OpenAIEmbeddingConfig:
     model: str
     dimensions: int
     chat_user: str | None = None
@@ -109,7 +109,7 @@ def to_python_arg(self) -> str:
         return format_python_arg("embedding", self)
 
     @classmethod
-    def from_db_config(cls, openai_config: OpenAI) -> "EmbeddingConfig":
+    def from_db_config(cls, openai_config: OpenAI) -> "OpenAIEmbeddingConfig":
         return cls(
             model=openai_config.model,
             dimensions=openai_config.dimensions or 1536,
@@ -118,6 +118,47 @@ def from_db_config(cls, openai_config: OpenAI) -> "EmbeddingConfig":
         )
 
 
+@dataclass
+class OllamaEmbeddingConfig:
+    model: str
+    dimensions: int
+    base_url: str | None = None
+    truncate: bool | None = None
+    keep_alive: str | None = None
+
+    _defaults = {"dimensions": 1536}
+
+    @override
+    def __eq__(self, other: object) -> bool:
+        return equivalent_dataclass_with_defaults(self, other, self._defaults)
+
+    def to_sql_argument(self) -> str:
+        params = [
+            f"'{self.model}'",
+            str(self.dimensions),
+        ]
+        if self.base_url:
+            params.append(f"base_url=>'{self.base_url}'")
+        if self.truncate is False:
+            params.append("truncate=>false")
+        if self.keep_alive:
+            params.append(f"keep_alive=>'{self.keep_alive}'")
+        return f", embedding => ai.embedding_ollama({', '.join(params)})"
+
+    def to_python_arg(self) -> str:
+        return format_python_arg("embedding", self)
+
+    @classmethod
+    def from_db_config(cls, config: Ollama) -> "OllamaEmbeddingConfig":
+        return cls(
+            model=config.model,
+            dimensions=config.dimensions,
+            base_url=config.base_url,
+            truncate=config.truncate,
+            keep_alive=config.keep_alive,
+        )
+
+
 @dataclass
 class ChunkingConfig:
     chunk_column: str
@@ -339,7 +380,7 @@ def to_python_arg(self) -> str:
     @classmethod
     def from_db_config(cls, config: TimescaleScheduling) -> "SchedulingConfig":
         return cls(
-            schedule_interval=config.schedule_interval,
+            schedule_interval=str(config.schedule_interval),
             initial_start=config.initial_start,
             fixed_schedule=config.fixed_schedule,
             timezone=config.timezone,
@@ -387,7 +428,7 @@ def format_bool_param(name: str, value: bool) -> str:
 @dataclass
 class CreateVectorizerParams:
     source_table: str | None
-    embedding: EmbeddingConfig | None = None
+    embedding: OpenAIEmbeddingConfig | OllamaEmbeddingConfig | None = None
     chunking: ChunkingConfig | None = None
     indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None
     formatting_template: str | None = None
@@ -411,7 +452,7 @@ class CreateVectorizerParams:
     }
 
     # These fields are hard to compare
-    ignored_fields = ("queue_table", "grant_to")
+    ignored_fields = ("queue_table", "grant_to", "scheduling")
 
     @override
     def __eq__(self, other: object) -> bool:
@@ -512,8 +553,11 @@ def from_db_config(cls, vectorizer: Vectorizer) -> "CreateVectorizerParams":
         Returns:
             CreateVectorizerParams: A new instance configured from database settings
         """
-
-        embedding_config = EmbeddingConfig.from_db_config(vectorizer.config.embedding)
+        embedding_config: None | OpenAIEmbeddingConfig | OllamaEmbeddingConfig = None
+        if isinstance(vectorizer.config.embedding, OpenAI):
+            embedding_config = OpenAIEmbeddingConfig.from_db_config(vectorizer.config.embedding)
+        if isinstance(vectorizer.config.embedding, Ollama):
+            embedding_config = OllamaEmbeddingConfig.from_db_config(vectorizer.config.embedding)
         chunking_config = ChunkingConfig.from_db_config(vectorizer.config.chunking)
         processing_config = ProcessingConfig.from_db_config(vectorizer.config.processing)
         indexing_config: None | DiskANNIndexingConfig | HNSWIndexingConfig = None
diff --git a/projects/pgai/pgai/sqlalchemy/__init__.py b/projects/pgai/pgai/sqlalchemy/__init__.py
index 2f6d331a..8245e692 100644
--- a/projects/pgai/pgai/sqlalchemy/__init__.py
+++ b/projects/pgai/pgai/sqlalchemy/__init__.py
@@ -8,7 +8,7 @@
     ChunkingConfig,
     CreateVectorizerParams,
     DiskANNIndexingConfig,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
     HNSWIndexingConfig,
     NoSchedulingConfig,
     ProcessingConfig,
@@ -41,7 +41,7 @@ class EmbeddingModel(DeclarativeBase, Generic[T]):
 class VectorizerField:
     def __init__(
         self,
-        embedding: EmbeddingConfig,
+        embedding: OpenAIEmbeddingConfig,
         chunking: ChunkingConfig,
         formatting_template: str | None = None,
         indexing: DiskANNIndexingConfig | HNSWIndexingConfig | None = None,
diff --git a/projects/pgai/pgai/vectorizer/embeddings.py b/projects/pgai/pgai/vectorizer/embeddings.py
index 246b4f5c..c48bd4e6 100644
--- a/projects/pgai/pgai/vectorizer/embeddings.py
+++ b/projects/pgai/pgai/vectorizer/embeddings.py
@@ -526,6 +526,7 @@ class Ollama(BaseModel, Embedder):
 
     implementation: Literal["ollama"]
     model: str
+    dimensions: int
     base_url: str | None = None
     truncate: bool = True
     options: OllamaOptions | None = None
diff --git a/projects/pgai/pgai/vectorizer/scheduling.py b/projects/pgai/pgai/vectorizer/scheduling.py
index 21166d5b..c8d96cef 100644
--- a/projects/pgai/pgai/vectorizer/scheduling.py
+++ b/projects/pgai/pgai/vectorizer/scheduling.py
@@ -1,3 +1,4 @@
+import datetime
 from typing import Literal
 
 from pydantic import BaseModel
@@ -12,9 +13,10 @@ class TimescaleScheduling(BaseModel):
         retention_policy: The retention policy to use.
     """
 
-    implementation: Literal["timescale"]
-    schedule_interval: str
+    implementation: Literal["timescaledb"]
+    schedule_interval: datetime.timedelta
     initial_start: str
+    job_id: int
     fixed_schedule: bool
     timezone: str
 
diff --git a/projects/pgai/tests/vectorizer/conftest.py b/projects/pgai/tests/vectorizer/conftest.py
index fbd283c0..a447f24f 100644
--- a/projects/pgai/tests/vectorizer/conftest.py
+++ b/projects/pgai/tests/vectorizer/conftest.py
@@ -1,7 +1,6 @@
 import os
 from pathlib import Path
 from typing import Any
-
 import pytest
 import tiktoken
 import vcr  # type:ignore
@@ -84,5 +83,6 @@ def timescale_ha_container():
         password="my-password",
         dbname="tsdb",
         driver=None,
+        command="postgres -c shared_preload_libraries=timescaledb"
     ).with_env("OPENAI_API_KEY", os.environ["OPENAI_API_KEY"]) as postgres:
         yield postgres
\ No newline at end of file
diff --git a/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer.py.template b/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer.py.template
index d2fda147..dfd89edb 100644
--- a/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer.py.template
+++ b/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer.py.template
@@ -7,7 +7,7 @@ Create Date: {create_date}
 from alembic import op
 from pgai.alembic import CreateVectorizerOp
 from pgai.configuration import (
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
     ChunkingConfig
 )
 from sqlalchemy import text
@@ -22,7 +22,7 @@ def upgrade():
     op.create_vectorizer(
         'blog',
         target_table='blog_embeddings',
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model='text-embedding-3-small',
             dimensions=768
         ),
diff --git a/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer_all_fields.py.template b/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer_all_fields.py.template
index 96a0c690..c01fc9dc 100644
--- a/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer_all_fields.py.template
+++ b/projects/pgai/tests/vectorizer/extensions/fixtures/migrations/002_create_vectorizer_all_fields.py.template
@@ -7,7 +7,7 @@ Create Date: {create_date}
 from alembic import op
 from pgai.alembic import CreateVectorizerOp, DropVectorizerOp
 import sqlalchemy as sa
-from pgai.configuration import ChunkingConfig, DiskANNIndexingConfig, EmbeddingConfig, ProcessingConfig, SchedulingConfig
+from pgai.configuration import ChunkingConfig, DiskANNIndexingConfig, OpenAIEmbeddingConfig, ProcessingConfig, SchedulingConfig
 
 # revision identifiers
 revision = '{revision_id}'
@@ -25,7 +25,7 @@ def upgrade():
     )
     op.create_vectorizer(
         'blog_posts',
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model='text-embedding-3-small',
             dimensions=768,
             chat_user='test_user',
@@ -38,16 +38,13 @@ def upgrade():
             separator=' ',
             is_separator_regex=True
         ),
-        indexing=DiskANNIndexingConfig(
-            min_rows=10,
-            storage_layout='plain',
-            num_neighbors=5,
-            search_list_size=10,
-            max_alpha=0.5,
-            num_dimensions=10,
-            num_bits_per_dimension=10,
-            create_when_queue_empty=False
+        scheduling=SchedulingConfig(
+            schedule_interval="1h",
+            initial_start="2022-01-01T00:00:00Z",
+            fixed_schedule=True,
+            timezone="UTC"
         ),
+        {indexing},
         processing=ProcessingConfig(
             batch_size=10,
             concurrency=5
diff --git a/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post.py.template b/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post.py.template
index 2989c666..4dc4c7dd 100644
--- a/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post.py.template
+++ b/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post.py.template
@@ -2,7 +2,7 @@ from sqlalchemy.orm import declarative_base
 from sqlalchemy import Column, Integer, Text
 from pgai.sqlalchemy import (
     VectorizerField,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
     ChunkingConfig
 )
 
@@ -16,7 +16,7 @@ class BlogPost(Base):
     content = Column(Text, nullable=False)
 
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="{model}",
             dimensions={dimensions}
         ),
diff --git a/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post_all_fields.py.template b/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post_all_fields.py.template
index 34481540..6e25fcac 100644
--- a/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post_all_fields.py.template
+++ b/projects/pgai/tests/vectorizer/extensions/fixtures/models/blog_post_all_fields.py.template
@@ -4,7 +4,7 @@ from sqlalchemy import Column, Integer, Text
 from pgai.configuration import DiskANNIndexingConfig, SchedulingConfig, ProcessingConfig
 from pgai.sqlalchemy import (
     VectorizerField,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
     ChunkingConfig
 )
 
@@ -18,7 +18,7 @@ class BlogPost(Base):
     content = Column(Text, nullable=False)
 
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-small",
             dimensions=768,
             chat_user="test_user",
diff --git a/projects/pgai/tests/vectorizer/extensions/test_alembic.py b/projects/pgai/tests/vectorizer/extensions/test_alembic.py
index 62f872b9..b790e98f 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_alembic.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_alembic.py
@@ -1,10 +1,14 @@
+from datetime import timedelta
 from pathlib import Path
 
 from alembic.command import downgrade, upgrade
 from alembic.config import Config
 from sqlalchemy import Engine, inspect, text
 
+from pgai.configuration import DiskANNIndexingConfig
 from pgai.vectorizer import Vectorizer
+from pgai.vectorizer.indexing import DiskANNIndexing
+from pgai.vectorizer.scheduling import TimescaleScheduling
 from tests.vectorizer.extensions.conftest import load_template
 
 
@@ -130,24 +134,36 @@ def test_vectorizer_migration(
 
 
 def test_vectorizer_migration_all_fields(
-    timescale_alembic_config: Config,
-    initialized_engine_with_timescale: Engine,
+    alembic_config: Config,
+    initialized_engine: Engine,
     cleanup_modules: None,  # noqa: ARG001
 ):
     """Test vectorizer creation with a bunch of fields"""
-    migrations_dir = Path(timescale_alembic_config.get_main_option("script_location"))  # type: ignore
+    migrations_dir = Path(alembic_config.get_main_option("script_location"))  # type: ignore
     versions_dir = migrations_dir / "versions"
 
-    with initialized_engine_with_timescale.connect() as conn:
+    with initialized_engine.connect() as conn:
         conn.execute(
             text(
                 """
                 CREATE schema timescale;
+                CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE;
                 """
             )
         )
+        conn.commit()
 
     # First migration - create blog table
+    indexing_config = DiskANNIndexingConfig(
+        min_rows=10,
+        storage_layout='plain',
+        num_neighbors=5,
+        search_list_size=10,
+        max_alpha=0.5,
+        num_dimensions=10,
+        num_bits_per_dimension=10,
+        create_when_queue_empty=False
+    )
     blog_content = load_template(
         "migrations/001_create_blog_table.py.template",
         revision_id="001",
@@ -165,21 +181,28 @@ def test_vectorizer_migration_all_fields(
         revises="001",
         create_date="2024-03-19 10:01:00.000000",
         down_revision="001",
+        indexing=indexing_config.to_python_arg()
     )
     with open(versions_dir / "002_create_vectorizer.py", "w") as f:
         f.write(vectorizer_content)
     # Run upgrade
-    upgrade(timescale_alembic_config, "head")
+    upgrade(alembic_config, "head")
 
     # Verify vectorizer exists
-    with initialized_engine_with_timescale.connect() as conn:
-        result = conn.execute(
+    with initialized_engine.connect() as conn:
+        rows = conn.execute(
             text("""
                 select pg_catalog.to_jsonb(v) as vectorizer
                 from ai.vectorizer v
             """)
         ).fetchall()
-        assert len(result) == 1
-        parsed_vectorizer = Vectorizer.model_validate(row.vectorizer)  # type: ignore
-        assert parsed_vectorizer.target_table == "blog"
-
\ No newline at end of file
+        assert len(rows) == 1
+        parsed_vectorizer = Vectorizer.model_validate(rows[0].vectorizer)  # type: ignore
+        assert parsed_vectorizer.target_table == "blog_posts_embedding"
+        assert isinstance(parsed_vectorizer.config.scheduling, TimescaleScheduling)
+        assert parsed_vectorizer.config.scheduling.fixed_schedule is True
+        assert parsed_vectorizer.config.scheduling.schedule_interval == timedelta(hours=1)
+
+        assert isinstance(parsed_vectorizer.config.indexing, DiskANNIndexing)
+        assert parsed_vectorizer.config.indexing.min_rows == 10
+        assert parsed_vectorizer.config.indexing.storage_layout == 'plain'
\ No newline at end of file
diff --git a/projects/pgai/tests/vectorizer/extensions/test_alembic_autogenerate.py b/projects/pgai/tests/vectorizer/extensions/test_alembic_autogenerate.py
index f58b7f18..a9c364c3 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_alembic_autogenerate.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_alembic_autogenerate.py
@@ -166,7 +166,7 @@
 
     assert (
         "from pgai.configuration import ChunkingConfig, DiskANNIndexingConfig,"
-        " EmbeddingConfig, ProcessingConfig, SchedulingConfig"
+        " OpenAIEmbeddingConfig, ProcessingConfig, SchedulingConfig"
     ) in migration_contents
 
     # Verify vectorizer creation and basic config
@@ -174,7 +174,7 @@
     assert "'blog_posts'" in migration_contents
 
     # Verify embedding config
-    assert "embedding=EmbeddingConfig" in migration_contents
+    assert "embedding=OpenAIEmbeddingConfig" in migration_contents
     assert "model='text-embedding-3-small'" in migration_contents
     assert "dimensions=768" in migration_contents
     assert "chat_user='test_user'" in migration_contents
@@ -244,7 +244,7 @@ def test_multiple_vectorizer_fields_autogeneration(
     model_content = """
 from sqlalchemy.orm import declarative_base
 from sqlalchemy import Column, Integer, Text
-from pgai.sqlalchemy import VectorizerField, EmbeddingConfig, ChunkingConfig
+from pgai.sqlalchemy import VectorizerField, OpenAIEmbeddingConfig, ChunkingConfig
 
 Base = declarative_base()
 
@@ -257,7 +257,7 @@ class BlogPost(Base):
     summary = Column(Text, nullable=False)
 
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-small",
             dimensions=768
         ),
@@ -270,7 +270,7 @@ class BlogPost(Base):
     )
 
     summary_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-large",
             dimensions=1536
         ),
@@ -362,7 +362,7 @@ def test_multiple_vectorizer_fields_change_autogeneration(
     model_content = """
 from sqlalchemy.orm import declarative_base
 from sqlalchemy import Column, Integer, Text
-from pgai.sqlalchemy import VectorizerField, EmbeddingConfig, ChunkingConfig
+from pgai.sqlalchemy import VectorizerField, OpenAIEmbeddingConfig, ChunkingConfig
 
 Base = declarative_base()
 
@@ -375,7 +375,7 @@ class BlogPost(Base):
     summary = Column(Text, nullable=False)
 
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-small",
             dimensions=768
         ),
@@ -388,7 +388,7 @@ class BlogPost(Base):
     )
 
     summary_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-small",
             dimensions=768
         ),
@@ -429,7 +429,7 @@ class BlogPost(Base):
     modified_model_content = """
 from sqlalchemy.orm import declarative_base
 from sqlalchemy import Column, Integer, Text
-from pgai.sqlalchemy import VectorizerField, EmbeddingConfig, ChunkingConfig
+from pgai.sqlalchemy import VectorizerField, OpenAIEmbeddingConfig, ChunkingConfig
 
 Base = declarative_base()
 
@@ -442,7 +442,7 @@ class BlogPost(Base):
     summary = Column(Text, nullable=False)
 
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-large",
             dimensions=1536
         ),
@@ -455,7 +455,7 @@ class BlogPost(Base):
     )
 
     summary_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-small",
             dimensions=768
         ),
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py
index d577ffd7..ce3adb11 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py
@@ -7,7 +7,7 @@
 from pgai.cli import vectorizer_worker
 from pgai.configuration import (
     ChunkingConfig,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
 )
 from pgai.sqlalchemy import VectorizerField
 
@@ -25,7 +25,7 @@ class BlogPost(Base):
     content = Column(Text, nullable=False)
 
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(model="text-embedding-3-small", dimensions=768),
+        embedding=OpenAIEmbeddingConfig(model="text-embedding-3-small", dimensions=768),
         chunking=ChunkingConfig(
             chunk_column="content", chunk_size=500, chunk_overlap=50
         ),
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py
index c3528715..eddc0c16 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py
@@ -8,7 +8,7 @@
 from pgai.cli import vectorizer_worker
 from pgai.configuration import (
     ChunkingConfig,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
 )
 from pgai.sqlalchemy import VectorizerField
 
@@ -23,7 +23,7 @@ class BlogPost(Base):
     title = Column(Text, nullable=False)
     content = Column(Text, nullable=False)
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(
+        embedding=OpenAIEmbeddingConfig(
             model="text-embedding-3-large",
             dimensions=1536,
         ),
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_no_relationship.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_no_relationship.py
index 5ec2a126..268f38e8 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_no_relationship.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_no_relationship.py
@@ -8,7 +8,7 @@
 from pgai.cli import vectorizer_worker
 from pgai.configuration import (
     ChunkingConfig,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
 )
 from pgai.sqlalchemy import VectorizerField
 
@@ -23,7 +23,7 @@ class BlogPost(Base):
     title = Column(Text, nullable=False)
     content = Column(Text, nullable=False)
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(model="text-embedding-3-small", dimensions=768),
+        embedding=OpenAIEmbeddingConfig(model="text-embedding-3-small", dimensions=768),
         chunking=ChunkingConfig(
             chunk_column="content", chunk_size=500, chunk_overlap=50
         ),
diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py
index 8a16affe..a977598d 100644
--- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py
+++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py
@@ -8,7 +8,7 @@
 from pgai.cli import vectorizer_worker
 from pgai.configuration import (
     ChunkingConfig,
-    EmbeddingConfig,
+    OpenAIEmbeddingConfig,
 )
 from pgai.sqlalchemy import EmbeddingModel, VectorizerField
 
@@ -23,7 +23,7 @@ class BlogPost(Base):
     title = Column(Text, nullable=False)
     content = Column(Text, nullable=False)
     content_embeddings = VectorizerField(
-        embedding=EmbeddingConfig(model="text-embedding-3-small", dimensions=768),
+        embedding=OpenAIEmbeddingConfig(model="text-embedding-3-small", dimensions=768),
         chunking=ChunkingConfig(
             chunk_column="content", chunk_size=500, chunk_overlap=50
         ),
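
Reviewer note (not part of the patch): a minimal sketch of how the new OllamaEmbeddingConfig introduced in projects/pgai/pgai/configuration.py renders the embedding clause, based only on the to_sql_argument() implementation added above. The model name, dimension count, and base URL below are illustrative assumptions, not values taken from this change.

    # Hypothetical usage sketch; exercises only code added in this patch.
    from pgai.configuration import OllamaEmbeddingConfig

    embedding = OllamaEmbeddingConfig(
        model="nomic-embed-text",           # assumed model name, illustration only
        dimensions=768,                     # assumed dimension count
        base_url="http://localhost:11434",  # assumed local Ollama endpoint
    )

    # to_sql_argument() above builds the clause appended to ai.create_vectorizer(), e.g.
    #   , embedding => ai.embedding_ollama('nomic-embed-text', 768, base_url=>'http://localhost:11434')
    print(embedding.to_sql_argument())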