Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bootstrap #148

Merged
merged 12 commits into from
Sep 17, 2024
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

### New Features

- Added `lmtp_survival()` function for estimating the entire survival curve. Enforces monotonicity using isotonic regression (see issue \#140.
- Added `lmtp_survival()` function for estimating the entire survival curve. Enforces monotonicity using isotonic regression (see issue \#140).
- Bootstrap for TMLE with the `boot` argument using a modified TMLE algorithm (https://arxiv.org/abs/1810.03030).

### Bug Fixes

Expand Down
39 changes: 36 additions & 3 deletions R/estimators.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
#' @param mtp \[\code{logical(1)}\]\cr
#' Is the intervention of interest a modified treatment policy?
#' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.
#' @param boot \[\code{logical(1)}\]\cr
#' Compute standard errors using the bootstrap? Default is \code{FALSE}. If \code{FALSE}, standard
#' errors will be calculated using the empirical variance of the efficient influence function.
#' @param outcome_type \[\code{character(1)}\]\cr
#' Outcome variable type (i.e., continuous, binomial, survival).
#' @param id \[\code{character(1)}\]\cr
Expand Down Expand Up @@ -77,7 +80,8 @@
#' \item{standard_error}{The estimated standard error of the LMTP effect.}
#' \item{low}{Lower bound of the 95% confidence interval of the LMTP effect.}
#' \item{high}{Upper bound of the 95% confidence interval of the LMTP effect.}
#' \item{eif}{The estimated, un-centered, influence function of the estimate.}
#' \item{eif}{The estimated, un-centered, influence function of the estimate,
#' \code{NULL} if \code{boot = TRUE}.}
#' \item{shift}{The shift function specifying the treatment policy of interest.}
#' \item{outcome_reg}{An n x Tau + 1 matrix of outcome regression predictions.
#' The mean of the first column is used for calculating theta.}
Expand All @@ -92,7 +96,8 @@
#' @export
lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
cens = NULL, shift = NULL, shifted = NULL, k = Inf,
mtp = FALSE, outcome_type = c("binomial", "continuous", "survival"),
mtp = FALSE, boot = FALSE,
outcome_type = c("binomial", "continuous", "survival"),
id = NULL, bounds = NULL,
learners_outcome = "SL.glm",
learners_trt = "SL.glm",
Expand Down Expand Up @@ -125,6 +130,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
checkmate::assertNumber(control$.bound)
checkmate::assertNumber(control$.trim, upper = 1)
checkmate::assertLogical(control$.return_full_fits, len = 1)
checkmate::assertLogical(boot, len = 1)
check_trt_type(data, unlist(trt), mtp)

task <- lmtp_task$new(
Expand All @@ -147,6 +153,33 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,

pb <- progressr::progressor(task$tau*folds*2)

if (isTRUE(boot)) {
ratios <- cf_r(task, learners_trt, mtp, control, pb)
Qn <- cf_sub(task, "tmp_lmtp_scaled_outcome", learners_outcome, control, pb)
Qnb_eps <- cf_tmle2(task, ratios$ratios, Qn, control)

ans <- theta_boot(
list(
estimator = "TMLE",
m = Qnb_eps$psi,
r = ratios$ratios,
boots = Qnb_eps$booted,
tau = task$tau,
folds = task$folds,
id = task$id,
outcome_type = task$outcome_type,
bounds = task$bounds,
weights = task$weights,
shift = if (is.null(shifted)) deparse(substitute((shift))) else NULL,
fits_m = Qn$fits,
fits_r = ratios$fits,
outcome_type = task$outcome_type,
seed = Qnb_eps$seed
)
)
return(ans)
}

ratios <- cf_r(task, learners_trt, mtp, control, pb)
estims <- cf_tmle(task,
"tmp_lmtp_scaled_outcome",
Expand Down Expand Up @@ -486,7 +519,7 @@ lmtp_sub <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, cens

theta_sub(
eta = list(
m = estims$m,
m = estims$ms,
outcome_type = task$outcome_type,
bounds = task$bounds,
folds = task$folds,
Expand Down
14 changes: 8 additions & 6 deletions R/gcomp.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ cf_sub <- function(task, outcome, learners, control, pb) {
out <- future::value(out)

list(
m = recombine_outcome(out, "m", task$folds),
ms = recombine_outcome(out, "ms", task$folds),
mn = recombine_outcome(out, "mn", task$folds),
fits = lapply(out, function(x) x[["fits"]])
)
}

estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
tau, outcome_type, learners, control, pb) {

m <- matrix(nrow = nrow(natural$valid), ncol = tau)
ms <- mn <- matrix(nrow = nrow(natural$valid), ncol = tau)
fits <- vector("list", length = tau)

for (t in tau:1) {
Expand Down Expand Up @@ -79,13 +79,15 @@ estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
under_shift_valid[, trt_t] <- shifted$valid[jv & rv, trt_t]

natural$train[jt & rt, pseudo] <- bound(SL_predict(fit, under_shift_train), 1e-05)
m[jv & rv, t] <- bound(SL_predict(fit, under_shift_valid), 1e-05)
ms[jv & rv, t] <- bound(SL_predict(fit, under_shift_valid), 1e-05)
mn[jv & rv, t] <- bound(SL_predict(fit, natural$valid[jv & rv, vars]), 1e-05)

natural$train[!rt, pseudo] <- 0
m[!rv, t] <- 0
ms[!rv, t] <- 0
mn[!rv, t] <- 0

pb()
}

list(m = m, fits = fits)
list(ms = ms, mn = mn, fits = fits)
}
10 changes: 8 additions & 2 deletions R/lmtp_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#' The number of cross-validation folds for \code{learners_trt}.
#' @param .return_full_fits \[\code{logical(1)}\]\cr
#' Return full SuperLearner fits? Default is \code{FALSE}, return only SuperLearner weights.
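#' @param .B \[\code{integer(1)}\]\cr
#' The number of bootstrap replicates to draw when standard errors are computed with the bootstrap.
#' Default is 1000.
#' @param .boot_seed \[\code{integer(1)}\]\cr
#' An optional seed used when drawing the bootstrap resamples. Default is \code{NULL}.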
#'
#' @return A list of parameters controlling the estimation procedure.
#' @export
Expand All @@ -23,10 +25,14 @@ lmtp_control <- function(.bound = 1e5,
.trim = 0.999,
.learners_outcome_folds = 10,
.learners_trt_folds = 10,
.return_full_fits = FALSE) {
.return_full_fits = FALSE,
.B = 1000,
.boot_seed = NULL) {
list(.bound = .bound,
.trim = .trim,
.learners_outcome_folds = .learners_outcome_folds,
.learners_trt_folds = .learners_trt_folds,
.return_full_fits = .return_full_fits)
.return_full_fits = .return_full_fits,
.B = .B,
.boot_seed = .boot_seed)
}
5 changes: 5 additions & 0 deletions R/lmtp_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
#' @param mtp \[\code{logical(1)}\]\cr
#' Is the intervention of interest a modified treatment policy?
#' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.
#' @param boot \[\code{logical(1)}\]\cr
#' Compute standard errors using the bootstrap? Default is \code{FALSE}. If \code{FALSE}, standard
#' errors will be calculated using the empirical variance of the efficient influence function.
#' Ignored if \code{estimator = "lmtp_sdr"}.
#' @param id \[\code{character(1)}\]\cr
#' An optional column name containing cluster level identifiers.
#' @param learners_outcome \[\code{character}\]\cr A vector of \code{SuperLearner} algorithms for estimation
Expand All @@ -64,6 +68,7 @@ lmtp_survival <- function(data, trt, outcomes, baseline = NULL, time_vary = NULL
estimator = c("lmtp_tmle", "lmtp_sdr"),
k = Inf,
mtp = FALSE,
boot = FALSE,
id = NULL,
learners_outcome = "SL.glm",
learners_trt = "SL.glm",
Expand Down
38 changes: 38 additions & 0 deletions R/theta.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,41 @@ theta_dr <- function(eta, augmented = FALSE) {
class(out) <- "lmtp"
out
}

# TODO: NEED TO SAVE THE SEED FOR THE REPLICATES AND THE BOOTED ESTIMATES FOR ESTIMATING CONTRASTS
# Assemble an "lmtp" result whose standard error comes from bootstrap
# replicates rather than from the variance of the influence function.
#
# @param eta A list with elements `estimator`, `m` (matrix of targeted outcome
#   regression predictions; the weighted mean of its first column is the point
#   estimate), `r` (density ratios), `boots` (bootstrap replicate estimates),
#   `weights`, `outcome_type`, `bounds`, `id`, `shift`, `fits_m`, `fits_r`
#   and `seed`.
# @return A list of class "lmtp". `theta`, `standard_error`, `low`, `high`,
#   `boots` and `outcome_reg` are all reported on the same outcome scale.
theta_boot <- function(eta) {
  theta <- weighted.mean(eta$m[, 1], eta$weights)
  boots <- eta$boots

  # The replicates are averages of the same scaled predictions as `theta`, so
  # they have to be back-transformed together with it. Otherwise the standard
  # error and the confidence interval would mix two different scales.
  if (eta$outcome_type == "continuous") {
    theta <- rescale_y_continuous(theta, eta$bounds)
    boots <- rescale_y_continuous(boots, eta$bounds)
  }

  # TODO: NEED TO FIGURE OUT HOW THIS WOULD WORK WITH CLUSTERING
  se <- sd(boots)
  half_width <- qnorm(0.975) * se

  out <- list(
    estimator = eta$estimator,
    theta = theta,
    standard_error = se,
    low = theta - half_width,
    high = theta + half_width,
    boots = boots,
    id = eta$id,
    shift = eta$shift,
    outcome_reg = switch(
      eta$outcome_type,
      continuous = rescale_y_continuous(eta$m, eta$bounds),
      binomial = eta$m
    ),
    density_ratios = eta$r,
    fits_m = eta$fits_m,
    fits_r = eta$fits_r,
    outcome_type = eta$outcome_type,
    seed = eta$seed
  )

  class(out) <- "lmtp"
  out
}
67 changes: 67 additions & 0 deletions R/tmle2.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Targeted estimate on the full sample plus bootstrap replicates of it.
#
# The initial outcome regressions (`m_init`) and density ratios (`ratios`) are
# reused for every replicate; only the targeting step in `estimate_tmle2()` is
# rerun on each resample.
#
# @param task Object exposing `natural`, `cens`, `risk`, `tau` and `weights`.
# @param ratios Matrix of density ratios, one row per row of `task$natural`.
# @param m_init List with matrices `mn` and `ms` of initial predictions.
# @param control List with `.B` (number of replicates) and `.boot_seed`
#   (seed for the resampling, or NULL to draw one).
# @return A list with `psi` (matrix returned by `estimate_tmle2()` on the full
#   data), `booted` (one estimate per replicate) and `seed` (the seed used to
#   draw the resamples).
cf_tmle2 <- function(task, ratios, m_init, control) {
  psi <- estimate_tmle2(task$natural,
                        task$cens,
                        task$risk,
                        task$tau,
                        m_init$mn,
                        m_init$ms,
                        ratios,
                        task$weights)

  seed <- control$.boot_seed
  if (is.null(seed)) {
    # `.Random.seed[1]` encodes the RNG kind, not a seed: it is constant across
    # sessions and `.Random.seed` may not exist yet. Draw a real seed instead,
    # which still honours any `set.seed()` issued by the caller.
    seed <- sample.int(.Machine$integer.max, 1L)
  }

  set.seed(seed)
  n <- nrow(task$natural)
  boots <- replicate(control$.B,
                     sample.int(n, n, replace = TRUE),
                     simplify = FALSE)

  Qnb <- vector("list", control$.B)
  for (i in seq_along(boots)) {
    ib <- boots[[i]]
    Qnb[[i]] <- future::future({
      psis <- estimate_tmle2(task$natural[ib, , drop = FALSE],
                             task$cens,
                             task$risk,
                             task$tau,
                             m_init$mn[ib, , drop = FALSE],
                             m_init$ms[ib, , drop = FALSE],
                             ratios[ib, , drop = FALSE],
                             task$weights[ib])
      weighted.mean(psis[, 1], task$weights[ib])
    },
    seed = TRUE)
  }

  list(psi = psi, booted = unlist(future::value(Qnb)), seed = seed)
}

# Sequentially target the initial outcome regression predictions.
#
# Working backwards from the last time point, an intercept-only weighted
# logistic regression with offset `qlogis(mn[, t])` is fit to the column to
# its right, and the fitted intercept is added to `qlogis(ms[, t])`.
#
# @param data Data frame containing `tmp_lmtp_scaled_outcome`.
# @param cens,risk Passed through to `censored()` and `at_risk()`.
# @param tau Number of time points.
# @param mn,ms Matrices of initial predictions, indexed by row and time
#   point; `mn` supplies the offset and `ms` the values that are updated.
#   (NOTE(review): presumably predictions under the natural and shifted
#   treatment values respectively; confirm against the caller.)
# @param ratios Matrix of density ratios; column `t` multiplies `weights`.
# @param weights Observation weights, one per row of `data`.
# @return An `nrow(data)` x `tau + 1` matrix; the last column is the scaled
#   outcome and column `t` holds the updated predictions for time `t`.
estimate_tmle2 <- function(data, cens, risk, tau, mn, ms, ratios, weights) {
  m_eps <- matrix(nrow = nrow(data), ncol = tau + 1)
  m_eps[, tau + 1] <- data$tmp_lmtp_scaled_outcome

  for (t in rev(seq_len(tau))) {
    # One call per time point; this function runs once per bootstrap replicate.
    obs <- censored(data, cens, t)
    i <- obs$i
    j <- obs$j
    r <- at_risk(data, risk, t)

    wts <- ratios[i & r, t] * weights[i & r]

    fit <- sw(
      glm(
        m_eps[i & r, t + 1] ~ offset(qlogis(mn[i & r, t])),
        weights = wts,
        family = "binomial"
      )
    )

    # `coef(fit)` is the single fluctuation intercept.
    m_eps[j & r, t] <- bound(plogis(qlogis(ms[j & r, t]) + coef(fit)))
    m_eps[!r, t] <- 0
  }

  m_eps
}
8 changes: 7 additions & 1 deletion man/lmtp_control.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/lmtp_survival.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/lmtp_tmle.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-censoring.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ truth <- 0.8

sub <- sw(lmtp_sub(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))
ipw <- sw(lmtp_ipw(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))
tmle <- sw(lmtp_tmle(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))
tmle <- sw(lmtp_tmle(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, boot = F, folds = 1))
sdr <- sw(lmtp_sdr(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))

# tests
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-dynamic.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ cens <- c("C1", "C2")
nodes <- list(c(NULL), c("L"))

# truth = 0.308
tml.stc <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = static_binary_on, folds = 1))
tml.stc <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = static_binary_on, folds = 1, boot = F))

# truth = 0.528
sdr.stc <- sw(lmtp_sdr(sim, a, "Y", baseline, nodes, cens, shift = static_binary_off, folds = 1))

# truth = 0.433
tml.tv <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = time_vary_on, folds = 1))
tml.tv <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = time_vary_on, folds = 1, boot = F))
sdr.tv <- sw(lmtp_sdr(sim, a, "Y", baseline, nodes, cens, shift = time_vary_on, folds = 1))

# time varying and covariate dynamic
# truth = 0.345
tml.dyn <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = dynamic_vec, folds = 1))
tml.dyn <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = dynamic_vec, folds = 1, boot = F))
sdr.dyn <- sw(lmtp_sdr(sim, a, "Y", baseline, nodes, cens, shift = dynamic_vec, folds = 1))

test_that("Dynamic intervention fidelity", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-point_treatment.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ truth <- mean(tmp$Y.1)

sub <- lmtp_sub(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)
ipw <- lmtp_ipw(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)
tmle <- lmtp_tmle(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)
tmle <- lmtp_tmle(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1, boot = F)
sdr <- lmtp_sdr(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)

# tests
Expand Down
Loading
Loading