diff --git a/duckdb-rfuns-r/R/aggregate.R b/duckdb-rfuns-r/R/aggregate.R index 168c964..9f45d74 100644 --- a/duckdb-rfuns-r/R/aggregate.R +++ b/duckdb-rfuns-r/R/aggregate.R @@ -11,18 +11,17 @@ #' @rdname aggregate #' @export rfuns_sum <- function(x, ...) { - rfuns("aggregate", "sum", tibble(x = x), ...) + rfuns("sum", tibble(x = x), ..., op = "aggregate") } #' @rdname aggregate #' @export rfuns_min <- function(x, ...) { - rfuns("aggregate", "min", tibble(x = x), ...) + rfuns("min", tibble(x = x), ..., op = "aggregate") } #' @rdname aggregate #' @export rfuns_max <- function(x, ...) { - rfuns("aggregate", "max", tibble(x = x), ...) + rfuns("max", tibble(x = x), ..., op = "aggregate") } - diff --git a/duckdb-rfuns-r/R/project.R b/duckdb-rfuns-r/R/project.R index bb9222b..5898357 100644 --- a/duckdb-rfuns-r/R/project.R +++ b/duckdb-rfuns-r/R/project.R @@ -4,7 +4,7 @@ #' #' @export rfuns_is.na <- function(x) { - rfuns("project", "is.na", tibble(x = x)) + rfuns("is.na", tibble(x = x)) } #' as.integer() @@ -13,7 +13,7 @@ rfuns_is.na <- function(x) { #' #' @export rfuns_as.integer <- function(x) { - rfuns("project", "as.integer", tibble(x = x)) + rfuns("as.integer", tibble(x = x)) } #' as.numeric() @@ -22,5 +22,5 @@ rfuns_as.integer <- function(x) { #' #' @export rfuns_as.numeric <- function(x) { - rfuns("project", "as.numeric", tibble(x = x)) + rfuns("as.numeric", tibble(x = x)) } diff --git a/duckdb-rfuns-r/R/rfuns.R b/duckdb-rfuns-r/R/rfuns.R index 8b33e23..5b4b647 100644 --- a/duckdb-rfuns-r/R/rfuns.R +++ b/duckdb-rfuns-r/R/rfuns.R @@ -1,4 +1,4 @@ -rfuns <- function(op = c("project", "aggregate"), fun, data, ..., error_call = caller_env()) { +rfuns <- function(fun, data, ..., error_call = caller_env(), op = c("project", "aggregate")) { withr::local_options(list(duckdb.materialize_message = FALSE)) con <- local_duckdb_con() @@ -13,6 +13,7 @@ rfuns <- function(op = c("project", "aggregate"), fun, data, ..., error_call = c ) ) + op <- rlang::arg_match(op, error_call = error_call) result <- switch(op, "project" = duckdb:::rel_project(in_rel, exprs), "aggregate" = duckdb:::rel_aggregate(in_rel, list(), exprs) diff --git a/duckdb-rfuns-r/tests/testthat/test-sum.R b/duckdb-rfuns-r/tests/testthat/test-sum.R index fce0d9c..c99d881 100644 --- a/duckdb-rfuns-r/tests/testthat/test-sum.R +++ b/duckdb-rfuns-r/tests/testthat/test-sum.R @@ -2,13 +2,13 @@ test_that("r_base::sum()", { x <- c(TRUE, TRUE, FALSE, NA) empty <- logical() - expect_equal(rfuns_sum(x, na.rm = TRUE) , sum(x, na.rm = TRUE)) - expect_equal(rfuns_sum(x, na.rm = FALSE), sum(x, na.rm = FALSE)) + expect_equal(rfuns_sum(x, na.rm = TRUE)[1] , sum(x, na.rm = TRUE)) + expect_equal(rfuns_sum(x, na.rm = FALSE)[1], sum(x, na.rm = FALSE)) expect_equal(rfuns_sum(empty, na.rm = FALSE), sum(empty, na.rm = TRUE)) expect_equal(rfuns_sum(empty, na.rm = TRUE) , sum(empty, na.rm = FALSE)) - expect_equal(rfuns_sum(x) , sum(x)) - expect_equal(rfuns_sum(x), sum(x)) + expect_equal(rfuns_sum(x)[1] , sum(x)) + expect_equal(rfuns_sum(x)[1], sum(x)) expect_equal(rfuns_sum(empty), sum(empty)) expect_equal(rfuns_sum(empty), sum(empty)) }) diff --git a/src/include/rfuns_extension.hpp b/src/include/rfuns_extension.hpp index 0940b99..df8271f 100644 --- a/src/include/rfuns_extension.hpp +++ b/src/include/rfuns_extension.hpp @@ -63,7 +63,7 @@ ScalarFunctionSet base_r_is_na(); ScalarFunctionSet base_r_as_integer(); ScalarFunctionSet base_r_as_numeric(); -// sum +// aggregates AggregateFunctionSet base_r_sum(); AggregateFunctionSet base_r_min(); AggregateFunctionSet base_r_max();