-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes #726: Implemented two accuracy functions #842
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,91 @@ | |
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from sml.preprocessing.preprocessing import label_binarize | ||
from spu.ops.groupby import groupby, groupby_sum | ||
|
||
from .auc import binary_clf_curve, binary_roc_auc | ||
|
||
|
||
def confusion_matrix(y_true, y_pred, labels, sample_weight=None, normalize=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add some docs for the functionality of this function, and the means of all params. |
||
"""calculate the confusion matrix""" | ||
y_true = jnp.array(y_true) | ||
y_pred = jnp.array(y_pred) | ||
|
||
# Get the number of tags | ||
num_labels = len(labels) | ||
|
||
# Initialize the confusion matrix | ||
cm = jnp.zeros((num_labels, num_labels), dtype=jnp.int32) | ||
|
||
# Calculate the confusion matrix | ||
for i, label in enumerate(labels): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Vectorized operations can replace these two for-loops, So the round complexity can be reduced. e.g. |
||
# Get the true label and predicted label as the Boolean value of the current label | ||
true_mask = y_true == label | ||
pred_mask = y_pred == label | ||
|
||
# Update the confusion matrix | ||
for j, _ in enumerate(labels): | ||
# Calculate TP, FP, FN, TN | ||
cm = cm.at[i, j].set(jnp.sum(true_mask & (y_pred == j))) | ||
|
||
return cm | ||
|
||
|
||
def balanced_accuracy_score(y_true, y_pred, labels, sample_weight=None, adjusted=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add some docs for the functionality of this function, and the means of all params. |
||
"""calculate balanced accuracy score""" | ||
C = confusion_matrix(y_true, y_pred, labels, sample_weight=sample_weight) | ||
with np.errstate(divide="ignore", invalid="ignore"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The context management can not work in SPU, can just delete this |
||
per_class = jnp.diag(C) / C.sum(axis=1) | ||
score = jnp.mean(per_class) | ||
if adjusted: | ||
n_classes = len(per_class) | ||
chance = 1 / n_classes | ||
score -= chance | ||
score /= 1 - chance | ||
return score | ||
|
||
|
||
def top_k_accuracy_score( | ||
y_true, y_score, k=2, normalize=True, sample_weight=None, labels=None | ||
): | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add some docs for the functionality of this function, and the means of all params. |
||
Top-k Accuracy classification score. | ||
This metric computes the number of times when the correct label is among | ||
the top `k` labels predicted (ranked by predicted scores). | ||
""" | ||
|
||
y_true = jnp.asarray(y_true) | ||
y_score = jnp.asarray(y_score) | ||
|
||
if labels is not None: | ||
# If labels are provided, make sure y_true and y_score are included in labels | ||
labels = jnp.asarray(labels) | ||
y_true = jnp.searchsorted(labels, y_true, sorter=jnp.argsort(labels)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Searchsorted is a costly op, you can just skip it. Then, you should comment in the docs to hint to the user that the y_true should only contain values 0,1,2..., len(labels)-1 |
||
|
||
# Compute the indices of the top k predictions for each sample | ||
top_k_indices = jnp.argsort(y_score, axis=1)[:, -k:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are some easy cases: Binary case, when k=1, you just compare to 0.5; k>1, you just return 1 (Indeed, you can always check whether k >= len(labels)). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general cases (multi-class case), you can use |
||
|
||
# Check if y_true is among the top k predictions | ||
y_true_in_top_k = jnp.any(jnp.isin(y_true[:, None], top_k_indices), axis=1) | ||
|
||
# Calculate accuracy | ||
correct_predictions = jnp.sum(y_true_in_top_k) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should think twice about the
Maybe you can refer to sklearn about the logic |
||
|
||
if sample_weight is not None: | ||
sample_weight = jnp.asarray(sample_weight) | ||
accuracy = jnp.sum(sample_weight * y_true_in_top_k) / jnp.sum(sample_weight) | ||
else: | ||
accuracy = correct_predictions / len(y_true) | ||
|
||
if normalize: | ||
return accuracy | ||
else: | ||
return correct_predictions | ||
|
||
|
||
def roc_auc_score(y_true, y_pred): | ||
sorted_arr = create_sorted_label_score_pair(y_true, y_pred) | ||
return binary_roc_auc(sorted_arr) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,15 +27,86 @@ | |
from sml.metrics.classification.classification import ( | ||
accuracy_score, | ||
average_precision_score, | ||
balanced_accuracy_score, | ||
f1_score, | ||
precision_score, | ||
recall_score, | ||
roc_auc_score, | ||
top_k_accuracy_score, | ||
) | ||
|
||
|
||
# TODO: design the enumation framework, just like py.unittest | ||
# all emulation action should begin with `emul_` (for reflection) | ||
|
||
|
||
def emul_balanced_accuracy(mode: emulation.Mode.MULTIPROCESS): | ||
def proc(y_true: jnp.ndarray, y_pred: jnp.ndarray, labels: jnp.ndarray): | ||
balanced_score = balanced_accuracy_score(y_true, y_pred, labels) | ||
return balanced_score | ||
|
||
def sklearn_proc(y_true, y_pred): | ||
balanced_score = metrics.balanced_accuracy_score(y_true, y_pred) | ||
return balanced_score | ||
|
||
def check(spu_result, sk_result): | ||
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5) | ||
|
||
# Test binary | ||
y_true = jnp.array([0, 1, 1, 0, 1, 1]) | ||
y_pred = jnp.array([0, 0, 1, 0, 1, 1]) | ||
labels = jnp.array([0, 1]) | ||
spu_result = emulator.run(proc)(y_true, y_pred, labels) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you should seal the input tensors before feeding them into the emulator, you can refer to other |
||
sk_result = sklearn_proc(y_true, y_pred) | ||
check(spu_result, sk_result) | ||
|
||
# Test multiclass | ||
y_true = jnp.array([0, 1, 1, 0, 2, 1]) | ||
y_pred = jnp.array([0, 0, 1, 0, 2, 1]) | ||
labels = jnp.array([0, 1, 2]) | ||
spu_result = emulator.run(proc)(y_true, y_pred, labels) | ||
sk_result = sklearn_proc(y_true, y_pred) | ||
check(spu_result, sk_result) | ||
|
||
|
||
def emul_top_k_accuracy_score(mode: emulation.Mode.MULTIPROCESS): | ||
def proc( | ||
y_true: jnp.ndarray, y_pred: jnp.ndarray, k, normalize, sample_weight, labels | ||
): | ||
top_k_score = top_k_accuracy_score( | ||
y_true, | ||
y_pred, | ||
k=k, | ||
normalize=normalize, | ||
sample_weight=sample_weight, | ||
labels=labels, | ||
) | ||
return top_k_score | ||
|
||
def sklearn_proc(y_true, y_pred, k, labels): | ||
top_k_score = metrics.top_k_accuracy_score(y_true, y_pred, k=k, labels=labels) | ||
return top_k_score | ||
|
||
def check(spu_result, sk_result): | ||
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5) | ||
|
||
# Test multiclass | ||
y_true = jnp.array([0, 1, 2, 2, 0]) | ||
y_score = jnp.array( | ||
[ | ||
[0.8, 0.1, 0.1], | ||
[0.3, 0.4, 0.3], | ||
[0.1, 0.1, 0.8], | ||
[0.2, 0.2, 0.6], | ||
[0.7, 0.2, 0.1], | ||
] | ||
) | ||
spu_result = emulator.run(proc, static_argnums=(2, 3))( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise, should seal the data beforehand |
||
y_true, y_score, 2, True, None, None | ||
) | ||
sk_result = sklearn_proc(y_true, y_score, k=2, labels=None) | ||
check(spu_result, sk_result) | ||
|
||
|
||
def emul_auc(mode: emulation.Mode.MULTIPROCESS): | ||
# Create dataset | ||
row = 10000 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,16 +33,100 @@ | |
from sml.metrics.classification.classification import ( | ||
accuracy_score, | ||
average_precision_score, | ||
balanced_accuracy_score, | ||
bin_counts, | ||
equal_obs, | ||
f1_score, | ||
precision_score, | ||
recall_score, | ||
roc_auc_score, | ||
top_k_accuracy_score, | ||
) | ||
|
||
|
||
class UnitTests(unittest.TestCase): | ||
|
||
def test_balanced_accuracy(self): | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) | ||
|
||
def proc(y_true: jnp.ndarray, y_pred: jnp.ndarray, labels: jnp.ndarray): | ||
balanced_score = balanced_accuracy_score(y_true, y_pred, labels) | ||
return balanced_score | ||
|
||
def sklearn_proc(y_true, y_pred): | ||
balanced_score = metrics.balanced_accuracy_score(y_true, y_pred) | ||
return balanced_score | ||
|
||
def check(spu_result, sk_result): | ||
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# Test binary | ||
y_true = jnp.array([0, 1, 1, 0, 1, 1]) | ||
y_pred = jnp.array([0, 0, 1, 0, 1, 1]) | ||
labels = jnp.array([0, 1]) | ||
spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, labels) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test the param |
||
sk_result = sklearn_proc(y_true, y_pred) | ||
check(spu_result, sk_result) | ||
|
||
# Test multiclass | ||
y_true = jnp.array([0, 1, 1, 0, 2, 1]) | ||
y_pred = jnp.array([0, 0, 1, 0, 2, 1]) | ||
labels = jnp.array([0, 1, 2]) | ||
spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, labels) | ||
sk_result = sklearn_proc(y_true, y_pred) | ||
check(spu_result, sk_result) | ||
|
||
def test_top_k_accuracy(self): | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) | ||
|
||
def proc( | ||
y_true: jnp.ndarray, | ||
y_pred: jnp.ndarray, | ||
k, | ||
normalize, | ||
sample_weight, | ||
labels, | ||
): | ||
top_k_score = top_k_accuracy_score( | ||
y_true, | ||
y_pred, | ||
k=k, | ||
normalize=normalize, | ||
sample_weight=sample_weight, | ||
labels=labels, | ||
) | ||
return top_k_score | ||
|
||
def sklearn_proc(y_true, y_pred, k, labels): | ||
top_k_score = metrics.top_k_accuracy_score( | ||
y_true, y_pred, k=k, labels=labels | ||
) | ||
return top_k_score | ||
|
||
def check(spu_result, sk_result): | ||
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# Test multiclass | ||
y_true = jnp.array([0, 1, 2, 2, 0]) | ||
y_score = jnp.array( | ||
[ | ||
[0.8, 0.1, 0.1], | ||
[0.3, 0.4, 0.3], | ||
[0.1, 0.1, 0.8], | ||
[0.2, 0.2, 0.6], | ||
[0.7, 0.2, 0.1], | ||
] | ||
) | ||
spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 3))( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test more please, at least you should test |
||
y_true, y_score, 2, True, None, None | ||
) | ||
sk_result = sklearn_proc(y_true, y_score, k=2, labels=None) | ||
check(spu_result, sk_result) | ||
|
||
def test_auc(self): | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not implement the logic about
sample_weight
andnormalize