Skip to content

Commit

Permalink
simplify imports of metric functions (#3292)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzkadc authored Oct 14, 2024
1 parent ad02551 commit a5d3464
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 20 deletions.
16 changes: 5 additions & 11 deletions ignite/metrics/cohen_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,17 @@ def __init__(
# initalize weights
self.weights = weights

self.cohen_kappa_compute = self.get_cohen_kappa_fn()

super(CohenKappa, self).__init__(
self.cohen_kappa_compute,
self._cohen_kappa_score,
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
skip_unrolling=skip_unrolling,
)

def get_cohen_kappa_fn(self) -> Callable[[torch.Tensor, torch.Tensor], float]:
"""Return a function computing Cohen Kappa from scikit-learn."""
def _cohen_kappa_score(self, y_targets: torch.Tensor, y_preds: torch.Tensor) -> float:
from sklearn.metrics import cohen_kappa_score

def wrapper(y_targets: torch.Tensor, y_preds: torch.Tensor) -> float:
y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
return cohen_kappa_score(y_true, y_pred, weights=self.weights)

return wrapper
y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
return cohen_kappa_score(y_true, y_pred, weights=self.weights)
15 changes: 6 additions & 9 deletions ignite/metrics/regression/spearman_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _get_spearman_r() -> Callable[[Tensor, Tensor], float]:
def _spearman_r(predictions: Tensor, targets: Tensor) -> float:
from scipy.stats import spearmanr

def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float:
np_preds = predictions.flatten().numpy()
np_targets = targets.flatten().numpy()
r = spearmanr(np_preds, np_targets).statistic
return r

return _compute_spearman_r
np_preds = predictions.flatten().numpy()
np_targets = targets.flatten().numpy()
r = spearmanr(np_preds, np_targets).statistic
return r


class SpearmanRankCorrelation(EpochMetric):
Expand Down Expand Up @@ -92,7 +89,7 @@ def __init__(
except ImportError:
raise ModuleNotFoundError("This module requires scipy to be installed.")

super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling)
super().__init__(_spearman_r, output_transform, check_compute_fn, device, skip_unrolling)

def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
Expand Down

0 comments on commit a5d3464

Please sign in to comment.