Skip to content

Commit

Permalink
Added immediate exit if not beta_fit does not converge
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Aug 28, 2024
1 parent 15ab6fe commit 71371bf
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
20 changes: 16 additions & 4 deletions R/main.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ fit_devil <- function(
groups <- devil:::get_groups_for_model_matrix(design_matrix)

if (!is.null(groups)) {
beta_0 <- devil:::init_beta_groups(input_mat, groups, offset_matrix)
beta_0 <- devil:::init_beta(input_mat, design_matrix, offset_matrix)
beta_0_groups <- devil:::init_beta_groups(input_mat, groups, offset_matrix)
} else {
beta_0 <- devil:::init_beta(input_mat, design_matrix, offset_matrix)
}
Expand Down Expand Up @@ -151,9 +152,20 @@ fit_devil <- function(
} else {

if (verbose) { message("Fit beta coefficients") }
tmp <- parallel::mclapply(1:ngenes, function(i) {
devil:::beta_fit(input_mat[i,], design_matrix, beta_0[i,], offset_matrix[i,], dispersion_init[i], max_iter = max_iter, eps = tolerance)
}, mc.cores = n.cores)

i <- 1
if (!is.null(groups)) {
tmp <- parallel::mclapply(1:ngenes, function(i) {
r <- devil:::beta_fit(input_mat[i,], design_matrix, beta_0_groups[i,], offset_matrix[i,], dispersion_init[i], max_iter = max_iter, eps = tolerance)
if (sum(is.na(r$mu_beta))) {r <- devil:::beta_fit(input_mat[i,], design_matrix, beta_0[i,], offset_matrix[i,], dispersion_init[i], max_iter = max_iter, eps = tolerance)}
r
}, mc.cores = n.cores)
} else {
tmp <- parallel::mclapply(1:ngenes, function(i) {
r <- devil:::beta_fit(input_mat[i,], design_matrix, beta_0[i,], offset_matrix[i,], dispersion_init[i], max_iter = max_iter, eps = tolerance)
r
}, mc.cores = n.cores)
}

beta <- lapply(1:ngenes, function(i) { tmp[[i]]$mu_beta }) %>% do.call("rbind", .)
rownames(beta) <- gene_names
Expand Down
4 changes: 4 additions & 0 deletions src/beta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ List beta_fit(Eigen::VectorXd y, Eigen::MatrixXd X, Eigen::VectorXd mu_beta, Eig
mu_beta += delta;
converged = delta.cwiseAbs().maxCoeff() < eps;
iter++;
if (delta[0] != delta[0]) {converged = TRUE;}

}

// Return both mu_beta and Zigma as a List
Expand All @@ -57,6 +59,8 @@ List beta_fit_group(Eigen::VectorXd y, float mu_beta, Eigen::VectorXd off, float
mu_g = (k + y.array()) / (1 + k * w_q.array());
Zigma = 1.0 / (k * (mu_g.array() * w_q.array()).sum());



delta = Zigma * (k * (mu_g.array() * w_q.array() - 1).sum());
mu_beta += delta;
converged = delta < eps;
Expand Down

0 comments on commit 71371bf

Please sign in to comment.