Skip to content

Commit

Permalink
Update code comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hesy7 committed Sep 3, 2024
1 parent 1e1029a commit a4668c2
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions sml/metrics/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ def confusion_matrix(y_true, y_pred, labels, sample_weight=None, normalize=None)
y_true = jnp.array(y_true)
y_pred = jnp.array(y_pred)

# 获取标签的数量
# Get the number of tags
num_labels = len(labels)

# 初始化混淆矩阵
# Initialize the confusion matrix
cm = jnp.zeros((num_labels, num_labels), dtype=jnp.int32)

# 计算混淆矩阵
# Calculate the confusion matrix
for i, label in enumerate(labels):
# 获取真实标签和预测标签为当前标签的布尔值
# Get the true label and predicted label as the Boolean value of the current label
true_mask = (y_true == label)
pred_mask = (y_pred == label)

# 更新混淆矩阵
# Update the confusion matrix
for j, _ in enumerate(labels):
# 计算 TP, FP, FN, TN
# Calculate TP, FP, FN, TN
cm = cm.at[i, j].set(jnp.sum(true_mask & (y_pred == j)))

return cm
Expand Down Expand Up @@ -72,22 +72,21 @@ def top_k_accuracy_score(
the top `k` labels predicted (ranked by predicted scores).
"""

# 转换 y_true 和 y_score 为 JAX 数组
y_true = jnp.asarray(y_true)
y_score = jnp.asarray(y_score)

if labels is not None:
# 如果提供了标签,确保 y_true y_score 包含在 labels 中
# If labels are provided, make sure y_true and y_score are included in labels
labels = jnp.asarray(labels)
y_true = jnp.searchsorted(labels, y_true, sorter=jnp.argsort(labels))

# 计算每个样本的前 k 个预测的索引
# Compute the indices of the top k predictions for each sample
top_k_indices = jnp.argsort(y_score, axis=1)[:, -k:]

# 检查 y_true 是否在前 k 个预测中
# Check if y_true is among the top k predictions
y_true_in_top_k = jnp.any(jnp.isin(y_true[:, None], top_k_indices), axis=1)

# 计算准确率
# Calculate accuracy
correct_predictions = jnp.sum(y_true_in_top_k)

if sample_weight is not None:
Expand Down

0 comments on commit a4668c2

Please sign in to comment.