From 7a991d65801072c3a09474fa4f9b8be075bb7357 Mon Sep 17 00:00:00 2001 From: Wei-Chen-hub <1259566226@qq.com> Date: Tue, 15 Aug 2023 10:39:49 +0800 Subject: [PATCH] reformat and restructure 230815 --- mmhuman3d/data/data_converters/__init__.py | 8 +- mmhuman3d/data/data_converters/h36m.py | 427 ------------------ .../data/data_converters/h36m_neural_annot.py | 413 +++++++++++++++++ mmhuman3d/data/data_converters/mpii.py | 91 ---- .../data/data_converters/mpii_neural_annot.py | 285 ++++++++++++ tools/convert_datasets.py | 10 +- tools/convert_datasets_commands.txt | 76 ++-- tools/postprocess/humandata_sample.py | 0 tools/{solvepnp => postprocess}/solvepnp.py | 0 tools/postprocess/testpnp.py | 60 +++ .../fit_shape2smplx.py | 7 +- tools/preprocess/mpii.py | 81 ++++ 12 files changed, 901 insertions(+), 557 deletions(-) delete mode 100644 mmhuman3d/data/data_converters/h36m.py create mode 100644 mmhuman3d/data/data_converters/h36m_neural_annot.py delete mode 100644 mmhuman3d/data/data_converters/mpii.py create mode 100644 mmhuman3d/data/data_converters/mpii_neural_annot.py create mode 100644 tools/postprocess/humandata_sample.py rename tools/{solvepnp => postprocess}/solvepnp.py (100%) create mode 100644 tools/postprocess/testpnp.py rename tools/{shape2smplx => preprocess}/fit_shape2smplx.py (97%) create mode 100644 tools/preprocess/mpii.py diff --git a/mmhuman3d/data/data_converters/__init__.py b/mmhuman3d/data/data_converters/__init__.py index f8f34ee0..481fa15e 100644 --- a/mmhuman3d/data/data_converters/__init__.py +++ b/mmhuman3d/data/data_converters/__init__.py @@ -19,7 +19,7 @@ from .freihand import FreihandConverter from .gta_human import GTAHumanConverter from .gta_human2 import GTAHuman2Converter -from .h36m import H36mConverter +from .h36m_neural_annot import H36mNeuralConverter from .h36m_hybrik import H36mHybrIKConverter from .h36m_smplx import H36mSMPLXConverter from .hanco import HancoConverter @@ -34,7 +34,7 @@ from .moyo import MoyoConverter from .mpi_inf_3dhp import MpiInf3dhpConverter from .mpi_inf_3dhp_hybrik import MpiInf3dhpHybrIKConverter -from .mpii import MpiiConverter +from .mpii_neural_annot import MpiiNeuralConverter from .penn_action import PennActionConverter from .posetrack import PosetrackConverter from .pw3d import Pw3dConverter @@ -56,8 +56,8 @@ __all__ = [ 'build_data_converter', 'AgoraConverter', - 'MpiiConverter', - 'H36mConverter', + 'MpiiNeuralConverter', + 'H36mNeuralConverter', 'AmassConverter', 'CocoConverter', 'CocoWholebodyConverter', diff --git a/mmhuman3d/data/data_converters/h36m.py b/mmhuman3d/data/data_converters/h36m.py deleted file mode 100644 index bd5a638d..00000000 --- a/mmhuman3d/data/data_converters/h36m.py +++ /dev/null @@ -1,427 +0,0 @@ -import glob -import os -import pickle -import xml.etree.ElementTree as ET -from typing import List - -import cdflib -import cv2 -import h5py -import numpy as np -from tqdm import tqdm - -from mmhuman3d.core.cameras.camera_parameters import CameraParameter -from mmhuman3d.core.conventions.keypoints_mapping import convert_kps -from mmhuman3d.data.data_structures.human_data import HumanData -from mmhuman3d.data.data_structures.multi_human_data import MultiHumanData -from .base_converter import BaseModeConverter -from .builder import DATA_CONVERTERS - - -class H36mCamera(): - """Extract camera information from Human3.6M Metadata. 
- - Args: - metadata (str): path to metadata.xml file - """ - - def __init__(self, metadata: str): - self.subjects = [] - self.sequence_mappings = {} - self.action_names = {} - self.metadata = metadata - self.camera_ids = [] - self._load_metadata() - self.image_sizes = { - '54138969': { - 'width': 1000, - 'height': 1002 - }, - '55011271': { - 'width': 1000, - 'height': 1000 - }, - '58860488': { - 'width': 1000, - 'height': 1000 - }, - '60457274': { - 'width': 1000, - 'height': 1002 - } - } - - def _load_metadata(self) -> None: - """Load meta data from metadata.xml.""" - - assert os.path.exists(self.metadata) - - tree = ET.parse(self.metadata) - root = tree.getroot() - - for i, tr in enumerate(root.find('mapping')): - if i == 0: - _, _, *self.subjects = [td.text for td in tr] - self.sequence_mappings \ - = {subject: {} for subject in self.subjects} - elif i < 33: - action_id, subaction_id, *prefixes = [td.text for td in tr] - for subject, prefix in zip(self.subjects, prefixes): - self.sequence_mappings[subject][(action_id, subaction_id)]\ - = prefix - - for i, elem in enumerate(root.find('actionnames')): - action_id = str(i + 1) - self.action_names[action_id] = elem.text - - self.camera_ids \ - = [elem.text for elem in root.find('dbcameras/index2id')] - - w0 = root.find('w0') - self.cameras_raw = [float(num) for num in w0.text[1:-1].split()] - - @staticmethod - def get_intrinsic_matrix(f: List[float], - c: List[float], - inv: bool = False) -> np.ndarray: - """Get intrisic matrix (or its inverse) given f and c.""" - intrinsic_matrix = np.zeros((3, 3)).astype(np.float32) - intrinsic_matrix[0, 0] = f[0] - intrinsic_matrix[0, 2] = c[0] - intrinsic_matrix[1, 1] = f[1] - intrinsic_matrix[1, 2] = c[1] - intrinsic_matrix[2, 2] = 1 - - if inv: - intrinsic_matrix = np.linalg.inv(intrinsic_matrix).astype( - np.float32) - return intrinsic_matrix - - def _get_camera_params(self, camera: int, subject: str) -> dict: - """Get camera parameters given camera id and subject id.""" - metadata_slice = np.zeros(15) - start = 6 * (camera * 11 + (subject - 1)) - - metadata_slice[:6] = self.cameras_raw[start:start + 6] - metadata_slice[6:] = self.cameras_raw[265 + camera * 9 - 1:265 + - (camera + 1) * 9 - 1] - - # extrinsics - x, y, z = -metadata_slice[0], metadata_slice[1], -metadata_slice[2] - - R_x = np.array([[1, 0, 0], [0, np.cos(x), np.sin(x)], - [0, -np.sin(x), np.cos(x)]]) - R_y = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], - [-np.sin(y), 0, np.cos(y)]]) - R_z = np.array([[np.cos(z), np.sin(z), 0], [-np.sin(z), - np.cos(z), 0], [0, 0, 1]]) - R = (R_x @ R_y @ R_z).T - T = metadata_slice[3:6].reshape(-1) - # convert unit from millimeter to meter - T *= 0.001 - - # intrinsics - c = metadata_slice[8:10, None] - f = metadata_slice[6:8, None] - K = self.get_intrinsic_matrix(f, c) - - # distortion - k = metadata_slice[10:13, None] - p = metadata_slice[13:15, None] - - w = self.image_sizes[self.camera_ids[camera]]['width'] - h = self.image_sizes[self.camera_ids[camera]]['height'] - - camera_name = f'S{subject}_{self.camera_ids[camera]}' - camera_params = CameraParameter(camera_name, h, w) - camera_params.set_KRT(K, R, T) - camera_params.set_value('k1', float(k[0])) - camera_params.set_value('k2', float(k[1])) - camera_params.set_value('k3', float(k[2])) - camera_params.set_value('p1', float(p[0])) - camera_params.set_value('p2', float(p[1])) - return camera_params.to_dict() - - def generate_cameras_dict(self) -> dict: - """Generate dictionary of camera params which contains camera - parameters for 11 subjects 
each with 4 cameras.""" - cameras = {} - for subject in range(1, 12): - for camera in range(4): - key = (f'S{subject}', self.camera_ids[camera]) - cameras[key] = self._get_camera_params(camera, subject) - - return cameras - - -@DATA_CONVERTERS.register_module() -class H36mConverter(BaseModeConverter): - """Human3.6M dataset - `Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human - Sensing in Natural Environments' TPAMI`2014 - More details can be found in the `paper - `__. - - Args: - modes (list): 'valid' or 'train' for accepted modes - protocol (int): 1 or 2 for available protocols - extract_img (bool): Store True to extract images into a separate - folder. Default: False. - mosh_dir (str, optional): Path to directory containing mosh files. - """ - ACCEPTED_MODES = ['valid', 'train'] - - def __init__(self, - modes: List = [], - protocol: int = 1, - extract_img: bool = False, - mosh_dir=None) -> None: - super(H36mConverter, self).__init__(modes) - accepted_protocol = [1, 2] - if protocol not in accepted_protocol: - raise ValueError('Input protocol not in accepted protocol. \ - Use either 1 or 2') - self.protocol = protocol - self.extract_img = extract_img - self.get_mosh = False - if mosh_dir is not None and os.path.exists(mosh_dir): - self.get_mosh = True - self.mosh_dir = mosh_dir - self.camera_name_to_idx = { - '54138969': 0, - '55011271': 1, - '58860488': 2, - '60457274': 3, - } - - def convert_by_mode(self, - dataset_path: str, - out_path: str, - mode: str, - enable_multi_human_data: bool = False) -> dict: - """ - Args: - dataset_path (str): Path to directory where raw images and - annotations are stored. - out_path (str): Path to directory to save preprocessed npz file - mode (str): Mode in accepted modes - enable_multi_human_data (bool): - Whether to generate a multi-human data. If set to True, - stored in MultiHumanData() format. - Default: False, stored in HumanData() format. 
- - Returns: - dict: - A dict containing keys image_path, bbox_xywh, keypoints2d, - keypoints2d_mask, keypoints3d, keypoints3d_mask, cam_param - stored in HumanData() format - """ - if enable_multi_human_data: - # use MultiHumanData to store all data - human_data = MultiHumanData() - else: - # use HumanData to store all data - human_data = HumanData() - - # pick 17 joints from 32 (repeated) joints - h36m_idx = [ - 11, 6, 7, 8, 1, 2, 3, 12, 24, 14, 15, 17, 18, 19, 25, 26, 27 - ] - - # structs we use - image_path_, bbox_xywh_, keypoints2d_, keypoints3d_ = [], [], [], [] - - smpl = {} - smpl['body_pose'] = [] - smpl['global_orient'] = [] - smpl['betas'] = [] - - # choose users ids for different set - if mode == 'train': - user_list = [1, 5, 6, 7, 8] - elif mode == 'valid': - user_list = [9, 11] - - # go over each user - for user_i in tqdm(user_list, desc='user id'): - user_name = f'S{user_i}' - # path with GT bounding boxes - bbox_path = os.path.join(dataset_path, user_name, 'MySegmentsMat', - 'ground_truth_bs') - # path with GT 2D pose - pose2d_path = os.path.join(dataset_path, user_name, - 'MyPoseFeatures', 'D2_Positions') - # path with GT 3D pose - pose_path = os.path.join(dataset_path, user_name, 'MyPoseFeatures', - 'D3_Positions_mono') - # path with videos - vid_path = os.path.join(dataset_path, user_name, 'Videos') - - # go over all the sequences of each user - seq_list = glob.glob(os.path.join(pose_path, '*.cdf')) - seq_list.sort() - - # mosh path - if self.get_mosh: - mosh_path = os.path.join(self.mosh_dir, user_name) - - for seq_i in tqdm(seq_list, desc='sequence id'): - - # sequence info - seq_name = seq_i.split('/')[-1] - action, camera, _ = seq_name.split('.') - action_raw = action - action = action.replace(' ', '_') - # irrelevant sequences - if action == '_ALL': - continue - - # 2D pose file - pose2d_file = os.path.join(pose2d_path, seq_name) - poses_2d = cdflib.CDF(pose2d_file)['Pose'][0] - - # 3D pose file - poses_3d = cdflib.CDF(seq_i)['Pose'][0] - - # 3D mosh file - if self.get_mosh: - mosh_name = '%s_cam%s_aligned.pkl' % ( - action_raw, self.camera_name_to_idx[camera]) - mosh_file = os.path.join(mosh_path, mosh_name) - if os.path.exists(mosh_file): - with open(mosh_file, 'rb') as file: - mosh_data = pickle.load(file, encoding='latin1') - else: - print(f'mosh file {mosh_name} is missing') - continue - thetas = mosh_data['new_poses'] - betas = mosh_data['betas'] - - # bbox file - bbox_file = os.path.join(bbox_path, - seq_name.replace('cdf', 'mat')) - bbox_h5py = h5py.File(bbox_file) - - # video file - if self.extract_img: - vid_file = os.path.join(vid_path, - seq_name.replace('cdf', 'mp4')) - vidcap = cv2.VideoCapture(vid_file) - - # go over each frame of the sequence - for frame_i in tqdm(range(poses_3d.shape[0]), desc='frame id'): - # read video frame - if self.extract_img: - success, image = vidcap.read() - if not success: - break - - # check if you can keep this frame - if frame_i % 5 == 0 and (self.protocol == 1 - or camera == '60457274'): - # image name - seq_id = f'{user_name}_{action}' - image_name = f'{seq_id}.{camera}_{frame_i + 1:06d}.jpg' - img_folder_name = f'{user_name}_{action}.{camera}' - image_path = os.path.join(user_name, 'images', - img_folder_name, image_name) - image_abs_path = os.path.join(dataset_path, image_path) - # save image - if self.extract_img: - folder = os.path.dirname(image_abs_path) - if not os.path.exists(folder): - os.makedirs(folder, exist_ok=True) - cv2.imwrite(image_abs_path, image) - - # get bbox from mask - mask = 
bbox_h5py[bbox_h5py['Masks'][frame_i, 0]][:].T - ys, xs = np.where(mask == 1) - bbox_xyxy = np.array([ - np.min(xs), - np.min(ys), - np.max(xs) + 1, - np.max(ys) + 1 - ]) - bbox_xyxy = self._bbox_expand( - bbox_xyxy, scale_factor=0.9) - bbox_xywh = self._xyxy2xywh(bbox_xyxy) - - # read GT 2D pose - keypoints2dall = np.reshape(poses_2d[frame_i, :], - [-1, 2]) - keypoints2d17 = keypoints2dall[h36m_idx] - keypoints2d17 = np.concatenate( - [keypoints2d17, np.ones((17, 1))], axis=1) - - # read GT 3D pose - keypoints3dall = np.reshape(poses_3d[frame_i, :], - [-1, 3]) / 1000. - keypoints3d17 = keypoints3dall[h36m_idx] - keypoints3d17 -= keypoints3d17[0] # root-centered - keypoints3d17 = np.concatenate( - [keypoints3d17, np.ones((17, 1))], axis=1) - - # store data - image_path_.append(image_path) - bbox_xywh_.append(bbox_xywh) - keypoints2d_.append(keypoints2d17) - keypoints3d_.append(keypoints3d17) - - # get mosh data - if self.get_mosh: - pose = thetas[frame_i // 5, :] - R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0] - R_root = cv2.Rodrigues(pose[:3])[0] - new_root = R_root.dot(R_mod) - pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3) - smpl['body_pose'].append(pose[3:].reshape((23, 3))) - smpl['global_orient'].append(pose[:3]) - smpl['betas'].append(betas) - - if self.get_mosh: - smpl['body_pose'] = np.array(smpl['body_pose']).reshape( - (-1, 23, 3)) - smpl['global_orient'] = np.array(smpl['global_orient']).reshape( - (-1, 3)) - smpl['betas'] = np.array(smpl['betas']).reshape((-1, 10)) - human_data['smpl'] = smpl - - if enable_multi_human_data: - frame_range = np.array([[i, i + 1] - for i in range(len(image_path_))]) - human_data['frame_range'] = frame_range - - metadata_path = os.path.join(dataset_path, 'metadata.xml') - if isinstance(metadata_path, str): - camera = H36mCamera(metadata_path) - cam_param = camera.generate_cameras_dict() - bbox_xywh_ = np.array(bbox_xywh_).reshape((-1, 4)) - bbox_xywh_ = np.hstack([bbox_xywh_, np.ones([bbox_xywh_.shape[0], 1])]) - keypoints2d_ = np.array(keypoints2d_).reshape((-1, 17, 3)) - keypoints2d_, mask = convert_kps(keypoints2d_, 'h36m', 'human_data') - keypoints3d_ = np.array(keypoints3d_).reshape((-1, 17, 4)) - keypoints3d_, _ = convert_kps(keypoints3d_, 'h36m', 'human_data') - - human_data['image_path'] = image_path_ - human_data['bbox_xywh'] = bbox_xywh_ - human_data['keypoints2d_mask'] = mask - human_data['keypoints3d_mask'] = mask - human_data['keypoints2d'] = keypoints2d_ - human_data['keypoints3d'] = keypoints3d_ - human_data['cam_param'] = cam_param - human_data['config'] = 'h36m' - human_data.compress_keypoints_by_mask() - - # store the data struct - if not os.path.isdir(out_path): - os.makedirs(out_path) - - if mode == 'train': - if self.get_mosh: - out_file = os.path.join(out_path, 'h36m_mosh_train.npz') - else: - out_file = os.path.join(out_path, 'h36m_train.npz') - elif mode == 'valid': - out_file = os.path.join(out_path, - f'h36m_valid_protocol{self.protocol}.npz') - human_data.dump(out_file) diff --git a/mmhuman3d/data/data_converters/h36m_neural_annot.py b/mmhuman3d/data/data_converters/h36m_neural_annot.py new file mode 100644 index 00000000..7ec50d32 --- /dev/null +++ b/mmhuman3d/data/data_converters/h36m_neural_annot.py @@ -0,0 +1,413 @@ +import glob +import json +import os +import random +from typing import List + +import numpy as np +import torch +from tqdm import tqdm +import cv2 + +from mmhuman3d.core.cameras import build_cameras +# from mmhuman3d.core.conventions.keypoints_mapping import smplx +from 
mmhuman3d.core.conventions.keypoints_mapping import ( + convert_kps, + get_keypoint_idx, + get_keypoint_idxs_by_part, +) +from mmhuman3d.data.data_structures.human_data import HumanData +from mmhuman3d.models.body_models.builder import build_body_model +from mmhuman3d.models.body_models.utils import transform_to_camera_frame +# from mmhuman3d.utils.transforms import aa_to_rotmat, rotmat_to_aa +from .base_converter import BaseModeConverter +from .builder import DATA_CONVERTERS + +import pdb + + +@DATA_CONVERTERS.register_module() +class H36mNeuralConverter(BaseModeConverter): + """Human3.6M dataset + `Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human + Sensing in Natural Environments' TPAMI`2014 + More details can be found in the `paper + `__. + + Args: + modes (list): 'val' or 'train' for accepted modes + protocol (int): 1 or 2 for available protocols + extract_img (bool): Store True to extract images into a separate + folder. Default: False. + mosh_dir (str, optional): Path to directory containing mosh files. + """ + ACCEPTED_MODES = ['valid', 'train'] + + def __init__(self, modes: List = []) -> None: + + self.device = torch.device('cuda:0') + self.misc = dict( + bbox_source='by_dataset', + cam_param_type='prespective', + cam_param_source='original', + smplx_source='neural_annot', + ) + self.smplx_shape = { + 'betas': (-1, 10), + 'transl': (-1, 3), + 'global_orient': (-1, 3), + 'body_pose': (-1, 21, 3), + # 'left_hand_pose': (-1, 15, 3), + # 'right_hand_pose': (-1, 15, 3), + # 'leye_pose': (-1, 3), + # 'reye_pose': (-1, 3), + # 'jaw_pose': (-1, 3), + # 'expression': (-1, 10) + } + super(H36mNeuralConverter, self).__init__(modes) + + + def convert_by_mode(self, + dataset_path: str, + out_path: str, + mode: str, + enable_multi_human_data: bool = False) -> dict: + """ + Args: + dataset_path (str): Path to directory where raw images and + annotations are stored. + out_path (str): Path to directory to save preprocessed npz file + mode (str): Mode in accepted modes + enable_multi_human_data (bool): + Whether to generate a multi-human data. If set to True, + stored in MultiHumanData() format. + Default: False, stored in HumanData() format. 
+ + Returns: + dict: + A dict containing keys image_path, bbox_xywh, keypoints2d, + keypoints2d_mask, keypoints3d, keypoints3d_mask, cam_param + stored in HumanData() format + """ + # get targeted seq list + targeted_seqs = sorted( + glob.glob(os.path.join(dataset_path, 'images', 's_*_act_*'))) + + # get all subject_ids and rearrange the seqs + subject_ids = [int(os.path.basename(seq)[2:4]) for seq in targeted_seqs] + subject_ids = list(set(subject_ids)) + + subject_seq_dict = {} + for subject_id in subject_ids: + subject_seq_dict[subject_id] = [] + for seq in targeted_seqs: + if int(os.path.basename(seq)[2:4]) == subject_id: + subject_seq_dict[subject_id].append(seq) + + # choose sebjetct ids for different mode + if mode == 'train': + user_list = [1, 5, 6, 7, 8] + elif mode == 'val': + user_list = [9, 11] + subject_ids = list(set(subject_ids) & set(user_list)) + + # calculate size + seqs_len = 0 + for key in subject_seq_dict: + seqs_len += len(subject_seq_dict[key]) + size_i = min(size_i, seqs_len) + + # parse seqs + for s, sid in enumerate(subject_ids): + + # use HumanData to store all data + human_data = HumanData() + + # init seed and size + seed, size = '230811', '999' + size_i = min(int(size), len(targeted_seqs)) + random.seed(int(seed)) + targeted_seqs = targeted_seqs[:size_i] + # random.shuffle(npzs) + + # initialize output for human_data + smplx_ = {} + for key in self.smplx_shape.keys(): + smplx_[key] = [] + keypoints2d_smplx_, keypoints3d_smplx_, = [], [] + keypoints2d_orig_, keypoints3d_orig_ = [], [] + bboxs_ = {} + for bbox_name in ['bbox_xywh']: + bboxs_[bbox_name] = [] + meta_ = {} + for key in ['focal_length', 'principal_point', 'height', 'width']: + meta_[key] = [] + image_path_ = [] + + # init smplx model + smplx_model = build_body_model( + dict( + type='SMPLX', + keypoint_src='smplx', + keypoint_dst='smplx', + model_path='data/body_models/smplx', + gender='neutral', + num_betas=10, + use_face_contour=True, + flat_hand_mean=False, + use_pca=False, + batch_size=1)).to(self.device) + + # load subject annotations + anno_base_path = os.path.join(dataset_path, 'annotations') + anno_base_name = f'Human36M_subject{int(sid)}' + + # load camera parameters + cam_param = f'{anno_base_name}_camera.json' + with open(os.path.join(anno_base_path, cam_param)) as f: + cam_params = json.load(f) + + # load data annotations + data_an = f'{anno_base_name}_data.json' + with open(os.path.join(anno_base_path, data_an)) as f: + data_annos = json.load(f) + + # load joints 3d annotations + joints3d_an = f'{anno_base_name}_joint_3d.json' + with open(os.path.join(anno_base_path, joints3d_an)) as f: + j3d_annos = json.load(f) + + # load smplx annotations (NeuralAnnot) + smplx_an = f'{anno_base_name}_SMPLX_NeuralAnnot.json' + with open(os.path.join(anno_base_path, smplx_an)) as f: + smplx_annos = json.load(f) + + for seq in tqdm(subject_seq_dict[sid], + desc=f'Processing subject {s + 1} / {len(subject_ids)}', + position=0, leave=False): + + # get ids + seqn = os.path.basename(seq) + action_id = str(int(seqn[9:11])) + subaction_id = str(int(seqn[19:21])) + camera_id = str(int(seqn[-2:])) + + # get annotation slice + smplx_anno_seq = smplx_annos[action_id][subaction_id] + + # get frames list + frames = smplx_anno_seq.keys() + + # get seq annotations + data_anno_seq = [data_annos['annotations'][idx] | data_annos['images'][idx] + for idx, finfo in enumerate(data_annos['images']) if + os.path.basename(finfo['file_name'])[:-11] == seqn] + assert len(data_anno_seq) == len(frames) + + # get joints 3d 
annotations + j3d_anno_seq = j3d_annos[action_id][subaction_id] + + # get camera parameters + cam_param = cam_params[camera_id] + R = np.array(cam_param['R']).reshape(3, 3) + T = np.array(cam_param['t']).reshape(1, 3) + focal_length = np.array(cam_param['f']).reshape(1, 2) + principal_point = np.array(cam_param['c']).reshape(1, 2) + + # create extrinsics + extrinsics = np.eye(4) + extrinsics[:3, :3] = R + extrinsics[:3, 3] = T / 1000 + + # create intrinsics camera, assume resolution is same in seq + width, height = data_anno_seq[0]['width'], data_anno_seq[0]['height'] + camera = build_cameras( + dict( + type='PerspectiveCameras', + convention='opencv', + in_ndc=False, + focal_length=focal_length, + image_size=(width, height), + principal_point=principal_point)).to(self.device) + + for fid, frame in enumerate(tqdm(frames, desc=f'Processing seq {seqn}', + position=1, leave=False)): + + smplx_anno = smplx_anno_seq[frame] + + # get image and bbox + info_anno = data_anno_seq[fid] + width, height = info_anno['width'], info_anno['height'] + bbox_xywh = info_anno['bbox'] + bbox_xywh.append(1) + imgp = os.path.join(dataset_path, 'images', info_anno['file_name']) + image_path = imgp.replace(f'{dataset_path}{os.path.sep}', '') + + # reformat smplx_anno + smplx_param = {} + smplx_param['global_orient'] = np.array(smplx_anno['root_pose']).reshape(-1, 3) + smplx_param['body_pose'] = np.array(smplx_anno['body_pose']).reshape(-1, 21, 3) + smplx_param['betas'] = np.array(smplx_anno['shape']).reshape(-1, 10) + smplx_param['transl'] = np.array(smplx_anno['trans']).reshape(-1, 3) + + # get pelvis world + intersect_keys = list( + set(smplx_param.keys()) & set(self.smplx_shape.keys())) + body_model_param_tensor = { + key: torch.tensor( + np.array(smplx_param[key]).reshape(self.smplx_shape[key]), + device=self.device, + dtype=torch.float32) + for key in intersect_keys + } + output = smplx_model(**body_model_param_tensor, return_joints=True) + + keypoints_3d = output['joints'] + pelvis_world = keypoints_3d.detach().cpu().numpy()[ + 0, get_keypoint_idx('pelvis', 'smplx')] + + # transform to camera space + global_orient, transl = transform_to_camera_frame( + global_orient=smplx_param['global_orient'], + transl=smplx_param['transl'], + pelvis=pelvis_world, + extrinsic=extrinsics) + + # update smplx param + smplx_param['global_orient'] = global_orient + smplx_param['transl'] = transl + + # update smplx + for update_key in ['global_orient', 'transl']: + body_model_param_tensor[update_key] = torch.tensor( + np.array(smplx_param[update_key]).reshape( + self.smplx_shape[update_key]), + device=self.device, + dtype=torch.float32) + output = smplx_model(**body_model_param_tensor, return_joints=True) + keypoints_3d = output['joints'] + + # get kps2d + keypoints_2d_xyd = camera.transform_points_screen(keypoints_3d) + keypoints_2d = keypoints_2d_xyd[..., :2].detach().cpu().numpy() + keypoints_3d = keypoints_3d.detach().cpu().numpy() + + # get j3d and project to 2d + j3d = np.array(j3d_anno_seq[frame]).reshape(-1, 3) + j3d_c = np.dot(R, j3d.transpose(1,0)).transpose(1,0) + T.reshape(1,3) + j2d = camera.transform_points_screen(torch.tensor(j3d_c, device=self.device, dtype=torch.float32)) + j2d = j2d.detach().cpu().numpy()[..., :2] + + # test projection + # pdb.set_trace() + # kps3d_c = np.dot(R, keypoints_3d[0].transpose(1,0)).transpose(1,0) + T.reshape(1,3) + # kps2d = camera.transform_points_screen(torch.tensor(kps3d_c, device=self.device, dtype=torch.float32)) + # kps2d = kps2d.detach().cpu().numpy()[..., :2] + + # # test 
overlay j2d + # img = cv2.imread(f'{dataset_path}/{image_path}') + # for kp in j2d: + # if 0 < kp[0] < 1920 and 0 < kp[1] < 1080: + # cv2.circle(img, (int(kp[0]), int(kp[1])), 1, (0,0,255), -1) + # pass + # # write image + # os.makedirs(f'{out_path}', exist_ok=True) + # cv2.imwrite(f'{out_path}/{os.path.basename(seq)}_{fid}.jpg', img) + + # append image path + image_path_.append(image_path) + + # append keypoints2d and 3d + keypoints2d_smplx_.append(keypoints_2d) + keypoints3d_smplx_.append(keypoints_3d) + keypoints2d_orig_.append(j2d) + keypoints3d_orig_.append(j3d_c) + + # append bbox + bboxs_['bbox_xywh'].append(bbox_xywh) + + # append smpl + for key in smplx_param.keys(): + smplx_[key].append(smplx_param[key]) + + # append meta + meta_['principal_point'].append(principal_point) + meta_['focal_length'].append(focal_length) + meta_['height'].append(height) + meta_['width'].append(width) + + # pdb.set_trace() + + # meta + human_data['meta'] = meta_ + + # image path + human_data['image_path'] = image_path_ + + # save bbox + for bbox_name in bboxs_.keys(): + bbox_ = np.array(bboxs_[bbox_name]).reshape(-1, 5) + human_data[bbox_name] = bbox_ + + # save smplx + # human_data.skip_keys_check = ['smplx'] + for key in smplx_.keys(): + smplx_[key] = np.concatenate( + smplx_[key], axis=0).reshape(self.smplx_shape[key]) + human_data['smplx'] = smplx_ + + # keypoints2d_smplx + keypoints2d_smplx = np.concatenate( + keypoints2d_smplx_, axis=0).reshape(-1, 144, 2) + keypoints2d_smplx_conf = np.ones([keypoints2d_smplx.shape[0], 144, 1]) + keypoints2d_smplx = np.concatenate( + [keypoints2d_smplx, keypoints2d_smplx_conf], axis=-1) + keypoints2d_smplx, keypoints2d_smplx_mask = \ + convert_kps(keypoints2d_smplx, src='smplx', dst='human_data') + human_data['keypoints2d_smplx'] = keypoints2d_smplx + human_data['keypoints2d_smplx_mask'] = keypoints2d_smplx_mask + + # keypoints3d_smplx + keypoints3d_smplx = np.concatenate( + keypoints3d_smplx_, axis=0).reshape(-1, 144, 3) + keypoints3d_smplx_conf = np.ones([keypoints3d_smplx.shape[0], 144, 1]) + keypoints3d_smplx = np.concatenate( + [keypoints3d_smplx, keypoints3d_smplx_conf], axis=-1) + keypoints3d_smplx, keypoints3d_smplx_mask = \ + convert_kps(keypoints3d_smplx, src='smplx', dst='human_data') + human_data['keypoints3d_smplx'] = keypoints3d_smplx + human_data['keypoints3d_smplx_mask'] = keypoints3d_smplx_mask + + # keypoints2d_orig + keypoints2d_orig = np.concatenate( + keypoints2d_orig_, axis=0).reshape(-1, 17, 2) + keypoints2d_orig_conf = np.ones([keypoints2d_orig.shape[0], 17, 1]) + keypoints2d_orig = np.concatenate( + [keypoints2d_orig, keypoints2d_orig_conf], axis=-1) + keypoints2d_orig, keypoints2d_orig_mask = \ + convert_kps(keypoints2d_orig, src='h36m', dst='human_data') + human_data['keypoints2d_original'] = keypoints2d_orig + human_data['keypoints2d_original_mask'] = keypoints2d_orig_mask + + # keypoints3d_orig + keypoints3d_orig = np.concatenate( + keypoints3d_orig_, axis=0).reshape(-1, 17, 3) + keypoints3d_orig_conf = np.ones([keypoints3d_orig.shape[0], 17, 1]) + keypoints3d_orig = np.concatenate( + [keypoints3d_orig, keypoints3d_orig_conf], axis=-1) + keypoints3d_orig, keypoints3d_orig_mask = \ + convert_kps(keypoints3d_orig, src='h36m', dst='human_data') + human_data['keypoints3d_original'] = keypoints3d_orig + human_data['keypoints3d_original_mask'] = keypoints3d_orig_mask + + # misc + human_data['misc'] = self.misc + human_data['config'] = f'h36m_neural_annot_{mode}' + + # save + human_data.compress_keypoints_by_mask() + os.makedirs(out_path, 
exist_ok=True) + out_file = os.path.join( + out_path, + f'h36m_neural_{mode}_{seed}_{"{:04d}".format(size_i)}_subject{sid}.npz') + human_data.dump(out_file) \ No newline at end of file diff --git a/mmhuman3d/data/data_converters/mpii.py b/mmhuman3d/data/data_converters/mpii.py deleted file mode 100644 index 59c6ccdd..00000000 --- a/mmhuman3d/data/data_converters/mpii.py +++ /dev/null @@ -1,91 +0,0 @@ -import os -from typing import List - -import h5py -import numpy as np -from tqdm import tqdm - -from mmhuman3d.core.conventions.keypoints_mapping import convert_kps -from mmhuman3d.data.data_structures.human_data import HumanData -from .base_converter import BaseConverter -from .builder import DATA_CONVERTERS - - -@DATA_CONVERTERS.register_module() -class MpiiConverter(BaseConverter): - """MPII Dataset `2D Human Pose Estimation: New Benchmark and State of the - Art Analysis' CVPR'2014. More details can be found in the `paper. - - `__ . - """ - - @staticmethod - def center_scale_to_bbox(center: float, scale: float) -> List[float]: - """Obtain bbox given center and scale.""" - w, h = scale * 200, scale * 200 - x, y = center[0] - w / 2, center[1] - h / 2 - return [x, y, w, h] - - def convert(self, dataset_path: str, out_path: str) -> dict: - """ - Args: - dataset_path (str): Path to directory where raw images and - annotations are stored. - out_path (str): Path to directory to save preprocessed npz file - - Returns: - dict: - A dict containing keys image_path, bbox_xywh, keypoints2d, - keypoints2d_mask stored in HumanData() format - """ - # use HumanData to store all data - human_data = HumanData() - - # structs we use - image_path_, bbox_xywh_, keypoints2d_ = [], [], [] - - # annotation files - annot_file = os.path.join(dataset_path, 'train.h5') - - # read annotations - f = h5py.File(annot_file, 'r') - centers, image_path, keypoints2d, scales = \ - f['center'], f['imgname'], f['part'], f['scale'] - - # go over all annotated examples - for center, imgname, keypoints2d16, scale in tqdm( - zip(centers, image_path, keypoints2d, scales)): - imgname = imgname.decode('utf-8') - # check if all major body joints are annotated - if (keypoints2d16 > 0).sum() < 2 * 16: - continue - - # keypoints - keypoints2d16 = np.hstack([keypoints2d16, np.ones([16, 1])]) - - # bbox - bbox_xywh = self.center_scale_to_bbox(center, scale) - - # store data - image_path_.append(os.path.join('images', imgname)) - bbox_xywh_.append(bbox_xywh) - keypoints2d_.append(keypoints2d16) - - bbox_xywh_ = np.array(bbox_xywh_).reshape((-1, 4)) - bbox_xywh_ = np.hstack([bbox_xywh_, np.ones([bbox_xywh_.shape[0], 1])]) - keypoints2d_ = np.array(keypoints2d_).reshape((-1, 16, 3)) - keypoints2d_, mask = convert_kps(keypoints2d_, 'mpii', 'human_data') - - human_data['image_path'] = image_path_ - human_data['bbox_xywh'] = bbox_xywh_ - human_data['keypoints2d_mask'] = mask - human_data['keypoints2d'] = keypoints2d_ - human_data['config'] = 'mpii' - human_data.compress_keypoints_by_mask() - - # store the data struct - if not os.path.isdir(out_path): - os.makedirs(out_path) - - out_file = os.path.join(out_path, 'mpii_train.npz') - human_data.dump(out_file) diff --git a/mmhuman3d/data/data_converters/mpii_neural_annot.py b/mmhuman3d/data/data_converters/mpii_neural_annot.py new file mode 100644 index 00000000..358c8e03 --- /dev/null +++ b/mmhuman3d/data/data_converters/mpii_neural_annot.py @@ -0,0 +1,285 @@ +import glob +import json +import os +import random +from typing import List + +import numpy as np +import torch +from tqdm import tqdm +import 
cv2 + +from mmhuman3d.core.cameras import build_cameras +# from mmhuman3d.core.conventions.keypoints_mapping import smplx +from mmhuman3d.core.conventions.keypoints_mapping import ( + convert_kps, + get_keypoint_idx, + get_keypoint_idxs_by_part, +) +from mmhuman3d.data.data_structures.human_data import HumanData +from mmhuman3d.models.body_models.builder import build_body_model +from mmhuman3d.models.body_models.utils import transform_to_camera_frame +# from mmhuman3d.utils.transforms import aa_to_rotmat, rotmat_to_aa +from .base_converter import BaseModeConverter +from .builder import DATA_CONVERTERS + +import pdb + + +@DATA_CONVERTERS.register_module() +class MpiiNeuralConverter(BaseModeConverter): + """MPII Dataset `2D Human Pose Estimation: New Benchmark and State of the + Art Analysis' CVPR'2014. More details can be found in the `paper. + + `__ . + """ + + ACCEPTED_MODES = ['test', 'train'] + + def __init__(self, modes: List = []) -> None: + + self.device = torch.device('cuda:0') + self.misc = dict( + bbox_source='by_dataset', + cam_param_type='prespective', + cam_param_source='original', + smplx_source='neural_annot', + ) + self.smplx_shape = { + 'betas': (-1, 10), + 'transl': (-1, 3), + 'global_orient': (-1, 3), + 'body_pose': (-1, 21, 3), + # 'left_hand_pose': (-1, 15, 3), + # 'right_hand_pose': (-1, 15, 3), + # 'leye_pose': (-1, 3), + # 'reye_pose': (-1, 3), + # 'jaw_pose': (-1, 3), + # 'expression': (-1, 10) + } + super(MpiiNeuralConverter, self).__init__(modes) + + + def convert_by_mode(self, + dataset_path: str, + out_path: str, + mode: str, + enable_multi_human_data: bool = False) -> dict: + """ + Args: + dataset_path (str): Path to directory where raw images and + annotations are stored. + out_path (str): Path to directory to save preprocessed npz file + + Returns: + dict: + A dict containing keys image_path, bbox_xywh, keypoints2d, + keypoints2d_mask stored in HumanData() format + """ + # use HumanData to store all data + human_data = HumanData() + + # initialize output for human_data + smplx_ = {} + for key in self.smplx_shape.keys(): + smplx_[key] = [] + keypoints2d_smplx_, keypoints3d_smplx_, = [], [] + keypoints2d_orig_ = [ ] + bboxs_ = {} + for bbox_name in ['bbox_xywh']: + bboxs_[bbox_name] = [] + meta_ = {} + for key in ['focal_length', 'principal_point', 'height', 'width']: + meta_[key] = [] + image_path_ = [] + + # load data seperate + split_path = os.path.join(dataset_path, 'annotations', f'{mode}_reformat.json') + with open(split_path, 'r') as f: + image_data = json.load(f) + + # load smplx annot + smplx_path = os.path.join(dataset_path, 'annotations', f'MPII_train_SMPLX_NeuralAnnot.json') + with open(smplx_path, 'r') as f: + smplx_data = json.load(f) + + # get targeted frame list + image_list = list(image_data.keys()) + + # init seed and size + seed, size = '230814', '90999' + size_i = min(int(size), len(image_list)) + random.seed(int(seed)) + image_list = image_list[:size_i] + # random.shuffle(npzs) + + # init smplx model + smplx_model = build_body_model( + dict( + type='SMPLX', + keypoint_src='smplx', + keypoint_dst='smplx', + model_path='data/body_models/smplx', + gender='neutral', + num_betas=10, + use_face_contour=True, + flat_hand_mean=False, + use_pca=False, + batch_size=1)).to(self.device) + + for fname in tqdm(image_list, desc=f'Converting MPII {mode} data'): + + # get info slice + image_info = image_data[fname] + + # prepare image path + image_path = os.path.join('images', f'{fname}') + imgp = os.path.join(dataset_path, image_path) + + # access image 
info + annot_id = image_info['id'] + width = image_info['width'] + height = image_info['height'] + + # read keypoints2d and bbox + j2d = np.array(image_info['keypoints']).reshape(-1, 3) + + bbox_xywh = image_info['bbox'] + bbox_xywh.append(1) + + # read smplx annot info + annot_info = smplx_data[str(annot_id)] + smplx_info = annot_info['smplx_param'] + cam_info = annot_info['cam_param'] + + # reformat smplx anno + smplx_param = {} + smplx_param['global_orient'] = np.array(smplx_info['root_pose']).reshape(-1, 3) + smplx_param['body_pose'] = np.array(smplx_info['body_pose']).reshape(-1, 21, 3) + smplx_param['betas'] = np.array(smplx_info['shape']).reshape(-1, 10) + smplx_param['transl'] = np.array(smplx_info['trans']).reshape(-1, 3) + + # get camera param and build camera + focal_length = cam_info['focal'] + principal_point = cam_info['princpt'] + + camera = build_cameras( + dict( + type='PerspectiveCameras', + convention='opencv', + in_ndc=False, + focal_length=focal_length, + image_size=(width, height), + principal_point=principal_point)).to(self.device) + + # get smplx output + intersect_keys = list( + set(smplx_param.keys()) & set(self.smplx_shape.keys())) + body_model_param_tensor = { + key: torch.tensor( + np.array(smplx_param[key]).reshape(self.smplx_shape[key]), + device=self.device, + dtype=torch.float32) + for key in intersect_keys + } + output = smplx_model(**body_model_param_tensor, return_joints=True) + + # get kps2d from projection + keypoints_3d = output['joints'] + keypoints_2d_xyd = camera.transform_points_screen(keypoints_3d) + keypoints_2d = keypoints_2d_xyd[..., :2].detach().cpu().numpy() + keypoints_3d = keypoints_3d.detach().cpu().numpy() + + + # # test overlay j2d + # img = cv2.imread(f'{dataset_path}/{image_path}') + # for kp in keypoints_2d[0]: + # if 0 < kp[0] < 1920 and 0 < kp[1] < 1080: + # cv2.circle(img, (int(kp[0]), int(kp[1])), 1, (0,0,255), -1) + # pass + # # write image + # os.makedirs(f'{out_path}', exist_ok=True) + # cv2.imwrite(f'{out_path}/{fname}', img) + + # append image path + image_path_.append(image_path) + + # append keypoints2d and 3d + keypoints2d_smplx_.append(keypoints_2d) + keypoints3d_smplx_.append(keypoints_3d) + keypoints2d_orig_.append(j2d) + + # append bbox + bboxs_['bbox_xywh'].append(bbox_xywh) + + # append smpl + for key in smplx_param.keys(): + smplx_[key].append(smplx_param[key]) + + # append meta + meta_['principal_point'].append(principal_point) + meta_['focal_length'].append(focal_length) + meta_['height'].append(height) + meta_['width'].append(width) + + # pdb.set_trace() + + # meta + human_data['meta'] = meta_ + + # image path + human_data['image_path'] = image_path_ + + # save bbox + for bbox_name in bboxs_.keys(): + bbox_ = np.array(bboxs_[bbox_name]).reshape(-1, 5) + human_data[bbox_name] = bbox_ + + # save smplx + # human_data.skip_keys_check = ['smplx'] + for key in smplx_.keys(): + smplx_[key] = np.concatenate( + smplx_[key], axis=0).reshape(self.smplx_shape[key]) + human_data['smplx'] = smplx_ + + # keypoints2d_orig + keypoints2d_orig = np.concatenate( + keypoints2d_orig_, axis=0).reshape(-1, 16, 3) + keypoints2d_orig, keypoints2d_orig_mask = \ + convert_kps(keypoints2d_orig, src='mpii', dst='human_data') + human_data['keypoints2d_original'] = keypoints2d_orig + human_data['keypoints2d_original_mask'] = keypoints2d_orig_mask + + # keypoints2d_smplx + keypoints2d_smplx = np.concatenate( + keypoints2d_smplx_, axis=0).reshape(-1, 144, 2) + keypoints2d_smplx_conf = np.ones([keypoints2d_smplx.shape[0], 144, 1]) + 
keypoints2d_smplx = np.concatenate( + [keypoints2d_smplx, keypoints2d_smplx_conf], axis=-1) + keypoints2d_smplx, keypoints2d_smplx_mask = \ + convert_kps(keypoints2d_smplx, src='smplx', dst='human_data') + human_data['keypoints2d_smplx'] = keypoints2d_smplx + human_data['keypoints2d_smplx_mask'] = keypoints2d_smplx_mask + + # keypoints3d_smplx + keypoints3d_smplx = np.concatenate( + keypoints3d_smplx_, axis=0).reshape(-1, 144, 3) + keypoints3d_smplx_conf = np.ones([keypoints3d_smplx.shape[0], 144, 1]) + keypoints3d_smplx = np.concatenate( + [keypoints3d_smplx, keypoints3d_smplx_conf], axis=-1) + keypoints3d_smplx, keypoints3d_smplx_mask = \ + convert_kps(keypoints3d_smplx, src='smplx', dst='human_data') + human_data['keypoints3d_smplx'] = keypoints3d_smplx + human_data['keypoints3d_smplx_mask'] = keypoints3d_smplx_mask + + # misc + human_data['misc'] = self.misc + human_data['config'] = f'mpii_neural_annot_{mode}' + + # save + human_data.compress_keypoints_by_mask() + os.makedirs(out_path, exist_ok=True) + out_file = os.path.join( + out_path, + f'mpii_neural_{mode}_{seed}_{"{:05d}".format(size_i)}.npz') + human_data.dump(out_file) \ No newline at end of file diff --git a/tools/convert_datasets.py b/tools/convert_datasets.py index 91bcd2de..1cb5ac3d 100644 --- a/tools/convert_datasets.py +++ b/tools/convert_datasets.py @@ -11,7 +11,7 @@ crowdpose=dict( type='CrowdposeConverter', modes=['train', 'val', 'test', 'trainval']), pw3d=dict(type='Pw3dConverter', modes=['train', 'test']), - mpii=dict(type='MpiiConverter'), + h36m_p1=dict( type='H36mConverter', modes=['train', 'valid'], @@ -110,6 +110,10 @@ type='Hsc4dConverter', # real, in progress prefix='hsc4d', modes=['train']), + h36m=dict( + type='H36mConverter', + modes=['train', 'val'], + prefix='h36m'), motionx=dict( type='MotionXConverter', # real, in progress prefix='motionx', @@ -118,6 +122,10 @@ type='MoyoConverter', # real prefix='moyo', modes=['train', 'val']), + mpii=dict( + type='MpiiConverter', # real multi-human? 
+ prefix='mpii', + modes=['train', 'test'],), renbody=dict( type='RenbodyConverter', # real prefix='renbody', diff --git a/tools/convert_datasets_commands.txt b/tools/convert_datasets_commands.txt index c4c88f6a..a5169cd0 100644 --- a/tools/convert_datasets_commands.txt +++ b/tools/convert_datasets_commands.txt @@ -4,22 +4,22 @@ # behave-converter python tools/convert_datasets.py \ --datasets behave \ - --root_path /mnt/d \ - --output_path /mnt/d/behave/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/behave/output \ --modes train # blurhand-converter python tools/convert_datasets.py \ --datasets blurhand \ - --root_path /mnt/d \ - --output_path /mnt/d/blurhand/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/blurhand/output \ --modes train # cimi4d-converter python tools/convert_datasets.py \ --datasets cimi4d \ - --root_path /mnt/d \ - --output_path /mnt/d/cimi4d/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/cimi4d/output \ --modes train # dynacam-converter @@ -32,8 +32,8 @@ python tools/convert_datasets.py \ # egobody-converter python tools/convert_datasets.py \ --datasets egobody \ - --root_path /mnt/d \ - --output_path /mnt/d/egobody/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/egobody/output \ --modes egocentric_train python release_vis_kinect_scene.py @@ -50,8 +50,8 @@ python tools/convert_datasets.py \ # freihand-converter python tools/convert_datasets.py \ --datasets freihand \ - --root_path /mnt/d \ - --output_path /mnt/d/freihand/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/freihand/output \ --modes train # gta-converter @@ -64,8 +64,8 @@ python tools/convert_datasets.py \ # hanco-converter python tools/convert_datasets.py \ --datasets hanco \ - --root_path /mnt/d \ - --output_path /mnt/d/hanco/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/hanco/output \ --modes train # hsc4d-converter @@ -75,53 +75,67 @@ python tools/convert_datasets.py \ --output_path /mnt/e/hsc4d/output \ --modes train +# human36m-neural-converter +python tools/convert_datasets.py \ + --datasets h36m\ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/h36m/output \ + --modes train + # humanart-converter python tools/convert_datasets.py \ --datasets humanart \ - --root_path /mnt/d \ - --output_path /mnt/d/humanart/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/humanart/output \ --modes cosplay # interhand-converter python tools/convert_datasets.py \ --datasets interhand26m \ - --root_path /mnt/d \ - --output_path /mnt/d/interhand26m/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/interhand26m/output \ --modes train # motionx-converter python tools/convert_datasets.py \ --datasets motionx \ - --root_path /mnt/d \ - --output_path /mnt/d/motionx/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/motionx/output \ --modes train # moyo-converter python tools/convert_datasets.py \ --datasets moyo \ - --root_path /mnt/d \ - --output_path /mnt/d/moyo/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/moyo/output \ --modes train +# mpii-converter +python tools/convert_datasets.py \ + --datasets mpii \ + --root_path /mnt/e \ + --output_path /mnt/e/mpii/output \ + --modes test + # renbody-converter python tools/convert_datasets.py \ --datasets renbody \ - --root_path /mnt/d \ - --output_path /mnt/d/renbody/output \ + --root_path /mnt/d/datasets \ + 
--output_path /mnt/d/datasets/renbody/output \ --modes train # rmar-converter python tools/convert_datasets.py \ --datasets sminchisescu \ - --root_path /mnt/d \ - --output_path /mnt/d/sminchisescu-research-datasets/output + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/sminchisescu-research-datasets/output --modes HumanSC3D # sgnify-converter python tools/convert_datasets.py \ --datasets sgnify \ - --root_path /mnt/d \ - --output_path /mnt/d/sgnify/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/sgnify/output \ --modes train # shapy-converter @@ -134,8 +148,8 @@ python tools/convert_datasets.py \ # sloper4d-converter python tools/convert_datasets.py \ --datasets sloper4d \ - --root_path /mnt/d \ - --output_path /mnt/d/sloper4d/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/sloper4d/output \ --modes train # ssp3d-converter @@ -153,8 +167,8 @@ python tools/convert_datasets.py \ # ubody-converter python tools/convert_datasets.py \ --datasets ubody \ - --root_path /mnt/d \ - --output_path /mnt/d/ubody/output \ + --root_path /mnt/d/datasets \ + --output_path /mnt/d/datasets/ubody/output \ --modes inter intra @@ -181,7 +195,7 @@ python tools/convert_datasets.py \ srun -p Zoetrope python tools/convert_datasets.py \ --datasets synbody \ --root_path /mnt/lustre/share_data/weichen1/datasets \ - --output_path /mnt/lustre/share_data/weichen1/converted_humandata \ + --output_path /mnt/lustre/share_data/weichen1/converted_humandata_new \ --modes v1_train python tools/preprocess/synbody_preprocess.py \ diff --git a/tools/postprocess/humandata_sample.py b/tools/postprocess/humandata_sample.py new file mode 100644 index 00000000..e69de29b diff --git a/tools/solvepnp/solvepnp.py b/tools/postprocess/solvepnp.py similarity index 100% rename from tools/solvepnp/solvepnp.py rename to tools/postprocess/solvepnp.py diff --git a/tools/postprocess/testpnp.py b/tools/postprocess/testpnp.py new file mode 100644 index 00000000..7f2b40c1 --- /dev/null +++ b/tools/postprocess/testpnp.py @@ -0,0 +1,60 @@ +import cv2 +import os +import numpy as np + +def main(): + + kps2d = np.array([[743.20697021, 580.30181885, 1. ], + [737.36425781, 594.02532959, 1. ], + [753.28991699, 593.3303833 , 1. ], + [744.3972168 , 565.92321777, 1. ], + [704.01763916, 619.74078369, 1. ], + [756.39501953, 635.73272705, 1. ], + [744.85351562, 548.64349365, 1. ], + [722.13586426, 668.10357666, 1. ], + [764.72460938, 688.8692627 , 1. ], + [744.24841309, 540.1550293 , 1. 
]]) + + kps3d = np.array([[-1.70768964e+00, 3.17459553e-01, 9.12188816e+00, + 1.00000000e+00], + [-1.74542880e+00, 4.23549652e-01, 9.07879639e+00, + 1.00000000e+00], + [-1.63456345e+00, 4.21710700e-01, 9.15716743e+00, + 1.00000000e+00], + [-1.69063628e+00, 2.03275323e-01, 9.08065033e+00, + 1.00000000e+00], + [-2.04099011e+00, 6.35786593e-01, 9.23319435e+00, + 1.00000000e+00], + [-1.63573825e+00, 7.69105434e-01, 9.30350399e+00, + 1.00000000e+00], + [-1.68367076e+00, 6.76413178e-02, 9.06241703e+00, + 1.00000000e+00], + [-1.87235999e+00, 1.00837398e+00, 9.11552143e+00, + 1.00000000e+00], + [-1.56375468e+00, 1.19213676e+00, 9.27346706e+00, + 1.00000000e+00], + [-1.68993425e+00, 1.21444464e-03, 9.07061863e+00, + 1.00000000e+00]]) + + kps2d = kps2d[:, :2] + kps3d = kps3d[:, :3] + + focal_length = (1158.0337, 1158.0337) + camera_center = (960, 540) + + # create camera matrix + cameraMatrix = np.zeros((3,3)) + cameraMatrix[0,0] = focal_length[0] + cameraMatrix[1,1] = focal_length[1] + cameraMatrix[0,2] = camera_center[0] + cameraMatrix[1,2] = camera_center[1] + + objPoints = np.array(kps3d).reshape(-1, 3) + imgPoints = np.array(kps2d).reshape(-1, 2) + + retval, rvec, tvec = cv2.solvePnP(objPoints, imgPoints, cameraMatrix, distCoeffs=None) + + print(rvec, tvec) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tools/shape2smplx/fit_shape2smplx.py b/tools/preprocess/fit_shape2smplx.py similarity index 97% rename from tools/shape2smplx/fit_shape2smplx.py rename to tools/preprocess/fit_shape2smplx.py index e7bc7624..e94b37f5 100644 --- a/tools/shape2smplx/fit_shape2smplx.py +++ b/tools/preprocess/fit_shape2smplx.py @@ -188,7 +188,7 @@ def main(args): basedir = osp.basename(load_dir) # prepare mesh type - SUPPORTED_MESH_TYPES = ['obj', 'npy'] + SUPPORTED_MESH_TYPES = ['obj', 'npy', 'ply'] assert mesh_type in SUPPORTED_MESH_TYPES, \ f'mesh type {mesh_type} not supported' @@ -234,12 +234,13 @@ def main(args): target_verts = trimesh.load(fp).vertices.reshape(1, 10475, 3) if mesh_type == 'npy': target_verts = np.load(fp).reshape(1, 10475, 3) + if mesh_type == 'ply': + target_verts = trimesh.load(fp).vertices.reshape(1, 10475, 3) # fit parameters params = fit_params(target_verts, body_model=smplx_model, args=args) stem, _ = osp.splitext(osp.basename(fp)) - save_path = osp.join(fp).replace('.obj', '.npz').replace( - basedir, 'fitted_params') + save_path = osp.join(load_dir, 'fitted_params', osp.basename(fp)).replace(f'.{args.mesh_type}', '.npz') os.makedirs(osp.dirname(save_path), exist_ok=True) # pdb.set_trace() diff --git a/tools/preprocess/mpii.py b/tools/preprocess/mpii.py new file mode 100644 index 00000000..5db2fb85 --- /dev/null +++ b/tools/preprocess/mpii.py @@ -0,0 +1,81 @@ +import argparse +import glob +import json +import os + +from tqdm import tqdm + +import pdb +# from memory_profiler import profile + + +# @profile +def rewrite_anno_json(args): + + anno_bp = args.dataset_path + # for interhand2.6m + # anno_bp = os.path.join(anno_bp, 'anno*5*', '*', '*test*data.json') + # for blurhand + anno_bp = os.path.join(anno_bp, 'annotations', 'train.json') + + anno_ps = glob.glob(anno_bp) + anno_ps = [x for x in anno_ps if 'SMPLX' not in x] + print(anno_ps) + + for annop in anno_ps: + + with open(annop, 'r') as f: + anno_data = json.load(f) + + image_data = {} + + + image_info_dict = {} + anno_info_dict = {} + + for idx in tqdm(range(len(anno_data['images'])), + desc='extracting image info'): + + image_info_slice = anno_data['images'][idx] + 
        image_info_dict[image_info_slice['id']] = image_info_slice
+        # pdb.set_trace()
+
+    for idx in tqdm(range(len(anno_data['annotations'])),
+                    desc='extracting anno info'):
+
+        anno_info_slice = anno_data['annotations'][idx]
+        anno_info_dict[anno_info_slice['image_id']] = anno_info_slice
+
+
+    for id in tqdm(image_info_dict.keys(), desc='merging info'):
+
+        if id not in anno_info_dict.keys():
+            continue
+
+        info_slice = image_info_dict[id]
+        anno_slice = anno_info_dict[id]
+
+        imgp = os.path.basename(info_slice['file_name'])
+
+        del info_slice['file_name']
+        del info_slice['id']
+
+        image_data[imgp] = info_slice | anno_slice
+
+    # pdb.set_trace()
+    # save json
+    annop_new = annop.replace('.json', '_reformat.json')
+    json.dump(image_data, open(annop_new, 'w'))
+
+    # pdb.set_trace()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='MPII preprocess - write dataset format')
+    parser.add_argument(
+        '--dataset_path', type=str, required=True, help='path to the dataset')
+    # python tools/preprocess/mpii.py --dataset_path /mnt/e/mpii
+    args = parser.parse_args()
+    rewrite_anno_json(args)
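
# Note on the H36M projection step in h36m_neural_annot.py
The converter projects Human3.6M world-space joints into each camera using the R, t, f and c read from Human36M_subject*_camera.json. Below is a minimal numpy sketch of that step for quick standalone checks; it assumes, as the raw annotations suggest, that the joints and t share the same unit (millimetres) and that f and c are pixel-valued focal length and principal point.

import numpy as np

def project_h36m_joints(j3d_world, R, t, f, c):
    """Project world-space H36M joints (N, 3) to pixel coordinates."""
    R = np.asarray(R, dtype=np.float64).reshape(3, 3)
    t = np.asarray(t, dtype=np.float64).reshape(1, 3)
    # world -> camera frame, equivalent to np.dot(R, j3d.T).T + t in the converter
    j3d_cam = j3d_world @ R.T + t
    # pinhole projection: perspective divide, then apply intrinsics
    xy = j3d_cam[:, :2] / j3d_cam[:, 2:3]
    j2d = xy * np.asarray(f).reshape(1, 2) + np.asarray(c).reshape(1, 2)
    return j2d, j3d_cam

Keeping the camera-frame joints in the annotation's native millimetre unit matches how keypoints3d_original is stored; only the extrinsic translation is divided by 1000 before being handed to transform_to_camera_frame, which operates in metres on the SMPL-X output.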
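
# Note on tools/postprocess/testpnp.py
testpnp.py stops at printing rvec and tvec. A natural follow-up, sketched here rather than included in the patch, is to reproject the 3D points with the recovered pose and report the pixel error. Variable names follow testpnp.py, and cv2.projectPoints is called with distCoeffs=None for a zero-distortion model.

import cv2
import numpy as np

def reprojection_error(obj_points, img_points, rvec, tvec, camera_matrix):
    """Mean pixel distance between observed 2D points and reprojected 3D points."""
    projected, _ = cv2.projectPoints(
        np.asarray(obj_points, dtype=np.float64), rvec, tvec, camera_matrix, None)
    projected = projected.reshape(-1, 2)
    return float(np.linalg.norm(projected - np.asarray(img_points), axis=1).mean())

# usage, continuing from main() in testpnp.py:
# retval, rvec, tvec = cv2.solvePnP(objPoints, imgPoints, cameraMatrix, distCoeffs=None)
# print('mean reprojection error (px):',
#       reprojection_error(objPoints, imgPoints, rvec, tvec, cameraMatrix))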
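
# Note on the {mode}_reformat.json contract
tools/preprocess/mpii.py writes {mode}_reformat.json keyed by image basename, and MpiiNeuralConverter reads exactly those fields back. A single record looks roughly like the Python literal below; the field names are the ones accessed in the converter, while the concrete values are illustrative only.

# one entry of train_reformat.json (values illustrative)
example_record = {
    '015601864.jpg': {
        # kept from the COCO-style 'images' entry (file_name and image id removed)
        'width': 1280,
        'height': 720,
        # merged from the matching 'annotations' entry
        'id': 42,                  # annotation id, used to index MPII_train_SMPLX_NeuralAnnot.json
        'image_id': 107,
        'keypoints': [0.0] * 48,   # 16 MPII joints x (x, y, visibility), flattened
        'bbox': [376.0, 195.0, 168.0, 339.0],  # xywh; the converter appends a confidence of 1
    },
}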
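
# Note on inspecting the converter output
Both converters finish with HumanData.dump(out_file). Assuming HumanData writes an ordinary compressed .npz (which is how mmhuman3d stores HumanData on disk), the result can be sanity-checked with numpy alone; the filename below is illustrative, since the converters build it from the mode, seed and size suffix.

import numpy as np

# illustrative output name; see the out_file pattern in mpii_neural_annot.py
data = np.load('output/mpii_neural_train_230814_90999.npz', allow_pickle=True)
print(sorted(data.files))        # expect image_path, bbox_xywh, smplx, keypoints2d_smplx, ...
print(data['bbox_xywh'].shape)   # (N, 5): xywh plus the confidence column appended above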