diff --git a/CHANGES.txt b/CHANGES.txt
index 6bd0eee1f..0f21b512b 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -147,6 +147,7 @@ v<0.9.6>, <11/05/2021> -- Minor bug fix for COPOD.
 v<0.9.6>, <12/24/2021> -- Bug fix for MAD (#358).
 v<0.9.6>, <12/24/2021> -- Bug fix for COPOD plotting (#337).
 v<0.9.6>, <12/24/2021> -- Model persistence doc improvement.
+v<0.9.7>, <01/03/2022> -- Add ECOD.
diff --git a/README.rst b/README.rst
index 501238b2c..e397fd9b3 100644
--- a/README.rst
+++ b/README.rst
@@ -307,6 +307,12 @@ PyOD toolkit consists of three major functional groups:
 =================== ================== ====================================================================================================== ===== ========================================
 Type                Abbr               Algorithm                                                                                              Year  Ref
 =================== ================== ====================================================================================================== ===== ========================================
+Probabilistic       ECOD               Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions                       2021  [#Li2021ECOD]_
+Probabilistic       ABOD               Angle-Based Outlier Detection                                                                          2008  [#Kriegel2008Angle]_
+Probabilistic       FastABOD           Fast Angle-Based Outlier Detection using approximation                                                 2008  [#Kriegel2008Angle]_
+Probabilistic       COPOD              COPOD: Copula-Based Outlier Detection                                                                  2020  [#Li2020COPOD]_
+Probabilistic       MAD                Median Absolute Deviation (MAD)                                                                        1993  [#Iglewicz1993How]_
+Probabilistic       SOS                Stochastic Outlier Selection                                                                           2012  [#Janssens2012Stochastic]_
 Linear Model        PCA                Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003  [#Shyu2003A]_
 Linear Model        MCD                Minimum Covariance Determinant (use the mahalanobis distances as the outlier scores)                  1999  [#Hardin2004Outlier]_ [#Rousseeuw1999A]_
 Linear Model        OCSVM              One-Class Support Vector Machines                                                                      2001  [#Scholkopf2001Estimating]_
@@ -322,11 +328,6 @@ Proximity-Based     AvgKNN             Average kNN (use the average distance t
 Proximity-Based     MedKNN             Median kNN (use the median distance to k nearest neighbors as the outlier score)                      2002  [#Angiulli2002Fast]_
 Proximity-Based     SOD                Subspace Outlier Detection                                                                             2009  [#Kriegel2009Outlier]_
 Proximity-Based     ROD                Rotation-based Outlier Detection                                                                       2020  [#Almardeny2020A]_
-Probabilistic       ABOD               Angle-Based Outlier Detection                                                                          2008  [#Kriegel2008Angle]_
-Probabilistic       COPOD              COPOD: Copula-Based Outlier Detection                                                                  2020  [#Li2020COPOD]_
-Probabilistic       FastABOD           Fast Angle-Based Outlier Detection using approximation                                                 2008  [#Kriegel2008Angle]_
-Probabilistic       MAD                Median Absolute Deviation (MAD)                                                                        1993  [#Iglewicz1993How]_
-Probabilistic       SOS                Stochastic Outlier Selection                                                                           2012  [#Janssens2012Stochastic]_
 Outlier Ensembles   IForest            Isolation Forest                                                                                       2008  [#Liu2008Isolation]_
 Outlier Ensembles   FB                 Feature Bagging                                                                                        2005  [#Lazarevic2005Feature]_
 Outlier Ensembles   LSCP               LSCP: Locally Selective Combination of Parallel Outlier Ensembles                                     2019  [#Zhao2019LSCP]_
@@ -571,6 +572,8 @@ Reference
 .. [#Li2020COPOD] Li, Z., Zhao, Y., Botta, N., Ionescu, C. and Hu, X. COPOD: Copula-Based Outlier Detection. *IEEE International Conference on Data Mining (ICDM)*, 2020.
 
+.. [#Li2021ECOD] Li, Z., Zhao, Y., Hu, X., Botta, N., Ionescu, C. and Chen, H. G. ECOD: Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions. arXiv preprint arXiv:2201.00382 (2022).
+
 .. [#Liu2008Isolation] Liu, F.T., Ting, K.M. and Zhou, Z.H., 2008, December. Isolation forest. In *International Conference on Data Mining*\ , pp. 413-422. IEEE.
 
 .. [#Liu2019Generative] Liu, Y., Li, Z., Zhou, C., Jiang, Y., Sun, J., Wang, M. and He, X., 2019. Generative adversarial active learning for unsupervised outlier detection. *IEEE Transactions on Knowledge and Data Engineering*.
diff --git a/docs/index.rst b/docs/index.rst
index 8a2c557bd..d405bf5c6 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -143,17 +143,22 @@ PyOD toolkit consists of three major functional groups:
 **(i) Individual Detection Algorithms** :
 
-1. Linear Models for Outlier Detection:
-
 =================== ================ ====================================================================================================== ===== =================================================== ======================================================
 Type                Abbr             Algorithm                                                                                              Year  Class                                               Ref
 =================== ================ ====================================================================================================== ===== =================================================== ======================================================
+Probabilistic       ECOD             Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions                       2021  :class:`pyod.models.ecod.ECOD`                      :cite:`a-li2021ecod`
+Probabilistic       COPOD            COPOD: Copula-Based Outlier Detection                                                                  2020  :class:`pyod.models.copod.COPOD`                    :cite:`a-li2020copod`
+Probabilistic       ABOD             Angle-Based Outlier Detection                                                                          2008  :class:`pyod.models.abod.ABOD`                      :cite:`a-kriegel2008angle`
+Probabilistic       FastABOD         Fast Angle-Based Outlier Detection using approximation                                                 2008  :class:`pyod.models.abod.ABOD`                      :cite:`a-kriegel2008angle`
+Probabilistic       MAD              Median Absolute Deviation (MAD)                                                                        1993  :class:`pyod.models.mad.MAD`                        :cite:`a-iglewicz1993detect`
+Probabilistic       SOS              Stochastic Outlier Selection                                                                           2012  :class:`pyod.models.sos.SOS`                        :cite:`a-janssens2012stochastic`
 Linear Model        PCA              Principal Component Analysis (the sum of weighted projected distances to the eigenvector hyperplanes) 2003  :class:`pyod.models.pca.PCA`                        :cite:`a-shyu2003novel`
 Linear Model        MCD              Minimum Covariance Determinant (use the mahalanobis distances as the outlier scores)                  1999  :class:`pyod.models.mcd.MCD`                        :cite:`a-rousseeuw1999fast,a-hardin2004outlier`
 Linear Model        OCSVM            One-Class Support Vector Machines                                                                      2001  :class:`pyod.models.ocsvm.OCSVM`                    :cite:`a-scholkopf2001estimating`
 Linear Model        LMDD             Deviation-based Outlier Detection (LMDD)                                                               1996  :class:`pyod.models.lmdd.LMDD`                      :cite:`a-arning1996linear`
 Proximity-Based     LOF              Local Outlier Factor                                                                                   2000  :class:`pyod.models.lof.LOF`                        :cite:`a-breunig2000lof`
 Proximity-Based     COF              Connectivity-Based Outlier Factor                                                                      2002  :class:`pyod.models.cof.COF`                        :cite:`a-tang2002enhancing`
+Proximity-Based     Incr. COF        Memory Efficient Connectivity-Based Outlier Factor (slower but reduces storage complexity)             2002  :class:`pyod.models.cof.COF`                        :cite:`a-tang2002enhancing`
 Proximity-Based     CBLOF            Clustering-Based Local Outlier Factor                                                                  2003  :class:`pyod.models.cblof.CBLOF`                    :cite:`a-he2003discovering`
 Proximity-Based     LOCI             LOCI: Fast outlier detection using the local correlation integral                                      2003  :class:`pyod.models.loci.LOCI`                      :cite:`a-papadimitriou2003loci`
 Proximity-Based     HBOS             Histogram-based Outlier Score                                                                          2012  :class:`pyod.models.hbos.HBOS`                      :cite:`a-goldstein2012histogram`
@@ -162,13 +167,8 @@ Proximity-Based     AvgKNN           Average kNN (use the average distance to
 Proximity-Based     MedKNN           Median kNN (use the median distance to k nearest neighbors as the outlier score)                      2002  :class:`pyod.models.knn.KNN`                        :cite:`a-ramaswamy2000efficient,a-angiulli2002fast`
 Proximity-Based     SOD              Subspace Outlier Detection                                                                             2009  :class:`pyod.models.sod.SOD`                        :cite:`a-kriegel2009outlier`
 Proximity-Based     ROD              Rotation-based Outlier Detection                                                                       2020  :class:`pyod.models.rod.ROD`                        :cite:`a-almardeny2020novel`
-Probabilistic       ABOD             Angle-Based Outlier Detection                                                                          2008  :class:`pyod.models.abod.ABOD`                      :cite:`a-kriegel2008angle`
-Probabilistic       FastABOD         Fast Angle-Based Outlier Detection using approximation                                                 2008  :class:`pyod.models.abod.ABOD`                      :cite:`a-kriegel2008angle`
-Probabilistic       COPOD            COPOD: Copula-Based Outlier Detection                                                                  2020  :class:`pyod.models.copod.COPOD`                    :cite:`a-li2020copod`
-Probabilistic       MAD              Median Absolute Deviation (MAD)                                                                        1993  :class:`pyod.models.mad.MAD`                        :cite:`a-iglewicz1993detect`
-Probabilistic       SOS              Stochastic Outlier Selection                                                                           2012  :class:`pyod.models.sos.SOS`                        :cite:`a-janssens2012stochastic`
 Outlier Ensembles   IForest          Isolation Forest                                                                                       2008  :class:`pyod.models.iforest.IForest`                :cite:`a-liu2008isolation,a-liu2012isolation`
-Outlier Ensembles                    Feature Bagging                                                                                        2005  :class:`pyod.models.feature_bagging.FeatureBagging` :cite:`a-lazarevic2005feature`
+Outlier Ensembles   FB               Feature Bagging                                                                                        2005  :class:`pyod.models.feature_bagging.FeatureBagging` :cite:`a-lazarevic2005feature`
 Outlier Ensembles   LSCP             LSCP: Locally Selective Combination of Parallel Outlier Ensembles                                     2019  :class:`pyod.models.lscp.LSCP`                      :cite:`a-zhao2019lscp`
 Outlier Ensembles   XGBOD            Extreme Boosting Based Outlier Detection **(Supervised)**                                             2018  :class:`pyod.models.xgbod.XGBOD`                    :cite:`a-zhao2018xgbod`
 Outlier Ensembles   LODA             Lightweight On-line Detector of Anomalies                                                              2016  :class:`pyod.models.loda.LODA`                      :cite:`a-pevny2016loda`
diff --git a/docs/pyod.models.rst b/docs/pyod.models.rst
index 32859a1b4..123d36d4c 100644
--- a/docs/pyod.models.rst
+++ b/docs/pyod.models.rst
@@ -77,6 +77,16 @@ pyod.models.deep\_svdd module
     :show-inheritance:
     :inherited-members:
 
+pyod.models.ecod module
+------------------------
+
+.. automodule:: pyod.models.ecod
+    :members:
+    :exclude-members:
+    :undoc-members:
+    :show-inheritance:
+    :inherited-members:
+
 pyod.models.feature\_bagging module
 -----------------------------------
diff --git a/docs/zreferences.bib b/docs/zreferences.bib
index 5f2f686a3..0ad77d308 100644
--- a/docs/zreferences.bib
+++ b/docs/zreferences.bib
@@ -377,4 +377,11 @@ @inproceedings{perini2020quantifying
   pages={227--243},
   year={2020},
   publisher={Springer}
-}
\ No newline at end of file
+}
+
+@article{li2021ecod,
+  title={ECOD: Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions},
+  author={Li, Zheng and Zhao, Yue and Hu, Xiyang and Botta, Nicola and Ionescu, Cezar and Chen, H. George},
+  journal={arXiv preprint arXiv:2201.00382},
+  year={2022}
+}
\ No newline at end of file
diff --git a/examples/ecod_example.py b/examples/ecod_example.py
new file mode 100644
index 000000000..5ef68a394
--- /dev/null
+++ b/examples/ecod_example.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+"""Example of using ECOD for outlier detection
+"""
+# Author: Yue Zhao
+# License: BSD 2 clause
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+# temporary solution for relative imports in case pyod is not installed
+# if pyod is installed, no need to use the following line
+sys.path.append(
+    os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from pyod.models.ecod import ECOD
+from pyod.utils.data import generate_data
+from pyod.utils.data import evaluate_print
+from pyod.utils.example import visualize
+
+if __name__ == "__main__":
+    contamination = 0.1  # percentage of outliers
+    n_train = 200  # number of training points
+    n_test = 100  # number of testing points
+
+    # Generate sample data
+    X_train, y_train, X_test, y_test = \
+        generate_data(n_train=n_train,
+                      n_test=n_test,
+                      n_features=2,
+                      contamination=contamination,
+                      random_state=42)
+
+    # train ECOD detector
+    clf_name = 'ECOD'
+    clf = ECOD()
+
+    # you could try the parallel version as well.
+    # clf = ECOD(n_jobs=2)
+    clf.fit(X_train)
+
+    # get the prediction labels and outlier scores of the training data
+    y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
+    y_train_scores = clf.decision_scores_  # raw outlier scores
+
+    # get the prediction on the test data
+    y_test_pred = clf.predict(X_test)  # outlier labels (0 or 1)
+    y_test_scores = clf.decision_function(X_test)  # outlier scores
+
+    # evaluate and print the results
+    print("\nOn Training Data:")
+    evaluate_print(clf_name, y_train, y_train_scores)
+    print("\nOn Test Data:")
+    evaluate_print(clf_name, y_test, y_test_scores)
+
+    # visualize the results
+    visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred,
+              y_test_pred, show_figure=True, save_figure=False)
diff --git a/pyod/models/ecod.py b/pyod/models/ecod.py
new file mode 100644
index 000000000..1593c07d3
--- /dev/null
+++ b/pyod/models/ecod.py
@@ -0,0 +1,302 @@
+"""Unsupervised Outlier Detection Using
+Empirical Cumulative Distribution Functions (ECOD)
+"""
+# Author: Zheng Li
+# Author: Yue Zhao
+# License: BSD 2 clause
+
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+import numpy as np
+
+from statsmodels.distributions.empirical_distribution import ECDF
+from scipy.stats import skew
+from sklearn.utils import check_array
+from joblib import Parallel, delayed, effective_n_jobs
+import matplotlib.pyplot as plt
+
+from .base import BaseDetector
+from .sklearn_base import _partition_estimators
+
+
+def ecdf(X):
+    """Calculate the empirical CDF of a given one-dimensional array.
+
+    Parameters
+    ----------
+    X : numpy array of shape (n_samples,)
+        A single feature column of the training dataset.
+
+    Returns
+    -------
+    ecdf(X) : numpy array of shape (n_samples,)
+        The empirical CDF of X, evaluated at each sample.
+    """
+    ecdf = ECDF(X)
+    return ecdf(X)
+
+
+def _parallel_ecdf(n_dims, X):
+    """Private method to calculate ecdf in parallel.
+
+    Parameters
+    ----------
+    n_dims : int
+        The number of dimensions of the current input matrix.
+
+    X : numpy array
+        The subarray for building the ECDF.
+
+    Returns
+    -------
+    U_l_mat : numpy array
+        Left-tail ECDF subarray.
+
+    U_r_mat : numpy array
+        Right-tail ECDF subarray.
+    """
+    U_l_mat = np.zeros([X.shape[0], n_dims])
+    U_r_mat = np.zeros([X.shape[0], n_dims])
+
+    for i in range(n_dims):
+        U_l_mat[:, i] = ecdf(X[:, i])
+        U_r_mat[:, i] = ecdf(X[:, i] * -1)
+    return U_l_mat, U_r_mat
+
+
+class ECOD(BaseDetector):
+    """ECOD class for Unsupervised Outlier Detection Using Empirical
+    Cumulative Distribution Functions (ECOD).
+    ECOD is a parameter-free, highly interpretable outlier detection
+    algorithm based on empirical CDF functions.
+    See :cite:`li2021ecod` for details.
+
+    Parameters
+    ----------
+    contamination : float in (0., 0.5), optional (default=0.1)
+        The amount of contamination of the data set, i.e.
+        the proportion of outliers in the data set. Used when fitting to
+        define the threshold on the decision function.
+
+    n_jobs : optional (default=1)
+        The number of jobs to run in parallel for both `fit` and
+        `predict`. If -1, then the number of jobs is set to the
+        number of cores.
+
+    Attributes
+    ----------
+    decision_scores_ : numpy array of shape (n_samples,)
+        The outlier scores of the training data.
+        The higher, the more abnormal. Outliers tend to have higher
+        scores. This value is available once the detector is
+        fitted.
+
+    threshold_ : float
+        The threshold is based on ``contamination``. It is the
+        ``n_samples * contamination`` most abnormal samples in
+        ``decision_scores_``. The threshold is calculated for generating
+        binary outlier labels.
+
+    labels_ : int, either 0 or 1
+        The binary labels of the training data. 0 stands for inliers
+        and 1 for outliers/anomalies. It is generated by applying
+        ``threshold_`` on ``decision_scores_``.
+    """
+
+    def __init__(self, contamination=0.1, n_jobs=1):
+        super(ECOD, self).__init__(contamination=contamination)
+        self.n_jobs = n_jobs
+
+    def fit(self, X, y=None):
+        """Fit detector. y is ignored in unsupervised methods.
+
+        Parameters
+        ----------
+        X : numpy array of shape (n_samples, n_features)
+            The input samples.
+
+        y : Ignored
+            Not used, present for API consistency by convention.
+
+        Returns
+        -------
+        self : object
+            Fitted estimator.
+        """
+        X = check_array(X)
+        self._set_n_classes(y)
+        # scores are computed before X_train is stored, so the training
+        # scores are based on X alone; at predict time X_train is prepended
+        self.decision_scores_ = self.decision_function(X)
+        self.X_train = X
+        self._process_decision_scores()
+        return self
+
+    def decision_function(self, X):
+        """Predict raw anomaly score of X using the fitted detector.
+        For consistency, outliers are assigned with larger anomaly scores.
+
+        Parameters
+        ----------
+        X : numpy array of shape (n_samples, n_features)
+            The input samples. Sparse matrices are accepted only
+            if they are supported by the base estimator.
+
+        Returns
+        -------
+        anomaly_scores : numpy array of shape (n_samples,)
+            The anomaly score of the input samples.
+        """
+        # use parallel execution if requested
+        if self.n_jobs != 1:
+            return self._decision_function_parallel(X)
+        if hasattr(self, 'X_train'):
+            original_size = X.shape[0]
+            X = np.concatenate((self.X_train, X), axis=0)
+        # negative log of the left- and right-tail ECDF probabilities
+        self.U_l = -1 * np.log(np.apply_along_axis(ecdf, 0, X))
+        self.U_r = -1 * np.log(np.apply_along_axis(ecdf, 0, -X))
+
+        skewness = np.sign(skew(X, axis=0))
+        self.U_skew = self.U_l * -1 * np.sign(
+            skewness - 1) + self.U_r * np.sign(skewness + 1)
+        # dimensional outlier score: element-wise maximum of the three
+        # tail estimates (np.maximum is binary, so the calls are nested)
+        self.O = np.maximum(self.U_skew, np.maximum(self.U_l, self.U_r))
+        if hasattr(self, 'X_train'):
+            decision_scores_ = self.O.sum(axis=1)[-original_size:]
+        else:
+            decision_scores_ = self.O.sum(axis=1)
+        return decision_scores_.ravel()
+
+    def _decision_function_parallel(self, X):
+        """Predict raw anomaly score of X using the fitted detector.
+        For consistency, outliers are assigned with larger anomaly scores.
+
+        Parameters
+        ----------
+        X : numpy array of shape (n_samples, n_features)
+            The input samples. Sparse matrices are accepted only
+            if they are supported by the base estimator.
+
+        Returns
+        -------
+        anomaly_scores : numpy array of shape (n_samples,)
+            The anomaly score of the input samples.
+        """
+        if hasattr(self, 'X_train'):
+            original_size = X.shape[0]
+            X = np.concatenate((self.X_train, X), axis=0)
+
+        n_samples, n_features = X.shape[0], X.shape[1]
+
+        if n_features < 2:
+            raise ValueError(
+                'n_jobs should not be used on one-dimensional datasets')
+
+        if n_features <= self.n_jobs:
+            self.n_jobs = n_features
+            warnings.warn("n_features <= n_jobs; setting them equal instead.")
+
+        n_jobs, n_dims_list, starts = _partition_estimators(n_features,
+                                                            self.n_jobs)
+
+        all_results = Parallel(n_jobs=n_jobs, max_nbytes=None,
+                               verbose=True)(
+            delayed(_parallel_ecdf)(
+                n_dims_list[i],
+                X[:, starts[i]:starts[i + 1]],
+            )
+            for i in range(n_jobs))
+
+        # recover the results
+        self.U_l = np.zeros([n_samples, n_features])
+        self.U_r = np.zeros([n_samples, n_features])
+
+        for i in range(n_jobs):
+            self.U_l[:, starts[i]:starts[i + 1]] = all_results[i][0]
+            self.U_r[:, starts[i]:starts[i + 1]] = all_results[i][1]
+
+        self.U_l = -1 * np.log(self.U_l)
+        self.U_r = -1 * np.log(self.U_r)
+
+        skewness = np.sign(skew(X, axis=0))
+        self.U_skew = self.U_l * -1 * np.sign(
+            skewness - 1) + self.U_r * np.sign(skewness + 1)
+        # dimensional outlier score, as in the single-threaded path
+        self.O = np.maximum(self.U_skew, np.maximum(self.U_l, self.U_r))
+        if hasattr(self, 'X_train'):
+            decision_scores_ = self.O.sum(axis=1)[-original_size:]
+        else:
+            decision_scores_ = self.O.sum(axis=1)
+        return decision_scores_.ravel()
+
+    def explain_outlier(self, ind, columns=None, cutoffs=None,
+                        feature_names=None, file_name=None,
+                        file_type=None):  # pragma: no cover
+        """Plot dimensional outlier graph for a given data point within
+        the dataset.
+
+        Parameters
+        ----------
+        ind : int
+            The index of the data point one wishes to obtain
+            a dimensional outlier graph for.
+
+        columns : list
+            Specify a list of features/dimensions for plotting. If not
+            specified, use all features.
+
+        cutoffs : list of floats in (0., 1), optional (default=[0.95, 0.99])
+            The significance cutoff bands of the dimensional outlier graph.
+
+        feature_names : list of strings
+            The display names of all columns of the dataset,
+            to show on the x-axis of the plot.
+
+        file_name : string
+            The name to save the figure.
+
+        file_type : string
+            The file type to save the figure.
+
+        Returns
+        -------
+        Plot : matplotlib plot
+            The dimensional outlier graph for data point with index ind.
+        """
+        if columns is None:
+            columns = list(range(self.O.shape[1]))
+            column_range = range(1, self.O.shape[1] + 1)
+        else:
+            column_range = range(1, len(columns) + 1)
+
+        cutoffs = [1 - self.contamination,
+                   0.99] if cutoffs is None else cutoffs
+
+        # plot outlier scores
+        plt.scatter(column_range, self.O[ind, columns], marker='^', c='black',
+                    label='Outlier Score')
+
+        for i in cutoffs:
+            plt.plot(column_range,
+                     np.quantile(self.O[:, columns], q=i, axis=0),
+                     '--',
+                     label='{percentile} Cutoff Band'.format(percentile=i))
+        plt.ylim([0, int(self.O[:, columns].max().max()) + 1])
+        plt.ylabel('Dimensional Outlier Score')
+        plt.xlabel('Dimension')
+
+        ticks = list(column_range)
+        if feature_names is not None:
+            assert len(feature_names) == len(ticks), \
+                "Length of feature_names does not match dataset dimensions."
+            plt.xticks(ticks, labels=feature_names)
+        else:
+            plt.xticks(ticks)
+
+        plt.yticks(range(0, int(self.O[:, columns].max().max()) + 1))
+        plt.xlim(0.95, ticks[-1] + 0.05)
+        label = 'Outlier' if self.labels_[ind] == 1 else 'Inlier'
+        plt.title(
+            'Outlier score breakdown for sample #{index} ({label})'.format(
+                index=ind + 1, label=label))
+        plt.legend()
+        plt.tight_layout()
+
+        # save the file if specified
+        if file_name is not None:
+            if file_type is not None:
+                plt.savefig(file_name + '.' + file_type, dpi=300)
+            # if not specified, save as png
+            else:
+                plt.savefig(file_name + '.' + 'png', dpi=300)
+        plt.show()
+
+        # todo: consider returning results
+        # return self.O[ind, columns], self.O[:, columns].quantile(q=cutoffs[0], axis=0), self.O[:, columns].quantile(q=cutoffs[1], axis=0)
diff --git a/pyod/test/test_ecod.py b/pyod/test/test_ecod.py
new file mode 100644
index 000000000..5dfd31e22
--- /dev/null
+++ b/pyod/test/test_ecod.py
@@ -0,0 +1,272 @@
+# -*- coding: utf-8 -*-
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+import unittest
+# noinspection PyProtectedMember
+from numpy.testing import assert_allclose
+from numpy.testing import assert_array_less
+from numpy.testing import assert_equal
+from numpy.testing import assert_raises
+
+from sklearn.metrics import roc_auc_score
+from sklearn.base import clone
+from scipy.stats import rankdata
+
+# temporary solution for relative imports in case pyod is not installed
+# if pyod is installed, no need to use the following line
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from pyod.models.ecod import ECOD
+from pyod.utils.data import generate_data
+
+
+class TestECOD(unittest.TestCase):
+    def setUp(self):
+        self.n_train = 200
+        self.n_test = 100
+        self.contamination = 0.1
+        self.roc_floor = 0.8
+        self.X_train, self.y_train, self.X_test, self.y_test = generate_data(
+            n_train=self.n_train, n_test=self.n_test, n_features=10,
+            contamination=self.contamination, random_state=42)
+
+        self.clf = ECOD(contamination=self.contamination)
+        self.clf.fit(self.X_train)
+
+    def test_parameters(self):
+        assert (hasattr(self.clf, 'decision_scores_') and
+                self.clf.decision_scores_ is not None)
+        assert (hasattr(self.clf, 'labels_') and
+                self.clf.labels_ is not None)
+        assert (hasattr(self.clf, 'threshold_') and
+                self.clf.threshold_ is not None)
+
+    def test_train_scores(self):
+        assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])
+
+    def test_prediction_scores(self):
+        pred_scores = self.clf.decision_function(self.X_test)
+
+        # check score shapes
+        assert_equal(pred_scores.shape[0], self.X_test.shape[0])
+
+        # check performance
+        assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor)
+
+    def test_prediction_labels(self):
+        pred_labels = self.clf.predict(self.X_test)
+        assert_equal(pred_labels.shape, self.y_test.shape)
+
+    def test_prediction_proba(self):
+        pred_proba = self.clf.predict_proba(self.X_test)
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+    def test_prediction_proba_linear(self):
+        pred_proba = self.clf.predict_proba(self.X_test, method='linear')
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+    def test_prediction_proba_unify(self):
+        pred_proba = self.clf.predict_proba(self.X_test, method='unify')
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+    def test_prediction_proba_parameter(self):
+        with assert_raises(ValueError):
+            self.clf.predict_proba(self.X_test,
+                                   method='something')
+
+    def test_prediction_labels_confidence(self):
+        pred_labels, confidence = self.clf.predict(self.X_test,
+                                                   return_confidence=True)
+        assert_equal(pred_labels.shape, self.y_test.shape)
+        assert_equal(confidence.shape, self.y_test.shape)
+        assert (confidence.min() >= 0)
+        assert (confidence.max() <= 1)
+
+    def test_prediction_proba_linear_confidence(self):
+        pred_proba, confidence = self.clf.predict_proba(self.X_test,
+                                                        method='linear',
+                                                        return_confidence=True)
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+        assert_equal(confidence.shape, self.y_test.shape)
+        assert (confidence.min() >= 0)
+        assert (confidence.max() <= 1)
+
+    def test_fit_predict(self):
+        pred_labels = self.clf.fit_predict(self.X_train)
+        assert_equal(pred_labels.shape, self.y_train.shape)
+
+    def test_fit_predict_score(self):
+        self.clf.fit_predict_score(self.X_test, self.y_test)
+        self.clf.fit_predict_score(self.X_test, self.y_test,
+                                   scoring='roc_auc_score')
+        self.clf.fit_predict_score(self.X_test, self.y_test,
+                                   scoring='prc_n_score')
+        with assert_raises(NotImplementedError):
+            self.clf.fit_predict_score(self.X_test, self.y_test,
+                                       scoring='something')
+
+    def test_predict_rank(self):
+        pred_scores = self.clf.decision_function(self.X_test)
+        pred_ranks = self.clf._predict_rank(self.X_test)
+
+        # assert the order is preserved
+        assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=3)
+        assert_array_less(pred_ranks, self.X_train.shape[0] + 1)
+        assert_array_less(-0.1, pred_ranks)
+
+    def test_predict_rank_normalized(self):
+        pred_scores = self.clf.decision_function(self.X_test)
+        pred_ranks = self.clf._predict_rank(self.X_test, normalized=True)
+
+        # assert the order is preserved
+        assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=3)
+        assert_array_less(pred_ranks, 1.01)
+        assert_array_less(-0.1, pred_ranks)
+
+    # def test_plot(self):
+    #     os, cutoff1, cutoff2 = self.clf.explain_outlier(ind=1)
+    #     assert_array_less(0, os)
+
+    def test_model_clone(self):
+        clone_clf = clone(self.clf)
+
+    def tearDown(self):
+        pass
+
+
+class TestECODParallel(unittest.TestCase):
+    def setUp(self):
+        self.n_train = 200
+        self.n_test = 100
+        self.contamination = 0.1
+        self.roc_floor = 0.8
+        self.X_train, self.y_train, self.X_test, self.y_test = generate_data(
+            n_train=self.n_train, n_test=self.n_test, n_features=10,
+            contamination=self.contamination, random_state=42)
+
+        self.clf = ECOD(contamination=self.contamination, n_jobs=2)
+        self.clf.fit(self.X_train)
+
+        # fit a single-threaded copy of the model for comparison
+        self.clf_ = ECOD(contamination=self.contamination)
+        self.clf_.fit(self.X_train)
+
+    def test_parameters(self):
+        assert (hasattr(self.clf, 'decision_scores_') and
+                self.clf.decision_scores_ is not None)
+        assert (hasattr(self.clf, 'labels_') and
+                self.clf.labels_ is not None)
+        assert (hasattr(self.clf, 'threshold_') and
+                self.clf.threshold_ is not None)
+
+    def test_train_scores(self):
+        assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])
+        assert_allclose(self.clf.decision_scores_, self.clf_.decision_scores_)
+
+    def test_prediction_scores(self):
+        pred_scores = self.clf.decision_function(self.X_test)
+
+        # check score shapes
+        assert_equal(pred_scores.shape[0], self.X_test.shape[0])
+
+        # check performance
+        assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor)
+
+    def test_prediction_labels(self):
+        pred_labels = self.clf.predict(self.X_test)
+        assert_equal(pred_labels.shape, self.y_test.shape)
+
+        pred_labels_ = self.clf_.predict(self.X_test)
+        assert_equal(pred_labels, pred_labels_)
+
+    def test_prediction_proba(self):
+        pred_proba = self.clf.predict_proba(self.X_test)
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+    def test_prediction_proba_linear(self):
+        pred_proba = self.clf.predict_proba(self.X_test, method='linear')
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+    def test_prediction_proba_unify(self):
+        pred_proba = self.clf.predict_proba(self.X_test, method='unify')
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+    def test_prediction_proba_parameter(self):
+        with assert_raises(ValueError):
+            self.clf.predict_proba(self.X_test, method='something')
+
+    def test_prediction_labels_confidence(self):
+        pred_labels, confidence = self.clf.predict(self.X_test,
+                                                   return_confidence=True)
+        assert_equal(pred_labels.shape, self.y_test.shape)
+        assert_equal(confidence.shape, self.y_test.shape)
+        assert (confidence.min() >= 0)
+        assert (confidence.max() <= 1)
+
+    def test_prediction_proba_linear_confidence(self):
+        pred_proba, confidence = self.clf.predict_proba(self.X_test,
+                                                        method='linear',
+                                                        return_confidence=True)
+        assert (pred_proba.min() >= 0)
+        assert (pred_proba.max() <= 1)
+
+        assert_equal(confidence.shape, self.y_test.shape)
+        assert (confidence.min() >= 0)
+        assert (confidence.max() <= 1)
+
+    def test_fit_predict(self):
+        pred_labels = self.clf.fit_predict(self.X_train)
+        assert_equal(pred_labels.shape, self.y_train.shape)
+
+    def test_fit_predict_score(self):
+        self.clf.fit_predict_score(self.X_test, self.y_test)
+        self.clf.fit_predict_score(self.X_test, self.y_test,
+                                   scoring='roc_auc_score')
+        self.clf.fit_predict_score(self.X_test, self.y_test,
+                                   scoring='prc_n_score')
+        with assert_raises(NotImplementedError):
+            self.clf.fit_predict_score(self.X_test, self.y_test,
+                                       scoring='something')
+
+    def test_predict_rank(self):
+        pred_scores = self.clf.decision_function(self.X_test)
+        pred_ranks = self.clf._predict_rank(self.X_test)
+
+        # assert the order is preserved
+        assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=3)
+        assert_array_less(pred_ranks, self.X_train.shape[0] + 1)
+        assert_array_less(-0.1, pred_ranks)
+
+    def test_predict_rank_normalized(self):
+        pred_scores = self.clf.decision_function(self.X_test)
+        pred_ranks = self.clf._predict_rank(self.X_test, normalized=True)
+
+        # assert the order is preserved
+        assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=3)
+        assert_array_less(pred_ranks, 1.01)
+        assert_array_less(-0.1, pred_ranks)
+
+    # def test_plot(self):
+    #     os, cutoff1, cutoff2 = self.clf.explain_outlier(ind=1)
+    #     assert_array_less(0, os)
+
+    def test_model_clone(self):
+        clone_clf = clone(self.clf)
+
+    def tearDown(self):
+        pass
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/pyod/version.py b/pyod/version.py
index faea927cc..44ab5a58c 100644
--- a/pyod/version.py
+++ b/pyod/version.py
@@ -20,4 +20,4 @@
 # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
 # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
 #
-__version__ = '0.9.6'  # pragma: no cover
+__version__ = '0.9.7'  # pragma: no cover
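
Note for reviewers: below is a minimal, self-contained sketch of the scoring rule that ECOD.decision_function in this patch implements — per-feature left- and right-tail probabilities from the empirical CDF, a skewness-corrected variant, and the element-wise maximum of the three, summed across features (the negative log turns rare tail events into large scores). It is illustrative only and not part of the patch: the helper name ecod_scores and the synthetic-data check at the bottom are made up for the demo, and it assumes statsmodels is available (ecod.py above already imports it).

import numpy as np
from scipy.stats import skew
from statsmodels.distributions.empirical_distribution import ECDF


def ecod_scores(X):
    """Sum of per-dimension negative-log tail probabilities (ECOD idea)."""
    X = np.asarray(X, dtype=float)
    U_l = np.zeros_like(X)  # left-tail ECDF probabilities
    U_r = np.zeros_like(X)  # right-tail ECDF probabilities
    for j in range(X.shape[1]):
        U_l[:, j] = ECDF(X[:, j])(X[:, j])
        U_r[:, j] = ECDF(-X[:, j])(-X[:, j])
    U_l, U_r = -np.log(U_l), -np.log(U_r)  # rare tails -> large scores
    # pick the tail suggested by each dimension's skewness, mirroring the
    # sign arithmetic in ECOD.decision_function
    s = np.sign(skew(X, axis=0))
    U_skew = U_l * -np.sign(s - 1) + U_r * np.sign(s + 1)
    # per-dimension score: the largest of the three tail estimates
    # (np.maximum is binary, hence the nested calls)
    O = np.maximum(U_skew, np.maximum(U_l, U_r))
    return O.sum(axis=1)  # higher means more anomalous


rng = np.random.default_rng(42)
X = rng.normal(size=(200, 5))
X[:5] += 6  # plant five obvious outliers in every feature
scores = ecod_scores(X)
print(scores[:5].min() > np.median(scores))  # expected: True

Because the per-feature scores in O are comparable across dimensions, explain_outlier can plot them directly against quantile cutoff bands, which is what makes the detector interpretable.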