MTRec

MTRec is a simple multi-task recommendation package based on TensorFlow 2.x. It lets you use classic multi-task models such as MMoE, and a test file is provided to help you get started.

Let's get started!

Installation in Python

MTRec is on PyPI, so you can install it with pip:

pip install mtrec

To pin a specific version, for example 0.0.1:

pip install mtrec==0.0.1

By default, MTRec depends on TensorFlow 2.x.

Example

The following steps show simple usage of the mtrec package.

First, prepare the dataset.

  1. MTRec requires all features to be sparse (discrete) features; continuous features must first be discretized into buckets.
  2. The dataset is provided as a dictionary that maps each feature name to its values, for example {'name1': [n1, n2, ...], 'name2': [n1, n2, ...]}.
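Both rules can be illustrated with a minimal plain-Python sketch. The feature names, raw values, and bucket boundaries below are hypothetical; in practice you might choose boundaries from quantiles of the training data or use pandas utilities instead. Continuous values are mapped to bucket ids with `bisect`, categorical values are mapped to integer ids, and the result is a dictionary of feature name to list of ids.

```python
from bisect import bisect_right


def bucketize(values, boundaries):
    """Map each continuous value to the index of its bucket.

    With sorted boundaries [b0, b1, ...], values < b0 go to bucket 0,
    b0 <= v < b1 go to bucket 1, and so on.
    """
    return [bisect_right(boundaries, v) for v in values]


def label_encode(values):
    """Map each distinct categorical value to an id 0..n-1 (first-seen order)."""
    ids = {}
    return [ids.setdefault(v, len(ids)) for v in values]


# Hypothetical raw data: one categorical and one continuous column.
raw = {
    "city": ["beijing", "shanghai", "beijing", "shenzhen"],
    "age": [18.0, 35.5, 52.0, 27.0],
}

# Dictionary-style dataset: feature name -> list of discrete ids.
dataset = {
    "city": label_encode(raw["city"]),
    "age": bucketize(raw["age"], boundaries=[25.0, 40.0]),  # 3 buckets
}
print(dataset)  # {'city': [0, 1, 0, 2], 'age': [0, 1, 2, 1]}
```

Keeping ids contiguous from 0 makes the vocabulary size of each feature equal to its number of distinct values, which matters when building the feature columns in the next step.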

Second, build the model.

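from mtrec import MMoE

# Names of the tasks learned jointly, and the number of shared experts.
task_names = ['task1', 'task2']
num_experts = 3

# sparse_feature_columns describes the input features; it is built in the next step.
model = MMoE(task_names, num_experts, sparse_feature_columns)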

To build sparse_feature_columns:

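from mtrec.functions.feature_column import sparseFeature

embed_dim = 4  # embedding dimension of each sparse feature

# data_df: a pandas DataFrame holding the discretized features;
# sparse_features: the list of feature names.
# The second argument is each feature's vocabulary size (number of distinct values).
sparse_feature_columns = [sparseFeature(feat, len(data_df[feat].unique()), embed_dim=embed_dim)
                          for feat in sparse_features]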

See the test/utils.py file for details.
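When the data is already in the dictionary format described above rather than a DataFrame, the vocabulary sizes can be computed directly from it. The sketch below uses a hypothetical stand-in, `sparse_feature`, that returns a plain dict; it is not mtrec's `sparseFeature`, whose actual return type should be checked in `mtrec.functions.feature_column`.

```python
def sparse_feature(feat, feat_num, embed_dim=4):
    """Hypothetical stand-in for mtrec's sparseFeature: describes one sparse feature.

    The real function may return a different structure.
    """
    return {"feat_name": feat, "feat_num": feat_num, "embed_dim": embed_dim}


def build_sparse_feature_columns(dataset, embed_dim=4):
    """Build one column description per feature of a dict-style dataset."""
    # Vocabulary size = number of distinct ids, mirroring
    # len(data_df[feat].unique()) in the snippet above.
    return [sparse_feature(feat, len(set(values)), embed_dim=embed_dim)
            for feat, values in dataset.items()]


dataset = {"city": [0, 1, 0, 2], "age": [0, 1, 2, 1]}
columns = build_sparse_feature_columns(dataset, embed_dim=4)
for col in columns:
    print(col)
```

Counting distinct values gives the right embedding-table size only when ids are contiguous from 0; if some ids never appear in the data, `max(values) + 1` is the safer choice.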

Third, compile, fit, and predict.

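from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import AUC

# learning_rate, epochs and batch_size are your own hyperparameters.
# One binary cross-entropy loss per task; the keys correspond to task_names.
model.compile(loss={'task1': 'binary_crossentropy', 'task2': 'binary_crossentropy'},
              optimizer=Adam(learning_rate=learning_rate),
              metrics=[AUC()])

# train_X: dictionary of feature name -> values (the dataset format above);
# train_y: the labels for each task.
model.fit(
    train_X,
    train_y,
    epochs=epochs,
    batch_size=batch_size,
)

# Predict on the test set.
pred = model.predict(test_X, batch_size=batch_size)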
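With a dictionary of losses, Keras minimizes the sum of the per-task losses, each scaled by the `loss_weights` argument of `compile` (1.0 for every task by default). The plain-Python sketch below reproduces that arithmetic for two hypothetical tasks with made-up labels and predictions; it only illustrates the combined objective, not how Keras computes gradients, and it ignores any regularization losses the model may add.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over a batch, clipping predictions to avoid log(0)."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


def multi_task_loss(y_true, y_pred, loss_weights=None):
    """Weighted sum of per-task losses; each weight defaults to 1.0."""
    loss_weights = loss_weights or {}
    return sum(loss_weights.get(task, 1.0) * binary_crossentropy(y_true[task], y_pred[task])
               for task in y_true)


# Hypothetical labels and predicted probabilities for the two tasks.
y_true = {"task1": [1, 0], "task2": [0, 1]}
y_pred = {"task1": [0.9, 0.2], "task2": [0.3, 0.6]}
print(multi_task_loss(y_true, y_pred))
print(multi_task_loss(y_true, y_pred, loss_weights={"task2": 0.5}))
```

Down-weighting a task through `loss_weights` is a common way to keep one task's loss from dominating the shared parameters.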

Model

| Model | Paper | Published | Author/Group |
| --- | --- | --- | --- |
| MMoE | Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts | KDD, 2018 | Google |
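MMoE shares a set of expert networks across all tasks and gives each task its own softmax gate, which decides per example how much each expert contributes before a task-specific tower produces the output. The NumPy sketch below shows one forward pass under simplifying assumptions: single-layer ReLU experts, linear towers with a sigmoid output, and random untrained weights. It illustrates the architecture from the paper and is not MTRec's implementation.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mmoe_forward(x, expert_weights, gate_weights, tower_weights):
    """Forward pass of a one-layer MMoE.

    x:              (batch, d_in) input, e.g. concatenated feature embeddings
    expert_weights: list of num_experts matrices, each (d_in, d_expert)
    gate_weights:   list of num_tasks matrices, each (d_in, num_experts)
    tower_weights:  list of num_tasks vectors, each (d_expert,)
    Returns a list with one (batch,) array of probabilities per task.
    """
    # Shared experts, each a ReLU layer: shape (num_experts, batch, d_expert).
    experts = np.stack([np.maximum(x @ w, 0.0) for w in expert_weights])
    outputs = []
    for w_gate, w_tower in zip(gate_weights, tower_weights):
        gate = softmax(x @ w_gate)                      # (batch, num_experts), rows sum to 1
        mixed = np.einsum("be,ebd->bd", gate, experts)  # per-example weighted sum of experts
        logits = mixed @ w_tower                        # (batch,)
        outputs.append(1.0 / (1.0 + np.exp(-logits)))   # sigmoid for a binary task
    return outputs


rng = np.random.default_rng(0)
batch, d_in, d_expert, num_experts, num_tasks = 5, 8, 4, 3, 2
x = rng.normal(size=(batch, d_in))
experts_w = [rng.normal(size=(d_in, d_expert)) for _ in range(num_experts)]
gates_w = [rng.normal(size=(d_in, num_experts)) for _ in range(num_tasks)]
towers_w = [rng.normal(size=(d_expert,)) for _ in range(num_tasks)]
preds = mmoe_forward(x, experts_w, gates_w, towers_w)
print([p.shape for p in preds])
```

Because each task has its own gate, two tasks can rely on different mixtures of the same experts, which is what lets MMoE model loosely related tasks better than a single shared bottom.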

Discussion

  1. If you have suggestions or questions about the project, open an issue or email [email protected].
  2. WeChat: