From 174817d0dbce0f14915e87aec98b1332686ac376 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Fri, 22 Nov 2024 14:13:24 +0100 Subject: [PATCH] Add the option for multiple views loading - take2 (#346) * moving to new branch * final test fix * Update movement/io/load_poses.py Co-authored-by: Niko Sirmpilatze * doc change * fixed doc --------- Co-authored-by: Niko Sirmpilatze --- .../getting_started/movement_dataset.md | 15 +++++++- movement/io/load_poses.py | 35 +++++++++++++++++++ tests/test_unit/test_load_poses.py | 17 +++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/movement_dataset.md b/docs/source/getting_started/movement_dataset.md index c13128d2..79106e2e 100644 --- a/docs/source/getting_started/movement_dataset.md +++ b/docs/source/getting_started/movement_dataset.md @@ -16,7 +16,6 @@ To discuss the specifics of both types of `movement` datasets, it is useful to c To learn more about `xarray` data structures in general, see the relevant [documentation](xarray:user-guide/data-structures.html). - ## Dataset structure ```{figure} ../_static/dataset_structure.png @@ -135,6 +134,20 @@ In both cases, appropriate **coordinates** are assigned to each **dimension**. - `space` is labelled with either `x`, `y` (2D) or `x`, `y`, `z` (3D). Note that bounding boxes datasets are restricted to 2D space. - `time` is labelled in seconds if `fps` is provided, otherwise the **coordinates** are expressed in frames (ascending 0-indexed integers). +:::{dropdown} Additional dimensions +:color: info +:icon: info +The above **dimensions** and **coordinates** are created +by default when loading a `movement` dataset from a single +file containing pose or bounding boxes tracks. + +In some cases, you may encounter or create datasets with extra +**dimensions**. For example, the +{func}`movement.io.load_poses.from_multiview_files()` function +creates an additional `views` **dimension**, +with the **coordinates** being the names given to each camera view. +::: + ### Data variables The data variables in a `movement` dataset are the arrays that hold the actual data, as {class}`xarray.DataArray` objects. diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index f425d8a1..4255607b 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -351,6 +351,41 @@ def from_dlc_file( ) +def from_multiview_files( + file_path_dict: dict[str, Path | str], + source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], + fps: float | None = None, +) -> xr.Dataset: + """Load and merge pose tracking data from multiple views (cameras). + + Parameters + ---------- + file_path_dict : dict[str, Union[Path, str]] + A dict whose keys are the view names and values are the paths to load. + source_software : {'LightningPose', 'SLEAP', 'DeepLabCut'} + The source software of the file. + fps : float, optional + The number of frames per second in the video. If None (default), + the `time` coordinates will be in frame numbers. + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata, with an additional ``views`` dimension. + + """ + views_list = list(file_path_dict.keys()) + new_coord_views = xr.DataArray(views_list, dims="view") + + dataset_list = [ + from_file(f, source_software=source_software, fps=fps) + for f in file_path_dict.values() + ] + + return xr.concat(dataset_list, dim=new_coord_views) + + def _ds_from_lp_or_dlc_file( file_path: Path | str, source_software: Literal["LightningPose", "DeepLabCut"], diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index 77990a42..b5145fc9 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -300,3 +300,20 @@ def test_from_numpy_valid( source_software=source_software, ) self.assert_dataset(ds, expected_source_software=source_software) + + def test_from_multiview_files(self): + """Test that the from_file() function delegates to the correct + loader function according to the source_software. + """ + view_names = ["view_0", "view_1"] + file_path_dict = { + view: DATA_PATHS.get("DLC_single-wasp.predictions.h5") + for view in view_names + } + multi_view_ds = load_poses.from_multiview_files( + file_path_dict, source_software="DeepLabCut" + ) + + assert isinstance(multi_view_ds, xr.Dataset) + assert "view" in multi_view_ds.dims + assert multi_view_ds.view.values.tolist() == view_names