Add support for multiple data types (remainder of #196) (#203)
Addresses #95 and #111.
Follow-up to #198, #199, #201

Trying again, since #130 failed. This time, I made the Model class polymorphic, which minimizes the amount of pointer indirection.

Summary: Model is an opaque container that wraps the polymorphic handle ModelImpl<ThresholdType, LeafOutputType>. The handle in turn stores the list of trees Tree<ThresholdType, LeafOutputType>. To unbox the Model container and obtain ModelImpl<ThresholdType, LeafOutputType>, use Model::Dispatch(<lambda expression>).

Also, upgrade to C++14 to gain access to generic lambdas, which proved very useful in the dispatching logic for the polymorphic Model class.
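The dispatching idea from the summary can be sketched roughly as follows. This is a minimal illustration, not the actual Treelite code: the `TypeInfo` enum, the `TypeToInfo` and `Unbox` helpers, the `threshold_type` alias, and the restriction to float/double combinations are all assumptions made for the example. Only the names `Model`, `ModelImpl`, `Tree`, and `Dispatch` come from the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <vector>

// Hypothetical runtime type tag (assumed for this sketch).
enum class TypeInfo { kFloat32, kFloat64 };

template <typename T> TypeInfo TypeToInfo();
template <> inline TypeInfo TypeToInfo<float>() { return TypeInfo::kFloat32; }
template <> inline TypeInfo TypeToInfo<double>() { return TypeInfo::kFloat64; }

// Stand-in for Tree<ThresholdType, LeafOutputType>; the real class holds much more.
template <typename ThresholdType, typename LeafOutputType>
struct Tree {
  std::vector<ThresholdType> thresholds;
  std::vector<LeafOutputType> leaf_outputs;
};

// The typed handle that stores the list of trees.
template <typename ThresholdType, typename LeafOutputType>
struct ModelImpl {
  using threshold_type = ThresholdType;
  std::vector<Tree<ThresholdType, LeafOutputType>> trees;
};

// Opaque container: callers never name ModelImpl<...> directly.
class Model {
 public:
  template <typename ThresholdType, typename LeafOutputType>
  static Model Create() {
    Model model;
    model.handle_ = std::make_shared<ModelImpl<ThresholdType, LeafOutputType>>();
    model.threshold_type_ = TypeToInfo<ThresholdType>();
    model.leaf_output_type_ = TypeToInfo<LeafOutputType>();
    return model;
  }

  // Unbox the container and pass the typed handle to func.
  // func must return the same type for every ModelImpl instantiation.
  template <typename Func>
  auto Dispatch(Func func) {
    const bool t32 = (threshold_type_ == TypeInfo::kFloat32);
    const bool l32 = (leaf_output_type_ == TypeInfo::kFloat32);
    if (t32 && l32) return func(Unbox<float, float>());
    if (t32 && !l32) return func(Unbox<float, double>());
    if (!t32 && l32) return func(Unbox<double, float>());
    return func(Unbox<double, double>());
  }

 private:
  Model() = default;

  template <typename T, typename L>
  ModelImpl<T, L>& Unbox() {
    return *static_cast<ModelImpl<T, L>*>(handle_.get());
  }

  std::shared_ptr<void> handle_;  // type-erased ModelImpl<T, L>
  TypeInfo threshold_type_ = TypeInfo::kFloat32;
  TypeInfo leaf_output_type_ = TypeInfo::kFloat32;
};

// A C++14 generic lambda lets one body serve every type combination.
std::size_t CountTrees(Model& model) {
  return model.Dispatch([](auto& impl) { return impl.trees.size(); });
}

// The lambda can also recover the concrete types from its argument.
std::size_t ThresholdBytes(Model& model) {
  return model.Dispatch([](auto& impl) {
    using Impl = std::decay_t<decltype(impl)>;
    return sizeof(typename Impl::threshold_type);
  });
}
```

Because the lambda parameter is `auto&`, the compiler instantiates the body once per type combination, so the type switch lives in a single place (`Dispatch`) rather than at every call site.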

* Turn the Model and Tree classes into template classes
* Revise the string templates so that correct data types are used in the generated C code
* Rewrite the model builder class
* Revise the zero-copy serializer
* Create an abstract matrix class that supports multiple data types (float32, float64 for now).
* Move the DMatrix class to the runtime.
* Extend the DMatrix class so that it can hold float32 and float64.
* Redesign the C runtime API using the DMatrix class.
* Ensure accuracy of scikit-learn models. To achieve the best results, use float32 for the input matrix and float64 for the split thresholds and leaf outputs.
* Revise the JVM runtime.
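
To illustrate why the precision of split thresholds matters for accuracy, here is a small sketch comparing a float32 feature value against the same threshold stored as float64 and as float32. The helper names `GoesLeft` and `DecisionsDiffer` and the `x <= threshold` test are assumptions made for the example; they are not part of this commit.

```cpp
#include <cassert>

// Returns true when a sample takes the left branch under the test
// "feature value <= threshold", evaluated in ThresholdType precision.
template <typename ThresholdType>
bool GoesLeft(float feature_value, ThresholdType threshold) {
  return static_cast<ThresholdType>(feature_value) <= threshold;
}

// Returns true when narrowing a float64 threshold to float32 flips the decision.
bool DecisionsDiffer(float feature_value, double threshold) {
  const bool exact = GoesLeft<double>(feature_value, threshold);
  const bool narrowed = GoesLeft<float>(feature_value, static_cast<float>(threshold));
  return exact != narrowed;
}
```

For example, the float32 value closest to 0.1 is slightly larger than the float64 value closest to 0.1. Against the float64 threshold the sample goes right, while against the narrowed float32 threshold it goes left, so the two models can reach different leaves for the same input.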
hcho3 committed Oct 9, 2020
1 parent 607a92b commit 303fd4f
Showing 86 changed files with 2,722 additions and 2,317 deletions.
6 changes: 6 additions & 0 deletions cmake/ExternalLibs.cmake
@@ -6,6 +6,12 @@ FetchContent_Declare(
GIT_TAG v0.4
)
FetchContent_MakeAvailable(dmlccore)
target_compile_options(dmlc PRIVATE
-D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE)
if (TARGET dmlc_unit_tests)
target_compile_options(dmlc_unit_tests PRIVATE
-D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE)
endif (TARGET dmlc_unit_tests)

# fmtlib
find_package(fmt)
3 changes: 1 addition & 2 deletions include/treelite/annotator.h
@@ -24,8 +24,7 @@ class BranchAnnotator {
* \param nthread number of threads to use
* \param verbose whether to produce extra messages
*/
void Annotate(const Model& model, const DMatrix* dmat,
int nthread, int verbose);
void Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose);
/*!
* \brief load branch annotation from a JSON file
* \param fi input stream
3 changes: 2 additions & 1 deletion include/treelite/base.h
@@ -31,7 +31,8 @@ enum class Operator : int8_t {
kGT, /*!< operator > */
kGE, /*!< operator >= */
};
/*! \brief conversion table from string to operator, defined in optable.cc */

/*! \brief conversion table from string to Operator, defined in tables.cc */
extern const std::unordered_map<std::string, Operator> optable;

/*!
194 changes: 52 additions & 142 deletions include/treelite/c_api.h
@@ -19,8 +19,6 @@
* opaque handles
* \{
*/
/*! \brief handle to a data matrix */
typedef void* DMatrixHandle;
/*! \brief handle to a decision tree ensemble model */
typedef void* ModelHandle;
/*! \brief handle to tree builder class */
@@ -31,100 +29,8 @@ typedef void* ModelBuilderHandle;
typedef void* AnnotationHandle;
/*! \brief handle to compiler class */
typedef void* CompilerHandle;
/*! \} */

/*!
* \defgroup dmatrix
* Data matrix interface
* \{
*/
/*!
* \brief create DMatrix from a file
* \param path file path
* \param format file format
* \param nthread number of threads to use
* \param verbose whether to produce extra messages
* \param out the created DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixCreateFromFile(const char* path,
const char* format,
int nthread,
int verbose,
DMatrixHandle* out);
/*!
* \brief create DMatrix from a (in-memory) CSR matrix
* \param data feature values
* \param col_ind feature indices
* \param row_ptr pointer to row headers
* \param num_row number of rows
* \param num_col number of columns
* \param out the created DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixCreateFromCSR(const float* data,
const unsigned* col_ind,
const size_t* row_ptr,
size_t num_row,
size_t num_col,
DMatrixHandle* out);
/*!
* \brief create DMatrix from a (in-memory) dense matrix
* \param data feature values
* \param num_row number of rows
* \param num_col number of columns
* \param missing_value value to represent missing value
* \param out the created DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixCreateFromMat(const float* data,
size_t num_row,
size_t num_col,
float missing_value,
DMatrixHandle* out);
/*!
* \brief get dimensions of a DMatrix
* \param handle handle to DMatrix
* \param out_num_row used to set number of rows
* \param out_num_col used to set number of columns
* \param out_nelem used to set number of nonzero entries
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixGetDimension(DMatrixHandle handle,
size_t* out_num_row,
size_t* out_num_col,
size_t* out_nelem);

/*!
* \brief produce a human-readable preview of a DMatrix
* Will print first and last 25 non-zero entries, along with their locations
* \param handle handle to DMatrix
* \param out_preview used to save the address of the string literal
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixGetPreview(DMatrixHandle handle,
const char** out_preview);

/*!
* \brief extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
* \param handle handle to DMatrix
* \param out_data used to save pointer to array containing feature values
* \param out_col_ind used to save pointer to array containing feature indices
* \param out_row_ptr used to save pointer to array containing pointers to
* row headers
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixGetArrays(DMatrixHandle handle,
const float** out_data,
const uint32_t** out_col_ind,
const size_t** out_row_ptr);

/*!
* \brief delete DMatrix from memory
* \param handle handle to DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle);
/*! \brief handle to a polymorphic value type, used in the model builder API */
typedef void* ValueHandle;
/*! \} */

/*!
@@ -142,11 +48,8 @@ TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle);
* \param out used to save handle for the created annotation
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteAnnotateBranch(ModelHandle model,
DMatrixHandle dmat,
int nthread,
int verbose,
AnnotationHandle* out);
TREELITE_DLL int TreeliteAnnotateBranch(
ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out);
/*!
* \brief save branch annotation to a JSON file
* \param handle annotation to save
@@ -311,12 +214,33 @@ TREELITE_DLL int TreeliteFreeModel(ModelHandle handle);
* Model builder interface: build trees incrementally
* \{
*/
/*!
* \brief Create a new Value object. Some model builder API functions accept this Value type to
* accommodate values of multiple types.
* \param init_value pointer to the value to be stored
* \param type Type of the value to be stored
* \param out newly created Value object
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderCreateValue(const void* init_value, const char* type,
ValueHandle* out);
/*!
* \brief Delete a Value object from memory
* \param handle pointer to the Value object to be deleted
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderDeleteValue(ValueHandle handle);
/*!
* \brief Create a new tree builder
* \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model
* must have the same type.
* \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the
* same type.
* \param out newly created tree builder
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteCreateTreeBuilder(TreeBuilderHandle* out);
TREELITE_DLL int TreeliteCreateTreeBuilder(const char* threshold_type, const char* leaf_output_type,
TreeBuilderHandle* out);
/*!
* \brief Delete a tree builder from memory
* \param handle tree builder to remove
@@ -329,24 +253,21 @@ TREELITE_DLL int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle);
* \param node_key unique integer key to identify the new node
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle,
int node_key);
TREELITE_DLL int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key);
/*!
* \brief Remove a node from a tree
* \param handle tree builder
* \param node_key unique integer key to identify the node to be removed
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle,
int node_key);
TREELITE_DLL int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key);
/*!
* \brief Set a node as the root of a tree
* \param handle tree builder
* \param node_key unique integer key to identify the root node
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle,
int node_key);
TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key);
/*!
* \brief Turn an empty node into a test node with numerical split.
* The test is in the form [feature value] OP [threshold]. Depending on the
@@ -363,12 +284,8 @@ TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle,
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode(
TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const char* opname,
float threshold, int default_left,
int left_child_key,
int right_child_key);
TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname,
ValueHandle threshold, int default_left, int left_child_key, int right_child_key);
/*!
* \brief Turn an empty node into a test node with categorical split.
* A list defines all categories that would be classified as the left side.
@@ -386,13 +303,9 @@ TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode(
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode(
TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const unsigned int* left_categories,
size_t left_categories_len,
int default_left,
int left_child_key,
int right_child_key);
TreeBuilderHandle handle, int node_key, unsigned feature_id,
const unsigned int* left_categories, size_t left_categories_len, int default_left,
int left_child_key, int right_child_key);
/*!
* \brief Turn an empty node into a leaf node
* \param handle tree builder
@@ -401,9 +314,8 @@ TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode(
* \param leaf_value leaf value (weight) of the leaf node
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle,
int node_key,
float leaf_value);
TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(
TreeBuilderHandle handle, int node_key, ValueHandle leaf_value);
/*!
* \brief Turn an empty node into a leaf vector node
* The leaf vector (collection of multiple leaf weights per leaf node) is
@@ -415,29 +327,27 @@ TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle,
* \param leaf_vector_len length of leaf_vector
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle,
int node_key,
const float* leaf_vector,
size_t leaf_vector_len);
TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode(
TreeBuilderHandle handle, int node_key, const ValueHandle* leaf_vector, size_t leaf_vector_len);
/*!
* \brief Create a new model builder
* \param num_feature number of features used in model being built. We assume
* that all feature indices are between 0 and
* (num_feature - 1).
* \param num_output_group number of output groups. Set to 1 for binary
* classification and regression; >1 for multiclass
* classification
* \param random_forest_flag whether the model is a random forest. Set to 0 if
* the model is gradient boosted trees. Any nonzero
* value shall indicate that the model is a
* random forest.
* \param num_feature number of features used in model being built. We assume that all feature
* indices are between 0 and (num_feature - 1).
* \param num_output_group number of output groups. Set to 1 for binary classification and
* regression; >1 for multiclass classification
* \param random_forest_flag whether the model is a random forest. Set to 0 if the model is
* gradient boosted trees. Any nonzero value shall indicate that the
* model is a random forest.
* \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model
* must have the same type.
* \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the
* same type.
* \param out newly created model builder
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteCreateModelBuilder(int num_feature,
int num_output_group,
int random_forest_flag,
ModelBuilderHandle* out);
TREELITE_DLL int TreeliteCreateModelBuilder(
int num_feature, int num_output_group, int random_forest_flag, const char* threshold_type,
const char* leaf_output_type, ModelBuilderHandle* out);
/*!
* \brief Set a model parameter
* \param handle model builder
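
The `ValueHandle` declarations in c_api.h above take a pointer to the value plus a type string. One way such a type-tagged value could work is sketched below. This is a self-contained stand-in, not the Treelite implementation: the `Value` struct, `ValueType` enum, the `MakeValue`, `AsDouble`, and `RequireType` helpers, and the accepted strings `"float32"` and `"float64"` are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical tag for the types mentioned in the commit message.
enum class ValueType { kFloat32, kFloat64 };

// Hypothetical object that an opaque ValueHandle might point to.
struct Value {
  ValueType type;
  float f32;
  double f64;
};

// Mirrors the shape of TreeliteTreeBuilderCreateValue(init_value, type, out):
// copy the pointed-to value into a newly created object, tagged by the type string.
std::unique_ptr<Value> MakeValue(const void* init_value, const char* type) {
  auto value = std::make_unique<Value>();  // members start zeroed
  if (std::strcmp(type, "float32") == 0) {
    value->type = ValueType::kFloat32;
    std::memcpy(&value->f32, init_value, sizeof(float));
  } else if (std::strcmp(type, "float64") == 0) {
    value->type = ValueType::kFloat64;
    std::memcpy(&value->f64, init_value, sizeof(double));
  } else {
    throw std::invalid_argument(std::string("unsupported type: ") + type);
  }
  return value;
}

// Read the stored number back, whichever type it was stored as.
double AsDouble(const Value& value) {
  return value.type == ValueType::kFloat32 ? static_cast<double>(value.f32) : value.f64;
}

// A builder created with a fixed threshold or leaf output type could reject
// values carrying a different tag, since all values in a model must share one type.
void RequireType(const Value& value, ValueType expected) {
  if (value.type != expected) {
    throw std::invalid_argument("value type does not match the builder's type");
  }
}
```

Passing values through a tagged handle keeps the C function signatures fixed while still letting the builder check that every threshold and leaf output matches the types chosen at creation.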