Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use type checkers in backend-*.R files #1556

Merged
merged 11 commits into from
Nov 12, 2024
8 changes: 4 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# dbplyr (development version)

* Tightened argument checks for Snowflake SQL translations. These changes should
result in more informative errors in cases where code already failed; if you
see errors with code that used to run without issue, please report them to
the package authors (@simonpcouch, #1554).
* Tightened argument checks for SQL translations. These changes should
result in more informative errors in cases where code already failed, possibly
silently; if you see errors with code that used to run correctly, please report
them to the package authors (@simonpcouch, #1554, #1555).

* `clock::add_years()` translates to correct SQL on Spark (@ablack3, #1510).

Expand Down
2 changes: 2 additions & 0 deletions R/backend-.R
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ base_scalar <- sql_translator(
# base R
nchar = sql_prefix("LENGTH", 1),
nzchar = function(x, keepNA = FALSE) {
check_bool(keepNA)
if (keepNA) {
exp <- expr(!!x != "")
translate_sql(!!exp, con = sql_current_con())
Expand Down Expand Up @@ -281,6 +282,7 @@ base_scalar <- sql_translator(
str_c = sql_paste(""),
str_sub = sql_str_sub("SUBSTR"),
str_like = function(string, pattern, ignore_case = TRUE) {
check_bool(ignore_case)
if (isTRUE(ignore_case)) {
sql_expr(!!string %LIKE% !!pattern)
} else {
Expand Down
1 change: 1 addition & 0 deletions R/backend-hive.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ sql_table_analyze.Hive <- function(con, table, ...) {

#' @export
sql_query_set_op.Hive <- function(con, x, y, method, ..., all = FALSE, lvl = 0) {
check_bool(all)
# parentheses are not allowed
method <- paste0(method, if (all) " ALL")
glue_sql2(
Expand Down
28 changes: 12 additions & 16 deletions R/backend-mssql.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ simulate_mssql <- function(version = "15.0") {
conflict = c("error", "ignore"),
returning_cols = NULL,
method = NULL) {
check_string(method, allow_null = TRUE)
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
method <- method %||% "where_not_exists"
arg_match(method, "where_not_exists", error_arg = "method")
# https://stackoverflow.com/questions/25969/insert-into-values-select-from
Expand Down Expand Up @@ -177,6 +178,7 @@ simulate_mssql <- function(version = "15.0") {
...,
returning_cols = NULL,
method = NULL) {
check_string(method, allow_null = TRUE)
method <- method %||% "merge"
arg_match(method, "merge", error_arg = "method")

Expand Down Expand Up @@ -333,6 +335,7 @@ simulate_mssql <- function(version = "15.0") {
second = function(x) sql_expr(DATEPART(SECOND, !!x)),

month = function(x, label = FALSE, abbr = TRUE) {
check_bool(label)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be tempted to move the check for check_unsupported_arg() next to this check.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, interesting. I moved that check_unsupported_arg() call to be evaluated unconditionally before realizing that the only allowed value for abbr is FALSE and the default is TRUE. Bringing the check out of if (label) would mean that month(x) would fail by default on SQL server. I opted not to make your suggested change here and will leave this as unresolved for a day or two to give you a chance to reply before merging.

if (!label) {
sql_expr(DATEPART(MONTH, !!x))
} else {
Expand All @@ -342,6 +345,7 @@ simulate_mssql <- function(version = "15.0") {
},

quarter = function(x, with_year = FALSE, fiscal_start = 1) {
check_bool(with_year)
check_unsupported_arg(fiscal_start, 1, backend = "SQL Server")

if (with_year) {
Expand All @@ -361,6 +365,7 @@ simulate_mssql <- function(version = "15.0") {
sql_expr(DATEADD(YEAR, !!n, !!x))
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
sql_expr(DATEFROMPARTS(!!year, !!month, !!day))
},
get_year = function(x) {
Expand All @@ -373,27 +378,16 @@ simulate_mssql <- function(version = "15.0") {
sql_expr(DATEPART(DAY, !!x))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
check_unsupported_arg(n, allowed = 1L)

sql_expr(DATEDIFF(DAY, !!start, !!end))
},

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr(DATEDIFF(DAY, !!time2, !!time1))
}
Expand Down Expand Up @@ -545,7 +539,7 @@ mssql_version <- function(con) {

#' @export
`sql_returning_cols.Microsoft SQL Server` <- function(con, cols, table, ...) {
stopifnot(table %in% c("DELETED", "INSERTED"))
arg_match(table, values = c("DELETED", "INSERTED"))
returning_cols <- sql_named_cols(con, cols, table = table)

sql_clause("OUTPUT", returning_cols)
Expand Down Expand Up @@ -637,6 +631,8 @@ mssql_bit_int_bit <- function(f) {

#' @export
`db_sql_render.Microsoft SQL Server` <- function(con, sql, ..., cte = FALSE, use_star = TRUE) {
check_unsupported_arg(cte, allowed = FALSE)
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
check_unsupported_arg(use_star, allowed = TRUE)
# Post-process WHERE to cast logicals from BIT to BOOLEAN
sql$lazy_query <- purrr::modify_tree(
sql$lazy_query,
Expand Down
28 changes: 15 additions & 13 deletions R/backend-postgres.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ postgres_grepl <- function(pattern,
check_unsupported_arg(perl, FALSE, backend = "PostgreSQL")
check_unsupported_arg(fixed, FALSE, backend = "PostgreSQL")
check_unsupported_arg(useBytes, FALSE, backend = "PostgreSQL")
check_bool(ignore.case)

if (ignore.case) {
sql_expr(((!!x)) %~*% ((!!pattern)))
Expand Down Expand Up @@ -123,6 +124,7 @@ sql_translation.PqConnection <- function(con) {
},
# https://www.postgresql.org/docs/current/functions-matching.html
str_like = function(string, pattern, ignore_case = TRUE) {
check_bool(ignore_case)
if (isTRUE(ignore_case)) {
sql_expr(!!string %ILIKE% !!pattern)
} else {
Expand Down Expand Up @@ -162,6 +164,9 @@ sql_translation.PqConnection <- function(con) {
sql_expr(EXTRACT(DAY %FROM% !!x))
},
wday = function(x, label = FALSE, abbr = TRUE, week_start = NULL) {
check_bool(label)
check_bool(abbr)
check_number_whole(week_start, allow_null = TRUE)
if (!label) {
week_start <- week_start %||% getOption("lubridate.week.start", 7)
offset <- as.integer(7 - week_start)
Expand All @@ -182,6 +187,8 @@ sql_translation.PqConnection <- function(con) {
sql_expr(EXTRACT(WEEK %FROM% !!x))
},
month = function(x, label = FALSE, abbr = TRUE) {
check_bool(label)
check_bool(abbr)
if (!label) {
sql_expr(EXTRACT(MONTH %FROM% !!x))
} else {
Expand All @@ -193,6 +200,7 @@ sql_translation.PqConnection <- function(con) {
}
},
quarter = function(x, with_year = FALSE, fiscal_start = 1) {
check_bool(with_year)
check_unsupported_arg(fiscal_start, 1, backend = "PostgreSQL")

if (with_year) {
Expand Down Expand Up @@ -246,17 +254,14 @@ sql_translation.PqConnection <- function(con) {
glue_sql2(sql_current_con(), "({.col x} + {.val n}*INTERVAL'1 year')")
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
sql_expr(make_date(!!year, !!month, !!day))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
check_unsupported_arg(n, allowed = 1L)

sql_expr(!!end - !!start)
},
Expand All @@ -272,13 +277,8 @@ sql_translation.PqConnection <- function(con) {

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr((CAST(!!time1 %AS% DATE) - CAST(!!time2 %AS% DATE)))
},
Expand Down Expand Up @@ -344,6 +344,7 @@ sql_query_insert.PqConnection <- function(con,
...,
returning_cols = NULL,
method = NULL) {
check_string(method, allow_null = TRUE)
method <- method %||% "on_conflict"
arg_match(method, c("on_conflict", "where_not_exists"), error_arg = "method")
if (method == "where_not_exists") {
Expand Down Expand Up @@ -379,6 +380,7 @@ sql_query_upsert.PqConnection <- function(con,
...,
returning_cols = NULL,
method = NULL) {
check_string(method, allow_null = TRUE)
method <- method %||% "on_conflict"
arg_match(method, c("cte_update", "on_conflict"), error_arg = "method")

Expand Down
20 changes: 5 additions & 15 deletions R/backend-redshift.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ sql_translation.RedshiftConnection <- function(con) {
sql_expr(DATEADD(YEAR, !!n, !!x))
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
glue_sql2(sql_current_con(), "TO_DATE(CAST({.val year} AS TEXT) || '-' CAST({.val month} AS TEXT) || '-' || CAST({.val day} AS TEXT)), 'YYYY-MM-DD')")
},
get_year = function(x) {
Expand All @@ -84,27 +85,16 @@ sql_translation.RedshiftConnection <- function(con) {
sql_expr(DATE_PART('day', !!x))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
check_unsupported_arg(n, allowed = 1L)

sql_expr(DATEDIFF(DAY, !!start, !!end))
},

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr(DATEDIFF(DAY, !!time2, !!time1))
}
Expand Down
23 changes: 7 additions & 16 deletions R/backend-spark-sql.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL")
sql_expr(add_months(!!x, !!n*12))
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
sql_expr(make_date(!!year, !!month, !!day))
},
get_year = function(x) {
Expand All @@ -59,27 +60,16 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL")
sql_expr(date_part('DAY', !!x))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
check_unsupported_arg(n, allowed = 1L)

sql_expr(datediff(!!end, !!start))
},

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr(datediff(!!time2, !!time1))
}
Expand Down Expand Up @@ -153,7 +143,8 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL")
indexes = list(),
analyze = TRUE,
in_transaction = FALSE) {

check_bool(overwrite)
check_bool(temporary)
sql <- glue_sql2(
con,
"CREATE ", if (overwrite) "OR REPLACE ",
Expand Down
2 changes: 2 additions & 0 deletions R/backend-teradata.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ sql_translation.Teradata <- function(con) {
row_number = win_rank("ROW_NUMBER", empty_order = TRUE),
weighted.mean = function(x, w, na.rm = T) {
# nocov start
check_unsupported_arg(na.rm, allowed = TRUE)
win_over(
sql_expr(SUM((!!x * !!w))/SUM(!!w)),
win_current_group(),
Expand Down Expand Up @@ -191,6 +192,7 @@ sql_translation.Teradata <- function(con) {
},
weighted.mean = function(x, w, na.rm = T) {
# nocov start
check_unsupported_arg(na.rm, allowed = TRUE)
win_over(
sql_expr(SUM((!!x * !!w))/SUM(!!w)),
win_current_group(),
Expand Down
1 change: 1 addition & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ res_warn_incomplete <- function(res, hint = "n = -1") {
}

add_temporary_prefix <- function(con, table, temporary = TRUE) {
check_bool(temporary)
check_table_path(table)

if (!temporary) {
Expand Down
25 changes: 7 additions & 18 deletions tests/testthat/_snaps/backend-mssql.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,34 @@
test_translate_sql(date_count_between(date_column_1, date_column_2, "year"))
Condition
Error in `date_count_between()`:
! `precision` must be "day" on SQL backends.
! `precision = "year"` isn't supported on database backends.
i It must be "day" instead.

---

Code
test_translate_sql(date_count_between(date_column_1, date_column_2, "day", n = 5))
Condition
Error in `date_count_between()`:
! `n` must be "1" on SQL backends.
! `n = 5` isn't supported on database backends.
i It must be 1 instead.

# difftime is translated correctly

Code
test_translate_sql(difftime(start_date, end_date, units = "auto"))
Condition
Error in `difftime()`:
! The only supported value for `units` on SQL backends is "days"
! `units = "auto"` isn't supported on database backends.
i It must be "days" instead.

---

Code
test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))
Condition
Error in `difftime()`:
! The `tz` argument is not supported for SQL backends.
! Argument `tz` isn't supported on database backends.

# convert between bit and boolean as needed

Expand Down Expand Up @@ -494,20 +497,6 @@
FROM `df`
ORDER BY `y`

# can copy_to() and compute() with temporary tables (#438)

Code
db <- copy_to(con, df, name = unique_table_name(), temporary = TRUE)
Message
Created a temporary table named #dbplyr_{tmp}

---

Code
db2 <- db %>% mutate(y = x + 1) %>% compute()
Message
Created a temporary table named #dbplyr_{tmp}

# add prefix to temporary table

Code
Expand Down
Loading
Loading