fix problems
xbw886 committed Oct 24, 2024
1 parent cab923a commit 8c1f00e
Showing 4 changed files with 85 additions and 87 deletions.
24 changes: 15 additions & 9 deletions sml/linear_model/emulations/quantile_emul.py
@@ -41,7 +41,7 @@ def proc_wrapper(
def proc(X, y):
quantile_custom_fit = quantile_custom.fit(X, y)
result = quantile_custom_fit.predict(X)
return result
return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_

return proc
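
For reference, the full wrapper presumably looks like the sketch below; only the inner proc appears in this hunk, so the surrounding QuantileRegressor construction is inferred from the keyword arguments used at the call sites and is an assumption, not the repository's exact code:

    from sml.linear_model.quantile import QuantileRegressor  # assumed import path

    def proc_wrapper(quantile, alpha, fit_intercept, lr, max_iter):
        # build the model once, outside the traced function
        quantile_custom = QuantileRegressor(
            quantile=quantile,
            alpha=alpha,
            fit_intercept=fit_intercept,
            lr=lr,
            max_iter=max_iter,
        )

        def proc(X, y):
            quantile_custom_fit = quantile_custom.fit(X, y)
            result = quantile_custom_fit.predict(X)
            # also return the fitted parameters so they can be inspected after emulation
            return result, quantile_custom_fit.coef_, quantile_custom_fit.intercept_

        return proc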

@@ -69,30 +69,36 @@ def generate_data():

# compare with sklearn
quantile_sklearn = SklearnQuantileRegressor(
quantile=0.3, alpha=0.1, fit_intercept=True, solver='highs'
quantile=0.2, alpha=0.1, fit_intercept=True, solver='highs'
)
start = time.time()
quantile_sklearn_fit = quantile_sklearn.fit(X, y)
score_plain = jnp.mean(y <= quantile_sklearn_fit.predict(X))
y_pred_plain = quantile_sklearn_fit.predict(X)
rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
end = time.time()
print(f"Running time in SKlearn: {end - start:.2f}s")
print(quantile_sklearn_fit.coef_)
print(quantile_sklearn_fit.intercept_)

# mark these data to be protected in SPU
X_spu, y_spu = emulator.seal(X, y)

# run
# Larger max_iter can give higher accuracy, but it will take more time to run
proc = proc_wrapper(
quantile=0.3, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
)
start = time.time()
result = emulator.run(proc)(X_spu, y_spu)
result, coef, intercept = emulator.run(proc)(X_spu, y_spu)
end = time.time()
score_encrypted = jnp.mean(y <= result)
rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))
print(f"Running time in SPU: {end - start:.2f}s")
print(coef)
print(intercept)

# print acc
print(f"Accuracy in SKlearn: {score_plain:.2f}")
print(f"Accuracy in SPU: {score_encrypted:.2f}")
# print RMSE
print(f"RMSE in SKlearn: {rmse_plain:.2f}")
print(f"RMSE in SPU: {rmse_encrypted:.2f}")

finally:
emulator.down()
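
The script now reports RMSE rather than the earlier quantile-coverage score. If a quantile-specific metric is also wanted, the pinball (check) loss is the quantity quantile regression minimizes. A minimal JAX sketch, assuming y, y_pred_plain and result as defined above:

    import jax.numpy as jnp

    def pinball_loss(y_true, y_pred, q):
        # mean check loss: residuals above the prediction cost q, those below cost (1 - q)
        diff = y_true - y_pred
        return jnp.mean(jnp.maximum(q * diff, (q - 1) * diff))

    # e.g. compare pinball_loss(y, y_pred_plain, 0.2) with pinball_loss(y, result, 0.2)
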
37 changes: 24 additions & 13 deletions sml/linear_model/quantile.py
@@ -12,10 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
import warnings
from warnings import warn

import jax
import jax.numpy as jnp
import pandas as pd
@@ -46,6 +42,8 @@ class QuantileRegressor:
The maximum number of iterations for the optimization algorithm.
This controls how long the model will continue to update the weights
before stopping.
max_val : float, default=1e10
Threshold treated as infinity by the internal simplex solver; tableau
entries at or above this value are considered masked.
Attributes
----------
coef_ : array-like of shape (n_features,)
@@ -57,13 +55,20 @@
"""

def __init__(
self, quantile=0.5, alpha=1.0, fit_intercept=True, lr=0.01, max_iter=1000
self,
quantile=0.5,
alpha=1.0,
fit_intercept=True,
lr=0.01,
max_iter=1000,
max_val=1e10,
):
self.quantile = quantile
self.alpha = alpha
self.fit_intercept = fit_intercept
self.lr = lr
self.max_iter = max_iter
self.max_val = max_val

self.coef_ = None
self.intercept_ = None
@@ -94,7 +99,6 @@ def fit(self, X, y, sample_weight=None):
n_samples, n_features = X.shape
n_params = n_features

# sample_weight = jnp.ones((n_samples,))
if sample_weight is None:
sample_weight = jnp.ones((n_samples,))

@@ -141,9 +145,11 @@ def fit(self, X, y, sample_weight=None):

b = y

result = _linprog_simplex(c, A, b, maxiter=self.max_iter, tol=1e-3)
result = _linprog_simplex(
c, A, b, maxiter=self.max_iter, tol=1e-3, max_val=self.max_val
)

solution = result[0]
solution = result

params = solution[:n_params] - solution[n_params : 2 * n_params]

@@ -177,9 +183,14 @@ def predict(self, X):
- If there is no intercept, the method simply computes the dot product between `X` and the coefficients.
"""

if self.fit_intercept:
X = jnp.column_stack((jnp.ones(X.shape[0]), X))
assert (
self.coef_ is not None and self.intercept_ is not None
), "Model has not been fitted yet. Please fit the model before predicting."

return jnp.dot(X, jnp.hstack([self.intercept_, self.coef_]))
else:
return jnp.dot(X, self.coef_)
n_features = len(self.coef_)
assert X.shape[1] == n_features, (
f"Input X must have {n_features} features, "
f"but got {X.shape[1]} features instead."
)

return jnp.dot(X, self.coef_) + self.intercept_
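
A minimal usage sketch of the updated class (the import path and toy data are assumptions, not part of the diff):

    import jax.numpy as jnp
    from sml.linear_model.quantile import QuantileRegressor  # assumed import path

    X = jnp.array([[0.0], [1.0], [2.0], [3.0]])
    y = jnp.array([0.1, 1.4, 2.1, 2.9])

    model = QuantileRegressor(
        quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
    )
    model.fit(X, y)            # populates coef_ and intercept_
    y_hat = model.predict(X)   # asserts the model is fitted and that X has len(coef_) columns
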
16 changes: 8 additions & 8 deletions sml/linear_model/tests/quantile_test.py
@@ -15,8 +15,6 @@
import unittest

import jax.numpy as jnp

# import numpy as np
from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor

import spu.spu_pb2 as spu_pb2 # type: ignore
@@ -71,20 +69,22 @@ def generate_data():
quantile=0.2, alpha=0.1, fit_intercept=True, solver='revised simplex'
)
quantile_sklearn_fit = quantile_sklearn.fit(X, y)
acc_sklearn = jnp.mean(y <= quantile_sklearn_fit.predict(X))
print(f"Accuracy in SKlearn: {acc_sklearn:.2f}")
y_pred_plain = quantile_sklearn_fit.predict(X)
rmse_plain = jnp.sqrt(jnp.mean((y - y_pred_plain) ** 2))
print(f"RMSE in SKlearn: {rmse_plain:.2f}")
print(quantile_sklearn_fit.coef_)
print(quantile_sklearn_fit.intercept_)

# run
# Larger max_iter can give higher accuracy, but it will take more time to run
proc = proc_wrapper(
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=300
quantile=0.2, alpha=0.1, fit_intercept=True, lr=0.01, max_iter=100
)
result, coef, intercept = spsim.sim_jax(sim, proc)(X, y)
acc_custom = jnp.mean(y <= result)
rmse_encrypted = jnp.sqrt(jnp.mean((y - result) ** 2))

# print accuracy
print(f"Accuracy in SPU: {acc_custom:.2f}")
# print RMSE
print(f"RMSE in SPU: {rmse_encrypted:.2f}")
print(coef)
print(intercept)

95 changes: 38 additions & 57 deletions sml/linear_model/utils/_linprog_simplex.py
@@ -12,52 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from warnings import warn

import jax
import jax.numpy as jnp
from jax import jit, lax


def _pivot_col(T, tol=1e-5, bland=False):
def _pivot_col(T, tol=1e-5):
mask = T[-1, :-1] >= -tol

all_masked = jnp.all(mask)

bland_first_col = jnp.argmin(jnp.where(mask, jnp.inf, jnp.arange(T.shape[1] - 1)))
# select the column with the minimum (most negative) objective-row value
ma = jnp.where(mask, jnp.inf, T[-1, :-1])
min_col = jnp.argmin(ma)

result = jnp.where(bland, bland_first_col, min_col)

valid = ~all_masked
result = jnp.where(all_masked, 0, result)
result = jnp.where(all_masked, 0, min_col)

return valid, result


def _pivot_row(T, basis, pivcol, phase, tol=1e-5, bland=False):
def _pivot_row(T, basis, pivcol, phase, tol=1e-5, max_val=1e10):

def true_mask_func(T, pivcol):
mask = T[:-2, pivcol] <= tol
ma = jnp.where(mask, jnp.inf, T[:-2, pivcol])
mb = jnp.where(mask, jnp.inf, T[:-2, -1])
if phase == 1:
k = 2
else:
k = 1

mask = T[:-k, pivcol] <= tol
ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
mb = jnp.where(mask, jnp.inf, T[:-k, -1])

q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)
q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

# select the row with the minimum ratio
min_rows = jnp.nanargmin(q)
all_masked = jnp.all(mask)
return min_rows, all_masked

def false_mask_func(T, pivcol):
mask = T[:-1, pivcol] <= tol
ma = jnp.where(mask, jnp.inf, T[:-1, pivcol])
mb = jnp.where(mask, jnp.inf, T[:-1, -1])
if phase == 1:
k = 2
else:
k = 1

q = jnp.where(ma == 1.75921860e13, jnp.inf, mb / ma)
mask = T[:-k, pivcol] <= tol
ma = jnp.where(mask, jnp.inf, T[:-k, pivcol])
mb = jnp.where(mask, jnp.inf, T[:-k, -1])

q = jnp.where(ma >= max_val, jnp.inf, mb / ma)

# select the row with the minimum ratio
min_rows = jnp.nanargmin(q)
@@ -69,20 +73,14 @@ def false_mask_func(T, pivcol):
min_rows = jnp.where(phase == 1, true_min_rows, false_min_rows)
all_masked = jnp.where(phase == 1, true_all_masked, false_all_masked)

# check whether the mask array is fully masked
has_valid_row = min_rows.size > 0
row = min_rows

# handle the case where all rows are masked
row = jnp.where(all_masked, 0, row)

# handle the case where no row satisfies the condition
row = jnp.where(has_valid_row, row, 0)
return ~all_masked, row

return ~all_masked & has_valid_row, row
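
The change above replaces the magic constant 1.75921860e13 with an explicit max_val threshold: very large pivot-column entries stand in for masked values (presumably because a true jnp.inf does not survive the secure fixed-point arithmetic), so their ratios are forced to infinity before taking the argmin. A small standalone sketch of this masked minimum-ratio test, illustrative only:

    import jax.numpy as jnp

    def min_ratio_row(col, rhs, tol=1e-5, max_val=1e10):
        """Leaving-row choice: smallest rhs/col over rows with col > tol."""
        mask = col <= tol                       # rows that cannot be pivoted on
        ma = jnp.where(mask, jnp.inf, col)      # masked pivot-column entries
        mb = jnp.where(mask, jnp.inf, rhs)      # masked right-hand sides
        q = jnp.where(ma >= max_val, jnp.inf, mb / ma)  # treat saturated entries as inf
        return jnp.nanargmin(q), jnp.all(mask)  # candidate row, and "all rows masked" flag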


def _apply_pivot(T, basis, pivrow, pivcol, tol=1e-5):
def _apply_pivot(T, basis, pivrow, pivcol):
pivrow = jnp.int32(pivrow)
pivcol = jnp.int32(pivcol)

@@ -110,57 +108,45 @@ def _solve_simplex(
T,
n,
basis,
maxiter=300,
maxiter=100,
tol=1e-5,
phase=2,
bland=False,
):
status = 0
complete = False

num = 0
pivcol = 0
pivrow = 0
while num < maxiter:
pivcol_found, pivcol = _pivot_col(T, tol, bland)
pivcol_found, pivcol = _pivot_col(T, tol)

def cal_pivcol_found_True(
T, basis, pivcol, phase, tol, bland, status, complete
):
pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol, bland)
def cal_pivcol_found_True(T, basis, pivcol, phase, tol, complete):
pivrow_found, pivrow = _pivot_row(T, basis, pivcol, phase, tol)

pivrow_isnot_found = pivrow_found == False
status = jnp.where(pivrow_isnot_found, 1, status)
complete = jnp.where(pivrow_isnot_found, True, complete)

return pivrow, status, complete

pivcol_isnot_found = pivcol_found == False
pivcol = jnp.where(pivcol_isnot_found, 0, pivcol)
pivrow = jnp.where(pivcol_isnot_found, 0, pivrow)
status = jnp.where(pivcol_isnot_found, 0, status)
complete = jnp.where(pivcol_isnot_found, True, complete)
return pivrow, complete

pivcol_is_found = pivcol_found == True
pivrow_True, status_True, complete_True = cal_pivcol_found_True(
T, basis, pivcol, phase, tol, bland, status, complete
pivrow_True, complete_True = cal_pivcol_found_True(
T, basis, pivcol, phase, tol, complete
)

pivrow = jnp.where(pivcol_is_found, pivrow_True, pivrow)
status = jnp.where(pivcol_is_found, status_True, status)
pivrow = jnp.where(pivcol_is_found, pivrow_True, 0)

complete = jnp.where(pivcol_is_found, complete_True, complete)

complete_is_False = complete == False
apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol, tol)
apply_T, apply_basis = _apply_pivot(T, basis, pivrow, pivcol)
T = jnp.where(complete_is_False, apply_T, T)
basis = jnp.where(complete_is_False, apply_basis, basis)
num = num + 1

return T, basis, status
return T, basis


def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, bland=False):
status = 0
def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, max_val=1e10):
n, m = A.shape

# All constraints must have b >= 0.
Expand All @@ -178,21 +164,16 @@ def _linprog_simplex(c, A, b, c0=0, maxiter=300, tol=1e-5, bland=False):
T = jnp.vstack((row_constraints, row_objective, row_pseudo_objective))

# phase 1
T, basis, status = _solve_simplex(
T, n, basis, maxiter=maxiter, tol=tol, phase=1, bland=bland
)

status = jnp.where(jnp.abs(T[-1, -1]) < tol, status, 1)
T, basis = _solve_simplex(T, n, basis, maxiter=maxiter, tol=tol, phase=1)

T_new = T[:-1, :]
jit_delete = jit(jnp.delete, static_argnames=['assume_unique_indices'])
T = jnp.delete(T_new, av, 1, assume_unique_indices=True)

# phase 2
T, basis, status = _solve_simplex(T, n, basis, maxiter, tol, 2, bland)
T, basis = _solve_simplex(T, n, basis, maxiter, tol, 2)

solution = jnp.zeros(n + m)
solution = solution.at[basis[:n]].set(T[:n, -1])
x = solution[:m]

return x, status
return x
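
A minimal usage sketch of the solver, under the assumption (consistent with the two-phase tableau setup above) that it minimizes c @ x subject to A @ x = b and x >= 0, with b >= 0:

    import jax.numpy as jnp
    from sml.linear_model.utils._linprog_simplex import _linprog_simplex  # assumed import path

    # maximize x1 + 2*x2  s.t.  x1 + x2 <= 4,  x1 + 3*x2 <= 6,
    # written in equality form with slack variables x3 and x4
    c = jnp.array([-1.0, -2.0, 0.0, 0.0])
    A = jnp.array([[1.0, 1.0, 1.0, 0.0],
                   [1.0, 3.0, 0.0, 1.0]])
    b = jnp.array([4.0, 6.0])

    x = _linprog_simplex(c, A, b, maxiter=50, tol=1e-5, max_val=1e10)
    # if the assumptions hold, the optimum is x = [3., 1., 0., 0.]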
