diff --git a/gokart/task.py b/gokart/task.py index 2f3bda07..b4e7724e 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -191,21 +191,16 @@ def load_data_frame(self, target: Union[None, str, TargetOnKart] = None, required_columns: Optional[Set[str]] = None, drop_columns: bool = False) -> pd.DataFrame: - data = self.load(target=target) - if isinstance(data, list): - - def _pd_concat(dfs): - if isinstance(dfs, list): - return pd.concat([_pd_concat(df) for df in dfs]) - else: - return dfs - - data = _pd_concat(data) + def _flatten_recursively(dfs): + if isinstance(dfs, list): + return pd.concat([_flatten_recursively(df) for df in dfs]) + else: + return dfs + data = _flatten_recursively(self.load(target=target)) required_columns = required_columns or set() - if data.empty: + if data.empty and len(data.index) == 0: return pd.DataFrame(columns=required_columns) - assert required_columns.issubset(set(data.columns)), f'data must have columns {required_columns}, but actually have only {data.columns}.' if drop_columns: data = data[required_columns] diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py index dc6a510f..a99d0495 100644 --- a/test/test_task_on_kart.py +++ b/test/test_task_on_kart.py @@ -277,6 +277,18 @@ def test_load_data_frame_drop_columns(self): self.assertEqual(1, df.shape[0]) self.assertSetEqual({'a', 'c'}, set(df.columns)) + def test_load_index_only_dataframe(self): + task = _DummyTask() + task.load = MagicMock(return_value=pd.DataFrame(index=range(3))) + + # cannot load index-only frame with required_columns + self.assertRaises(AssertionError, lambda : task.load_data_frame(required_columns={'a', 'c'})) + + df: pd.DataFrame = task.load_data_frame() + self.assertIsInstance(df, pd.DataFrame) + self.assertTrue(df.empty) + self.assertListEqual(list(range(3)), list(df.index)) + def test_use_rerun_with_inherits(self): # All tasks are completed. task_c = _DummyTaskC()