diff --git a/checkmates/data_checks/checks/distribution_data_check.py b/checkmates/data_checks/checks/distribution_data_check.py
index 236b874..ce21bdf 100644
--- a/checkmates/data_checks/checks/distribution_data_check.py
+++ b/checkmates/data_checks/checks/distribution_data_check.py
@@ -1,18 +1,15 @@
-"""Data check that checks if the target data contains certain distributions that may need to be transformed prior training to improve model performance."""
-import diptest
-import numpy as np
+"""Data check that screens data for skewed or bimodal distributions prior to model training so that model performance is not adversely affected."""
 import woodwork as ww
 
 from checkmates.data_checks import (
     DataCheck,
     DataCheckActionCode,
     DataCheckActionOption,
-    DataCheckError,
     DataCheckMessageCode,
     DataCheckWarning,
 )
-from checkmates.utils import infer_feature_types
-
+from scipy.stats import skew
+from diptest import diptest
 
 class DistributionDataCheck(DataCheck):
     """Check if the overall data contains certain distributions that may need to be transformed prior training to improve model performance. Uses the skew test and yeojohnson transformation."""
@@ -43,7 +40,7 @@ def validate(self, X, y):
         ...             "details": {"distribution type": "positive skew", "Skew Value": 0.7939, "Bimodal Coefficient": 1.0,},
         ...             "action_options": [
         ...                 {
-        ...                     "code": "TRANSFORM_TARGET",
+        ...                     "code": "TRANSFORM_FEATURES",
         ...                     "data_check_name": "DistributionDataCheck",
         ...                     "parameters": {},
         ...                     "metadata": {
@@ -54,97 +51,51 @@ def validate(self, X, y):
         ...             ]
         ...         }
         ...     ]
-        ...
-        >>> X = pd.Series([1, 1, 1, 2, 2, 3, 4, 4, 5, 5, 5])
-        >>> assert target_check.validate(X, y) == []
-        ...
-        ...
-        >>> X = pd.Series(pd.date_range("1/1/21", periods=10))
-        >>> assert target_check.validate(X, y) == [
-        ...     {
-        ...         "message": "Target is unsupported datetime type. Valid Woodwork logical types include: integer, double, age, age_fractional",
-        ...         "data_check_name": "DistributionDataCheck",
-        ...         "level": "error",
-        ...         "details": {"columns": None, "rows": None, "unsupported_type": "datetime"},
-        ...         "code": "TARGET_UNSUPPORTED_TYPE",
-        ...         "action_options": []
-        ...     }
-        ... ]
         """
         messages = []
-        if y is None:
-            messages.append(
-                DataCheckError(
-                    message="Data is None",
-                    data_check_name=self.name,
-                    message_code=DataCheckMessageCode.TARGET_IS_NONE,
-                    details={},
-                ).to_dict(),
-            )
-            return messages
-
-        y = infer_feature_types(y)
-        allowed_types = [
-            ww.logical_types.Integer.type_string,
-            ww.logical_types.Double.type_string,
-            ww.logical_types.Age.type_string,
-            ww.logical_types.AgeFractional.type_string,
-        ]
-        is_supported_type = y.ww.logical_type.type_string in allowed_types
-
-        if not is_supported_type:
-            messages.append(
-                DataCheckError(
-                    message="Target is unsupported {} type. Valid Woodwork logical types include: {}".format(
-                        y.ww.logical_type.type_string,
-                        ", ".join([ltype for ltype in allowed_types]),
-                    ),
-                    data_check_name=self.name,
-                    message_code=DataCheckMessageCode.TARGET_UNSUPPORTED_TYPE,
-                    details={"unsupported_type": X.ww.logical_type.type_string},
-                ).to_dict(),
-            )
-            return messages
-
-        (
-            is_skew,
-            distribution_type,
-            skew_value,
-            coef,
-        ) = _detect_skew_distribution_helper(X)
-
-        if is_skew:
-            details = {
-                "distribution type": distribution_type,
-                "Skew Value": skew_value,
-                "Bimodal Coefficient": coef,
-            }
-            messages.append(
-                DataCheckWarning(
-                    message="Data may have a skewed distribution.",
-                    data_check_name=self.name,
-                    message_code=DataCheckMessageCode.SKEWED_DISTRIBUTION,
-                    details=details,
-                    action_options=[
-                        DataCheckActionOption(
-                            DataCheckActionCode.TRANSFORM_TARGET,
-                            data_check_name=self.name,
-                            metadata={
-                                "is_skew": True,
-                                "transformation_strategy": "yeojohnson",
-                            },
-                        ),
-                    ],
-                ).to_dict(),
-            )
+        numeric_X = X.ww.select(["Integer", "Double"])
+
+        for col in numeric_X:
+            (
+                is_skew,
+                distribution_type,
+                skew_value,
+                coef,
+            ) = _detect_skew_distribution_helper(numeric_X[col])
+
+            if is_skew:
+                details = {
+                    "distribution type": distribution_type,
+                    "Skew Value": skew_value,
+                    "Bimodal Coefficient": coef,
+                }
+                messages.append(
+                    DataCheckWarning(
+                        message="Data may have a skewed distribution.",
+                        data_check_name=self.name,
+                        message_code=DataCheckMessageCode.SKEWED_DISTRIBUTION,
+                        details=details,
+                        action_options=[
+                            DataCheckActionOption(
+                                DataCheckActionCode.TRANSFORM_FEATURES,
+                                data_check_name=self.name,
+                                metadata={
+                                    "is_skew": True,
+                                    "transformation_strategy": "yeojohnson",
+                                    "columns": [col],
+                                },
+                            ),
+                        ],
+                    ).to_dict(),
+                )
         return messages
 
 
 def _detect_skew_distribution_helper(X):
     """Helper method to detect skewed or bimodal distribution.
 
     Returns boolean, distribution type, the skew value, and bimodal coefficient."""
-    skew_value = np.stats.skew(X)
-    coef = diptest.diptest(X)[1]
+    skew_value = skew(X)
+    coef = diptest(X)[1]
 
     if coef < 0.05:
         return True, "bimodal distribution", skew_value, coef
@@ -153,3 +104,54 @@ def _detect_skew_distribution_helper(X):
     if skew_value > 0.5:
         return True, "positive skew", skew_value, coef
     return False, "no skew", skew_value, coef
+
+
+# Testing Data to make sure skews are recognized -- successful
+# import numpy as np
+# import pandas as pd
+# data = {
+#     'Column1': np.random.normal(0, 1, 1000),  # Normally distributed data
+#     'Column2': np.random.exponential(1, 1000),  # Right-skewed data
+#     'Column3': np.random.gamma(2, 2, 1000)  # Right-skewed data
+# }
+
+# df = pd.DataFrame(data)
+# df.ww.init()
+# messages = []
+
+# numeric_X = df.ww.select(["Integer", "Double"])
+# print(numeric_X)
+# for col in numeric_X:
+#     (
+#         is_skew,
+#         distribution_type,
+#         skew_value,
+#         coef,
+#     ) = _detect_skew_distribution_helper(numeric_X['Column2'])
+
+#     if is_skew:
+#         details = {
+#             "distribution type": distribution_type,
+#             "Skew Value": skew_value,
+#             "Bimodal Coefficient": coef,
+#         }
+#         messages.append(
+#             DataCheckWarning(
+#                 message="Data may have a skewed distribution.",
+#                 data_check_name="Distribution Data Check",
+#                 message_code=DataCheckMessageCode.SKEWED_DISTRIBUTION,
+#                 details=details,
+#                 action_options=[
+#                     DataCheckActionOption(
+#                         DataCheckActionCode.TRANSFORM_FEATURES,
+#                         data_check_name="Distribution Data Check",
+#                         metadata={
+#                             "is_skew": True,
+#                             "transformation_strategy": "yeojohnson",
+#                             "columns": col
+#                         },
+#                     ),
+#                 ],
+#             ).to_dict(),
+#         )
+# print(messages)
\ No newline at end of file
diff --git a/checkmates/data_checks/datacheck_meta/data_check_action_code.py b/checkmates/data_checks/datacheck_meta/data_check_action_code.py
index 0106221..7558195 100644
--- a/checkmates/data_checks/datacheck_meta/data_check_action_code.py
+++ b/checkmates/data_checks/datacheck_meta/data_check_action_code.py
@@ -19,6 +19,9 @@ class DataCheckActionCode(Enum):
     TRANSFORM_TARGET = "transform_target"
     """Action code for transforming the target data."""
 
+    TRANSFORM_FEATURES = "transform_features"
+    """Action code for transforming the feature data."""
+
     REGULARIZE_AND_IMPUTE_DATASET = "regularize_and_impute_dataset"
     """Action code for regularizing and imputing all features and target time series data."""
 
diff --git a/checkmates/pipelines/transformers.py b/checkmates/pipelines/transformers.py
index 815ee2f..2e26fee 100644
--- a/checkmates/pipelines/transformers.py
+++ b/checkmates/pipelines/transformers.py
@@ -105,6 +105,12 @@ def transform(self, X, y=None):
         Returns:
             pd.DataFrame: Transformed X
         """
+
+        # If there are no columns to normalize, return X unchanged
+        if not self._cols_to_normalize:
+            return X
+
+        X = X[self._cols_to_normalize]
         # Transform the data
         X_t = yeojohnson(X)
 
diff --git a/checkmates/pipelines/utils.py b/checkmates/pipelines/utils.py
index 5f4e555..00776cd 100644
--- a/checkmates/pipelines/utils.py
+++ b/checkmates/pipelines/utils.py
@@ -15,6 +15,7 @@
     TimeSeriesRegularizer,
 )
 from checkmates.pipelines.training_validation_split import TrainingValidationSplit
+from checkmates.pipelines.transformers import SimpleNormalizer
 from checkmates.problem_types import is_classification, is_regression, is_time_series
 from checkmates.utils import infer_feature_types
 
@@ -31,6 +32,8 @@ def _make_component_list_from_actions(actions):
     components = []
     cols_to_drop = []
     indices_to_drop = []
+    cols_to_normalize = []
+
 
     for action in actions:
         if action.action_code == DataCheckActionCode.REGULARIZE_AND_IMPUTE_DATASET:
@@ -47,6 +50,8 @@
             )
         elif action.action_code == DataCheckActionCode.DROP_COL:
             cols_to_drop.extend(action.metadata["columns"])
+        elif action.action_code == DataCheckActionCode.TRANSFORM_FEATURES:
+            cols_to_normalize.extend(action.metadata["columns"])
         elif action.action_code == DataCheckActionCode.IMPUTE_COL:
             metadata = action.metadata
             parameters = metadata.get("parameters", {})
@@ -65,6 +70,9 @@
     if indices_to_drop:
         indices_to_drop = sorted(set(indices_to_drop))
         components.append(DropRowsTransformer(indices_to_drop=indices_to_drop))
+    if cols_to_normalize:
+        cols_to_normalize = sorted(set(cols_to_normalize))
+        components.append(SimpleNormalizer(columns=cols_to_normalize))
 
     return components
 
diff --git a/pyproject.toml b/pyproject.toml
index 93f4105..d19a80d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,6 @@ dependencies = [
     "click>=8.0.0",
     "black[jupyter]>=22.3.0",
     "diptest>=0.5.2",
-    "scipy>=1.9.3",
 ]
 requires-python = ">=3.8,<4.0"
 readme = "README.md"