
Commit

Merge pull request #148 from nt-williams/bootstrap
Bootstrap
nt-williams authored Sep 17, 2024
2 parents db81fcc + 98fd383 commit ab9b1e4
Showing 16 changed files with 192 additions and 22 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
@@ -8,7 +8,8 @@

### New Features

- Added `lmtp_survival()` function for estimating the entire survival curve. Enforces monotonicity using isotonic regression (see issue \#140.
- Added `lmtp_survival()` function for estimating the entire survival curve. Enforces monotonicity using isotonic regression (see issue \#140).
- Added bootstrap standard errors for TMLE via the `boot` argument, using a modified TMLE algorithm (https://arxiv.org/abs/1810.03030).

### Bug Fixes

39 changes: 36 additions & 3 deletions R/estimators.R
@@ -41,6 +41,9 @@
#' @param mtp \[\code{logical(1)}\]\cr
#' Is the intervention of interest a modified treatment policy?
#' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.
#' @param boot \[\code{logical(1)}\]\cr
#' Compute standard errors using the bootstrap? Default is \code{FALSE}. If \code{FALSE}, standard
#' errors will be calculated using the empirical variance of the efficient influence function.
#' @param outcome_type \[\code{character(1)}\]\cr
#' Outcome variable type (i.e., continuous, binomial, survival).
#' @param id \[\code{character(1)}\]\cr
@@ -77,7 +80,8 @@
#' \item{standard_error}{The estimated standard error of the LMTP effect.}
#' \item{low}{Lower bound of the 95% confidence interval of the LMTP effect.}
#' \item{high}{Upper bound of the 95% confidence interval of the LMTP effect.}
#' \item{eif}{The estimated, un-centered, influence function of the estimate.}
#' \item{eif}{The estimated, un-centered, influence function of the estimate,
#' \code{NULL} if \code{boot = TRUE}.}
#' \item{shift}{The shift function specifying the treatment policy of interest.}
#' \item{outcome_reg}{An n x Tau + 1 matrix of outcome regression predictions.
#' The mean of the first column is used for calculating theta.}
@@ -92,7 +96,8 @@
#' @export
lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
cens = NULL, shift = NULL, shifted = NULL, k = Inf,
mtp = FALSE, outcome_type = c("binomial", "continuous", "survival"),
mtp = FALSE, boot = FALSE,
outcome_type = c("binomial", "continuous", "survival"),
id = NULL, bounds = NULL,
learners_outcome = "SL.glm",
learners_trt = "SL.glm",
@@ -125,6 +130,7 @@ lmtp_tmle <- function(data, trt, outcome, baseline = NULL, time_vary = NULL,
checkmate::assertNumber(control$.bound)
checkmate::assertNumber(control$.trim, upper = 1)
checkmate::assertLogical(control$.return_full_fits, len = 1)
checkmate::assertLogical(boot, len = 1)
check_trt_type(data, unlist(trt), mtp)

task <- lmtp_task$new(
@@ -147,6 +153,33 @@

pb <- progressr::progressor(task$tau*folds*2)

if (isTRUE(boot)) {
ratios <- cf_r(task, learners_trt, mtp, control, pb)
Qn <- cf_sub(task, "tmp_lmtp_scaled_outcome", learners_outcome, control, pb)
Qnb_eps <- cf_tmle2(task, ratios$ratios, Qn, control)

ans <- theta_boot(
list(
estimator = "TMLE",
m = Qnb_eps$psi,
r = ratios$ratios,
boots = Qnb_eps$booted,
tau = task$tau,
folds = task$folds,
id = task$id,
outcome_type = task$outcome_type,
bounds = task$bounds,
weights = task$weights,
shift = if (is.null(shifted)) deparse(substitute((shift))) else NULL,
fits_m = Qn$fits,
fits_r = ratios$fits,
seed = Qnb_eps$seed
)
)
return(ans)
}

ratios <- cf_r(task, learners_trt, mtp, control, pb)
estims <- cf_tmle(task,
"tmp_lmtp_scaled_outcome",
@@ -486,7 +519,7 @@ lmtp_sub <- function(data, trt, outcome, baseline = NULL, time_vary = NULL, cens

theta_sub(
eta = list(
m = estims$m,
m = estims$ms,
outcome_type = task$outcome_type,
bounds = task$bounds,
folds = task$folds,
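
For orientation, a minimal usage sketch of the new `boot` argument documented above (an editor's illustration, not part of this commit): the toy data and column names are invented, `.B = 200` is a deliberately small replicate count to keep the sketch fast, and `static_binary_on` is the built-in shift function already used in the package tests further down.

library(lmtp)

# Toy point-treatment data, purely for illustration.
set.seed(1)
n <- 500
W <- rnorm(n)
A <- rbinom(n, 1, plogis(W))
Y <- rbinom(n, 1, plogis(-1 + A + W))
dat <- data.frame(W = W, A = A, Y = Y)

# Request bootstrap standard errors instead of the EIF-based variance.
fit <- lmtp_tmle(
  dat, trt = "A", outcome = "Y", baseline = "W",
  shift = static_binary_on, folds = 1, boot = TRUE,
  control = lmtp_control(.B = 200, .boot_seed = 25)
)
fit$standard_error  # standard error computed in theta_boot() from the replicates
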
14 changes: 8 additions & 6 deletions R/gcomp.R
@@ -23,15 +23,15 @@ cf_sub <- function(task, outcome, learners, control, pb) {
out <- future::value(out)

list(
m = recombine_outcome(out, "m", task$folds),
ms = recombine_outcome(out, "ms", task$folds),
mn = recombine_outcome(out, "mn", task$folds),
fits = lapply(out, function(x) x[["fits"]])
)
}

estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
tau, outcome_type, learners, control, pb) {

m <- matrix(nrow = nrow(natural$valid), ncol = tau)
ms <- mn <- matrix(nrow = nrow(natural$valid), ncol = tau)
fits <- vector("list", length = tau)

for (t in tau:1) {
@@ -79,13 +79,15 @@ estimate_sub <- function(natural, shifted, trt, outcome, node_list, cens, risk,
under_shift_valid[, trt_t] <- shifted$valid[jv & rv, trt_t]

natural$train[jt & rt, pseudo] <- bound(SL_predict(fit, under_shift_train), 1e-05)
m[jv & rv, t] <- bound(SL_predict(fit, under_shift_valid), 1e-05)
ms[jv & rv, t] <- bound(SL_predict(fit, under_shift_valid), 1e-05)
mn[jv & rv, t] <- bound(SL_predict(fit, natural$valid[jv & rv, vars]), 1e-05)

natural$train[!rt, pseudo] <- 0
m[!rv, t] <- 0
ms[!rv, t] <- 0
mn[!rv, t] <- 0

pb()
}

list(m = m, fits = fits)
list(ms = ms, mn = mn, fits = fits)
}
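
The renaming above is what enables the bootstrap TMLE: estimate_sub() now returns predictions under the shifted treatment (ms) as well as under the natural treatment (mn). A small numeric sketch of how the two matrices are used downstream (editor's illustration with toy stand-ins, not package code), mirroring the offset/update logic of estimate_tmle2() in R/tmle2.R below:

set.seed(1)
n   <- 200
mn  <- plogis(rnorm(n))        # toy predictions under the observed (natural) treatment
ms  <- plogis(rnorm(n) + 0.3)  # toy predictions under the shifted treatment
y   <- rbinom(n, 1, mn)        # toy pseudo-outcome carried back from t + 1
wts <- runif(n, 0.5, 2)        # toy density-ratio weights

# Intercept-only logistic fluctuation with the natural predictions as the offset;
# suppressWarnings() hides the harmless non-integer-successes warning caused by the weights.
fit <- suppressWarnings(
  glm(y ~ offset(qlogis(mn)), weights = wts, family = "binomial")
)
updated <- plogis(qlogis(ms) + coef(fit))  # targeted prediction under the shift
head(updated)
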
10 changes: 8 additions & 2 deletions R/lmtp_control.R
@@ -13,6 +13,8 @@
#' The number of cross-validation folds for \code{learners_trt}.
#' @param .return_full_fits \[\code{logical(1)}\]\cr
#' Return full SuperLearner fits? Default is \code{FALSE}, return only SuperLearner weights.
#' @param .B \[\code{integer(1)}\]\cr The number of bootstrap replicates used when \code{boot = TRUE}. Default is \code{1000}.
#' @param .boot_seed \[\code{integer(1)}\]\cr An optional seed for the bootstrap resampling; if \code{NULL} (the default), a seed is taken from the current RNG state.
#'
#' @return A list of parameters controlling the estimation procedure.
#' @export
@@ -23,10 +25,14 @@ lmtp_control <- function(.bound = 1e5,
.trim = 0.999,
.learners_outcome_folds = 10,
.learners_trt_folds = 10,
.return_full_fits = FALSE) {
.return_full_fits = FALSE,
.B = 1000,
.boot_seed = NULL) {
list(.bound = .bound,
.trim = .trim,
.learners_outcome_folds = .learners_outcome_folds,
.learners_trt_folds = .learners_trt_folds,
.return_full_fits = .return_full_fits)
.return_full_fits = .return_full_fits,
.B = .B,
.boot_seed = .boot_seed)
}
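
A short sketch of how the two new control fields might be set (editor's illustration; the values are arbitrary):

library(lmtp)

# More bootstrap replicates plus a fixed resampling seed so the bootstrap
# standard errors are reproducible across runs.
ctrl <- lmtp_control(.B = 2000, .boot_seed = 56)
ctrl$.B          # 2000
ctrl$.boot_seed  # 56
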
5 changes: 5 additions & 0 deletions R/lmtp_survival.R
@@ -41,6 +41,10 @@
#' @param mtp \[\code{logical(1)}\]\cr
#' Is the intervention of interest a modified treatment policy?
#' Default is \code{FALSE}. If treatment variables are continuous this should be \code{TRUE}.
#' @param boot \[\code{logical(1)}\]\cr
#' Compute standard errors using the bootstrap? Default is \code{FALSE}. If \code{FALSE}, standard
#' errors will be calculated using the empirical variance of the efficient influence function.
#' Ignored if \code{estimator = "lmtp_sdr"}.
#' @param id \[\code{character(1)}\]\cr
#' An optional column name containing cluster level identifiers.
#' @param learners_outcome \[\code{character}\]\cr A vector of \code{SuperLearner} algorithms for estimation
@@ -64,6 +68,7 @@ lmtp_survival <- function(data, trt, outcomes, baseline = NULL, time_vary = NULL,
estimator = c("lmtp_tmle", "lmtp_sdr"),
k = Inf,
mtp = FALSE,
boot = FALSE,
id = NULL,
learners_outcome = "SL.glm",
learners_trt = "SL.glm",
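
The "Ignored if estimator = 'lmtp_sdr'" note above can be read as a simple dispatch rule. A stand-alone, purely hypothetical sketch of that behaviour (resolve_boot() does not exist in the package):

# Hypothetical illustration only: the boot flag is meaningful for the TMLE
# estimator but has no effect when the SDR estimator is selected.
resolve_boot <- function(estimator = c("lmtp_tmle", "lmtp_sdr"), boot = FALSE) {
  estimator <- match.arg(estimator)
  if (estimator == "lmtp_sdr" && isTRUE(boot)) {
    message("`boot` is ignored when estimator = 'lmtp_sdr'")
    boot <- FALSE
  }
  list(estimator = estimator, boot = boot)
}

resolve_boot("lmtp_sdr", boot = TRUE)   # boot is dropped with a message
resolve_boot("lmtp_tmle", boot = TRUE)  # boot is carried through to lmtp_tmle()
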
38 changes: 38 additions & 0 deletions R/theta.R
@@ -112,3 +112,41 @@ theta_dr <- function(eta, augmented = FALSE) {
class(out) <- "lmtp"
out
}

# TODO: NEED TO SAVE THE SEED FOR THE REPLICATES AND THE BOOTED ESTIMATES FOR ESTIMATING CONTRASTS
theta_boot <- function(eta) {
theta <- weighted.mean(eta$m[, 1], eta$weights)

if (eta$outcome_type == "continuous") {
theta <- rescale_y_continuous(theta, eta$bounds)
}

# TODO: NEED TO FIGURE OUT HOW THIS WOULD WORK WITH CLUSTERING
se <- sqrt(var(eta$boots))
ci_low <- theta - (qnorm(0.975) * se)
ci_high <- theta + (qnorm(0.975) * se)

out <- list(
estimator = eta$estimator,
theta = theta,
standard_error = se,
low = ci_low,
high = ci_high,
boots = eta$boots,
id = eta$id,
shift = eta$shift,
outcome_reg = switch(
eta$outcome_type,
continuous = rescale_y_continuous(eta$m, eta$bounds),
binomial = eta$m
),
density_ratios = eta$r,
fits_m = eta$fits_m,
fits_r = eta$fits_r,
outcome_type = eta$outcome_type,
seed = eta$seed
)

class(out) <- "lmtp"
out
}
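
A quick numeric illustration of the interval construction in theta_boot() (editor's sketch; the replicate values are simulated stand-ins, not real output). Note that the Wald-style interval is centred at the targeted point estimate, not at the bootstrap mean.

set.seed(1)
theta <- 0.45                     # point estimate from the targeted fit
boots <- rnorm(1000, 0.45, 0.02)  # stand-in bootstrap replicates of theta

se <- sqrt(var(boots))            # bootstrap standard error, as in theta_boot()
c(low  = theta - qnorm(0.975) * se,
  high = theta + qnorm(0.975) * se)
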
67 changes: 67 additions & 0 deletions R/tmle2.R
@@ -0,0 +1,67 @@
cf_tmle2 <- function(task, ratios, m_init, control) {
psi <- estimate_tmle2(task$natural,
task$cens,
task$risk,
task$tau,
m_init$mn,
m_init$ms,
ratios,
task$weights)

if (!is.null(control$.boot_seed)) {
seed <- control$.boot_seed
} else {
seed <- .Random.seed[1]
}

set.seed(seed)
boots <- replicate(control$.B,
sample(1:nrow(task$natural), nrow(task$natural), replace = TRUE),
simplify = FALSE)

Qnb <- vector("list", control$.B)
for (i in seq_along(boots)) {
ib <- boots[[i]]
Qnb[[i]] <- future::future({
psis <- estimate_tmle2(task$natural[ib, , drop = FALSE],
task$cens,
task$risk,
task$tau,
m_init$mn[ib, , drop = FALSE],
m_init$ms[ib, , drop = FALSE],
ratios[ib, , drop = FALSE],
task$weights[ib])
weighted.mean(psis[, 1], task$weights[ib])
},
seed = TRUE)
}

list(psi = psi, booted = unlist(future::value(Qnb)), seed = seed)
}

estimate_tmle2 <- function(data, cens, risk, tau, mn, ms, ratios, weights) {
m_eps <- matrix(nrow = nrow(data), ncol = tau + 1)
m_eps[, tau + 1] <- data$tmp_lmtp_scaled_outcome

fits <- vector("list", length = tau)
for (t in tau:1) {
i <- censored(data, cens, t)$i
j <- censored(data, cens, t)$j
r <- at_risk(data, risk, t)

wts <- ratios[i & r, t] * weights[i & r]

fit <- sw(
glm(
m_eps[i & r, t + 1] ~ offset(qlogis(mn[i & r, t])),
weights = wts,
family = "binomial"
)
)

m_eps[j & r, t] <- bound(plogis(qlogis(ms[j & r, t]) + coef(fit)))
m_eps[!r, t] <- 0
}

m_eps
}
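
Because every replicate in cf_tmle2() is wrapped in future::future() (with seed = TRUE for parallel-safe RNG), the bootstrap follows whatever future plan is active when lmtp_tmle(boot = TRUE) is called. A hedged sketch of enabling parallel replicates (the worker count is arbitrary):

library(future)

plan(multisession, workers = 4)  # resolve the bootstrap futures on 4 background R sessions
# ... fit <- lmtp_tmle(..., boot = TRUE) ...
plan(sequential)                 # restore the default sequential plan
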
8 changes: 7 additions & 1 deletion man/lmtp_control.Rd

Some generated files are not rendered by default.

6 changes: 6 additions & 0 deletions man/lmtp_survival.Rd

Some generated files are not rendered by default.

8 changes: 7 additions & 1 deletion man/lmtp_tmle.Rd

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion tests/testthat/test-censoring.R
@@ -7,7 +7,7 @@ truth <- 0.8

sub <- sw(lmtp_sub(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))
ipw <- sw(lmtp_ipw(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))
tmle <- sw(lmtp_tmle(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))
tmle <- sw(lmtp_tmle(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, boot = FALSE, folds = 1))
sdr <- sw(lmtp_sdr(sim_cens, A, "Y", time_vary = L, cens = C, k = 0, shift = NULL, folds = 1))

# tests
6 changes: 3 additions & 3 deletions tests/testthat/test-dynamic.R
@@ -32,18 +32,18 @@ cens <- c("C1", "C2")
nodes <- list(c(NULL), c("L"))

# truth = 0.308
tml.stc <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = static_binary_on, folds = 1))
tml.stc <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = static_binary_on, folds = 1, boot = FALSE))

# truth = 0.528
sdr.stc <- sw(lmtp_sdr(sim, a, "Y", baseline, nodes, cens, shift = static_binary_off, folds = 1))

# truth = 0.433
tml.tv <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = time_vary_on, folds = 1))
tml.tv <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = time_vary_on, folds = 1, boot = FALSE))
sdr.tv <- sw(lmtp_sdr(sim, a, "Y", baseline, nodes, cens, shift = time_vary_on, folds = 1))

# time varying and covariate dynamic
# truth = 0.345
tml.dyn <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = dynamic_vec, folds = 1))
tml.dyn <- sw(lmtp_tmle(sim, a, "Y", baseline, nodes, cens, shift = dynamic_vec, folds = 1, boot = FALSE))
sdr.dyn <- sw(lmtp_sdr(sim, a, "Y", baseline, nodes, cens, shift = dynamic_vec, folds = 1))

test_that("Dynamic intervention fidelity", {
2 changes: 1 addition & 1 deletion tests/testthat/test-point_treatment.R
@@ -15,7 +15,7 @@ truth <- mean(tmp$Y.1)

sub <- lmtp_sub(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)
ipw <- lmtp_ipw(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)
tmle <- lmtp_tmle(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)
tmle <- lmtp_tmle(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1, boot = FALSE)
sdr <- lmtp_sdr(tmp, "A", "Y", baseline = c("W1", "W2"), shift = static_binary_on, folds = 1)

# tests