Skip to content

Commit

Permalink
added data checking logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Nabil Fayak committed Sep 8, 2023
1 parent 7c3e224 commit 9abccac
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 91 deletions.
182 changes: 92 additions & 90 deletions checkmates/data_checks/checks/distribution_data_check.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
"""Data check that checks if the target data contains certain distributions that may need to be transformed prior to training to improve model performance."""
import diptest
import numpy as np
"""Data check that screens data for skewed or bimodal distributions prior to model training to ensure model performance is unaffected."""
import woodwork as ww

from checkmates.data_checks import (
DataCheck,
DataCheckActionCode,
DataCheckActionOption,
DataCheckError,
DataCheckMessageCode,
DataCheckWarning,
)
from checkmates.utils import infer_feature_types

from scipy.stats import skew
from diptest import diptest

class DistributionDataCheck(DataCheck):
    """Check if the overall data contains certain distributions that may need to be transformed prior to training to improve model performance. Uses the skew test and the Yeo-Johnson transformation."""
Expand Down Expand Up @@ -43,7 +40,7 @@ def validate(self, X, y):
... "details": {"distribution type": "positive skew", "Skew Value": 0.7939, "Bimodal Coefficient": 1.0,},
... "action_options": [
... {
... "code": "TRANSFORM_TARGET",
... "code": "TRANSFORM_FEATURES",
... "data_check_name": "DistributionDataCheck",
... "parameters": {},
... "metadata": {
Expand All @@ -54,97 +51,51 @@ def validate(self, X, y):
... ]
... }
... ]
...
>>> X = pd.Series([1, 1, 1, 2, 2, 3, 4, 4, 5, 5, 5])
>>> assert target_check.validate(X, y) == []
...
...
>>> X = pd.Series(pd.date_range("1/1/21", periods=10))
>>> assert target_check.validate(X, y) == [
... {
... "message": "Target is unsupported datetime type. Valid Woodwork logical types include: integer, double, age, age_fractional",
... "data_check_name": "DistributionDataCheck",
... "level": "error",
... "details": {"columns": None, "rows": None, "unsupported_type": "datetime"},
... "code": "TARGET_UNSUPPORTED_TYPE",
... "action_options": []
... }
... ]
"""
messages = []

if y is None:
messages.append(
DataCheckError(
message="Data is None",
data_check_name=self.name,
message_code=DataCheckMessageCode.TARGET_IS_NONE,
details={},
).to_dict(),
)
return messages

y = infer_feature_types(y)
allowed_types = [
ww.logical_types.Integer.type_string,
ww.logical_types.Double.type_string,
ww.logical_types.Age.type_string,
ww.logical_types.AgeFractional.type_string,
]
is_supported_type = y.ww.logical_type.type_string in allowed_types

if not is_supported_type:
messages.append(
DataCheckError(
message="Target is unsupported {} type. Valid Woodwork logical types include: {}".format(
y.ww.logical_type.type_string,
", ".join([ltype for ltype in allowed_types]),
),
data_check_name=self.name,
message_code=DataCheckMessageCode.TARGET_UNSUPPORTED_TYPE,
details={"unsupported_type": X.ww.logical_type.type_string},
).to_dict(),
)
return messages

(
is_skew,
distribution_type,
skew_value,
coef,
) = _detect_skew_distribution_helper(X)

if is_skew:
details = {
"distribution type": distribution_type,
"Skew Value": skew_value,
"Bimodal Coefficient": coef,
}
messages.append(
DataCheckWarning(
message="Data may have a skewed distribution.",
data_check_name=self.name,
message_code=DataCheckMessageCode.SKEWED_DISTRIBUTION,
details=details,
action_options=[
DataCheckActionOption(
DataCheckActionCode.TRANSFORM_TARGET,
data_check_name=self.name,
metadata={
"is_skew": True,
"transformation_strategy": "yeojohnson",
},
),
],
).to_dict(),
)
numeric_X = X.ww.select(["Integer", "Double"])

for col in numeric_X:
(
is_skew,
distribution_type,
skew_value,
coef,
) = _detect_skew_distribution_helper(col)

if is_skew:
details = {
"distribution type": distribution_type,
"Skew Value": skew_value,
"Bimodal Coefficient": coef,
}
messages.append(
DataCheckWarning(
message="Data may have a skewed distribution.",
data_check_name=self.name,
message_code=DataCheckMessageCode.SKEWED_DISTRIBUTION,
details=details,
action_options=[
DataCheckActionOption(
DataCheckActionCode.TRANSFORM_FEATURES,
data_check_name=self.name,
metadata={
"is_skew": True,
"transformation_strategy": "yeojohnson",
"columns" : col
},
),
],
).to_dict(),
)
return messages


def _detect_skew_distribution_helper(X):
"""Helper method to detect skewed or bimodal distribution. Returns boolean, distribution type, the skew value, and bimodal coefficient."""
skew_value = np.stats.skew(X)
coef = diptest.diptest(X)[1]
skew_value = skew(X)
coef = diptest(X)[1]

if coef < 0.05:
return True, "bimodal distribution", skew_value, coef
Expand All @@ -153,3 +104,54 @@ def _detect_skew_distribution_helper(X):
if skew_value > 0.5:
return True, "positive skew", skew_value, coef
return False, "no skew", skew_value, coef


# Testing data to make sure skews are recognized — successful
# import numpy as np
# import pandas as pd
# data = {
# 'Column1': np.random.normal(0, 1, 1000), # Normally distributed data
# 'Column2': np.random.exponential(1, 1000), # Right-skewed data
# 'Column3': np.random.gamma(2, 2, 1000) # Right-skewed data
# }

# df = pd.DataFrame(data)
# df.ww.init()
# messages = []

# numeric_X = df.ww.select(["Integer", "Double"])
# print(numeric_X)
# for col in numeric_X:
# (
# is_skew,
# distribution_type,
# skew_value,
# coef,
# ) = _detect_skew_distribution_helper(numeric_X['Column2'])

# if is_skew:
# details = {
# "distribution type": distribution_type,
# "Skew Value": skew_value,
# "Bimodal Coefficient": coef,
# }
# messages.append(
# DataCheckWarning(
# message="Data may have a skewed distribution.",
# data_check_name="Distribution Data Check",
# message_code=DataCheckMessageCode.SKEWED_DISTRIBUTION,
# details=details,
# action_options=[
# DataCheckActionOption(
# DataCheckActionCode.TRANSFORM_FEATURES,
# data_check_name="Distribution Data Check",
# metadata={
# "is_skew": True,
# "transformation_strategy": "yeojohnson",
# "columns" : col
# },
# ),
# ],
# ).to_dict(),
# )
# print(messages)
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class DataCheckActionCode(Enum):
TRANSFORM_TARGET = "transform_target"
"""Action code for transforming the target data."""

TRANSFORM_FEATURES = "transform_features"
"""Action code for transforming the features data."""

REGULARIZE_AND_IMPUTE_DATASET = "regularize_and_impute_dataset"
"""Action code for regularizing and imputing all features and target time series data."""

Expand Down
6 changes: 6 additions & 0 deletions checkmates/pipelines/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def transform(self, X, y=None):
Returns:
pd.DataFrame: Transformed X
"""

# If there are no columns to normalize, return early
if not self._cols_to_normalize:
return self

X = X[self._cols_to_normalize]
# Transform the data
X_t = yeojohnson(X)

Expand Down
8 changes: 8 additions & 0 deletions checkmates/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TimeSeriesRegularizer,
)
from checkmates.pipelines.training_validation_split import TrainingValidationSplit
from checkmates.pipelines.transformers import SimpleNormalizer
from checkmates.problem_types import is_classification, is_regression, is_time_series
from checkmates.utils import infer_feature_types

Expand All @@ -31,6 +32,8 @@ def _make_component_list_from_actions(actions):
components = []
cols_to_drop = []
indices_to_drop = []
cols_to_normalize = []


for action in actions:
if action.action_code == DataCheckActionCode.REGULARIZE_AND_IMPUTE_DATASET:
Expand All @@ -47,6 +50,8 @@ def _make_component_list_from_actions(actions):
)
elif action.action_code == DataCheckActionCode.DROP_COL:
cols_to_drop.extend(action.metadata["columns"])
elif action.action_code == DataCheckActionCode.TRANSFORM_FEATURES:
cols_to_normalize.extend(action.metadata["columns"])
elif action.action_code == DataCheckActionCode.IMPUTE_COL:
metadata = action.metadata
parameters = metadata.get("parameters", {})
Expand All @@ -65,6 +70,9 @@ def _make_component_list_from_actions(actions):
if indices_to_drop:
indices_to_drop = sorted(set(indices_to_drop))
components.append(DropRowsTransformer(indices_to_drop=indices_to_drop))
if cols_to_normalize:
cols_to_normalize = sorted(set(cols_to_normalize))
components.append(SimpleNormalizer(columns=cols_to_normalize))

return components

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ dependencies = [
"click>=8.0.0",
"black[jupyter]>=22.3.0",
"diptest>=0.5.2",
"scipy>=1.9.3",
]
requires-python = ">=3.8,<4.0"
readme = "README.md"
Expand Down

0 comments on commit 9abccac

Please sign in to comment.