diff --git a/dbt/adapters/risingwave/__version__.py b/dbt/adapters/risingwave/__version__.py index 582554e..6aaa73b 100644 --- a/dbt/adapters/risingwave/__version__.py +++ b/dbt/adapters/risingwave/__version__.py @@ -1 +1 @@ -version = "1.7.4" +version = "1.8.0" diff --git a/dbt/adapters/risingwave/connections.py b/dbt/adapters/risingwave/connections.py index 48baa5c..99e2fd7 100644 --- a/dbt/adapters/risingwave/connections.py +++ b/dbt/adapters/risingwave/connections.py @@ -1,13 +1,12 @@ -from contextlib import contextmanager from dataclasses import dataclass -import dbt.exceptions # noqa -from dbt.adapters.base import Credentials -from dbt.adapters.postgres import PostgresConnectionManager, PostgresCredentials -from dbt.helper_types import Port -from dbt.adapters.sql import SQLConnectionManager as connection_cls -from dbt.events import AdapterLogger from typing import Dict, Optional + import psycopg2 +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.postgres.connections import ( + PostgresConnectionManager, + PostgresCredentials, +) logger = AdapterLogger("RisingWave") @@ -68,9 +67,7 @@ def _super_open(cls, connection, extra_kwargs: Optional[Dict[str, str]] = None): search_path = credentials.search_path if search_path is not None and search_path != "": # see https://postgresql.org/docs/9.5/libpq-connect.html - kwargs["options"] = "-c search_path={}".format( - search_path.replace(" ", "\\ ") - ) + kwargs["options"] = "-c search_path={}".format(search_path.replace(" ", "\\ ")) if credentials.sslmode: kwargs["sslmode"] = credentials.sslmode diff --git a/dbt/adapters/risingwave/impl.py b/dbt/adapters/risingwave/impl.py index 034ee74..61ca453 100644 --- a/dbt/adapters/risingwave/impl.py +++ b/dbt/adapters/risingwave/impl.py @@ -1,15 +1,13 @@ +from dbt.adapters.postgres.impl import PostgresAdapter -from typing import Optional, List -from dbt.adapters.sql import SQLAdapter as adapter_cls -from dbt.adapters.base.relation import BaseRelation -from dbt.adapters.risingwave import RisingWaveConnectionManager +from dbt.adapters.risingwave.connections import RisingWaveConnectionManager from dbt.adapters.risingwave.relation import RisingWaveRelation -from dbt.adapters.postgres import PostgresAdapter class RisingWaveAdapter(PostgresAdapter): ConnectionManager = RisingWaveConnectionManager Relation = RisingWaveRelation + def _link_cached_relations(self, manifest): # lack of `pg_depend`, `pg_rewrite` pass diff --git a/dbt/adapters/risingwave/relation.py b/dbt/adapters/risingwave/relation.py index 00412cb..4d6d146 100644 --- a/dbt/adapters/risingwave/relation.py +++ b/dbt/adapters/risingwave/relation.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from typing import Optional, Type -from dbt.adapters.postgres import PostgresRelation -from dbt.dataclass_schema import StrEnum -from dbt.utils import classproperty +from dbt.adapters.postgres.relation import PostgresRelation +from dbt.adapters.utils import classproperty +from dbt_common.dataclass_schema import StrEnum class RisingWaveRelationType(StrEnum): @@ -22,7 +22,6 @@ class RisingWaveRelationType(StrEnum): class RisingWaveRelation(PostgresRelation): type: Optional[RisingWaveRelationType] = None - @classproperty def get_relation_type(cls) -> Type[RisingWaveRelationType]: return RisingWaveRelationType @@ -30,4 +29,4 @@ def get_relation_type(cls) -> Type[RisingWaveRelationType]: # RisingWave has no limitation on relation name length. # We set a relatively large value right now. def relation_max_name_length(self): - return 1024 \ No newline at end of file + return 1024 diff --git a/setup.py b/setup.py index 5e723bc..e74b452 100644 --- a/setup.py +++ b/setup.py @@ -1,20 +1,28 @@ #!/usr/bin/env python -import os +from pathlib import Path + from setuptools import find_namespace_packages, setup -package_name = "dbt-risingwave" -# make sure this always matches dbt/adapters/{adapter}/__version__.py -package_version = "1.7.4" -description = """The RisingWave adapter plugin for dbt""" +# to get the long description +README = Path(__file__).parent / "README.md" +# update the version number in dbt/adapters/risingwave/__version__.py +VERSION = Path(__file__).parent / "dbt/adapters/risingwave/__version__.py" + + +def _plugin_version() -> str: + """ + Pull the package version from the main package version file + """ + attributes = {} + exec(VERSION.read_text(), attributes) + return attributes["version"] -with open(os.path.join(os.path.dirname(__file__), "README.md")) as f: - README = f.read() setup( - name=package_name, - version=package_version, - description=description, - long_description=README, + name="dbt-risingwave", + version=_plugin_version(), + description="The RisingWave adapter plugin for dbt", + long_description=README.read_text(), long_description_content_type="text/markdown", license="http://www.apache.org/licenses/LICENSE-2.0", keywords="dbt RisingWave", @@ -23,5 +31,24 @@ url="https://github.com/risingwavelabs/dbt-risingwave", packages=find_namespace_packages(include=["dbt", "dbt.*"]), include_package_data=True, - install_requires=["dbt-postgres~=1.7.0"], + install_requires=[ + "dbt-postgres~=1.8.0", + "dbt-core~=1.8.0", + # not sure if these are needed due to inheritance from dbt-postgres + # but doesn't hurt to be explicit I suppose + "dbt-common>=1.0.4,<2.0", + "dbt-adapters>=1.1.1,<2.0", + ], + classifiers=[ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + ], + python_requires=">=3.8", )