diff --git a/sml/linear_model/emulations/quantile_emul.py b/sml/linear_model/emulations/quantile_emul.py
index 64ed95d3..c7c08d5f 100644
--- a/sml/linear_model/emulations/quantile_emul.py
+++ b/sml/linear_model/emulations/quantile_emul.py
@@ -41,7 +41,7 @@ def proc_wrapper(
     def proc(X, y):
         quantile_custom_fit = quantile_custom.fit(X, y)
         result = quantile_custom_fit.predict(X)
-        return result
+        return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_

     return proc

@@ -69,30 +69,36 @@ def generate_data():
        # compare with sklearn
        quantile_sklearn = SklearnQuantileRegressor(
-            quantile=0.3, alpha=0.1, fit_intercept=True, solver='highs'
+            quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs'
        )
        start = time.time()
        quantile_sklearn_fit = quantile_sklearn.fit(X, y)
-        score_plain = jnp.mean(y <= quantile_sklearn_fit.predict(X))
+        y_pred_plain = quantile_sklearn_fit.predict(X)
+        rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
        end = time.time()
        print(f"Running time in SKlearn: {end - start:.2f}s")
+        print(quantile_sklearn_fit.coef_)
+        print(quantile_sklearn_fit.intercept_)

        # mark these data to be protected in SPU
        X_spu, y_spu = emulator.seal(X, y)

        # run
+        # Larger max_iter can give higher accuracy, but it will take more time to run
        proc = proc_wrapper(
-            quantile=0.3, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
+            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
        )
        start = time.time()
-        result = emulator.run(proc)(X_spu, y_spu)
+        result, coef, intercept = emulator.run(proc)(X_spu, y_spu)
        end = time.time()
-        score_encrpted = jnp.mean(y <= result)
+        rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))
        print(f"Running time in SPU: {end - start:.2f}s")
+        print(coef)
+        print(intercept)

-        # print acc
-        print(f"Accuracy in SKlearn: {score_plain:.2f}")
-        print(f"Accuracy in SPU: {score_encrpted:.2f}")
+        # print RMSE
+        print(f"RMSE in SKlearn: {rmse_plain:.2f}")
+        print(f"RMSE in SPU: {rmse_encrypted:.2f}")

    finally:
        emulator.down()

diff --git a/sml/linear_model/quantile.py b/sml/linear_model/quantile.py
index ad7a2c8b..549e67ae 100644
--- a/sml/linear_model/quantile.py
+++ b/sml/linear_model/quantile.py
@@ -12,10 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import numbers
-import warnings
-from warnings import warn
-
 import jax
 import jax.numpy as jnp
 import pandas as pd
@@ -46,6 +42,8 @@ class QuantileRegressor:
        The maximum number of iterations for the optimization algorithm.
        This controls how long the model will continue to update the weights
        before stopping.
+    max_val : float, default=1e10
+        Threshold above which values are treated as infinite in the
+        underlying simplex solver's ratio test.

    Attributes
    ----------
    coef_ : array-like of shape (n_features,)
@@ -57,13 +55,20 @@ class QuantileRegressor:
    """

    def __init__(
-        self, quantile=0.5, alpha=1.0, fit_intercept=True, lr=0.01, max_iter=1000
+        self,
+        quantile=0.5,
+        alpha=1.0,
+        fit_intercept=True,
+        lr=0.01,
+        max_iter=1000,
+        max_val=1e10,
    ):
        self.quantile = quantile
        self.alpha = alpha
        self.fit_intercept = fit_intercept
        self.lr = lr
        self.max_iter = max_iter
+        self.max_val = max_val
        self.coef_ = None
        self.intercept_ = None
@@ -94,7 +99,6 @@ def fit(self, X, y, sample_weight=None):
        n_samples, n_features = X.shape
        n_params = n_features

-        # sample_weight = jnp.ones((n_samples,))
        if sample_weight is None:
            sample_weight = jnp.ones((n_samples,))
@@ -141,9 +145,11 @@ def fit(self, X, y, sample_weight=None):
        b = y

-        result = _linprog_simplex(c, A, b, maxiter=self.max_iter, tol=1e-3)
+        result = _linprog_simplex(
+            c, A, b, maxiter=self.max_iter, tol=1e-3, max_val=self.max_val
+        )

-        solution = result[0]
+        solution = result

        params = solution[:n_params] - solution[n_params : 2 * n_params]
@@ -177,9 +183,14 @@ def predict(self, X):
        - If there is no intercept, the method simply computes the dot product
          between `X` and the coefficients.
        """
-        if self.fit_intercept:
-            X = jnp.column_stack((jnp.ones(X.shape[0]), X))
+        assert (
+            self.coef_ is not None and self.intercept_ is not None
+        ), "Model has not been fitted yet. Please fit the model before predicting."

-            return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_]))
-        else:
-            return jnp.dot(X, self.coef_)
+        n_features = len(self.coef_)
+        assert X.shape[1] == n_features, (
+            f"Input X must have {n_features} features, "
+            f"but got {X.shape[1]} features instead."
+        )
+
+        return jnp.dot(X, self.coef_) + self.intercept_

diff --git a/sml/linear_model/tests/quantile_test.py b/sml/linear_model/tests/quantile_test.py
index 54daa5ce..5e693ede 100644
--- a/sml/linear_model/tests/quantile_test.py
+++ b/sml/linear_model/tests/quantile_test.py
@@ -15,8 +15,6 @@
 import unittest

 import jax.numpy as jnp
-
-# import numpy as np
 from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

 import spu.spu_pb2 as spu_pb2  # type: ignore
@@ -71,20 +69,22 @@ def generate_data():
            quantile=0.2, alpha=0.1, fit_intercept=True, solver='revised simplex'
        )
        quantile_sklearn_fit = quantile_sklearn.fit(X, y)
-        acc_sklearn = jnp.mean(y <= quantile_sklearn_fit.predict(X))
-        print(f"Accuracy in SKlearn: {acc_sklearn:.2f}")
+        y_pred_plain = quantile_sklearn_fit.predict(X)
+        rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
+        print(f"RMSE in SKlearn: {rmse_plain:.2f}")
        print(quantile_sklearn_fit.coef_)
        print(quantile_sklearn_fit.intercept_)

        # run
+        # Larger max_iter can give higher accuracy, but it will take more time to run
        proc = proc_wrapper(
-            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
+            quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
        )
        result, coef, intercept = spsim.sim_jax(sim, proc)(X, y)
-        acc_custom = jnp.mean(y <= result)
+        rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))

-        # print accuracy
-        print(f"Accuracy in SPU: {acc_custom:.2f}")
+        # print RMSE
+        print(f"RMSE in SPU: {rmse_encrypted:.2f}")
        print(coef)
        print(intercept)

diff --git a/sml/linear_model/utils/_linprog_simplex.py b/sml/linear_model/utils/_linprog_simplex.py
index a08c2bd1..a97a460f 100644
--- a/sml/linear_model/utils/_linprog_simplex.py
+++ b/sml/linear_model/utils/_linprog_simplex.py
@@ -12,40 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import warnings
-from warnings import warn
-
 import jax
 import jax.numpy as jnp
 from jax import jit, lax


-def _pivot_col(T, tol=1e-5, bland=False):
+def _pivot_col(T, tol=1e-5):
     mask = T[-1, :-1] >= -tol
     all_masked = jnp.all(mask)

-    bland_first_col = jnp.argmin(jnp.where(mask, jnp.inf, jnp.arange(T.shape[1] - 1)))

     # Select the column by the minimum value
     ma = jnp.where(mask, jnp.inf, T[-1, :-1])
     min_col = jnp.argmin(ma)

-    result = jnp.where(bland, bland_first_col, min_col)
-
     valid = ~all_masked
-    result = jnp.where(all_masked, 0, result)
+    result = jnp.where(all_masked, 0, min_col)

     return valid, result


-def _pivot_row(T, basis, pivcol, phase, tol=1e-5, bland=False):
+def _pivot_row(T, basis, pivcol, phase, tol=1e-5, max_val=1e10):
     def true_mask_func(T, pivcol):
-        mask = T[:-2, pivcol] <= tol
-        ma = jnp.where(mask, jnp.inf, T[:-2, pivcol])
-        mb = jnp.where(mask, jnp.inf, T[:-2, -1])
+        if phase == 1:
+            k = 2
+        else:
+            k = 1
+
+        mask = T[:-k, pivcol] <= tol
+        ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
+        mb = jnp.where(mask, jnp.inf, T[:-k, -1])

-        q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)
+        q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

         # Select the row with the minimum ratio
         min_rows = jnp.nanargmin(q)
@@ -53,11 +52,16 @@ def true_mask_func(T, pivcol):
         return min_rows, all_masked

     def false_mask_func(T, pivcol):
-        mask = T[:-1, pivcol] <= tol
-        ma = jnp.where(mask, jnp.inf, T[:-1, pivcol])
-        mb = jnp.where(mask, jnp.inf, T[:-1, -1])
+        if phase == 1:
+            k = 2
+        else:
+            k = 1

-        q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)
+        mask = T[:-k, pivcol] <= tol
+        ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
+        mb = jnp.where(mask, jnp.inf, T[:-k, -1])
+
+        q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

         # Select the row with the minimum ratio
         min_rows = jnp.nanargmin(q)
@@ -69,20 +73,14 @@ def false_mask_func(T, pivcol):
     min_rows = jnp.where(phase == 1, true_min_rows, false_min_rows)
     all_masked = jnp.where(phase == 1, true_all_masked, false_all_masked)

-    # Check whether the mask array is fully masked
-    has_valid_row = min_rows.size > 0
     row = min_rows

-    # Handle the case where everything is masked
     row = jnp.where(all_masked, 0, row)

-    # Handle the case where no row satisfies the condition
-    row = jnp.where(has_valid_row, row, 0)
+    return ~all_masked, row

-    return ~all_masked & has_valid_row, row
-

-def _apply_pivot(T, basis, pivrow, pivcol, tol=1e-5):
+def _apply_pivot(T, basis, pivrow, pivcol):
     pivrow = jnp.int32(pivrow)
     pivcol = jnp.int32(pivcol)
@@ -110,57 +108,45 @@ def _solve_simplex(
     T,
     n,
     basis,
-    maxiter=300,
+    maxiter=100,
     tol=1e-5,
     phase=2,
-    bland=False,
+    max_val=1e10,
 ):
-    status = 0
     complete = False

     num = 0
     pivcol = 0
     pivrow = 0
     while num < maxiter:
-        pivcol_found, pivcol = _pivot_col(T, tol, bland)
+        pivcol_found, pivcol = _pivot_col(T, tol)

-        def cal_pivcol_found_True(
-            T, basis, pivcol, phase, tol, bland, status, complete
-        ):
-            pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol, bland)
+        def cal_pivcol_found_True(T, basis, pivcol, phase, tol, complete):
+            pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol, max_val)

             pivrow_isnot_found = pivrow_found == False
-            status = jnp.where(pivrow_isnot_found, 1, status)
             complete = jnp.where(pivrow_isnot_found, True, complete)

-            return pivrow, status, complete
-
-        pivcol_isnot_found = pivcol_found == False
-        pivcol = jnp.where(pivcol_isnot_found, 0, pivcol)
-        pivrow = jnp.where(pivcol_isnot_found, 0, pivrow)
-        status = jnp.where(pivcol_isnot_found, 0, status)
-        complete = jnp.where(pivcol_isnot_found, True, complete)
+            return pivrow, complete

         pivcol_is_found = pivcol_found == True
-        pivrow_True, status_True, complete_True = cal_pivcol_found_True(
-            T, basis, pivcol, phase, tol, bland, status, complete
+        pivrow_True, complete_True = cal_pivcol_found_True(
+            T, basis, pivcol, phase, tol, complete
         )
-        pivrow = jnp.where(pivcol_is_found, pivrow_True, pivrow)
-        status = jnp.where(pivcol_is_found, status_True, status)
+        pivrow = jnp.where(pivcol_is_found, pivrow_True, 0)
+
         complete = jnp.where(pivcol_is_found, complete_True, complete)

         complete_is_False = complete == False
-        apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol, tol)
+        apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol)
         T = jnp.where(complete_is_False, apply_T, T)
         basis = jnp.where(complete_is_False, apply_basis, basis)
         num = num + 1

-    return T, basis, status
+    return T, basis


-def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, bland=False):
-    status = 0
+def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, max_val=1e10):
     n, m = A.shape

     # All constraints must have b >= 0.
@@ -178,21 +164,16 @@ def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, bland=False):
     T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective))

     # phase 1
-    T, basis, status = _solve_simplex(
-        T, n, basis, maxiter=maxiter, tol=tol, phase=1, bland=bland
-    )
-
-    status = jnp.where(jnp.abs(T[-1, -1]) < tol, status, 1)
+    T, basis = _solve_simplex(
+        T, n, basis, maxiter=maxiter, tol=tol, phase=1, max_val=max_val
+    )

     T_new = T[:-1, :]
-    jit_delete = jit(jnp.delete, static_argnames=['assume_unique_indices'])
     T = jnp.delete(T_new, av, 1, assume_unique_indices=True)

     # phase 2
-    T, basis, status = _solve_simplex(T, n, basis, maxiter, tol, 2, bland)
+    T, basis = _solve_simplex(T, n, basis, maxiter, tol, 2, max_val)

     solution = jnp.zeros(n + m)
     solution = solution.at[basis[:n]].set(T[:n, -1])
     x = solution[:m]

-    return x, status
+    return x
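
Note on the ratio-test change in _pivot_row: comparing `ma >= max_val` replaces the old exact match against the magic constant 1.75921860e13, so any value at or above the threshold is treated as infinite rather than only one specific bit pattern. A minimal plain-JAX sketch of the guarded ratio test; the numbers are illustrative, not taken from the code:

    import jax.numpy as jnp

    # Illustrative stand-ins for _pivot_row internals: `ma` holds pivot-column
    # entries (masked entries already set to inf), `mb` the right-hand sides.
    max_val = 1e10
    ma = jnp.array([2.0, 1e12, 4.0])
    mb = jnp.array([6.0, 5.0, 4.0])

    # Entries at or above max_val are treated as infinite, so they can never
    # win the ratio test.
    q = jnp.where(ma >= max_val, jnp.inf, mb / ma)
    print(q)                 # [3. inf 1.]
    print(jnp.nanargmin(q))  # 2, the row with the smallest ratio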
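
A quick smoke test for the new `_linprog_simplex` contract, which now returns only the solution vector (the `status` flag is gone). The toy LP below is illustrative; per the in-code comment, the problem must be given in equality form with b >= 0:

    import jax.numpy as jnp

    from sml.linear_model.utils._linprog_simplex import _linprog_simplex

    # Illustrative LP: minimize -x1 - 2*x2
    # subject to x1 + x2 + s1 = 4 and x1 + 3*x2 + s2 = 6, all variables >= 0.
    c = jnp.array([-1.0, -2.0, 0.0, 0.0])
    A = jnp.array([[1.0, 1.0, 1.0, 0.0], [1.0, 3.0, 0.0, 1.0]])
    b = jnp.array([4.0, 6.0])

    x = _linprog_simplex(c, A, b, maxiter=50, tol=1e-5, max_val=1e10)
    print(x)  # expected optimum is near [3. 1. 0. 0.]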
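
Finally, a plaintext sketch of the updated `QuantileRegressor` round trip that the test and emulation exercise: `fit` populates `coef_` and `intercept_`, and `predict` now asserts the model is fitted and adds the intercept explicitly. Data and hyperparameters here are illustrative:

    import jax.numpy as jnp

    from sml.linear_model.quantile import QuantileRegressor

    # Illustrative data: a rough linear relation over two features.
    X = jnp.array([[1.0, 2.0], [2.0, 1.0], [3.0, 4.0], [4.0, 3.0], [5.0, 5.0]])
    y = jnp.array([3.2, 2.9, 7.1, 6.8, 10.0])

    model = QuantileRegressor(
        quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100, max_val=1e10
    )
    model.fit(X, y)

    y_pred = model.predict(X)  # raises an AssertionError if called before fit
    rmse = jnp.sqrt(jnp.mean((y - y_pred) ** 2))
    print(f"RMSE: {rmse:.2f}")
    print(model.coef_, model.intercept_)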