Skip to content

Commit

Permalink
add synbody whac
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei-Chen-hub committed Feb 22, 2024
1 parent f7f6b57 commit 677a572
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 23 deletions.
65 changes: 65 additions & 0 deletions mmhuman3d/data/data_converters/base_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from mmhuman3d.core.conventions.keypoints_mapping import get_keypoint_idxs_by_part

class BaseConverter(metaclass=ABCMeta):
"""Base dataset.
Expand Down Expand Up @@ -100,6 +101,70 @@ def _keypoints_to_scaled_bbox(keypoints, scale=1.0):
return bbox


def _keypoints_to_scaled_bbox_bfh(self,
keypoints,
occ=None,
body_scale=1.0,
fh_scale=1.0,
convention='smplx'):
'''Obtain scaled bbox in xyxy format given keypoints
Args:
keypoints (np.ndarray): Keypoints
scale (float): Bounding Box scale
Returns:
bbox_xyxy (np.ndarray): Bounding box in xyxy format
'''
bboxs = []

# supported kps.shape: (1, n, k) or (n, k), k = 2 or 3
if keypoints.ndim == 3:
keypoints = keypoints[0]
if keypoints.shape[-1] != 2:
keypoints = keypoints[:, :2]

for body_part in ['body', 'head', 'left_hand', 'right_hand']:
if body_part == 'body':
scale = body_scale
kps = keypoints
else:
scale = fh_scale
kp_id = get_keypoint_idxs_by_part(
body_part, convention=convention)
kps = keypoints[kp_id]

if not occ is None:
occ_p = occ[kp_id]
if np.sum(occ_p) / len(kp_id) >= 0.1:
conf = 0
# print(f'{body_part} occluded, occlusion: {np.sum(occ_p) / len(kp_id)}, skip')
else:
# print(f'{body_part} good, {np.sum(self_occ_p + occ_p) / len(kp_id)}')
conf = 1
else:
conf = 1
if body_part == 'body':
conf = 1

xmin, ymin = np.amin(kps, axis=0)
xmax, ymax = np.amax(kps, axis=0)

width = (xmax - xmin) * scale
height = (ymax - ymin) * scale

x_center = 0.5 * (xmax + xmin)
y_center = 0.5 * (ymax + ymin)
xmin = x_center - 0.5 * width
xmax = x_center + 0.5 * width
ymin = y_center - 0.5 * height
ymax = y_center + 0.5 * height

bbox = np.stack([xmin, ymin, xmax, ymax, conf],
axis=0).astype(np.float32)
bboxs.append(bbox)

return bboxs


class BaseModeConverter(BaseConverter):
"""Convert datasets by mode.
Expand Down
78 changes: 55 additions & 23 deletions mmhuman3d/data/data_converters/synbody_whac.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,12 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
for contact_key in self.misc_config['contact_label']:
contact_[contact_key] = []

for seq in seqs_targeted:
# load contact region
from tools.utils.convert_contact_label import vertex2part_smplx_dict
left_foot_idxs = np.array(vertex2part_smplx_dict['left_foot'])
right_foot_idxs = np.array(vertex2part_smplx_dict['right_foot'])

for seq in tqdm(seqs_targeted, desc=f'Processing {mode}', leave=False, position=0):

# preprocess sequence

Expand All @@ -216,8 +221,35 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
# parse camera params and image
seq_base = os.path.dirname(os.path.dirname(seq))

# prepare smplx tensor
smplx_param_tensor = {}
for key in self.smplx_shape.keys():
smplx_param_tensor[key] = torch.tensor(smplx_param[key]
.reshape(self.smplx_shape[key])).to(self.device)

# get output
output = gendered_smplx[gender](**smplx_param_tensor, return_verts=True)

kps3d = output['joints'].detach().cpu().numpy()
pelvis_world = kps3d[:, get_keypoint_idx('pelvis', 'smplx'), :]

# get vertices and contact
vertices = output['vertices'].detach().cpu().numpy()

# height is -y, get lowest from frame 0
left_foot_y_lowest = np.sort(vertices[1, left_foot_idxs, 1])[-1]
right_foot_y_lowest = np.sort(vertices[1, right_foot_idxs, 1])[-1]

left_foot_contact = np.zeros([datalen])
right_foot_contact = np.zeros([datalen])

threshold = 0.01
for i in range(datalen):
left_foot_contact[i] = 1 if (vertices[i, left_foot_idxs, 1].max() > (left_foot_y_lowest - threshold)) else 0
right_foot_contact[i] = 1 if (vertices[i, right_foot_idxs, 1].max() > (right_foot_y_lowest - threshold)) else 0

# cids
cids = os.listdir(os.path.join(seq_base, 'img'))
cids = sorted(os.listdir(os.path.join(seq_base, 'img')))

for cid in cids:

Expand All @@ -238,20 +270,10 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
valid_idxs = np.intersect1d(valid_idxs_img, valid_idex_cam)
# valid_idxs = valid_idxs[valid_idxs > 0]


# prepare smplx tensor
smplx_param_tensor = {}
for key in self.smplx_shape.keys():
smplx_param_tensor[key] = torch.tensor(smplx_param[key][valid_idxs]
.reshape(self.smplx_shape[key])).to(self.device)

# get output
output = gendered_smplx[gender](**smplx_param_tensor, return_verts=True)

kps3d = output['joints'].detach().cpu().numpy()
pelvis_world = kps3d[:, get_keypoint_idx('pelvis', 'smplx'), :]

for vid in valid_idxs:
for vid in tqdm(valid_idxs, desc=f'Processing {sequence_name}, {cid} / {cids[-1]}',
leave=False, position=1):
if vid == 0:
continue
# get image path
img_p = os.path.join(img_f, f'{vid:04d}.jpeg')
image_path = img_p.replace(dataset_path + '/', '')
Expand All @@ -270,7 +292,7 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
[0, 0, 1, 0],
[0, 0, 0, 1]])

Rt = (Rt.T @ ue2opencv).T
# Rt = (Rt.T @ ue2opencv).T

# transform to cam space
global_orient, transl = transform_to_camera_frame(
Expand Down Expand Up @@ -303,12 +325,8 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
kps2d = camera.transform_points(kps3d_c).detach().cpu().numpy().squeeze()[:, :2]
kps3d_c = kps3d_c.detach().cpu().numpy().squeeze()

# test overlay
# img = cv2.imread(img_p)
# for kp in kps2d:
# cv2.circle(img, (int(kp[0]), int(kp[1])), 5, (0, 255, 0), -1)
# cv2.imwrite(f'{out_path}/{os.path.basename(seq_base)}_{cid}_{vid}.jpg', img)

kps2d = [1920, 1080] - kps2d

# get bbox from 2d keypoints
bboxs = self._keypoints_to_scaled_bbox_bfh(
kps2d,
Expand All @@ -329,6 +347,20 @@ def convert_by_mode(self, dataset_path: str, out_path: str,
bbox_xywh.append(conf) # (5,)
bboxs_[bbox_name].append(bbox_xywh)

# append contact
contact_['part_segmentation'].append([left_foot_contact[vid],
right_foot_contact[vid]])
# if 0 in [left_foot_contact[vid], right_foot_contact[vid]]:
# # test overlay
# img = cv2.imread(img_p)
# for kp in kps2d:
# cv2.circle(img, (int(kp[0]), int(kp[1])), 5, (0, 255, 0), -1)
# if left_foot_contact[vid] == 0:
# cv2.putText(img, 'left foot', (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
# if right_foot_contact[vid] == 0:
# cv2.putText(img, 'right foot', (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
# cv2.imwrite(f'{out_path}/{os.path.basename(seq_base)}_{cid}_{vid}.jpg', img)

# append image path
image_path_.append(image_path)

Expand Down

0 comments on commit 677a572

Please sign in to comment.