Fixed jit tags for MacOS and check_params()
jacobrepucci committed Jul 17, 2024
1 parent d115eb9 commit 10532f9
Showing 1 changed file with 81 additions and 87 deletions.
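
For context, the diff below swaps the commented-out @jit(nopython=True, fastmath=True, debug=True) decorators for @njit(locals={...}, fastmath=True). A minimal sketch of that pattern follows, with illustrative names not taken from the repository; njit is shorthand for jit(nopython=True), and the explicit locals entries pin intermediates to C-contiguous float64 arrays, presumably to keep Numba's type inference stable on macOS.

# Minimal sketch of the njit + locals pattern (illustrative names only,
# not code from this repository).
import numba
import numpy as np
from numba import njit


@njit(
    locals={"res": numba.types.float64[:, ::1]},
    fastmath=True)
def toy_predict_exp(tau, rate):
    # one row per tau value; the three columns could hold e.g. (c, u, s)
    res = np.empty((len(tau), 3))
    for i in range(len(tau)):
        res[i, :] = 1.0 - np.exp(-rate * tau[i])
    return res


print(toy_predict_exp(np.array([0.0, 0.5, 1.0]), 2.0))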
src/multivelo/dynamical_chrom_func.py: 168 changes (81 additions and 87 deletions)
@@ -16,7 +16,8 @@
import scvelo as scv
import pandas as pd
import seaborn as sns
from numba import jit
from numba import njit
import numba
from numba.typed import List
from tqdm.auto import tqdm
from joblib import Parallel, delayed
@@ -27,7 +28,7 @@
src_path = os.path.join(current_path, "..")
sys.path.append(src_path)


# a function to check for invalid values of different parameters
def check_params(alpha_c,
alpha,
beta,
@@ -127,69 +128,15 @@ def check_params(alpha_c,

return new_alpha_c, new_alpha, new_beta, new_gamma

# @jit(nopython=True, fastmath=True, debug=True)
# def check_params(alpha_c,
# alpha,
# beta,
# gamma,
# c0=None,
# u0=None,
# s0=None):

# # check if any of our parameters are infinite
# if c0 is not None and math.isinf(c0):
# logg.error("c0 is infinite.", v=1)
# if u0 is not None and math.isinf(u0):
# logg.error("u0 is infinite.", v=1)
# if s0 is not None and math.isinf(s0):
# logg.error("s0 is infinite.", v=1)
# if math.isinf(alpha_c):
# logg.error("alpha_c is infinite.", v=1)
# if math.isinf(alpha):
# logg.error("alpha is infinite.", v=1)
# if math.isinf(beta):
# logg.error("beta is infinite.", v=1)
# if math.isinf(gamma):
# logg.error("gamma is infinite.", v=1)

# # check if any of our parameters are nan
# if c0 is not None and math.isnan(c0):
# logg.error("c0 is nan.", v=1)
# if u0 is not None and math.isnan(u0):
# logg.error("u0 is nan.", v=1)
# if s0 is not None and math.isnan(s0):
# logg.error("s0 is nan.", v=1)
# if math.isnan(alpha_c):
# logg.error("alpha_c is nan.", v=1)
# if math.isnan(alpha):
# logg.error("alpha is nan.", v=1)
# if math.isnan(beta):
# logg.error("beta is nan.", v=1)
# if math.isnan(gamma):
# logg.error("gamma is nan.", v=1)

# # check if any of our rate parameters are 0
# if alpha_c < 1e-7:
# logg.error("alpha_c is zero.", v=1)
# if alpha < 1e-7:
# logg.error("alpha is zero.", v=1)
# if beta < 1e-7:
# logg.error("beta is zero.", v=1)
# if gamma < 1e-7:
# logg.error("gamma is zero.", v=1)

# if beta == alpha_c:
# logg.error("alpha_c and beta are equal, leading to divide by zero",
# v=1)
# if beta == gamma:
# logg.error("gamma and beta are equal, leading to divide by zero",
# v=1)
# if alpha_c == gamma:
# logg.error("gamma and alpha_c are equal, leading to divide by zero",
# v=1)


# @jit(nopython=True, fastmath=True, debug=True)
@njit(
locals={
"res": numba.types.float64[:, ::1],
"eat": numba.types.float64[::1],
"ebt": numba.types.float64[::1],
"egt": numba.types.float64[::1],
},
fastmath=True)
def predict_exp(tau,
c0,
u0,
@@ -204,14 +151,6 @@ def predict_exp(tau,
backward=False,
rna_only=False):

# check_params(alpha_c,
# alpha,
# beta,
# gamma,
# c0,
# u0,
# s0)

if len(tau) == 0:
return np.empty((0, 3))
if backward:
@@ -250,7 +189,23 @@
return res


# @jit(nopython=True, fastmath=True, debug=True)
@njit(locals={
"exp_sw1": numba.types.float64[:, ::1],
"exp_sw2": numba.types.float64[:, ::1],
"exp_sw3": numba.types.float64[:, ::1],
"exp1": numba.types.float64[:, ::1],
"exp2": numba.types.float64[:, ::1],
"exp3": numba.types.float64[:, ::1],
"exp4": numba.types.float64[:, ::1],
"tau_sw1": numba.types.float64[::1],
"tau_sw2": numba.types.float64[::1],
"tau_sw3": numba.types.float64[::1],
"tau1": numba.types.float64[::1],
"tau2": numba.types.float64[::1],
"tau3": numba.types.float64[::1],
"tau4": numba.types.float64[::1]
},
fastmath=True)
def generate_exp(tau_list,
t_sw_array,
alpha_c,
@@ -429,7 +384,23 @@
return (exp1, exp2, exp3, exp4), (exp_sw1, exp_sw2, exp_sw3)


# @jit(nopython=True, fastmath=True, debug=True)
@njit(locals={
"exp_sw1": numba.types.float64[:, ::1],
"exp_sw2": numba.types.float64[:, ::1],
"exp_sw3": numba.types.float64[:, ::1],
"exp1": numba.types.float64[:, ::1],
"exp2": numba.types.float64[:, ::1],
"exp3": numba.types.float64[:, ::1],
"exp4": numba.types.float64[:, ::1],
"tau_sw1": numba.types.float64[::1],
"tau_sw2": numba.types.float64[::1],
"tau_sw3": numba.types.float64[::1],
"tau1": numba.types.float64[::1],
"tau2": numba.types.float64[::1],
"tau3": numba.types.float64[::1],
"tau4": numba.types.float64[::1]
},
fastmath=True)
def generate_exp_backward(tau_list, t_sw_array, alpha_c, alpha, beta, gamma,
scale_cc=1, model=1):
if beta == alpha_c:
@@ -535,7 +506,10 @@ def generate_exp_backward(tau_list, t_sw_array, alpha_c, alpha, beta, gamma,
return (exp1, exp2, exp3), (exp_sw1, exp_sw2)


# @jit(nopython=True, fastmath=True, debug=True)
@njit(locals={
"res": numba.types.float64[:, ::1],
},
fastmath=True)
def ss_exp(alpha_c, alpha, beta, gamma, pred_r=True, chrom_open=True):
res = np.empty((1, 3))
if not chrom_open:
@@ -553,7 +527,13 @@ def ss_exp(alpha_c, alpha, beta, gamma, pred_r=True, chrom_open=True):
return res


# @jit(nopython=True, fastmath=True, debug=True)
@njit(locals={
"ss1": numba.types.float64[:, ::1],
"ss2": numba.types.float64[:, ::1],
"ss3": numba.types.float64[:, ::1],
"ss4": numba.types.float64[:, ::1]
},
fastmath=True)
def compute_ss_exp(alpha_c, alpha, beta, gamma, model=0):
if model == 0:
ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False)
@@ -574,7 +554,7 @@ def compute_ss_exp(alpha_c, alpha, beta, gamma, model=0):
return np.vstack((ss1, ss2, ss3, ss4))


# @jit(nopython=True, fastmath=True, debug=True)
@njit(fastmath=True)
def velocity_equations(c, u, s, alpha_c, alpha, beta, gamma, scale_cc=1,
pred_r=True, chrom_open=True, rna_only=False):
if rna_only:
Expand All @@ -594,7 +574,30 @@ def velocity_equations(c, u, s, alpha_c, alpha, beta, gamma, scale_cc=1,
return alpha_c - alpha_c * c, np.zeros(len(u)), np.zeros(len(u))


# @jit(nopython=True, fastmath=True, debug=True)
@njit(locals={
"state0": numba.types.boolean[::1],
"state1": numba.types.boolean[::1],
"state2": numba.types.boolean[::1],
"state3": numba.types.boolean[::1],
"tau1": numba.types.float64[::1],
"tau2": numba.types.float64[::1],
"tau3": numba.types.float64[::1],
"tau4": numba.types.float64[::1],
"exp_list": numba.types.Tuple((numba.types.float64[:, ::1],
numba.types.float64[:, ::1],
numba.types.float64[:, ::1],
numba.types.float64[:, ::1])),
"exp_sw_list": numba.types.Tuple((numba.types.float64[:, ::1],
numba.types.float64[:, ::1],
numba.types.float64[:, ::1])),
"c": numba.types.float64[::1],
"u": numba.types.float64[::1],
"s": numba.types.float64[::1],
"vc_vec": numba.types.float64[::1],
"vu_vec": numba.types.float64[::1],
"vs_vec": numba.types.float64[::1]
},
fastmath=True)
def compute_velocity(t,
t_sw_array,
state,
@@ -1331,15 +1334,6 @@ def predict_exp_ten(self,
backward=False,
rna_only=False):

#TODO: Check params??
# check_params(alpha_c,
# alpha,
# beta,
# gamma,
# c0,
# u0,
# s0)

if scale_cc is None:
scale_cc = torch.tensor(1.0, requires_grad=True,
device=self.device,

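On the check_params() half of the commit: the hunks above show check_params taking the four rate parameters (the commented-out version also accepted optional c0, u0, s0) and returning adjusted rates. A hypothetical usage sketch, with placeholder values and an import path assumed from the file shown in this diff:

# Hypothetical usage of check_params(); values are placeholders and the
# import path is assumed from the file path shown in this diff.
from multivelo.dynamical_chrom_func import check_params

alpha_c, alpha, beta, gamma = 0.1, 1.0, 1.2, 0.9
# Judging by the checks removed above, check_params guards against
# non-finite, near-zero, or mutually equal rates (which would divide by
# zero downstream) and returns values safe to pass to the jitted kernels.
alpha_c, alpha, beta, gamma = check_params(alpha_c, alpha, beta, gamma,
                                           c0=0.0, u0=0.0, s0=0.0)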