Skip to content

Commit

Permalink
Add quantile regression mode (#1209)
Browse files Browse the repository at this point in the history
* add a quantile regression mode to test with

* update type checkers

* avoid confusion with global all_models object

* add quantile_level argument to set_mode()

* initial data for quantreg

* some initial tests

* fix some issues

* enable quantile prediction

* tests for quantreg

* Quantile predictions output constructor (#1191)

* small change to predict checks

* add vctrs for quantiles and test, refactor *_rq_preds

* revise tests

* Apply some of the suggestions from code review

Co-authored-by: Simon P. Couch <[email protected]>

* rename tests on suggestion from code review

* export missing funs from vctrs for formatting

* convert errors to snapshot tests

* pass call through input check

* update snapshots for caller_env

* rename to parsnip_quantiles, add format snapshot tests

* Apply suggestions from @topepo

Co-authored-by: Max Kuhn <[email protected]>

* rename parsnip_quantiles to quantile_pred

* rename parsnip_quantiles to quantile_pred and add vector probability check

* fix: two bugs introduced earlier

* add formatting tests for single quantile

* replace walk with a loop to avoid "Error in map()"

* remove row/col names

* adjust quantile_pred format

* as_tibble method

* updated NEWS file

* add PR number

* small new update

* helper methods

* update docs

* re-enable quantiles prediction for #1203

* update some tests

* no longer needed

* use tibble::new_tibble

* braces

* test as_tibble

* remove print methods

---------

Co-authored-by: Simon P. Couch <[email protected]>
Co-authored-by: Max Kuhn <[email protected]>
Co-authored-by: ‘topepo’ <‘[email protected]’>

* quantile regression updates for new hardhat model (#1207)

* bump hardhat version

* remove parts now in hardhat

* update for new hardhat version

* quantile_levels (plural now)

* news update

* typo

* rename helper function

* run CI on PRs from branches

* forgotten remote

* actions for edited PRs

* plural

* expand branch list

* export function for censored to use

* updated snapshot

* remake snapshot

* Revert "remake snapshot"

This reverts commit 954e326.

* updated snapshot

* Update R/arguments.R

Co-authored-by: Hannah Frick <[email protected]>

* typo

* changes from reviewer feedback

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
Co-authored-by: Hannah Frick <[email protected]>

* Change to `quantile` argument to `quantile levels` (#1208)

* quantile -> quantile_levels for #1203

* defer test until censored updates in new PR

* update docs for quantile_levels

* update test

* disable quantile predictions for surv_reg

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>

* post conflict merge updates

* update news

* version bump and fix typo

* revert GHA branches

* small bug fix

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>
Co-authored-by: Emil Hvitfeldt <[email protected]>

* don't export median

* add call arg

* added documentation on model

* add mode

* convert error to warning

* remove rankdeficient

* added skip

* add deprecated `quantile` arg back in

* remove numeric prediction

---------

Co-authored-by: ‘topepo’ <‘[email protected]’>
Co-authored-by: Daniel McDonald <[email protected]>
Co-authored-by: Simon P. Couch <[email protected]>
Co-authored-by: Hannah Frick <[email protected]>
Co-authored-by: Emil Hvitfeldt <[email protected]>
  • Loading branch information
6 people authored Oct 11, 2024
1 parent 5ce414e commit 297320e
Show file tree
Hide file tree
Showing 35 changed files with 906 additions and 108 deletions.
15 changes: 8 additions & 7 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.2.1.9002
Version: 1.2.1.9003
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand All @@ -25,7 +25,7 @@ Imports:
ggplot2,
globals,
glue,
hardhat (>= 1.4.0),
hardhat (>= 1.4.0.9002),
lifecycle,
magrittr,
pillar,
Expand All @@ -40,8 +40,8 @@ Imports:
vctrs (>= 0.6.0),
withr
Suggests:
C50,
bench,
C50,
covr,
dials (>= 1.1.0),
earth,
Expand Down Expand Up @@ -69,16 +69,17 @@ Suggests:
xgboost (>= 1.5.0.1)
VignetteBuilder:
knitr
Remotes:
r-lib/sparsevctrs,
tidymodels/hardhat
ByteCompile: true
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
LiblineaR, mgcv, nnet, parsnip, quantreg, randomForest, ranger, rpart,
rstanarm, tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
xgboost
Config/rcmdcheck/ignore-inconsequential-notes: true
Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
Remotes:
r-lib/sparsevctrs
RoxygenNote: 7.3.2
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ export(make_classes)
export(make_engine_list)
export(make_seealso_list)
export(mars)
export(matrix_to_quantile_pred)
export(max_mtry_formula)
export(maybe_data_frame)
export(maybe_matrix)
Expand Down
24 changes: 20 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,41 @@
# parsnip (development version)

## New Features

* A new model mode (`"quantile regression"`) was added. Including:
* A `linear_reg()` engine for `"quantreg"`.
* Predictions are encoded via a custom vector type. See [hardhat::quantile_pred()].
* Predicted quantile levels are designated when the new mode is specified. See `?set_mode`.

* `fit_xy()` can now take dgCMatrix input for `x` argument (#1121).

* `fit_xy()` can now take sparse tibbles as data values (#1165).

* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).

* Transitioned package errors and warnings to use cli (#1147 and #1148 by
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
#1161, #1081).
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).

## Other Changes

* Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081).

* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).

* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).

* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
## Bug Fixes

* Ensure that `knit_engine_docs()` has the required packages installed (#1156).

* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).

## Breaking Change

* For quantile prediction, the `quantile` argument to `predict()` has been deprecate in facor of `quantile_levels`. This does not affect models with mode `"quantile regression"`.

* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model.


# parsnip 1.2.1

* Added a missing `tidy()` method for survival analysis glmnet models (#1086).
Expand Down
21 changes: 16 additions & 5 deletions R/aaa-import-standalone-types-check.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# Standalone file: do not edit by hand
# Source: <https://github.com/r-lib/rlang/blob/main/R/standalone-types-check.R>
# ----------------------------------------------------------------------
#
# ---
# repo: r-lib/rlang
# file: standalone-types-check.R
Expand All @@ -13,6 +9,9 @@
#
# ## Changelog
#
# 2024-08-15:
# - `check_character()` gains an `allow_na` argument (@martaalcalde, #1724)
#
# 2023-03-13:
# - Improved error messages of number checkers (@teunbrand)
# - Added `allow_infinite` argument to `check_number_whole()` (@mgirlich).
Expand Down Expand Up @@ -461,15 +460,28 @@ check_formula <- function(x,

# Vectors -----------------------------------------------------------------

# TODO: Figure out what to do with logical `NA` and `allow_na = TRUE`

check_character <- function(x,
...,
allow_na = TRUE,
allow_null = FALSE,
arg = caller_arg(x),
call = caller_env()) {

if (!missing(x)) {
if (is_character(x)) {
if (!allow_na && any(is.na(x))) {
abort(
sprintf("`%s` can't contain NA values.", arg),
arg = arg,
call = call
)
}

return(invisible(NULL))
}

if (allow_null && is_null(x)) {
return(invisible(NULL))
}
Expand All @@ -479,7 +491,6 @@ check_character <- function(x,
x,
"a character vector",
...,
allow_na = FALSE,
allow_null = allow_null,
arg = arg,
call = call
Expand Down
6 changes: 3 additions & 3 deletions R/aaa_models.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Initialize model environments

all_modes <- c("classification", "regression", "censored regression")
all_modes <- c("classification", "regression", "censored regression", "quantile regression")

# ------------------------------------------------------------------------------

Expand Down Expand Up @@ -195,8 +195,8 @@ stop_missing_engine <- function(cls, call) {
}

check_mode_for_new_engine <- function(cls, eng, mode, call = caller_env()) {
all_modes <- get_from_env(paste0(cls, "_modes"))
if (!(mode %in% all_modes)) {
model_modes <- get_from_env(paste0(cls, "_modes"))
if (!(mode %in% model_modes)) {
cli::cli_abort(
"{.val {mode}} is not a known mode for model {.fn {cls}}.",
call = call
Expand Down
17 changes: 17 additions & 0 deletions R/aaa_quantiles.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#' Reformat quantile predictions
#'
#' @param x A matrix of predictions with rows as samples and columns as quantile
#' levels.
#' @param object A parsnip `model_fit` object from a quantile regression model.
#' @keywords internal
#' @export
matrix_to_quantile_pred <- function(x, object) {
if (!is.matrix(x)) {
x <- as.matrix(x)
}
rownames(x) <- NULL
n_pred_quantiles <- ncol(x)
quantile_levels <- object$spec$quantile_levels

tibble::new_tibble(x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels)))
}
24 changes: 21 additions & 3 deletions R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ check_eng_args <- function(args, obj, core_args) {
#' set_args(mtry = 3, importance = TRUE) %>%
#' set_mode("regression")
#'
#' linear_reg() %>%
#' set_mode("quantile regression", quantile_levels = c(0.2, 0.5, 0.8))
#' @export
set_args <- function(object, ...) {
UseMethod("set_args")
Expand Down Expand Up @@ -89,12 +91,18 @@ set_args.default <- function(object,...) {

#' @rdname set_args
#' @export
set_mode <- function(object, mode) {
set_mode <- function(object, mode, ...) {
UseMethod("set_mode")
}

#' @rdname set_args
#' @param quantile_levels A vector of values between zero and one (only for the
#' `"quantile regression"` mode); otherwise, it is `NULL`. The model uses these
#' values to appropriately train quantile regression models to make predictions
#' for these values (e.g., `quantile_levels = 0.5` is the median).
#' @export
set_mode.model_spec <- function(object, mode) {
set_mode.model_spec <- function(object, mode, quantile_levels = NULL, ...) {
check_dots_empty()
cls <- class(object)[1]
if (rlang::is_missing(mode)) {
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
Expand All @@ -111,11 +119,21 @@ set_mode.model_spec <- function(object, mode) {

object$mode <- mode
object$user_specified_mode <- TRUE
if (mode == "quantile regression") {
hardhat::check_quantile_levels(quantile_levels)
} else {
if (!is.null(quantile_levels)) {
cli::cli_warn("{.arg quantile_levels} is only used when the mode is
{.val quantile regression}.")
}
}

object$quantile_levels <- quantile_levels
object
}

#' @export
set_mode.default <- function(object, mode) {
set_mode.default <- function(object, mode, ...) {
error_set_object(object, func = "set_mode")

invisible(FALSE)
Expand Down
9 changes: 8 additions & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ fit.model_spec <-
eval_env$formula <- formula
eval_env$weights <- wts

if (!is.null(object$quantile_levels)) {
eval_env$quantile_levels <- object$quantile_levels
}

data <- materialize_sparse_tibble(data, object, "data")

fit_interface <-
Expand All @@ -187,7 +191,6 @@ fit.model_spec <-
with a spark data object."
)


# populate `method` with the details for this model type
object <- add_methods(object, engine = object$engine)

Expand Down Expand Up @@ -295,6 +298,10 @@ fit_xy.model_spec <-
eval_env$y_var <- y_var
eval_env$weights <- weights_to_numeric(case_weights, object)

if (!is.null(object$quantile_levels)) {
eval_env$quantile_levels <- object$quantile_levels
}

# TODO case weights: pass in eval_env not individual elements
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)

Expand Down
4 changes: 2 additions & 2 deletions R/install_packages.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ install_engine_packages <- function(extension = TRUE, extras = TRUE,
}

if (extras) {
rmd_pkgs <- c("tidymodels", "broom.mixed", "glmnet", "Cubist", "xrf", "ape",
"rmarkdown")
rmd_pkgs <- c("ape", "broom.mixed", "Cubist", "glmnet", "quantreg",
"rmarkdown", "tidymodels", "xrf")
engine_packages <- unique(c(engine_packages, rmd_pkgs))
}

Expand Down
46 changes: 46 additions & 0 deletions R/linear_reg_data.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set_new_model("linear_reg")

set_model_mode("linear_reg", "regression")
set_model_mode("linear_reg", "quantile regression")

# ------------------------------------------------------------------------------

Expand Down Expand Up @@ -582,3 +583,48 @@ set_pred(
)
)

# ------------------------------------------------------------------------------

set_model_engine(model = "linear_reg", mode = "quantile regression", eng = "quantreg")
set_dependency(model = "linear_reg", eng = "quantreg", pkg = "quantreg", mode = "quantile regression")

set_fit(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
value = list(
interface = "formula",
protect = c("formula", "data", "weights"),
func = c(pkg = "quantreg", fun = "rq"),
defaults = list(tau = expr(quantile_levels))
)
)

set_encoding(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
options = list(
predictor_indicators = "traditional",
compute_intercept = TRUE,
remove_intercept = TRUE,
allow_sparse_x = FALSE
)
)

set_pred(
model = "linear_reg",
eng = "quantreg",
mode = "quantile regression",
type = "quantile",
value = list(
pre = NULL,
post = matrix_to_quantile_pred,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data)
)
)
)
11 changes: 11 additions & 0 deletions R/linear_reg_quantreg.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#' Linear quantile regression via the quantreg package
#'
#' [quantreg::rq()] optimizes quantile loss to fit models with numeric outcomes.
#'
#' @includeRmd man/rmd/linear_reg_quantreg.md details
#'
#' @name details_linear_reg_quantreg
#' @keywords internal
NULL

# See inst/README-DOCS.md for a description of how these files are processed
6 changes: 4 additions & 2 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,14 @@ check_pred_type <- function(object, type, ..., call = rlang::caller_env()) {
regression = "numeric",
classification = "class",
"censored regression" = "time",
"quantile regression" = "quantile",
cli::cli_abort(
"{.arg type} should be 'regression', 'censored regression', or 'classification'.",
"{.arg type} should be one of {.or {.val {all_modes}}}.",
call = call
)
)
}

if (!(type %in% pred_types))
cli::cli_abort(
"{.arg type} should be one of {.or {.arg {pred_types}}}.",
Expand Down Expand Up @@ -373,7 +375,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

# ----------------------------------------------------------------------------

other_args <- c("interval", "level", "std_error", "quantile",
other_args <- c("interval", "level", "std_error", "quantile_levels",
"time", "eval_time", "increasing")

eval_time_types <- c("survival", "hazard")
Expand Down
Loading

0 comments on commit 297320e

Please sign in to comment.