Skip to content

Commit

Permalink
Formatted with black
Browse files Browse the repository at this point in the history
  • Loading branch information
hesy7 committed Sep 3, 2024
1 parent a4668c2 commit d1d8166
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 30 deletions.
6 changes: 3 additions & 3 deletions sml/metrics/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def confusion_matrix(y_true, y_pred, labels, sample_weight=None, normalize=None)
# Calculate the confusion matrix
for i, label in enumerate(labels):
# Get the true label and predicted label as the Boolean value of the current label
true_mask = (y_true == label)
pred_mask = (y_pred == label)
true_mask = y_true == label
pred_mask = y_pred == label

# Update the confusion matrix
for j, _ in enumerate(labels):
Expand All @@ -50,7 +50,7 @@ def confusion_matrix(y_true, y_pred, labels, sample_weight=None, normalize=None)


def balanced_accuracy_score(y_true, y_pred, labels, sample_weight=None, adjusted=False):
""" calculate balanced accuracy score """
"""calculate balanced accuracy score"""
C = confusion_matrix(y_true, y_pred, labels, sample_weight=sample_weight)
with np.errstate(divide="ignore", invalid="ignore"):
per_class = jnp.diag(C) / C.sum(axis=1)
Expand Down
36 changes: 23 additions & 13 deletions sml/metrics/classification/classification_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@
from sml.metrics.classification.classification import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
f1_score,
precision_score,
recall_score,
roc_auc_score,
balanced_accuracy_score,
top_k_accuracy_score
top_k_accuracy_score,
)


# TODO: design the enumation framework, just like py.unittest
# all emulation action should begin with `emul_` (for reflection)


def emul_balanced_accuracy(mode: emulation.Mode.MULTIPROCESS):
def proc(y_true: jnp.ndarray, y_pred: jnp.ndarray, labels: jnp.ndarray):
balanced_score = balanced_accuracy_score(y_true, y_pred, labels)
Expand Down Expand Up @@ -69,9 +69,17 @@ def check(spu_result, sk_result):


def emul_top_k_accuracy_score(mode: emulation.Mode.MULTIPROCESS):
def proc(y_true: jnp.ndarray, y_pred: jnp.ndarray, k, normalize, sample_weight, labels):
top_k_score = top_k_accuracy_score(y_true, y_pred, k=k, normalize=normalize, sample_weight=sample_weight,
labels=labels)
def proc(
y_true: jnp.ndarray, y_pred: jnp.ndarray, k, normalize, sample_weight, labels
):
top_k_score = top_k_accuracy_score(
y_true,
y_pred,
k=k,
normalize=normalize,
sample_weight=sample_weight,
labels=labels,
)
return top_k_score

def sklearn_proc(y_true, y_pred, k, labels):
Expand All @@ -83,13 +91,15 @@ def check(spu_result, sk_result):

# Test multiclass
y_true = jnp.array([0, 1, 2, 2, 0])
y_score = jnp.array([
[0.8, 0.1, 0.1],
[0.3, 0.4, 0.3],
[0.1, 0.1, 0.8],
[0.2, 0.2, 0.6],
[0.7, 0.2, 0.1]
])
y_score = jnp.array(
[
[0.8, 0.1, 0.1],
[0.3, 0.4, 0.3],
[0.1, 0.1, 0.8],
[0.2, 0.2, 0.6],
[0.7, 0.2, 0.1],
]
)
spu_result = emulator.run(proc, static_argnums=(2, 3))(
y_true, y_score, 2, True, None, None
)
Expand Down
45 changes: 31 additions & 14 deletions sml/metrics/classification/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@
from sml.metrics.classification.classification import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
bin_counts,
equal_obs,
f1_score,
precision_score,
recall_score,
roc_auc_score,
balanced_accuracy_score,
top_k_accuracy_score
top_k_accuracy_score,
)


Expand Down Expand Up @@ -83,27 +83,44 @@ def test_top_k_accuracy(self):
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

def proc(y_true: jnp.ndarray, y_pred: jnp.ndarray, k, normalize, sample_weight, labels):
top_k_score = top_k_accuracy_score(y_true, y_pred, k=k, normalize=normalize, sample_weight=sample_weight,
labels=labels)
def proc(
y_true: jnp.ndarray,
y_pred: jnp.ndarray,
k,
normalize,
sample_weight,
labels,
):
top_k_score = top_k_accuracy_score(
y_true,
y_pred,
k=k,
normalize=normalize,
sample_weight=sample_weight,
labels=labels,
)
return top_k_score

def sklearn_proc(y_true, y_pred, k, labels):
top_k_score = metrics.top_k_accuracy_score(y_true, y_pred, k=k, labels=labels)
top_k_score = metrics.top_k_accuracy_score(
y_true, y_pred, k=k, labels=labels
)
return top_k_score

def check(spu_result, sk_result):
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5)

# Test multiclass
y_true = jnp.array([0, 1, 2, 2, 0])
y_score = jnp.array([
[0.8, 0.1, 0.1],
[0.3, 0.4, 0.3],
[0.1, 0.1, 0.8],
[0.2, 0.2, 0.6],
[0.7, 0.2, 0.1]
])
y_score = jnp.array(
[
[0.8, 0.1, 0.1],
[0.3, 0.4, 0.3],
[0.1, 0.1, 0.8],
[0.2, 0.2, 0.6],
[0.7, 0.2, 0.1],
]
)
spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 3))(
y_true, y_score, 2, True, None, None
)
Expand Down Expand Up @@ -160,7 +177,7 @@ def test_classification(self):
)

def proc(
y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
y_true, y_pred, average='binary', labels=None, pos_label=1, transform=1
):
f1 = f1_score(
y_true,
Expand Down

0 comments on commit d1d8166

Please sign in to comment.