Skip to content

Commit

Permalink
Draft reorder poses dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Nov 22, 2024
1 parent 174817d commit 451ac86
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 32 deletions.
6 changes: 3 additions & 3 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,9 +699,9 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
},
coords={
DIM_NAMES[0]: time_coords,
DIM_NAMES[1]: data.individual_names,
DIM_NAMES[2]: data.keypoint_names,
DIM_NAMES[3]: ["x", "y", "z"][:n_space],
DIM_NAMES[2]: data.keypoint_names,
DIM_NAMES[1]: data.individual_names,
},
attrs={
"fps": data.fps,
Expand All @@ -710,4 +710,4 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
"source_file": None,
"ds_type": "poses",
},
)
).transpose("time", "space", "keypoints", "individuals")
13 changes: 8 additions & 5 deletions movement/io/save_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ def _ds_to_dlc_style_df(
"""
# Concatenate the pose tracks and confidence scores into one array
# and reverse the order of the dimensions except for the time dimension
tracks_with_scores = np.concatenate(
(
ds.position.data,
ds.confidence.data[..., np.newaxis],
ds.confidence.data[:, np.newaxis, ...],
),
axis=-1,
axis=1,
)
transpose_order = [0] + list(range(tracks_with_scores.ndim - 1, 0, -1))
tracks_with_scores = tracks_with_scores.transpose(transpose_order)

# Create DataFrame with multi-index columns
df = pd.DataFrame(
Expand Down Expand Up @@ -320,9 +323,9 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None:
n_frames = frame_idxs[-1] - frame_idxs[0] + 1
pos_x = ds.position.sel(space="x").values
# Mask denoting which individuals are present in each frame
track_occupancy = (~np.all(np.isnan(pos_x), axis=2)).astype(int)
tracks = np.transpose(ds.position.data, (1, 3, 2, 0))
point_scores = np.transpose(ds.confidence.data, (1, 2, 0))
track_occupancy = (~np.all(np.isnan(pos_x), axis=1)).astype(int)
tracks = np.transpose(ds.position.data, (3, 1, 2, 0))
point_scores = np.transpose(ds.confidence.data, (2, 1, 0))
instance_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)
tracking_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)
labels_path = (
Expand Down
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,9 @@ def valid_poses_dataset(valid_position_array, request):
},
coords={
"time": np.arange(n_frames),
"individuals": [f"ind{i}" for i in range(1, n_individuals + 1)],
"keypoints": [f"key{i}" for i in range(1, n_keypoints + 1)],
"space": ["x", "y"],
"keypoints": [f"key{i}" for i in range(1, n_keypoints + 1)],
"individuals": [f"ind{i}" for i in range(1, n_individuals + 1)],
},
attrs={
"fps": None,
Expand All @@ -408,7 +408,7 @@ def valid_poses_dataset(valid_position_array, request):
"source_file": "test.h5",
"ds_type": "poses",
},
)
).transpose("time", "space", "keypoints", "individuals")


@pytest.fixture
Expand Down Expand Up @@ -504,9 +504,9 @@ def valid_poses_dataset_uniform_linear_motion(
},
coords={
dim_names[0]: np.arange(n_frames),
dim_names[1]: [f"id_{i}" for i in range(1, n_individuals + 1)],
dim_names[2]: ["centroid", "left", "right"],
dim_names[3]: ["x", "y"],
dim_names[2]: ["centroid", "left", "right"],
dim_names[1]: [f"id_{i}" for i in range(1, n_individuals + 1)],
},
attrs={
"fps": None,
Expand All @@ -515,7 +515,7 @@ def valid_poses_dataset_uniform_linear_motion(
"source_file": "test_poses.h5",
"ds_type": "poses",
},
)
).transpose("time", "space", "keypoints", "individuals")


@pytest.fixture
Expand Down
6 changes: 3 additions & 3 deletions tests/test_integration/test_kinematics_vector_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ def test_cart2pol_transform_on_kinematics(

# Build expected data array
expected_array_pol = xr.DataArray(
np.stack(expected_kinematics_polar, axis=1),
np.stack(expected_kinematics_polar, axis=-1),
# Stack along the "individuals" axis
dims=["time", "individuals", "space"],
dims=["time", "space", "individuals"],
)
if "keypoints" in ds.position.coords:
expected_array_pol = expected_array_pol.expand_dims(
{"keypoints": ds.position.coords["keypoints"].size}
)
expected_array_pol = expected_array_pol.transpose(
"time", "individuals", "keypoints", "space"
"time", "space", "keypoints", "individuals"
)

# Compare the values of the kinematic_array against the expected_array
Expand Down
12 changes: 6 additions & 6 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,19 @@ def test_kinematics_uniform_linear_motion(
# and in the final xarray.DataArray
expected_dims = ["time", "individuals"]
if kinematic_variable in ["displacement", "velocity", "acceleration"]:
expected_dims.append("space")
expected_dims.insert(1, "space")

# Build expected data array from the expected numpy array
expected_array = xr.DataArray(
# Stack along the "individuals" axis
np.stack(expected_kinematics, axis=1),
np.stack(expected_kinematics, axis=-1),
dims=expected_dims,
)
if "keypoints" in position.coords:
expected_array = expected_array.expand_dims(
{"keypoints": position.coords["keypoints"].size}
)
expected_dims.insert(2, "keypoints")
expected_dims.insert(-1, "keypoints")
expected_array = expected_array.transpose(*expected_dims)

# Compare the values of the kinematic_array against the expected_array
Expand Down Expand Up @@ -263,11 +263,11 @@ def test_path_length_across_time_ranges(
num_segments -= 9 - np.floor(min(9, stop))

expected_path_length = xr.DataArray(
np.ones((2, 3)) * np.sqrt(2) * num_segments,
dims=["individuals", "keypoints"],
np.ones((3, 2)) * np.sqrt(2) * num_segments,
dims=["keypoints", "individuals"],
coords={
"individuals": position.coords["individuals"],
"keypoints": position.coords["keypoints"],
"individuals": position.coords["individuals"],
},
)
xr.testing.assert_allclose(path_length, expected_path_length)
Expand Down
24 changes: 15 additions & 9 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,22 @@ def assert_dataset(
assert var in dataset.data_vars
assert isinstance(dataset[var], xr.DataArray)
assert dataset.position.ndim == 4
assert dataset.confidence.shape == dataset.position.shape[:-1]
# Check the dims and coords
position_shape = dataset.position.shape
# Confidence has the same shape as position, except for the space dim
assert (
dataset.confidence.shape == position_shape[:1] + position_shape[2:]
)
# Check the dims
DIM_NAMES = ValidPosesDataset.DIM_NAMES
assert all([i in dataset.dims for i in DIM_NAMES])
for d, dim in enumerate(DIM_NAMES[1:]):
assert dataset.sizes[dim] == dataset.position.shape[d + 1]
assert all(
[isinstance(s, str) for s in dataset.coords[dim].values]
)
assert all([i in dataset.coords["space"] for i in ["x", "y"]])
expected_dim_length_dict = {
DIM_NAMES[idx]: position_shape[i]
for i, idx in enumerate([0, 3, 2, 1])
}
assert expected_dim_length_dict == dataset.sizes
# Check the coords
for dim in DIM_NAMES[1:]:
assert all(isinstance(s, str) for s in dataset.coords[dim].values)
assert all(coord in dataset.coords["space"] for coord in ["x", "y"])
# Check the metadata attributes
assert (
dataset.source_file is None
Expand Down

0 comments on commit 451ac86

Please sign in to comment.