Skip to content

Commit

Permalink
Merge pull request #130 from Hi-king/fix/handle_indexonly_dataframe
Browse files Browse the repository at this point in the history
allow load_data_frame to handle index only dataframe
  • Loading branch information
nishiba authored May 25, 2020
2 parents c2e2066 + 73c7ed4 commit 7c9fd25
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
19 changes: 7 additions & 12 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,16 @@ def load_data_frame(self,
target: Union[None, str, TargetOnKart] = None,
required_columns: Optional[Set[str]] = None,
drop_columns: bool = False) -> pd.DataFrame:
data = self.load(target=target)
if isinstance(data, list):

def _pd_concat(dfs):
if isinstance(dfs, list):
return pd.concat([_pd_concat(df) for df in dfs])
else:
return dfs

data = _pd_concat(data)
def _flatten_recursively(dfs):
if isinstance(dfs, list):
return pd.concat([_flatten_recursively(df) for df in dfs])
else:
return dfs
data = _flatten_recursively(self.load(target=target))

required_columns = required_columns or set()
if data.empty:
if data.empty and len(data.index) == 0:
return pd.DataFrame(columns=required_columns)

assert required_columns.issubset(set(data.columns)), f'data must have columns {required_columns}, but actually have only {data.columns}.'
if drop_columns:
data = data[required_columns]
Expand Down
12 changes: 12 additions & 0 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,18 @@ def test_load_data_frame_drop_columns(self):
self.assertEqual(1, df.shape[0])
self.assertSetEqual({'a', 'c'}, set(df.columns))

def test_load_index_only_dataframe(self):
    """Check load_data_frame on a frame that has an index but no columns."""
    task = _DummyTask()
    index_only_frame = pd.DataFrame(index=range(3))
    task.load = MagicMock(return_value=index_only_frame)

    # An index-only frame cannot satisfy any required column, so asking
    # for columns is expected to raise AssertionError.
    with self.assertRaises(AssertionError):
        task.load_data_frame(required_columns={'a', 'c'})

    # With no required columns the frame is returned as a DataFrame that
    # is still "empty" (no columns) but keeps its three index entries.
    loaded: pd.DataFrame = task.load_data_frame()
    self.assertIsInstance(loaded, pd.DataFrame)
    self.assertTrue(loaded.empty)
    self.assertListEqual(list(range(3)), list(loaded.index))

def test_use_rerun_with_inherits(self):
# All tasks are completed.
task_c = _DummyTaskC()
Expand Down

0 comments on commit 7c9fd25

Please sign in to comment.