From 22a4a35ce369914fbf17f1d106f09e0e002c4bd9 Mon Sep 17 00:00:00 2001 From: Hi-king Date: Sun, 24 May 2020 17:42:22 +0900 Subject: [PATCH 1/2] :bug: allow gokart.task.TaskOnKart.load_data_frame to handle index only dataframe --- gokart/task.py | 19 +++++++------------ test/test_task_on_kart.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/gokart/task.py b/gokart/task.py index 2f3bda07..892aea61 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -191,21 +191,16 @@ def load_data_frame(self, target: Union[None, str, TargetOnKart] = None, required_columns: Optional[Set[str]] = None, drop_columns: bool = False) -> pd.DataFrame: - data = self.load(target=target) - if isinstance(data, list): - - def _pd_concat(dfs): - if isinstance(dfs, list): - return pd.concat([_pd_concat(df) for df in dfs]) - else: - return dfs - - data = _pd_concat(data) + def _pd_concat(dfs): + if isinstance(dfs, list): + return pd.concat([_pd_concat(df) for df in dfs]) + else: + return dfs + data = _pd_concat(self.load(target=target)) required_columns = required_columns or set() - if data.empty: + if data.empty and len(data.index) == 0: return pd.DataFrame(columns=required_columns) - assert required_columns.issubset(set(data.columns)), f'data must have columns {required_columns}, but actually have only {data.columns}.' 
if drop_columns: data = data[required_columns] diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index dc6a510f..a99d0495 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -277,6 +277,18 @@ def test_load_data_frame_drop_columns(self): self.assertEqual(1, df.shape[0]) self.assertSetEqual({'a', 'c'}, set(df.columns)) + def test_load_index_only_dataframe(self): + task = _DummyTask() + task.load = MagicMock(return_value=pd.DataFrame(index=range(3))) + + # cannot load index only frame with required_columns + self.assertRaises(AssertionError, lambda : task.load_data_frame(required_columns={'a', 'c'})) + + df: pd.DataFrame = task.load_data_frame() + self.assertIsInstance(df, pd.DataFrame) + self.assertTrue(df.empty) + self.assertListEqual(list(range(3)), list(df.index)) + def test_use_rerun_with_inherits(self): # All tasks are completed. task_c = _DummyTaskC() From 73c7ed486018d1e1b2def7e8cb8ec68b26ae9c06 Mon Sep 17 00:00:00 2001 From: Hi-king Date: Sun, 24 May 2020 17:50:11 +0900 Subject: [PATCH 2/2] :art: rename _flatten_recursively --- gokart/task.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gokart/task.py b/gokart/task.py index 892aea61..b4e7724e 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -191,12 +191,12 @@ def load_data_frame(self, target: Union[None, str, TargetOnKart] = None, required_columns: Optional[Set[str]] = None, drop_columns: bool = False) -> pd.DataFrame: - def _pd_concat(dfs): + def _flatten_recursively(dfs): if isinstance(dfs, list): - return pd.concat([_pd_concat(df) for df in dfs]) + return pd.concat([_flatten_recursively(df) for df in dfs]) else: return dfs - data = _pd_concat(self.load(target=target)) + data = _flatten_recursively(self.load(target=target)) required_columns = required_columns or set() if data.empty and len(data.index) == 0: