Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Devel #150

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open

Devel #150

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8d4ccd8
Modified TMLE for computationally efficient bootstrap
nt-williams Apr 26, 2024
a73f486
Pre-allocated for loop with future for bootstrap
nt-williams Apr 26, 2024
327c0bf
Updating tests
nt-williams Apr 26, 2024
6e63f60
Updated NEWS
nt-williams Apr 26, 2024
2a72e66
Todo notes
nt-williams Apr 26, 2024
feb8ec8
Merge
nt-williams Apr 29, 2024
3ee75c7
Merge branch 'devel' into bootstrap
nt-williams Apr 29, 2024
82d22e8
Merge branch 'devel' into bootstrap
nt-williams May 30, 2024
6b1a707
Basic implementation for lmtp_survival
nt-williams Jun 27, 2024
c989ed0
Isotonic project to keep estimates monotone
nt-williams Jun 27, 2024
d30f1f5
lmtp_survival documentation
nt-williams Jun 27, 2024
c728433
Version bump
nt-williams Jun 27, 2024
6434a86
Updated NEWS.md
nt-williams Jun 27, 2024
609cb61
Tidy method for lmtp_survival
nt-williams Jun 28, 2024
0c31ee3
Progress ticker for lmtp_survival, updated documentation
nt-williams Jun 28, 2024
c539f2f
Merge pull request #141 from nt-williams/lmtp_survival
nt-williams Jun 29, 2024
102002b
Refactoring for loop in lmtp_survival
nt-williams Jun 29, 2024
d220621
Merge branch 'devel' into bootstrap
nt-williams Jun 29, 2024
b3c826e
Capturing boot call in lmtp_survival
nt-williams Jun 29, 2024
dbe5f7f
bug fixes for lmtp_survival with time varying treatment
nt-williams Jun 30, 2024
3fb7733
Removing schoolmath dependency
nt-williams Sep 11, 2024
257bfd0
Removing schoolmath dependency
nt-williams Sep 11, 2024
818efea
version bump
nt-williams Sep 11, 2024
db81fcc
Updated NEWS
nt-williams Sep 11, 2024
b471978
Merge with devel
nt-williams Sep 17, 2024
98fd383
Changing default for boot to false
nt-williams Sep 17, 2024
ab9b1e4
Merge pull request #148 from nt-williams/bootstrap
nt-williams Sep 17, 2024
387070a
Accounting for ignored boot call in lmtp_survival
nt-williams Sep 19, 2024
15241d8
Updated NEWS
nt-williams Sep 19, 2024
7be1f43
Bug fix in isotonic regression
nt-williams Sep 26, 2024
74db0f0
updates NEWS
nt-williams Sep 26, 2024
65e16c0
Fixing version discrepancies
nt-williams Sep 26, 2024
9bec661
Fix for issue 151
nt-williams Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: lmtp
Title: Non-Parametric Causal Effects of Feasible Interventions Based on Modified Treatment Policies
Version: 1.4.0
Version: 1.4.1
Authors@R:
c(person(given = "Nicholas",
family = "Williams",
Expand All @@ -24,7 +24,7 @@ License: AGPL-3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Imports:
stats,
nnls,
Expand All @@ -37,7 +37,7 @@ Imports:
data.table (>= 1.13.0),
checkmate (>= 2.1.0),
SuperLearner,
schoolmath
isotone
URL: https://beyondtheate.com/, https://github.com/nt-williams/lmtp
BugReports: https://github.com/nt-williams/lmtp/issues
Suggests:
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

S3method(print,lmtp)
S3method(print,lmtp_contrast)
S3method(print,lmtp_survival)
S3method(tidy,lmtp)
S3method(tidy,lmtp_survival)
export(create_node_list)
export(event_locf)
export(ipsi)
Expand All @@ -11,6 +13,7 @@ export(lmtp_control)
export(lmtp_ipw)
export(lmtp_sdr)
export(lmtp_sub)
export(lmtp_survival)
export(lmtp_tmle)
export(static_binary_off)
export(static_binary_on)
Expand Down
16 changes: 16 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
# lmtp 1.4.1

### New Features

- Added `lmtp_survival()` function for estimating the entire survival curve. Enforces monotonicity using isotonic regression (see issue \#140).
- Bootstrap for TMLE with the `boot` argument using a modified TMLE algorithm (https://arxiv.org/abs/1810.03030).

### Bug Fixes

- Using fitted values from isotonic regression in `lmtp_survival()` instead of the original values (see issue \#149).
- Bootstrap TMLE uses cumulative density ratios (see issue \#151).

### General

- Removed dependency on `schoolmath` which used a very slow function for testing if a vector was "decimalish".

# lmtp 1.4.0

### New Features
Expand Down
2 changes: 1 addition & 1 deletion R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ check_trt_type <- function(data, trt, mtp) {
for (i in seq_along(trt)) {
a <- data[[trt[i]]]
if (is.character(a) | is.factor(a)) next
is_decimal[i] <- any(schoolmath::is.decimal(a[!is.na(a)]))
is_decimal[i] <- any(is_decimal(a[!is.na(a)]))
}
if (any(is_decimal) & isFALSE(mtp)) {
cli::cli_warn("Detected decimalish `trt` values and {.code mtp = FALSE}. Consider setting {.code mtp = TRUE} if getting errors.")
Expand Down
39 changes: 36 additions & 3 deletions R/estimators.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
#' @param mtp \[\code{logical(1)}\]\cr
#' Is the intervention of interest a modified treatment policy?
#' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.
#' @param boot \[\code{logical(1)}\]\cr
#' Compute standard errors using the bootstrap? Default is \code{FALSE}. If \code{FALSE}, standard
#' errors will be calculated using the empirical variance of the efficient influence function.
#' @param outcome_type \[\code{character(1)}\]\cr
#' Outcome variable type (i.e., continuous, binomial, survival).
#' @param id \[\code{character(1)}\]\cr
Expand Down Expand Up @@ -77,7 +80,8 @@
#' \item{standard_error}{The estimated standard error of the LMTP effect.}
#' \item{low}{Lower bound of the 95% confidence interval of the LMTP effect.}
#' \item{high}{Upper bound of the 95% confidence interval of the LMTP effect.}
#' \item{eif}{The estimated, un-centered, influence function of the estimate.}
#' \item{eif}{The estimated, un-centered, influence function of the estimate,
#' \code{NULL} if \code{boot = TRUE}.}
#' \item{shift}{The shift function specifying the treatment policy of interest.}
#' \item{outcome_reg}{An n x Tau + 1 matrix of outcome regression predictions.
#' The mean of the first column is used for calculating theta.}
Expand All @@ -92,7 +96,8 @@
#' @export
lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
cens = NULL, shift = NULL, shifted = NULL, k = Inf,
mtp = FALSE, outcome_type = c("binomial", "continuous", "survival"),
mtp = FALSE, boot = FALSE,
outcome_type = c("binomial", "continuous", "survival"),
id = NULL, bounds = NULL,
learners_outcome = "SL.glm",
learners_trt = "SL.glm",
Expand Down Expand Up @@ -125,6 +130,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
checkmate::assertNumber(control$.bound)
checkmate::assertNumber(control$.trim, upper = 1)
checkmate::assertLogical(control$.return_full_fits, len = 1)
checkmate::assertLogical(boot, len = 1)
check_trt_type(data, unlist(trt), mtp)

task <- lmtp_task$new(
Expand All @@ -147,6 +153,33 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,

pb <- progressr::progressor(task$tau*folds*2)

if (isTRUE(boot)) {
ratios <- cf_r(task, learners_trt, mtp, control, pb)
Qn <- cf_sub(task, "tmp_lmtp_scaled_outcome", learners_outcome, control, pb)
Qnb_eps <- cf_tmle2(task, ratios$ratios, Qn, control)

ans <- theta_boot(
list(
estimator = "TMLE",
m = Qnb_eps$psi,
r = ratios$ratios,
boots = Qnb_eps$booted,
tau = task$tau,
folds = task$folds,
id = task$id,
outcome_type = task$outcome_type,
bounds = task$bounds,
weights = task$weights,
shift = if (is.null(shifted)) deparse(substitute((shift))) else NULL,
fits_m = Qn$fits,
fits_r = ratios$fits,
outcome_type = task$outcome_type,
seed = Qnb_eps$seed
)
)
return(ans)
}

ratios <- cf_r(task, learners_trt, mtp, control, pb)
estims <- cf_tmle(task,
"tmp_lmtp_scaled_outcome",
Expand Down Expand Up @@ -486,7 +519,7 @@ lmtp_sub <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, cens

theta_sub(
eta = list(
m = estims$m,
m = estims$ms,
outcome_type = task$outcome_type,
bounds = task$bounds,
folds = task$folds,
Expand Down
14 changes: 8 additions & 6 deletions R/gcomp.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ cf_sub <- function(task, outcome, learners, control, pb) {
out <- future::value(out)

list(
m = recombine_outcome(out, "m", task$folds),
ms = recombine_outcome(out, "ms", task$folds),
mn = recombine_outcome(out, "mn", task$folds),
fits = lapply(out, function(x) x[["fits"]])
)
}

estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
tau, outcome_type, learners, control, pb) {

m <- matrix(nrow = nrow(natural$valid), ncol = tau)
ms <- mn <- matrix(nrow = nrow(natural$valid), ncol = tau)
fits <- vector("list", length = tau)

for (t in tau:1) {
Expand Down Expand Up @@ -79,13 +79,15 @@ estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
under_shift_valid[, trt_t] <- shifted$valid[jv & rv, trt_t]

natural$train[jt & rt, pseudo] <- bound(SL_predict(fit, under_shift_train), 1e-05)
m[jv & rv, t] <- bound(SL_predict(fit, under_shift_valid), 1e-05)
ms[jv & rv, t] <- bound(SL_predict(fit, under_shift_valid), 1e-05)
mn[jv & rv, t] <- bound(SL_predict(fit, natural$valid[jv & rv, vars]), 1e-05)

natural$train[!rt, pseudo] <- 0
m[!rv, t] <- 0
ms[!rv, t] <- 0
mn[!rv, t] <- 0

pb()
}

list(m = m, fits = fits)
list(ms = ms, mn = mn, fits = fits)
}
10 changes: 8 additions & 2 deletions R/lmtp_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#' The number of cross-validation folds for \code{learners_trt}.
#' @param .return_full_fits \[\code{logical(1)}\]\cr
#' Return full SuperLearner fits? Default is \code{FALSE}, return only SuperLearner weights.
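#' @param .B \[\code{integer(1)}\]\cr
#' Presumably the number of bootstrap replicates used when \code{boot = TRUE}
#' (to be confirmed against the bootstrap implementation). Default is 1000.
#' @param .boot_seed \[\code{integer(1)}\]\cr
#' Presumably an optional seed for the bootstrap resampling
#' (to be confirmed against the bootstrap implementation). Default is \code{NULL}.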
#'
#' @return A list of parameters controlling the estimation procedure.
#' @export
Expand All @@ -23,10 +25,14 @@ lmtp_control <- function(.bound = 1e5,
.trim = 0.999,
.learners_outcome_folds = 10,
.learners_trt_folds = 10,
.return_full_fits = FALSE) {
.return_full_fits = FALSE,
.B = 1000,
.boot_seed = NULL) {
list(.bound = .bound,
.trim = .trim,
.learners_outcome_folds = .learners_outcome_folds,
.learners_trt_folds = .learners_trt_folds,
.return_full_fits = .return_full_fits)
.return_full_fits = .return_full_fits,
.B = .B,
.boot_seed = .boot_seed)
}
142 changes: 142 additions & 0 deletions R/lmtp_survival.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#' LMTP Survival Curve Estimator
#'
#' Wrapper around \code{lmtp_tmle} and \code{lmtp_sdr} for survival outcomes to estimate the entire survival curve.
#' Estimates are reconstructed using isotonic regression to enforce monotonicity of the survival curve.
#' \bold{Confidence intervals correspond to marginal confidence intervals for the survival curve, not simultaneous intervals.}
#'
#' @param data \[\code{data.frame}\]\cr
#' A \code{data.frame} in wide format containing all necessary variables
#' for the estimation problem. Must not be a \code{data.table}.
#' @param trt \[\code{character}\] or \[\code{list}\]\cr
#' A vector containing the column names of treatment variables ordered by time.
#' Or, a list of vectors, the same length as the number of time points of observation.
#' Vectors should contain column names for the treatment variables at each time point. The list
#' should be ordered following the time ordering of the model.
#' @param outcomes \[\code{character}\]\cr
#' A vector containing the column names of intermediate outcome variables and the final
#' outcome variable ordered by time. Only numeric values are allowed. Variables should be coded as 0 and 1.
#' @param baseline \[\code{character}\]\cr
#' An optional vector containing the column names of baseline covariates to be
#' included for adjustment at every time point.
#' @param time_vary \[\code{list}\]\cr
#' A list the same length as the number of time points of observation with
#' the column names for new time-varying covariates introduced at each time point. The list
#' should be ordered following the time ordering of the model.
#' @param cens \[\code{character}\]\cr
#' An optional vector of column names of censoring indicators the same
#' length as the number of time points of observation. If missingness in the outcome is
#' present or if time-to-event outcome, must be provided.
#' @param shift \[\code{closure}\]\cr
#' A two argument function that specifies how treatment variables should be shifted.
#' See examples for how to specify shift functions for continuous, binary, and categorical exposures.
#' @param shifted \[\code{data.frame}\]\cr
#' An optional data frame, the same as in \code{data}, but modified according
#' to the treatment policy of interest. If specified, \code{shift} is ignored.
#' @param estimator \[\code{character(1)}\]\cr
#' The estimator to use. Either \code{"lmtp_tmle"} or \code{"lmtp_sdr"}.
#' @param k \[\code{integer(1)}\]\cr
#' An integer specifying how previous time points should be
#' used for estimation at the given time point. Default is \code{Inf},
#' all time points.
#' @param mtp \[\code{logical(1)}\]\cr
#' Is the intervention of interest a modified treatment policy?
#' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.
#' @param boot \[\code{logical(1)}\]\cr
#' Compute standard errors using the bootstrap? Default is \code{FALSE}. If \code{FALSE}, standard
#' errors will be calculated using the empirical variance of the efficient influence function.
#' Ignored if \code{estimator = "lmtp_sdr"}.
#' @param id \[\code{character(1)}\]\cr
#' An optional column name containing cluster level identifiers.
#' @param learners_outcome \[\code{character}\]\cr A vector of \code{SuperLearner} algorithms for estimation
#' of the outcome regression. Default is \code{"SL.glm"}, a main effects GLM.
#' @param learners_trt \[\code{character}\]\cr A vector of \code{SuperLearner} algorithms for estimation
#' of the exposure mechanism. Default is \code{"SL.glm"}, a main effects GLM.
#' \bold{Only include candidate learners capable of binary classification}.
#' @param folds \[\code{integer(1)}\]\cr
#' The number of folds to be used for cross-fitting.
#' @param weights \[\code{numeric(nrow(data))}\]\cr
#' An optional vector containing sampling weights.
#' @param control \[\code{list()}\]\cr
#' Output of \code{lmtp_control()}.
#'
#' @return A list of class \code{lmtp_survival} containing \code{lmtp} objects for each time point.
#'
#' @example inst/examples/lmtp_survival-ex.R
#' @export
lmtp_survival <- function(data, trt, outcomes, baseline = NULL, time_vary = NULL,
                          cens = NULL, shift = NULL, shifted = NULL,
                          estimator = c("lmtp_tmle", "lmtp_sdr"),
                          k = Inf,
                          mtp = FALSE,
                          boot = FALSE,
                          id = NULL,
                          learners_outcome = "SL.glm",
                          learners_trt = "SL.glm",
                          folds = 10,
                          weights = NULL,
                          control = lmtp_control()) {

  checkmate::assertCharacter(outcomes, min.len = 2, null.ok = FALSE, unique = TRUE, any.missing = FALSE)

  estimator <- match.arg(estimator)
  tau <- length(outcomes)
  estimates <- vector("list", tau)

  # Arguments that are identical for every per-time-point fit. The
  # time-dependent ones (trt, time_vary, outcome, cens, outcome_type) are
  # filled in inside the loop below.
  args <- list(
    data = data,
    baseline = baseline,
    shift = shift,
    shifted = shifted,
    k = k,
    mtp = mtp,
    id = id,
    learners_outcome = learners_outcome,
    learners_trt = learners_trt,
    folds = folds,
    weights = weights,
    control = control
  )

  # A length-one `trt` / `time_vary` is passed unchanged to every fit.
  if (length(trt) == 1) args$trt <- trt
  if (length(time_vary) == 1) args$time_vary <- time_vary

  # `boot` is only forwarded to lmtp_tmle; it is not passed to lmtp_sdr.
  if (estimator == "lmtp_tmle") {
    args$boot <- boot
    expr <- expression(do.call(lmtp_tmle, args))
  } else {
    expr <- expression(do.call(lmtp_sdr, args))
  }

  # `t` must exist before the progress step is created because the message
  # interpolates it.
  t <- 1
  cli::cli_progress_step("Working on time {t}/{tau}...")
  for (t in seq_len(tau)) {
    # Use only the variables observed up to (and including) time t.
    if (length(trt) > 1) args$trt <- trt[seq_len(t)]
    # Test the user-supplied `time_vary`, not `args$time_vary`: the latter is
    # never set when `time_vary` has more than one element, so checking it
    # would silently drop the time-varying covariates from every fit.
    if (length(time_vary) > 1) args$time_vary <- time_vary[seq_len(t)]
    args$outcome <- outcomes[seq_len(t)]
    args$cens <- cens[seq_len(t)]
    # With a single outcome column the first time point is fit as a binomial
    # outcome; later time points are fit as survival outcomes.
    args$outcome_type <- if (t == 1) "binomial" else "survival"

    estimates[[t]] <- future::future(eval(expr), seed = TRUE)
    cli::cli_progress_update()
  }

  cli::cli_progress_done()
  estimates <- future::value(estimates)
  estimates <- fix_surv_time1(estimates)
  # Enforce a monotone curve across time points.
  estimates <- isotonic_projection(estimates)

  class(estimates) <- "lmtp_survival"
  estimates
}

# Replace the point estimates in a list of per-time-point fits with the fitted
# values of an isotonic regression, and rebuild the confidence limits around
# the projected estimates.
#
# x:     list of fits; for each element `standard_error` is read and `theta`,
#        `low` and `high` are overwritten.
# alpha: significance level of the two-sided normal-approximation interval
#        (default 0.05, i.e. 95% intervals).
#
# Returns `x` with updated `theta`, `low` and `high`.
isotonic_projection <- function(x, alpha = 0.05) {
  # Two-sided normal critical value. Previously this was computed but the
  # bounds hard-coded qnorm(0.975), so `alpha` had no effect.
  cv <- abs(stats::qnorm(p = alpha / 2))
  estim <- tidy.lmtp_survival(x)
  # The regression is fit to 1 - estimate against the time index; the fitted
  # values are mapped back with 1 - fitted below.
  iso_fit <- isotone::gpava(seq_along(x), 1 - estim$estimate)
  for (i in seq_along(x)) {
    x[[i]]$theta <- (1 - iso_fit$x[i])
    x[[i]]$low <- x[[i]]$theta - (cv * x[[i]]$standard_error)
    x[[i]]$high <- x[[i]]$theta + (cv * x[[i]]$standard_error)
  }
  x
}
5 changes: 5 additions & 0 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,8 @@ print.lmtp_contrast <- function(x, ...) {
x$vals$p.value <- format.pval(x$vals$p.value, digits = 3, eps = 0.001)
print(format(x$vals, digits = 3))
}

#' @export
print.lmtp_survival <- function(x, ...) {
  # Summarise the per-time-point fits, then show them as a plain data frame.
  curve <- tidy.lmtp_survival(x)
  print(as.data.frame(curve))
}
Loading
Loading