Skip to content

Commit

Permalink
fix var len layer
Browse files Browse the repository at this point in the history
  • Loading branch information
oaksharks committed Nov 20, 2024
1 parent 76af489 commit dcf28c4
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 61 deletions.
19 changes: 2 additions & 17 deletions deeptables/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,20 +931,11 @@ def __init__(self, pooling_strategy='max', dropout_rate=0., **kwargs):
super(VarLenColumnEmbedding, self).__init__(**kwargs)

def build(self, input_shape):
import keras
super(VarLenColumnEmbedding, self).build(input_shape)

height = input_shape[1]
if self.pooling_strategy == "mean":
self._pooling_layer = keras.layers.AveragePooling2D(pool_size=(height, 1))
else:
self._pooling_layer = keras.layers.MaxPooling2D(pool_size=(height, 1))

if self.dropout_rate > 0:
self._dropout = SpatialDropout1D(self.dropout_rate)
else:
self._dropout = None

self.built = True

def call(self, inputs):
Expand All @@ -957,14 +948,8 @@ def call(self, inputs):
else:
dropout_output = embedding_output

# 3. expand dim for polling
inputs_4d = tf.expand_dims(dropout_output, 3) # add channels dim

# 4. polling
tensor_pooling = self._pooling_layer(inputs_4d)

# 5. format output
return tf.squeeze(tensor_pooling, 3)
# 3. format output
return dropout_output

def compute_mask(self, inputs, mask):
return None
Expand Down
60 changes: 30 additions & 30 deletions deeptables/tests/models/var_len_categorical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,33 @@
from hypernets.tabular import get_tool_box


# class TestVarLenCategoricalFeature:
#
# def setup_class(cls):
# cls.df = dsutils.load_movielens().drop(['timestamp', "title"], axis=1)
#
# def test_var_categorical_feature(self):
# X = self.df.copy()
# y = X.pop('rating').values.astype('float32')
#
# conf = deeptable.ModelConfig(nets=['dnn_nets'],
# task=consts.TASK_REGRESSION,
# categorical_columns=["movie_id", "user_id", "gender", "occupation", "zip", "title",
# "age"],
# metrics=['mse'],
# fixed_embedding_dim=True,
# embeddings_output_dim=4,
# apply_gbm_features=False,
# apply_class_weight=True,
# earlystopping_patience=5,
# var_len_categorical_columns=[('genres', "|", "max")]
# )
#
# dt = deeptable.DeepTable(config=conf)
#
# X_train, X_validation, y_train, y_validation = get_tool_box(X).train_test_split(X, y, test_size=0.2)
#
# model, history = dt.fit(X_train, y_train, validation_data=(X_validation, y_validation),
# epochs=10, batch_size=32)
#
# assert 'genres' in model.model.input_names
class TestVarLenCategoricalFeature:

def setup_class(cls):
cls.df = dsutils.load_movielens().drop(['timestamp', "title"], axis=1)

def test_var_categorical_feature(self):
X = self.df.copy()
y = X.pop('rating').values.astype('float32')

conf = deeptable.ModelConfig(nets=['dnn_nets'],
task=consts.TASK_REGRESSION,
categorical_columns=["movie_id", "user_id", "gender", "occupation", "zip", "title",
"age"],
metrics=['mse'],
fixed_embedding_dim=True,
embeddings_output_dim=4,
apply_gbm_features=False,
apply_class_weight=True,
earlystopping_patience=5,
var_len_categorical_columns=[('genres', "|", "max")]
)

dt = deeptable.DeepTable(config=conf)

X_train, X_validation, y_train, y_validation = get_tool_box(X).train_test_split(X, y, test_size=0.2)

model, history = dt.fit(X_train, y_train, validation_data=(X_validation, y_validation),
epochs=10, batch_size=32)
names = [_.name for _ in model.model.inputs]
assert 'genres' in names
28 changes: 14 additions & 14 deletions deeptables/tests/models/zdask_var_len_categorical_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# -*- encoding: utf-8 -*-
# from hypernets.tests.tabular.tb_dask import is_dask_installed, if_dask_ready, setup_dask
# from .var_len_categorical_test import TestVarLenCategoricalFeature
#
# if is_dask_installed:
# import dask.dataframe as dd
#
from hypernets.tests.tabular.tb_dask import is_dask_installed, if_dask_ready, setup_dask
from .var_len_categorical_test import TestVarLenCategoricalFeature

# @if_dask_ready
# class TestVarLenCategoricalFeatureByDask(TestVarLenCategoricalFeature):
#
# def setup_class(self):
# TestVarLenCategoricalFeature.setup_class(self)
#
# setup_dask(self)
# self.df = dd.from_pandas(self.df, npartitions=2)
if is_dask_installed:
import dask.dataframe as dd


@if_dask_ready
class TestVarLenCategoricalFeatureByDask(TestVarLenCategoricalFeature):

def setup_class(self):
TestVarLenCategoricalFeature.setup_class(self)

setup_dask(self)
self.df = dd.from_pandas(self.df, npartitions=2)

0 comments on commit dcf28c4

Please sign in to comment.