From 262f54be3b132d3525298cee91b793d7623992c5 Mon Sep 17 00:00:00 2001 From: louis-huang Date: Fri, 29 Sep 2023 12:58:02 -0700 Subject: [PATCH 1/3] add multi-label support --- xgboost_ray/data_sources/data_source.py | 6 +-- xgboost_ray/matrix.py | 12 ++++-- xgboost_ray/tests/test_matrix.py | 55 +++++++++++++++++++++++-- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/xgboost_ray/data_sources/data_source.py b/xgboost_ray/data_sources/data_source.py index 774bf6c9..c9bcfc7f 100644 --- a/xgboost_ray/data_sources/data_source.py +++ b/xgboost_ray/data_sources/data_source.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union import pandas as pd from ray.actor import ActorHandle @@ -118,12 +118,12 @@ def convert_to_series(data: Any) -> pd.Series: @classmethod def get_column( cls, data: pd.DataFrame, column: Any - ) -> Tuple[pd.Series, Optional[str]]: + ) -> Tuple[pd.Series, Optional[Union[str, List]]]: """Helper method wrapping around convert to series. This method should usually not be overwritten. """ - if isinstance(column, str): + if isinstance(column, str) or isinstance(column, List): return data[column], column elif column is not None: return cls.convert_to_series(column), None diff --git a/xgboost_ray/matrix.py b/xgboost_ray/matrix.py index f3f75239..dc9e895b 100644 --- a/xgboost_ray/matrix.py +++ b/xgboost_ray/matrix.py @@ -31,7 +31,6 @@ from ray.util.annotations import DeveloperAPI, PublicAPI from xgboost_ray.data_sources import DataSource, RayFileType, data_sources - try: from ray.data.dataset import Dataset as RayDataset except ImportError: @@ -307,7 +306,10 @@ def _split_dataframe( label, exclude = data_source.get_column(local_data, self.label) if exclude: - exclude_cols.add(exclude) + if isinstance(exclude, List): + exclude_cols.update(exclude) + else: + exclude_cols.add(exclude) weight, exclude = data_source.get_column(local_data, self.weight) if exclude: @@ -406,7 +408,9 @@ def get_data_source(self) -> Type[DataSource]: ): # noqa: E721: # Label is an object of a different type than the main data. # We have to make sure they are compatible - if not data_source.is_data_type(self.label): + # if it's a parquet data source and label is a list, then we consider it a multi-label data + if not data_source.is_data_type(self.label) \ + and not (isinstance(self.label, List) and data_source.__name__ == "Parquet"): raise ValueError( "The passed `data` and `label` types are not compatible." "\nFIX THIS by passing the same types to the " @@ -521,7 +525,7 @@ def get_data_source(self) -> Type[DataSource]: f"RayDMatrix." ) - if self.label is not None and not isinstance(self.label, str): + if self.label is not None and not isinstance(self.label, str) and not isinstance(self.label, List): raise ValueError( f"Invalid `label` value for distributed datasets: " f"{self.label}. Only strings are supported. " diff --git a/xgboost_ray/tests/test_matrix.py b/xgboost_ray/tests/test_matrix.py index 6c764492..b023cb16 100644 --- a/xgboost_ray/tests/test_matrix.py +++ b/xgboost_ray/tests/test_matrix.py @@ -33,6 +33,15 @@ def setUp(self): * repeat ) self.y = np.array([0, 1, 2, 3] * repeat) + self.multi_y = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 0], + ] + * repeat + ) @classmethod def setUpClass(cls): @@ -62,7 +71,7 @@ def testColumnOrdering(self): assert data.columns.tolist() == cols[:-1] - def _testMatrixCreation(self, in_x, in_y, **kwargs): + def _testMatrixCreation(self, in_x, in_y, multi_label = False, **kwargs): if "sharding" not in kwargs: kwargs["sharding"] = RayShardingMode.BATCH mat = RayDMatrix(in_x, in_y, **kwargs) @@ -81,7 +90,10 @@ def _load_data(params): x, y = _load_data(params) self.assertTrue(np.allclose(self.x, x)) - self.assertTrue(np.allclose(self.y, y)) + if multi_label: + self.assertTrue(np.allclose(self.multi_y, y)) + else: + self.assertTrue(np.allclose(self.y, y)) # Multi actor check mat = RayDMatrix(in_x, in_y, **kwargs) @@ -95,7 +107,10 @@ def _load_data(params): x2, y2 = _load_data(params) self.assertTrue(np.allclose(self.x, concat_dataframes([x1, x2]))) - self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2]))) + if multi_label: + self.assertTrue(np.allclose(self.multi_y, concat_dataframes([y1, y2]))) + else: + self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2]))) def testFromNumpy(self): in_x = self.x @@ -276,6 +291,18 @@ def testFromMultiCSVString(self): [data_file_1, data_file_2], "label", distributed=True ) + def testFromParquetStringMultiLabel(self): + with tempfile.TemporaryDirectory() as dir: + data_file = os.path.join(dir, "data.parquet") + + data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"]) + labels = [f"label_{l}" for l in range(4)] + data_df[labels] = self.multi_y + data_df.to_parquet(data_file) + + self._testMatrixCreation(data_file, labels, multi_label=True, distributed=False) + self._testMatrixCreation(data_file, labels, multi_label=True, distributed=True) + def testFromParquetString(self): with tempfile.TemporaryDirectory() as dir: data_file = os.path.join(dir, "data.parquet") @@ -286,6 +313,28 @@ def testFromParquetString(self): self._testMatrixCreation(data_file, "label", distributed=False) self._testMatrixCreation(data_file, "label", distributed=True) + + def testFromMultiParquetStringMultiLabel(self): + with tempfile.TemporaryDirectory() as dir: + data_file_1 = os.path.join(dir, "data_1.parquet") + data_file_2 = os.path.join(dir, "data_2.parquet") + + data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"]) + labels = [f"label_{l}" for l in range(4)] + data_df[labels] = self.multi_y + + df_1 = data_df[0 : len(data_df) // 2] + df_2 = data_df[len(data_df) // 2 :] + + df_1.to_parquet(data_file_1) + df_2.to_parquet(data_file_2) + + self._testMatrixCreation( + [data_file_1, data_file_2], labels, multi_label=True, distributed=False + ) + self._testMatrixCreation( + [data_file_1, data_file_2], labels, multi_label=True, distributed=True + ) def testFromMultiParquetString(self): with tempfile.TemporaryDirectory() as dir: From 8d4cfaff8647d8c26d23155f002e821de8a19f8d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 12 Oct 2023 12:48:23 -0700 Subject: [PATCH 2/3] Apply suggestions from code review Signed-off-by: Antoni Baum --- xgboost_ray/matrix.py | 3 ++- xgboost_ray/tests/test_matrix.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/xgboost_ray/matrix.py b/xgboost_ray/matrix.py index dc9e895b..f57a31fb 100644 --- a/xgboost_ray/matrix.py +++ b/xgboost_ray/matrix.py @@ -410,7 +410,8 @@ def get_data_source(self) -> Type[DataSource]: # We have to make sure they are compatible # if it's a parquet data source and label is a list, then we consider it a multi-label data if not data_source.is_data_type(self.label) \ - and not (isinstance(self.label, List) and data_source.__name__ == "Parquet"): + and not (isinstance(self.label, List) \ + and data_source.__name__ == "Parquet"): raise ValueError( "The passed `data` and `label` types are not compatible." "\nFIX THIS by passing the same types to the " diff --git a/xgboost_ray/tests/test_matrix.py b/xgboost_ray/tests/test_matrix.py index b023cb16..3aaf8598 100644 --- a/xgboost_ray/tests/test_matrix.py +++ b/xgboost_ray/tests/test_matrix.py @@ -296,7 +296,7 @@ def testFromParquetStringMultiLabel(self): data_file = os.path.join(dir, "data.parquet") data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"]) - labels = [f"label_{l}" for l in range(4)] + labels = [f"label_{label}" for label in range(4)] data_df[labels] = self.multi_y data_df.to_parquet(data_file) @@ -320,7 +320,7 @@ def testFromMultiParquetStringMultiLabel(self): data_file_2 = os.path.join(dir, "data_2.parquet") data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"]) - labels = [f"label_{l}" for l in range(4)] + labels = [f"label_{label}" for label in range(4)] data_df[labels] = self.multi_y df_1 = data_df[0 : len(data_df) // 2] From 462dcd55eb0f9c0d63c22076036dbbf74daa9d85 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 12 Oct 2023 17:38:39 -0700 Subject: [PATCH 3/3] Update xgboost_ray/matrix.py Signed-off-by: Antoni Baum --- xgboost_ray/matrix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xgboost_ray/matrix.py b/xgboost_ray/matrix.py index f57a31fb..e2dc11b5 100644 --- a/xgboost_ray/matrix.py +++ b/xgboost_ray/matrix.py @@ -408,7 +408,8 @@ def get_data_source(self) -> Type[DataSource]: ): # noqa: E721: # Label is an object of a different type than the main data. # We have to make sure they are compatible - # if it's a parquet data source and label is a list, then we consider it a multi-label data + # if it's a parquet data source and label is a list, + # then we consider it a multi-label data if not data_source.is_data_type(self.label) \ and not (isinstance(self.label, List) \ and data_source.__name__ == "Parquet"):