diff --git a/compose.yaml b/compose.yaml
index fb2728accec9..60c15bc6a949 100644
--- a/compose.yaml
+++ b/compose.yaml
@@ -542,7 +542,7 @@ services:
       - impala
   risingwave:
-    image: ghcr.io/risingwavelabs/risingwave:nightly-20240122
+    image: ghcr.io/risingwavelabs/risingwave:nightly-20240204
     command: "standalone --meta-opts=\" \
                     --advertise-addr 0.0.0.0:5690 \
                     --backend mem \
diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py
index 9456b00688b9..badce2233664 100644
--- a/ibis/backends/risingwave/__init__.py
+++ b/ibis/backends/risingwave/__init__.py
@@ -19,12 +19,31 @@
 from ibis import util
 from ibis.backends.postgres import Backend as PostgresBackend
 from ibis.backends.risingwave.compiler import RisingwaveCompiler
+from ibis.util import experimental

 if TYPE_CHECKING:
     import pandas as pd
     import pyarrow as pa


+def data_and_encode_format(data_format, encode_format, encode_properties):
+    res = ""
+    if data_format is not None:
+        res = res + " FORMAT " + data_format.upper()
+    if encode_format is not None:
+        res = res + " ENCODE " + encode_format.upper()
+    if encode_properties is not None:
+        res = res + " " + format_properties(encode_properties)
+    return res
+
+
+def format_properties(props):
+    tokens = []
+    for k, v in props.items():
+        tokens.append(f"{k}='{v}'")
+    return "( {} ) ".format(", ".join(tokens))
+
+
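These two helpers render the trailing FORMAT ... ENCODE ... clause of RisingWave DDL as a raw SQL string, apparently because the sqlglot-built statements cannot represent that clause directly. A minimal sketch of their output (the schema-registry property value here is made up for illustration):

    data_and_encode_format("PLAIN", "JSON", None)
    # -> " FORMAT PLAIN ENCODE JSON"
    data_and_encode_format("PLAIN", "JSON", {"schema.registry": "http://localhost:8081"})
    # -> " FORMAT PLAIN ENCODE JSON ( schema.registry='http://localhost:8081' ) "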
 class Backend(PostgresBackend):
     name = "risingwave"
     compiler = RisingwaveCompiler()

@@ -110,6 +129,11 @@ def create_table(
         database: str | None = None,
         temp: bool = False,
         overwrite: bool = False,
+        # TODO(Kexiang): add `append only`
+        connector_properties: dict | None = None,
+        data_format: str | None = None,
+        encode_format: str | None = None,
+        encode_properties: dict | None = None,
     ):
         """Create a table in RisingWave.

@@ -131,22 +155,37 @@ def create_table(
         overwrite
             If `True`, replace the table if it already exists, otherwise fail
             if the table exists
-
+        connector_properties
+            The properties of the table connector, providing the connector settings
+            needed to ingest data from the upstream data source. Refer to
+            https://docs.risingwave.com/docs/current/data-ingestion/ for the required
+            properties of the different connectors.
+        data_format
+            The data format for the new table, e.g., "PLAIN". `data_format` and
+            `encode_format` must be specified at the same time.
+        encode_format
+            The encode format for the new table, e.g., "JSON". `data_format` and
+            `encode_format` must be specified at the same time.
+        encode_properties
+            The properties of the encode format, providing information such as the
+            schema registry URL. Refer to
+            https://docs.risingwave.com/docs/current/sql-create-source/ for more details.
+
+        Returns
+        -------
+        Table
+            Table expression
         """
         if obj is None and schema is None:
             raise ValueError("Either `obj` or `schema` must be specified")

-        if database is not None and database != self.current_database:
+        if connector_properties is not None and (
+            encode_format is None or data_format is None
+        ):
             raise com.UnsupportedOperationError(
-                f"Creating tables in other databases is not supported by {self.name}"
+                "When creating tables with a connector, both encode_format and data_format are required"
             )
-        else:
-            database = None

         properties = []

         if temp:
-            properties.append(sge.TemporaryProperty())
+            raise com.UnsupportedOperationError(
+                f"Creating temp tables is not supported by {self.name}"
+            )

         if obj is not None:
             if not isinstance(obj, ir.Expr):
@@ -178,25 +217,35 @@ def create_table(
         else:
             temp_name = name

-        table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
+        table = sg.table(temp_name, db=database, quoted=self.compiler.quoted)
         target = sge.Schema(this=table, expressions=column_defs)

-        create_stmt = sge.Create(
-            kind="TABLE",
-            this=target,
-            properties=sge.Properties(expressions=properties),
-        )
+        if connector_properties is None:
+            create_stmt = sge.Create(
+                kind="TABLE",
+                this=target,
+                properties=sge.Properties(expressions=properties),
+            )
+        else:
+            create_stmt = sge.Create(
+                kind="TABLE",
+                this=target,
+                properties=sge.Properties(
+                    expressions=sge.Properties.from_dict(connector_properties)
+                ),
+            )
+            create_stmt = create_stmt.sql(self.dialect) + data_and_encode_format(
+                data_format, encode_format, encode_properties
+            )

-        this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
+        this = sg.table(name, db=database, quoted=self.compiler.quoted)
         with self._safe_raw_sql(create_stmt) as cur:
             if query is not None:
                 insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect)
                 cur.execute(insert_stmt)

             if overwrite:
-                cur.execute(
-                    sge.Drop(kind="TABLE", this=this, exists=True).sql(self.dialect)
-                )
+                self.drop_table(name, database=database, force=True)
                 cur.execute(
                     f"ALTER TABLE {table.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}"
                 )
@@ -268,3 +317,266 @@ def list_databases(
         databases = list(map(itemgetter(0), cur))

         return self._filter_with_like(databases, like)
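A usage sketch for the connector path, modeled on the `test_mv_on_table_with_connector` test added below (the connection arguments and table name are assumptions for illustration, not part of this diff):

    import ibis

    # hypothetical connection details for a local RisingWave instance
    con = ibis.risingwave.connect(
        host="localhost", port=4566, user="root", database="dev"
    )
    t = con.create_table(
        "gen",
        schema=ibis.schema({"v": "int32"}),
        connector_properties={
            "connector": "datagen",  # RisingWave's built-in row generator
            "fields.v.kind": "sequence",
            "fields.v.start": "1",
            "fields.v.end": "10",
        },
        data_format="PLAIN",
        encode_format="JSON",
    )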
+
+    @experimental
+    def create_materialized_view(
+        self,
+        name: str,
+        obj: ir.Table,
+        *,
+        database: str | None = None,
+        overwrite: bool = False,
+    ) -> ir.Table:
+        """Create a materialized view. Materialized views can be accessed like a normal table.
+
+        Parameters
+        ----------
+        name
+            Materialized view name to create.
+        obj
+            The select statement to materialize.
+        database
+            Name of the database where the view exists, if not the default
+        overwrite
+            Whether to overwrite the existing materialized view with the same name
+
+        Returns
+        -------
+        Table
+            Table expression
+        """
+        if overwrite:
+            temp_name = util.gen_name(f"{self.name}_table")
+        else:
+            temp_name = name
+
+        table = sg.table(temp_name, db=database, quoted=self.compiler.quoted)
+
+        create_stmt = sge.Create(
+            this=table,
+            kind="MATERIALIZED VIEW",
+            expression=self.compile(obj),
+        )
+        self._register_in_memory_tables(obj)
+
+        with self._safe_raw_sql(create_stmt) as cur:
+            if overwrite:
+                target = sg.table(name, db=database).sql(self.dialect)
+
+                self.drop_materialized_view(target, database=database, force=True)
+
+                cur.execute(
+                    f"ALTER MATERIALIZED VIEW {table.sql(self.dialect)} RENAME TO {target}"
+                )
+
+        return self.table(name, database=database)
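A usage sketch mirroring `test_simple_mv` below (`t`, `con`, and the view name are illustrative):

    expr = t["v"].sum()  # any Ibis expression over a table or source
    mv = con.create_materialized_view("v_sum_mv", expr, overwrite=True)
    result = mv.execute()  # reads the incrementally maintained result
    con.drop_materialized_view("v_sum_mv")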
+
+    def drop_materialized_view(
+        self,
+        name: str,
+        *,
+        database: str | None = None,
+        force: bool = False,
+    ) -> None:
+        """Drop a materialized view.
+
+        Parameters
+        ----------
+        name
+            Materialized view name to drop.
+        database
+            Name of the database where the view exists, if not the default.
+        force
+            If `False`, an exception is raised if the view does not exist.
+        """
+        src = sge.Drop(
+            this=sg.table(name, db=database, quoted=self.compiler.quoted),
+            kind="MATERIALIZED VIEW",
+            exists=force,
+        )
+        with self._safe_raw_sql(src):
+            pass
+
+    def create_source(
+        self,
+        name: str,
+        schema: ibis.Schema,
+        *,
+        database: str | None = None,
+        connector_properties: dict,
+        data_format: str,
+        encode_format: str,
+        encode_properties: dict | None = None,
+    ) -> ir.Table:
+        """Create a source.
+
+        Parameters
+        ----------
+        name
+            Source name to create.
+        schema
+            The schema for the new source.
+        database
+            Name of the database where the source exists, if not the default.
+        connector_properties
+            The properties of the source connector, providing the connector settings
+            needed to access the upstream data source. Refer to
+            https://docs.risingwave.com/docs/current/data-ingestion/ for the required
+            properties of the different data sources.
+        data_format
+            The data format for the new source, e.g., "PLAIN". `data_format` and
+            `encode_format` must be specified at the same time.
+        encode_format
+            The encode format for the new source, e.g., "JSON". `data_format` and
+            `encode_format` must be specified at the same time.
+        encode_properties
+            The properties of the encode format, providing information such as the
+            schema registry URL. Refer to
+            https://docs.risingwave.com/docs/current/sql-create-source/ for more details.
+
+        Returns
+        -------
+        Table
+            Table expression
+        """
+        column_defs = [
+            sge.ColumnDef(
+                this=sg.to_identifier(colname, quoted=self.compiler.quoted),
+                kind=self.compiler.type_mapper.from_ibis(typ),
+                constraints=(
+                    None
+                    if typ.nullable
+                    else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
+                ),
+            )
+            for colname, typ in schema.items()
+        ]
+
+        table = sg.table(name, db=database, quoted=self.compiler.quoted)
+        target = sge.Schema(this=table, expressions=column_defs)
+
+        create_stmt = sge.Create(
+            kind="SOURCE",
+            this=target,
+            properties=sge.Properties(
+                expressions=sge.Properties.from_dict(connector_properties)
+            ),
+        )
+
+        create_stmt = create_stmt.sql(self.dialect) + data_and_encode_format(
+            data_format, encode_format, encode_properties
+        )
+
+        with self._safe_raw_sql(create_stmt):
+            pass
+
+        return self.table(name, database=database)
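A usage sketch taken from `test_mv_on_simple_source` below; the `datagen` connector generates rows inside RisingWave itself, so no external system is needed (the source name is illustrative):

    source = con.create_source(
        "gen_source",
        ibis.schema({"v": "int32"}),
        connector_properties={
            "connector": "datagen",
            "fields.v.kind": "sequence",
            "fields.v.start": "1",
            "fields.v.end": "10",
        },
        data_format="PLAIN",
        encode_format="JSON",
    )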
+ """ + table = sg.table(name, db=database, quoted=self.compiler.quoted) + if sink_from is None and obj is None: + raise ValueError("Either `sink_from` or `obj` must be specified") + if sink_from is not None and obj is not None: + raise ValueError("Only one of `sink_from` or `obj` can be specified") + + if (encode_format is None) != (data_format is None): + raise com.UnsupportedArgumentError( + "When creating sinks, both encode_format and data_format must be provided, or neither should be" + ) + + if sink_from is not None: + create_stmt = f"CREATE SINK {table.sql(self.dialect)} FROM {sink_from}" + else: + create_stmt = sge.Create( + this=table, + kind="SINK", + expression=self.compile(obj), + ).sql(self.dialect) + create_stmt = ( + create_stmt + + " WITH " + + format_properties(connector_properties) + + data_and_encode_format(data_format, encode_format, encode_properties) + ) + with self._safe_raw_sql(create_stmt): + pass + + def drop_sink( + self, + name: str, + *, + database: str | None = None, + force: bool = False, + ) -> None: + """Drop a Sink. + + Parameters + ---------- + name + Sink name to drop. + database + Name of the database where the view exists, if not the default. + force + If `False`, an exception is raised if the source does not exist. + """ + src = sge.Drop( + this=sg.table(name, db=database, quoted=self.compiler.quoted), + kind="SINK", + exists=force, + ) + with self._safe_raw_sql(src): + pass diff --git a/ibis/backends/risingwave/tests/test_streaming.py b/ibis/backends/risingwave/tests/test_streaming.py new file mode 100644 index 000000000000..344788c51e2f --- /dev/null +++ b/ibis/backends/risingwave/tests/test_streaming.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import time + +import pandas as pd +import pandas.testing as tm +import pytest + +import ibis +from ibis import util + + +@pytest.mark.parametrize( + "column", + ["string_col", "double_col", "date_string_col", "timestamp_col"], +) +def test_simple_mv(con, alltypes, column): + expr = alltypes[[column]].distinct().order_by(column) + mv_name = util.gen_name("alltypes_mv") + mv = con.create_materialized_view(mv_name, expr, overwrite=True) + expected = expr.limit(5).execute() + result = mv.order_by(column).limit(5).execute() + tm.assert_frame_equal(result, expected) + con.drop_materialized_view(mv_name) + + +def test_mv_on_simple_source(con): + sc_name = util.gen_name("simple_sc") + schema = ibis.schema([("v", "int32")]) + # use Risingwave's internal data generator to imitate a upstream data source + connector_properties = { + "connector": "datagen", + "fields.v.kind": "sequence", + "fields.v.start": "1", + "fields.v.end": "10", + "datagen.rows.per.second": "10000", + "datagen.split.num": "1", + } + source = con.create_source( + sc_name, + schema, + connector_properties=connector_properties, + data_format="PLAIN", + encode_format="JSON", + ) + expr = source["v"].sum() + mv_name = util.gen_name("simple_mv") + mv = con.create_materialized_view(mv_name, expr) + # sleep 3s to make sure the data has been generated by the source and consumed by the MV. 
diff --git a/ibis/backends/risingwave/tests/test_streaming.py b/ibis/backends/risingwave/tests/test_streaming.py
new file mode 100644
index 000000000000..344788c51e2f
--- /dev/null
+++ b/ibis/backends/risingwave/tests/test_streaming.py
@@ -0,0 +1,110 @@
+from __future__ import annotations
+
+import time
+
+import pandas as pd
+import pandas.testing as tm
+import pytest
+
+import ibis
+from ibis import util
+
+
+@pytest.mark.parametrize(
+    "column",
+    ["string_col", "double_col", "date_string_col", "timestamp_col"],
+)
+def test_simple_mv(con, alltypes, column):
+    expr = alltypes[[column]].distinct().order_by(column)
+    mv_name = util.gen_name("alltypes_mv")
+    mv = con.create_materialized_view(mv_name, expr, overwrite=True)
+    expected = expr.limit(5).execute()
+    result = mv.order_by(column).limit(5).execute()
+    tm.assert_frame_equal(result, expected)
+    con.drop_materialized_view(mv_name)
+
+
+def test_mv_on_simple_source(con):
+    sc_name = util.gen_name("simple_sc")
+    schema = ibis.schema([("v", "int32")])
+    # use RisingWave's internal data generator to imitate an upstream data source
+    connector_properties = {
+        "connector": "datagen",
+        "fields.v.kind": "sequence",
+        "fields.v.start": "1",
+        "fields.v.end": "10",
+        "datagen.rows.per.second": "10000",
+        "datagen.split.num": "1",
+    }
+    source = con.create_source(
+        sc_name,
+        schema,
+        connector_properties=connector_properties,
+        data_format="PLAIN",
+        encode_format="JSON",
+    )
+    expr = source["v"].sum()
+    mv_name = util.gen_name("simple_mv")
+    mv = con.create_materialized_view(mv_name, expr)
+    # sleep 3s to make sure the data has been generated by the source and consumed by the MV
+    time.sleep(3)
+    result = mv.execute()
+    expected = pd.DataFrame({"Sum(v)": [55]})
+    tm.assert_frame_equal(result, expected)
+    con.drop_materialized_view(mv_name)
+    con.drop_source(sc_name)
+
+
+def test_mv_on_table_with_connector(con):
+    tblc_name = util.gen_name("simple_table_with_connector")
+    schema = ibis.schema([("v", "int32")])
+    # use RisingWave's internal data generator to imitate an upstream data source
+    connector_properties = {
+        "connector": "datagen",
+        "fields.v.kind": "sequence",
+        "fields.v.start": "1",
+        "fields.v.end": "10",
+        "datagen.rows.per.second": "10000",
+        "datagen.split.num": "1",
+    }
+    tblc = con.create_table(
+        name=tblc_name,
+        obj=None,
+        schema=schema,
+        connector_properties=connector_properties,
+        data_format="PLAIN",
+        encode_format="JSON",
+    )
+    expr = tblc["v"].sum()
+    mv_name = util.gen_name("simple_mv")
+    mv = con.create_materialized_view(mv_name, expr)
+    # sleep 1s to make sure the data has been generated by the source and consumed by the MV
+    time.sleep(1)
+
+    result_tblc = expr.execute()
+    assert result_tblc == 55
+
+    result_mv = mv.execute()
+    expected = pd.DataFrame({"Sum(v)": [55]})
+    tm.assert_frame_equal(result_mv, expected)
+    con.drop_materialized_view(mv_name)
+    con.drop_table(tblc_name)
+
+
+def test_sink_from(con, alltypes):
+    sk_name = util.gen_name("sk_from")
+    connector_properties = {
+        "connector": "blackhole",
+    }
+    con.create_sink(sk_name, "functional_alltypes", connector_properties)
+    con.drop_sink(sk_name)
+
+
+def test_sink_as_select(con, alltypes):
+    sk_name = util.gen_name("sk_as_select")
+    expr = alltypes[["string_col"]].distinct().order_by("string_col")
+    connector_properties = {
+        "connector": "blackhole",
+    }
+    con.create_sink(sk_name, None, connector_properties, obj=expr)
+    con.drop_sink(sk_name)
diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py
index cfb69d3f0736..79c84d7075af 100644
--- a/ibis/backends/tests/test_client.py
+++ b/ibis/backends/tests/test_client.py
@@ -302,10 +302,10 @@ def test_create_table_from_schema(con, new_schema, temp_table):
     reason="temporary tables not implemented",
     raises=NotImplementedError,
 )
-@pytest.mark.notyet(
+@pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
-    reason="truncate not supported upstream",
+    raises=com.UnsupportedOperationError,
+    reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 @pytest.mark.notimpl(
     ["flink"],
@@ -1205,7 +1205,7 @@ def test_create_table_timestamp(con, temp_table):
 @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
 @pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
+    raises=com.UnsupportedOperationError,
     reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 def test_persist_expression_ref_count(backend, con, alltypes):
@@ -1230,7 +1230,7 @@ def test_persist_expression_ref_count(backend, con, alltypes):
 @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
 @pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
+    raises=com.UnsupportedOperationError,
     reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 def test_persist_expression(backend, alltypes):
@@ -1249,7 +1249,7 @@ def test_persist_expression(backend, alltypes):
 @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
 @pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
+    raises=com.UnsupportedOperationError,
     reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 def test_persist_expression_contextmanager(backend, alltypes):
@@ -1270,7 +1270,7 @@ def test_persist_expression_contextmanager(backend, alltypes):
 @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
 @pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
+    raises=com.UnsupportedOperationError,
     reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 def test_persist_expression_contextmanager_ref_count(backend, con, alltypes):
@@ -1293,7 +1293,7 @@ def test_persist_expression_contextmanager_ref_count(backend, con, alltypes):
 )
 @pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
+    raises=com.UnsupportedOperationError,
     reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
@@ -1335,7 +1335,7 @@ def test_persist_expression_multiple_refs(backend, con, alltypes):
 @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
 @pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
+    raises=com.UnsupportedOperationError,
     reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 def test_persist_expression_repeated_cache(alltypes):
@@ -1355,18 +1355,13 @@ def test_persist_expression_repeated_cache(alltypes):
 @mark.notimpl(["exasol"], reason="Exasol does not support temporary tables")
 @pytest.mark.never(
     ["risingwave"],
-    raises=PsycoPg2InternalError,
+    raises=com.UnsupportedOperationError,
     reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
 )
 @mark.notimpl(
     ["oracle"],
     reason="Oracle error message for a missing table/view doesn't include the name of the table",
 )
-@pytest.mark.never(
-    ["risingwave"],
-    raises=PsycoPg2InternalError,
-    reason="Feature is not yet implemented: CREATE TEMPORARY TABLE",
-)
 def test_persist_expression_release(con, alltypes):
     non_cached_table = alltypes.mutate(
         test_column="calculation", other_column="big calc 3"
diff --git a/ibis/backends/tests/test_join.py b/ibis/backends/tests/test_join.py
index c4e9b82ed937..a02233275b88 100644
--- a/ibis/backends/tests/test_join.py
+++ b/ibis/backends/tests/test_join.py
@@ -175,6 +175,7 @@ def test_mutate_then_join_no_column_overlap(batting, awards_players):
 @pytest.mark.notimpl(["druid"])
 @pytest.mark.notyet(["dask"], reason="dask doesn't support descending order by")
 @pytest.mark.notyet(["flink"], reason="Flink doesn't support semi joins")
+@pytest.mark.skip("risingwave")  # TODO(Kexiang): RisingWave bug, under investigation
 @pytest.mark.parametrize(
     "func",
     [
diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py
index 19341a8c9b07..011adadf568a 100644
--- a/ibis/backends/tests/test_temporal.py
+++ b/ibis/backends/tests/test_temporal.py
@@ -1590,17 +1590,14 @@ def test_today_from_projection(alltypes):
 }


-@pytest.mark.notimpl(["pandas", "dask", "exasol"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(
-    ["druid"], raises=PyDruidProgrammingError, reason="SQL parse failed"
+    ["pandas", "dask", "exasol", "risingwave"], raises=com.OperationNotDefinedError
 )
 @pytest.mark.notimpl(
-    ["oracle"], raises=OracleDatabaseError, reason="ORA-00936 missing expression"
+    ["druid"], raises=PyDruidProgrammingError, reason="SQL parse failed"
 )
 @pytest.mark.notimpl(
-    ["risingwave"],
-    raises=com.OperationNotDefinedError,
-    reason="function make_date(integer, integer, integer) does not exist",
+    ["oracle"], raises=OracleDatabaseError, reason="ORA-00936 missing expression"
 )
 def test_date_literal(con, backend):
     expr = ibis.date(2022, 2, 4)
@@ -1631,11 +1628,6 @@
     raises=com.OperationNotDefinedError,
 )
 @pytest.mark.notyet(["impala"], raises=com.OperationNotDefinedError)
-@pytest.mark.notimpl(
-    ["risingwave"],
-    raises=PsycoPg2InternalError,
-    reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist",
-)
 def test_timestamp_literal(con, backend):
     expr = ibis.timestamp(2022, 2, 4, 16, 20, 0)
     result = con.execute(expr)
@@ -1689,11 +1681,6 @@ def test_timestamp_literal(con, backend):
         ", , , )"
     ),
 )
-@pytest.mark.notimpl(
-    ["risingwave"],
-    raises=PsycoPg2InternalError,
-    reason="function make_timestamp(integer, integer, integer, integer, integer, integer) does not exist",
-)
 def test_timestamp_with_timezone_literal(con, timezone, expected):
     expr = ibis.timestamp(2022, 2, 4, 16, 20, 0).cast(dt.Timestamp(timezone=timezone))
     result = con.execute(expr)
@@ -1722,11 +1709,6 @@ def test_timestamp_with_timezone_literal(con, timezone, expected):
     ["clickhouse", "impala", "exasol"], raises=com.OperationNotDefinedError
 )
 @pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError)
-@pytest.mark.notimpl(
-    ["risingwave"],
-    raises=PsycoPg2InternalError,
-    reason="function make_time(integer, integer, integer) does not exist",
-)
 def test_time_literal(con, backend):
     expr = ibis.time(16, 20, 0)
     result = con.execute(expr)
@@ -1845,7 +1827,9 @@ def test_interval_literal(con, backend):
     assert con.execute(expr.typeof()) == INTERVAL_BACKEND_TYPES[backend_name]


-@pytest.mark.notimpl(["pandas", "dask", "exasol"], raises=com.OperationNotDefinedError)
+@pytest.mark.notimpl(
+    ["pandas", "dask", "exasol", "risingwave"], raises=com.OperationNotDefinedError
+)
 @pytest.mark.broken(
     ["druid"],
     raises=AttributeError,
@@ -1854,11 +1838,6 @@
 @pytest.mark.broken(
     ["oracle"], raises=OracleDatabaseError, reason="ORA-00936: missing expression"
 )
-@pytest.mark.notimpl(
-    ["risingwave"],
-    raises=com.OperationNotDefinedError,
-    reason="function make_date(integer, integer, integer) does not exist",
-)
 def test_date_column_from_ymd(backend, con, alltypes, df):
     c = alltypes.timestamp_col
     expr = ibis.date(c.year(), c.month(), c.day())
@@ -1879,11 +1858,6 @@ def test_date_column_from_ymd(backend, con, alltypes, df):
     reason="StringColumn' object has no attribute 'year'",
 )
 @pytest.mark.notyet(["impala", "oracle"], raises=com.OperationNotDefinedError)
-@pytest.mark.notimpl(
-    ["risingwave"],
-    raises=PsycoPg2InternalError,
-    reason="function make_timestamp(smallint, smallint, smallint, smallint, smallint, smallint) does not exist",
-)
 def test_timestamp_column_from_ymdhms(backend, con, alltypes, df):
     c = alltypes.timestamp_col
     expr = ibis.timestamp(