Merge pull request #373 from yzhao062/development
V0.9.8 Add ECOD
yzhao062 authored Mar 5, 2022
2 parents 13b0cd5 + 73eb14c commit c5cdb11
Showing 8 changed files with 70 additions and 19 deletions.
4 changes: 3 additions & 1 deletion CHANGES.txt
@@ -147,7 +147,9 @@ v<0.9.6>, <11/05/2021> -- Minor bug fix for COPOD.
 v<0.9.6>, <12/24/2021> -- Bug fix for MAD (#358).
 v<0.9.6>, <12/24/2021> -- Bug fix for COPOD plotting (#337).
 v<0.9.6>, <12/24/2021> -- Model persistence doc improvement.
-v<0.9.7>, <01/03/2021> -- Add ECOD.
+v<0.9.7>, <01/03/2022> -- Add ECOD.
+v<0.9.8>, <02/23/2022> -- Add Feature Importance for iForest.
+v<0.9.8>, <03/05/2022> -- Update ECOD (TKDE 2022).

16 changes: 8 additions & 8 deletions README.rst
@@ -61,8 +61,8 @@ multivariate data. This exciting yet challenging field is commonly referred as
 or `Anomaly Detection <https://en.wikipedia.org/wiki/Anomaly_detection>`_.

 PyOD includes more than 30 detection algorithms, from classical LOF (SIGMOD 2000) to
-the latest COPOD (ICDM 2020) and SUOD (MLSys 2021). Since 2017, PyOD has been successfully used in numerous academic researches and
-commercial products [#Zhao2019LSCP]_ [#Zhao2021SUOD]_.
+the latest SUOD (MLSys 2021) and ECOD (TKDE 2022). Since 2017, PyOD has been successfully used in numerous academic researches and
+commercial products [#Zhao2019LSCP]_ [#Zhao2021SUOD]_ with more than 5 million downloads.
 It is also well acknowledged by the machine learning community with various dedicated posts/tutorials, including
 `Analytics Vidhya <https://www.analyticsvidhya.com/blog/2019/02/outlier-detection-python-pyod/>`_,
 `KDnuggets <https://www.kdnuggets.com/2019/02/outlier-detection-methods-cheat-sheet.html>`_,
@@ -74,7 +74,7 @@ It is also well acknowledged by the machine learning community with various dedi
 PyOD is featured for:

 * **Unified APIs, detailed documentation, and interactive examples** across various algorithms.
-* **Advanced models**\ , including **classical ones from scikit-learn**, **latest deep learning methods**, and **emerging algorithms like COPOD**.
+* **Advanced models**\ , including **classical ones from scikit-learn**, **latest deep learning methods**, and **emerging algorithms like ECOD**.
 * **Optimized performance with JIT and parallelization** when possible, using `numba <https://github.com/numba/numba>`_ and `joblib <https://github.com/joblib/joblib>`_.
 * **Fast training & prediction with SUOD** [#Zhao2021SUOD]_.
 * **Compatible with both Python 2 & 3**.
@@ -86,9 +86,9 @@ PyOD is featured for:
 .. code-block:: python


-    # train the COPOD detector
-    from pyod.models.copod import COPOD
-    clf = COPOD()
+    # train the ECOD detector
+    from pyod.models.ecod import ECOD
+    clf = ECOD()
     clf.fit(X_train)

     # get outlier scores
@@ -307,7 +307,7 @@ PyOD toolkit consists of three major functional groups:
 =================== ================== ====================================================================================================== ===== ========================================
 Type Abbr Algorithm Year Ref
 =================== ================== ====================================================================================================== ===== ========================================
-Probabilistic ECOD Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions 2021 [#Li2021ECOD]_
+Probabilistic ECOD Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions 2022 [#Li2021ECOD]_
 Probabilistic ABOD Angle-Based Outlier Detection 2008 [#Kriegel2008Angle]_
 Probabilistic FastABOD Fast Angle-Based Outlier Detection using approximation 2008 [#Kriegel2008Angle]_
 Probabilistic COPOD COPOD: Copula-Based Outlier Detection 2020 [#Li2020COPOD]_
@@ -572,7 +572,7 @@ Reference
 .. [#Li2020COPOD] Li, Z., Zhao, Y., Botta, N., Ionescu, C. and Hu, X. COPOD: Copula-Based Outlier Detection. *IEEE International Conference on Data Mining (ICDM)*, 2020.
-.. [#Li2021ECOD] Li, Z., Zhao, Y., Hu, X., Botta, N., Ionescu, C. and Chen, H. G. ECOD: Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions. arXiv preprint arXiv:2201.00382 (2021).
+.. [#Li2021ECOD] Li, Z., Zhao, Y., Hu, X., Botta, N., Ionescu, C. and Chen, H. G. ECOD: Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions. *IEEE Transactions on Knowledge and Data Engineering (TKDE)*, 2022.
 .. [#Liu2008Isolation] Liu, F.T., Ting, K.M. and Zhou, Z.H., 2008, December. Isolation forest. In *International Conference on Data Mining*\ , pp. 413-422. IEEE.

14 changes: 7 additions & 7 deletions docs/index.rst
@@ -67,8 +67,8 @@ multivariate data. This exciting yet challenging field is commonly referred as
 or `Anomaly Detection <https://en.wikipedia.org/wiki/Anomaly_detection>`_.

 PyOD includes more than 30 detection algorithms, from classical LOF (SIGMOD 2000) to
-the latest COPOD (ICDM 2020) and SUOD (MLSys 2021). Since 2017, PyOD :cite:`a-zhao2019pyod` has been successfully used in numerous
-academic researches and commercial products :cite:`a-zhao2019lscp,a-zhao2021suod`.
+the latest SUOD (MLSys 2021) and ECOD (TKDE 2022). Since 2017, PyOD :cite:`a-zhao2019pyod` has been successfully used in numerous
+academic researches and commercial products :cite:`a-zhao2019lscp,a-zhao2021suod` with more than 5 million downloads.
 It is also well acknowledged by the machine learning community with various dedicated posts/tutorials, including
 `Analytics Vidhya <https://www.analyticsvidhya.com/blog/2019/02/outlier-detection-python-pyod/>`_,
 `Towards Data Science <https://towardsdatascience.com/anomaly-detection-for-dummies-15f148e559c1>`_,
@@ -80,7 +80,7 @@ It is also well acknowledged by the machine learning community with various dedi
 PyOD is featured for:

 * **Unified APIs, detailed documentation, and interactive examples** across various algorithms.
-* **Advanced models**\ , including **classical ones from scikit-learn**, **latest deep learning methods**, and **emerging algorithms like COPOD**.
+* **Advanced models**\ , including **classical ones from scikit-learn**, **latest deep learning methods**, and **emerging algorithms like ECOD**.
 * **Optimized performance with JIT and parallelization** when possible, using `numba <https://github.com/numba/numba>`_ and `joblib <https://github.com/joblib/joblib>`_.
 * **Fast training & prediction with SUOD** :cite:`a-zhao2021suod`.
 * **Compatible with both Python 2 & 3**.
@@ -92,9 +92,9 @@ PyOD is featured for:
 .. code-block:: python


-    # train the COPOD detector
-    from pyod.models.copod import COPOD
-    clf = COPOD()
+    # train the ECOD detector
+    from pyod.models.ecod import ECOD
+    clf = ECOD()
     clf.fit(X_train)

     # get outlier scores
@@ -146,7 +146,7 @@ PyOD toolkit consists of three major functional groups:
 =================== ================ ====================================================================================================== ===== =================================================== ======================================================
 Type Abbr Algorithm Year Class Ref
 =================== ================ ====================================================================================================== ===== =================================================== ======================================================
-Probabilistic ECOD Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions 2021 :class:`pyod.models.ecod.ECOD` :cite:`a-li2021ecod`
+Probabilistic ECOD Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions 2022 :class:`pyod.models.ecod.ECOD` :cite:`a-li2021ecod`
 Probabilistic COPOD COPOD: Copula-Based Outlier Detection 2020 :class:`pyod.models.copod.COPOD` :cite:`a-li2020copod`
 Probabilistic ABOD Angle-Based Outlier Detection 2008 :class:`pyod.models.abod.ABOD` :cite:`a-kriegel2008angle`
 Probabilistic FastABOD Fast Angle-Based Outlier Detection using approximation 2008 :class:`pyod.models.abod.ABOD` :cite:`a-kriegel2008angle`

5 changes: 3 additions & 2 deletions docs/zreferences.bib
@@ -382,6 +382,7 @@ @inproceedings{perini2020quantifying
 @article{Li2021ecod,
 title={ECOD: Unsupervised Outlier Detection Using Empirical Cumulative Distribution Functions},
 author={Li, Zheng and Zhao, Yue and Hu, Xiyang and Botta, Nicola and Ionescu, Cezar and Chen, H. George},
-journal={arXiv preprint arXiv:2201.00382},
-year={2021}
+journal={IEEE Transactions on Knowledge and Data Engineering},
+year={2022},
+publisher={IEEE}
 }
4 changes: 4 additions & 0 deletions examples/iforest_example.py
@@ -53,6 +53,10 @@
     print("\nOn Test Data:")
     evaluate_print(clf_name, y_test, y_test_scores)

+    # example of the feature importance
+    feature_importance = clf.feature_importances_
+    print("Feature importance", feature_importance)
+
     # visualize the results
     visualize(clf_name, X_train, y_train, X_test, y_test, y_train_pred,
               y_test_pred, show_figure=True, save_figure=False)
40 changes: 40 additions & 0 deletions pyod/models/iforest.py
@@ -7,6 +7,11 @@
 from __future__ import division
 from __future__ import print_function

+import numpy as np
+from joblib import Parallel
+from joblib.parallel import delayed
+from sklearn.utils.fixes import _joblib_parallel_args
+
 from sklearn.ensemble import IsolationForest
 from sklearn.utils.validation import check_is_fitted
 from sklearn.utils import check_array
@@ -278,3 +283,38 @@ def max_samples_(self):
         Decorator for scikit-learn Isolation Forest attributes.
         """
         return self.detector_.max_samples_
+
+    @property
+    def feature_importances_(self):
+        """The impurity-based feature importance. The higher, the more
+        important the feature. The importance of a feature is computed as the
+        (normalized) total reduction of the criterion brought by that feature.
+        It is also known as the Gini importance.
+
+        .. warning::
+            impurity-based feature importance can be misleading for
+            high cardinality features (many unique values). See
+            https://scikit-learn.org/stable/modules/generated/sklearn.inspection.permutation_importance.html
+            as an alternative.
+
+        Returns
+        -------
+        feature_importances_ : ndarray of shape (n_features,)
+            The values of this array sum to 1, unless all trees are single node
+            trees consisting of only the root node, in which case it will be an
+            array of zeros.
+        """
+        check_is_fitted(self)
+        all_importances = Parallel(
+            n_jobs=self.n_jobs, **_joblib_parallel_args(prefer="threads")
+        )(
+            delayed(getattr)(tree, "feature_importances_")
+            for tree in self.detector_.estimators_
+            if tree.tree_.node_count > 1
+        )
+
+        if not all_importances:
+            return np.zeros(self.n_features_in_, dtype=np.float64)
+
+        all_importances = np.mean(all_importances, axis=0, dtype=np.float64)
+        return all_importances / np.sum(all_importances)
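
The aggregation performed by `feature_importances_` can be illustrated without scikit-learn or joblib. The following is a minimal sketch under stated assumptions, not the code from this commit: `aggregate_importances` and its input layout (a list of `(node_count, importances)` pairs) are hypothetical stand-ins for the fitted trees in `self.detector_.estimators_`. It mirrors the three steps in the hunk above: skip single-node trees, average the remaining per-tree vectors, then rescale so the result sums to 1.

```python
def aggregate_importances(trees, n_features):
    """Average per-tree importances and rescale them to sum to 1.

    `trees` is a hypothetical list of (node_count, importances) pairs that
    stands in for fitted trees. Trees with a single node (root only) carry
    no split information and are skipped.
    """
    kept = [importances for node_count, importances in trees if node_count > 1]

    # No usable tree: return all zeros, as the method above does.
    if not kept:
        return [0.0] * n_features

    # Column-wise mean over the kept trees.
    means = [sum(column) / len(kept) for column in zip(*kept)]

    # Rescale so the importances sum to 1.
    total = sum(means)
    return [value / total for value in means]


# Illustrative input: two usable trees and one root-only tree.
trees = [(3, [0.5, 0.5]), (1, [0.0, 0.0]), (5, [1.0, 0.0])]
result = aggregate_importances(trees, n_features=2)
print(result)
```

In this sketch the root-only tree is excluded before averaging, so it does not pull the mean toward zero; only the two trees that actually split contribute to the result.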
4 changes: 4 additions & 0 deletions pyod/test/test_iforest.py
@@ -147,6 +147,10 @@ def test_predict_rank_normalized(self):
     def test_model_clone(self):
         clone_clf = clone(self.clf)

+    def test_feature_importances(self):
+        feature_importances = self.clf.feature_importances_
+        assert (len(feature_importances) == 2)
+
     def tearDown(self):
         pass

2 changes: 1 addition & 1 deletion pyod/version.py
@@ -20,4 +20,4 @@
 # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
 # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
 #
-__version__ = '0.9.7' # pragma: no cover
+__version__ = '0.9.8' # pragma: no cover
