From f60813ff48527822c6a865f015a5e60f599116cb Mon Sep 17 00:00:00 2001
From: "yzhao062@gmail.com" <yzhao062@gmail.com>
Date: Mon, 3 Jul 2023 18:04:05 +0800
Subject: [PATCH] sklearn version change

---
 CHANGES.txt             |  3 ++-
 pyod/models/iforest.py  | 32 +++++++++-----------------------
 pyod/models/sampling.py |  8 +++++++-
 pyod/utils/utility.py   |  8 ++++++--
 requirements.txt        |  2 +-
 5 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/CHANGES.txt b/CHANGES.txt
index 23701b235..6e98df804 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -177,4 +177,5 @@ v<1.0.8>, <03/08/2023> -- Improve clone compatibility (#471).
 v<1.0.8>, <03/08/2023> -- Add QMCD detector (#452).
 v<1.0.8>, <03/08/2023> -- Optimized ECDF and drop Statsmodels dependency (#467).
 v<1.0.9>, <03/19/2023> -- Hot fix for errors in ECOD and COPOD due to the issue of scipy.
-v<1.1.0>, <06/19/2023> -- Further integration of PyThresh.
\ No newline at end of file
+v<1.1.0>, <06/19/2023> -- Further integration of PyThresh.
+v<1.1.1>, <07/03/2023> -- Bump up sklearn requirement and some hot fixes.
\ No newline at end of file

diff --git a/pyod/models/iforest.py b/pyod/models/iforest.py
index f45c69f40..02259992d 100644
--- a/pyod/models/iforest.py
+++ b/pyod/models/iforest.py
@@ -16,7 +16,6 @@
 from .base import BaseDetector
 # noinspection PyProtectedMember
-from ..utils.utility import _get_sklearn_version
 from ..utils.utility import invert_order
@@ -207,28 +206,15 @@ def fit(self, X, y=None):
         # In sklearn 0.20+ new behaviour is added (arg behaviour={'new','old'})
         # to IsolationForest that shifts the location of the anomaly scores
         # noinspection PyProtectedMember
-        sklearn_version = _get_sklearn_version()
-        if sklearn_version == 21:
-            self.detector_ = IsolationForest(n_estimators=self.n_estimators,
-                                             max_samples=self.max_samples,
-                                             contamination=self.contamination,
-                                             max_features=self.max_features,
-                                             bootstrap=self.bootstrap,
-                                             n_jobs=self.n_jobs,
-                                             behaviour=self.behaviour,
-                                             random_state=self.random_state,
-                                             verbose=self.verbose)
-
-        # Do not pass behaviour argument when sklearn version is < 0.20 or >0.21
-        else:  # pragma: no cover
-            self.detector_ = IsolationForest(n_estimators=self.n_estimators,
-                                             max_samples=self.max_samples,
-                                             contamination=self.contamination,
-                                             max_features=self.max_features,
-                                             bootstrap=self.bootstrap,
-                                             n_jobs=self.n_jobs,
-                                             random_state=self.random_state,
-                                             verbose=self.verbose)
+
+        self.detector_ = IsolationForest(n_estimators=self.n_estimators,
+                                         max_samples=self.max_samples,
+                                         contamination=self.contamination,
+                                         max_features=self.max_features,
+                                         bootstrap=self.bootstrap,
+                                         n_jobs=self.n_jobs,
+                                         random_state=self.random_state,
+                                         verbose=self.verbose)

         self.detector_.fit(X=X, y=None, sample_weight=None)

diff --git a/pyod/models/sampling.py b/pyod/models/sampling.py
index 669764e49..cd42564ee 100644
--- a/pyod/models/sampling.py
+++ b/pyod/models/sampling.py
@@ -7,12 +7,18 @@
 from __future__ import division, print_function

 import numpy as np
-from sklearn.neighbors import DistanceMetric
+
 from sklearn.utils import check_array, check_random_state
 from sklearn.utils.validation import check_is_fitted

 from .base import BaseDetector
+from ..utils.utility import _get_sklearn_version
+sklearn_version = _get_sklearn_version()
+if sklearn_version[:3] >= '1.3':
+    from sklearn.metrics import DistanceMetric
+else:
+    from sklearn.neighbors import DistanceMetric


 class Sampling(BaseDetector):
     """Sampling class for outlier detection.
diff --git a/pyod/utils/utility.py b/pyod/utils/utility.py
index 8f08c6e16..645374de6 100644
--- a/pyod/utils/utility.py
+++ b/pyod/utils/utility.py
@@ -443,8 +443,11 @@ def _get_sklearn_version():  # pragma: no cover
     # if int(sklearn_version.split(".")[1]) < 19 or int(
     #         sklearn_version.split(".")[1]) > 24:
     #     raise ValueError("Sklearn version error")
+    # print(sklearn_version)
+
+    return sklearn_version

-    return int(sklearn_version.split(".")[1])
+    # return int(sklearn_version.split(".")[1])


 # def _sklearn_version_21():  # pragma: no cover
@@ -544,6 +547,7 @@ def generate_indices(random_state, bootstrap, n_population, n_samples):

     return indices


+# todo: add a test for it in test_utility.py
 def get_optimal_n_bins(X, upper_bound=None, epsilon=1):
     """ Determine optimal number of bins for a histogram using the Birge
@@ -579,6 +583,6 @@ def get_optimal_n_bins(X, upper_bound=None, epsilon=1):
         maximum_likelihood[i] = np.sum(
             histogram * np.log(b * histogram / n + epsilon) - (
-                b - 1 + np.power(np.log(b), 2.5)))
+                    b - 1 + np.power(np.log(b), 2.5)))

     return np.argmax(maximum_likelihood) + 1

diff --git a/requirements.txt b/requirements.txt
index b6a124017..532873d40 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,5 +3,5 @@ matplotlib
 numpy>=1.19
 numba>=0.51
 scipy>=1.5.1
-scikit_learn>=0.20.0
+scikit_learn>=0.22.0
 six
\ No newline at end of file
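
Note on the version-gated import in pyod/models/sampling.py: the patch selects the DistanceMetric import location by comparing the first three characters of the scikit-learn version string against '1.3', since DistanceMetric was removed from sklearn.neighbors in scikit-learn 1.3. A minimal sketch of the same gating written with packaging.version, which also compares multi-digit minor releases such as '1.10' numerically, is shown below; the packaging dependency and the snippet itself are illustrative assumptions, not part of this patch.

    # Illustrative sketch (not part of the patch): version-gated import of
    # DistanceMetric. packaging.version compares "1.10" and "1.3" numerically,
    # whereas a plain string comparison would not.
    import sklearn
    from packaging.version import Version  # assumes 'packaging' is installed

    if Version(sklearn.__version__) >= Version("1.3"):
        from sklearn.metrics import DistanceMetric
    else:
        from sklearn.neighbors import DistanceMetric

    metric = DistanceMetric.get_metric("euclidean")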
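
Note on the pyod/utils/utility.py change: _get_sklearn_version now returns the full version string (for example "1.3.0") rather than the minor version as an int, so callers that previously compared against an integer would compare the string instead. A minimal caller-side sketch, assuming a simple tuple-of-ints comparison that drops rc/dev suffixes, follows; the helper name _version_tuple is hypothetical and not part of PyOD.

    # Illustrative sketch (not part of the patch): comparing the full version
    # string returned by the updated _get_sklearn_version numerically.
    from pyod.utils.utility import _get_sklearn_version


    def _version_tuple(version_str):
        # "1.3.0" -> (1, 3, 0); non-numeric parts such as "dev0" are dropped
        return tuple(int(p) for p in version_str.split(".") if p.isdigit())


    if _version_tuple(_get_sklearn_version()) >= (1, 3):
        from sklearn.metrics import DistanceMetric
    else:
        from sklearn.neighbors import DistanceMetric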