Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #726: Implemented two accuracy functions #842

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions sml/metrics/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,91 @@

import jax
import jax.numpy as jnp
import numpy as np

from sml.preprocessing.preprocessing import label_binarize
from spu.ops.groupby import groupby, groupby_sum

from .auc import binary_clf_curve, binary_roc_auc


def confusion_matrix(y_true, y_pred, labels, sample_weight=None, normalize=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not implement the logic about sample_weight and normalize

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some docs for the functionality of this function, and the means of all params.

"""calculate the confusion matrix"""
y_true = jnp.array(y_true)
y_pred = jnp.array(y_pred)

# Get the number of tags
num_labels = len(labels)

# Initialize the confusion matrix
cm = jnp.zeros((num_labels, num_labels), dtype=jnp.int32)

# Calculate the confusion matrix
for i, label in enumerate(labels):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Vectorized operations can replace these two for-loops, So the round complexity can be reduced.

e.g.
y_true == labels gives an n*c matrix, where n is the number of samples, and c is the number of labels.
Same as y_pred == labels, then the cm is just the inner product of all column-pairs of two matrics.

# Get the true label and predicted label as the Boolean value of the current label
true_mask = y_true == label
pred_mask = y_pred == label

# Update the confusion matrix
for j, _ in enumerate(labels):
# Calculate TP, FP, FN, TN
cm = cm.at[i, j].set(jnp.sum(true_mask & (y_pred == j)))

return cm


def balanced_accuracy_score(y_true, y_pred, labels, sample_weight=None, adjusted=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some docs for the functionality of this function, and the means of all params.

"""calculate balanced accuracy score"""
C = confusion_matrix(y_true, y_pred, labels, sample_weight=sample_weight)
with np.errstate(divide="ignore", invalid="ignore"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The context management can not work in SPU, can just delete this

per_class = jnp.diag(C) / C.sum(axis=1)
score = jnp.mean(per_class)
if adjusted:
n_classes = len(per_class)
chance = 1 / n_classes
score -= chance
score /= 1 - chance
return score


def top_k_accuracy_score(
y_true, y_score, k=2, normalize=True, sample_weight=None, labels=None
):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some docs for the functionality of this function, and the means of all params.

Top-k Accuracy classification score.
This metric computes the number of times when the correct label is among
the top `k` labels predicted (ranked by predicted scores).
"""

y_true = jnp.asarray(y_true)
y_score = jnp.asarray(y_score)

if labels is not None:
# If labels are provided, make sure y_true and y_score are included in labels
labels = jnp.asarray(labels)
y_true = jnp.searchsorted(labels, y_true, sorter=jnp.argsort(labels))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Searchsorted is a costly op, you can just skip it.

Then, you should comment in the docs to hint to the user that the y_true should only contain values 0,1,2..., len(labels)-1


# Compute the indices of the top k predictions for each sample
top_k_indices = jnp.argsort(y_score, axis=1)[:, -k:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some easy cases: Binary case, when k=1, you just compare to 0.5; k>1, you just return 1 (Indeed, you can always check whether k >= len(labels)).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general cases (multi-class case), you can use jax.lax.top_k to replace the full sort


# Check if y_true is among the top k predictions
y_true_in_top_k = jnp.any(jnp.isin(y_true[:, None], top_k_indices), axis=1)

# Calculate accuracy
correct_predictions = jnp.sum(y_true_in_top_k)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should think twice about the sample_weight and normalize.

  • if sample_weight is not None, and normalize=False, correct_predictions should also take sample_weight into consideration.

Maybe you can refer to sklearn about the logic


if sample_weight is not None:
sample_weight = jnp.asarray(sample_weight)
accuracy = jnp.sum(sample_weight * y_true_in_top_k) / jnp.sum(sample_weight)
else:
accuracy = correct_predictions / len(y_true)

if normalize:
return accuracy
else:
return correct_predictions


def roc_auc_score(y_true, y_pred):
sorted_arr = create_sorted_label_score_pair(y_true, y_pred)
return binary_roc_auc(sorted_arr)
Expand Down
73 changes: 72 additions & 1 deletion sml/metrics/classification/classification_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,86 @@
from sml.metrics.classification.classification import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
f1_score,
precision_score,
recall_score,
roc_auc_score,
top_k_accuracy_score,
)


# TODO: design the enumation framework, just like py.unittest
# all emulation action should begin with `emul_` (for reflection)


def emul_balanced_accuracy(mode: emulation.Mode.MULTIPROCESS):
def proc(y_true: jnp.ndarray, y_pred: jnp.ndarray, labels: jnp.ndarray):
balanced_score = balanced_accuracy_score(y_true, y_pred, labels)
return balanced_score

def sklearn_proc(y_true, y_pred):
balanced_score = metrics.balanced_accuracy_score(y_true, y_pred)
return balanced_score

def check(spu_result, sk_result):
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5)

# Test binary
y_true = jnp.array([0, 1, 1, 0, 1, 1])
y_pred = jnp.array([0, 0, 1, 0, 1, 1])
labels = jnp.array([0, 1])
spu_result = emulator.run(proc)(y_true, y_pred, labels)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should seal the input tensors before feeding them into the emulator, you can refer to other emul_ funcs

sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)

# Test multiclass
y_true = jnp.array([0, 1, 1, 0, 2, 1])
y_pred = jnp.array([0, 0, 1, 0, 2, 1])
labels = jnp.array([0, 1, 2])
spu_result = emulator.run(proc)(y_true, y_pred, labels)
sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)


def emul_top_k_accuracy_score(mode: emulation.Mode.MULTIPROCESS):
def proc(
y_true: jnp.ndarray, y_pred: jnp.ndarray, k, normalize, sample_weight, labels
):
top_k_score = top_k_accuracy_score(
y_true,
y_pred,
k=k,
normalize=normalize,
sample_weight=sample_weight,
labels=labels,
)
return top_k_score

def sklearn_proc(y_true, y_pred, k, labels):
top_k_score = metrics.top_k_accuracy_score(y_true, y_pred, k=k, labels=labels)
return top_k_score

def check(spu_result, sk_result):
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5)

# Test multiclass
y_true = jnp.array([0, 1, 2, 2, 0])
y_score = jnp.array(
[
[0.8, 0.1, 0.1],
[0.3, 0.4, 0.3],
[0.1, 0.1, 0.8],
[0.2, 0.2, 0.6],
[0.7, 0.2, 0.1],
]
)
spu_result = emulator.run(proc, static_argnums=(2, 3))(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise, should seal the data beforehand

y_true, y_score, 2, True, None, None
)
sk_result = sklearn_proc(y_true, y_score, k=2, labels=None)
check(spu_result, sk_result)


def emul_auc(mode: emulation.Mode.MULTIPROCESS):
# Create dataset
row = 10000
Expand Down
84 changes: 84 additions & 0 deletions sml/metrics/classification/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,100 @@
from sml.metrics.classification.classification import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
bin_counts,
equal_obs,
f1_score,
precision_score,
recall_score,
roc_auc_score,
top_k_accuracy_score,
)


class UnitTests(unittest.TestCase):

def test_balanced_accuracy(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FM64 is enough

)

def proc(y_true: jnp.ndarray, y_pred: jnp.ndarray, labels: jnp.ndarray):
balanced_score = balanced_accuracy_score(y_true, y_pred, labels)
return balanced_score

def sklearn_proc(y_true, y_pred):
balanced_score = metrics.balanced_accuracy_score(y_true, y_pred)
return balanced_score

def check(spu_result, sk_result):
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rtol and atol can be set to 1e-3


# Test binary
y_true = jnp.array([0, 1, 1, 0, 1, 1])
y_pred = jnp.array([0, 0, 1, 0, 1, 1])
labels = jnp.array([0, 1])
spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, labels)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test the param sample_weight and adjusted;
Test larger datasets, please (maybe ~1000 samples are enough).

sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)

# Test multiclass
y_true = jnp.array([0, 1, 1, 0, 2, 1])
y_pred = jnp.array([0, 0, 1, 0, 2, 1])
labels = jnp.array([0, 1, 2])
spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, labels)
sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)

def test_top_k_accuracy(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FM64 is enough

)

def proc(
y_true: jnp.ndarray,
y_pred: jnp.ndarray,
k,
normalize,
sample_weight,
labels,
):
top_k_score = top_k_accuracy_score(
y_true,
y_pred,
k=k,
normalize=normalize,
sample_weight=sample_weight,
labels=labels,
)
return top_k_score

def sklearn_proc(y_true, y_pred, k, labels):
top_k_score = metrics.top_k_accuracy_score(
y_true, y_pred, k=k, labels=labels
)
return top_k_score

def check(spu_result, sk_result):
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rtol and atol can be set to 1e-3


# Test multiclass
y_true = jnp.array([0, 1, 2, 2, 0])
y_score = jnp.array(
[
[0.8, 0.1, 0.1],
[0.3, 0.4, 0.3],
[0.1, 0.1, 0.8],
[0.2, 0.2, 0.6],
[0.7, 0.2, 0.1],
]
)
spu_result = spsim.sim_jax(sim, proc, static_argnums=(2, 3))(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test more please, at least you should test sample_weight, and normalize.
And you should check for some larger data, maybe a size of 1000 is the least requirement.

y_true, y_score, 2, True, None, None
)
sk_result = sklearn_proc(y_true, y_score, k=2, labels=None)
check(spu_result, sk_result)

def test_auc(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
Expand Down
Loading