[Feature] I tried to make skeleton_demo.py using mmpose Webcam API #2879

Open
kimdahyeon977 opened this issue Oct 10, 2024 · 0 comments
kimdahyeon977 commented Oct 10, 2024

What is the problem this feature will solve?

I'm developing skeleton_demo.py with the MMPose Webcam API, based on this tutorial. It works, but there are several challenges I need help with:

  1. PoseC3D currently doesn't support multi-class classification.
  2. PoseC3D's input is a stack of frames, not bounding boxes. This creates a challenge when multiple bounding boxes appear in a single frame: how do I handle labeling in such cases?
  3. PoseC3D's inference is quite slow, which introduces a noticeable delay compared to the much faster pose-coordinate inference; one possible workaround is sketched below.
    (Skeleton-based model (PoseC3D) for Real-Time Webcam Inference #2155)

I'm struggling to resolve these issues. Any help would be appreciated.
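
One way to hide the recognizer latency (challenge 3) is to run PoseC3D in a background thread and only publish its most recent prediction, so the pose-estimation loop is never blocked. This is a minimal sketch of the idea, not part of the webcam API; run_recognizer is a placeholder for whatever inference call is used.

import queue
import threading


class AsyncRecognizer:
    """Run a slow recognizer in a background thread, keeping only the most
    recent prediction."""

    def __init__(self, run_recognizer):
        self._run = run_recognizer          # callable: clip -> (label, score)
        self._clips = queue.Queue(maxsize=1)
        self.latest = None                  # most recent (label, score)
        threading.Thread(target=self._worker, daemon=True).start()

    def submit(self, clip):
        # Drop the submission if the worker still has a clip pending;
        # the displayed prediction then simply lags by one clip.
        try:
            self._clips.put_nowait(clip)
        except queue.Full:
            pass

    def _worker(self):
        while True:
            clip = self._clips.get()
            self.latest = self._run(clip)   # the slow PoseC3D call runs here

The pose loop would call submit() once enough frames are buffered and draw latest whenever it is set, so the action label lags by one clip instead of stalling every frame.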

What is the feature?

A webcam-based skeleton_demo.py built on the MMPose Webcam API.

What alternatives have you considered?

Here is my current code.
1. mmpose/mmpose/apis/webcam/nodes/model_nodes/pose_tracker_node.py

# Imports assumed from the mmpose 0.x webcam API layout; adjust the relative
# paths to match your tree.
import copy
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import mmcv
import numpy as np

from mmaction.apis import inference_recognizer, init_recognizer
from mmpose.apis import (get_track_id, inference_top_down_pose_model,
                         init_pose_model)
from mmpose.apis.inference import recognize_pose_model_batch
from mmpose.core import Smoother

from ...utils import Message, get_config_path
from ..node import Node
from ..registry import NODES

try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False


def _compute_area(obj: Dict) -> float:
    """Area of an object's bbox (assumed helper; the snippet references it
    without defining it)."""
    x1, y1, x2, y2 = obj['bbox'][:4]
    return max(0., x2 - x1) * max(0., y2 - y1)


def _merge_bbox(bboxes: List[Dict], ratio=0.5):
    """Merge overlapping bboxes in a clip into a single bbox that covers the
    region where the tracked person moves."""
    if len(bboxes) <= 1:
        return bboxes

    bboxes.sort(key=_compute_area, reverse=True)  # largest box first
    merged = False
    for i in range(1, len(bboxes)):
        small_area = _compute_area(bboxes[i])
        x1 = max(bboxes[0]['bbox'][0], bboxes[i]['bbox'][0])
        y1 = max(bboxes[0]['bbox'][1], bboxes[i]['bbox'][1])
        x2 = min(bboxes[0]['bbox'][2], bboxes[i]['bbox'][2])
        y2 = min(bboxes[0]['bbox'][3], bboxes[i]['bbox'][3])
        # Clamp to zero so non-overlapping boxes don't register a spurious
        # "intersection" through abs().
        inter = max(0., x2 - x1) * max(0., y2 - y1)
        area_ratio = inter / small_area
        if area_ratio > ratio:
            bboxes[0]['bbox'][0] = min(bboxes[0]['bbox'][0],
                                       bboxes[i]['bbox'][0])
            bboxes[0]['bbox'][1] = min(bboxes[0]['bbox'][1],
                                       bboxes[i]['bbox'][1])
            bboxes[0]['bbox'][2] = max(bboxes[0]['bbox'][2],
                                       bboxes[i]['bbox'][2])
            bboxes[0]['bbox'][3] = max(bboxes[0]['bbox'][3],
                                       bboxes[i]['bbox'][3])
            merged = True
            break

    if merged:
        bboxes.pop(i)
        return _merge_bbox(bboxes, ratio)
    else:
        # return the largest bounding box
        return [bboxes[0]]
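
# Example of the merge rule above, with two hypothetical boxes: the
# intersection of [2, 2, 8, 8] with [0, 0, 10, 10] covers 100% of the
# smaller box (> ratio), so they collapse into their union:
#   _merge_bbox([{'bbox': [0, 0, 10, 10]}, {'bbox': [2, 2, 8, 8]}])
#   -> [{'bbox': [0, 0, 10, 10]}]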


@dataclass
class TrackInfo:
    next_id: int = 0
    last_objects: Optional[List] = None


@NODES.register_module()
class PoseTrackerNode(Node):
    """Perform object detection and top-down pose estimation. Only detect
    objects every few frames, and use the pose estimation results to track the
    object at interval.

    Note that MMDetection is required for this node. Please refer to
    `MMDetection documentation <https://mmdetection.readthedocs.io/en
    /latest/get_started.html>`_ for the installation guide.

    Parameters:
        name (str): The node name (also thread name)
        model_config (str): The config file of the PoseC3D action recognizer
        model_checkpoint (str): The checkpoint file of the PoseC3D action
            recognizer
        det_model_config (str): The config file of the detection model
        det_model_checkpoint (str): The checkpoint file of the detection model
        pose_model_config (str): The config file of the pose estimation model
        pose_model_checkpoint (str): The checkpoint file of the pose
            estimation model
        input_buffer (str): The name of the input buffer
        output_buffer (str|list): The name(s) of the output buffer(s)
        enable_key (str|int, optional): Set a hot-key to toggle enable/disable
            of the node. If an int value is given, it will be treated as an
            ascii code of a key. Please note: (1) If ``enable_key`` is set,
            the ``bypass()`` method need to be overridden to define the node
            behavior when disabled; (2) Some hot-keys are reserved for
            particular use. For example: 'q', 'Q' and 27 are used for exiting.
            Default: ``None``
        enable (bool): Default enable/disable status. Default: ``True``
        device (str): Specify the device to hold model weights and inference
            the model. Default: ``'cuda:0'``
        det_interval (int): Set the detection interval in frames. For example,
            ``det_interval==10`` means inference the detection model every
            10 frames. Default: 1
        class_ids (list[int], optional): Specify the object category indices
            to apply pose estimation. If both ``class_ids`` and ``labels``
            are given, ``labels`` will be ignored. If neither is given, pose
            estimation will be applied for all objects. Default: ``None``
        labels (list[str], optional): Specify the object category names to
            apply pose estimation. See also ``class_ids``. Default: ``None``
        bbox_thr (float): Set a threshold to filter out objects with low bbox
            scores. Default: 0.5
        kpt2bbox_cfg (dict, optional): Configure the process to get object
            bbox from its keypoints during tracking. Specifically, the bbox
            is obtained from the minimal outer rectangle of the keypoints with
            the following configurable arguments: ``'scale'``, the coefficient
            to expand the keypoint outer rectangle, defaults to 1.5;
            ``'kpt_thr'``: a threshold to filter out low-scored keypoints,
            defaults to 0.3. See ``self.default_kpt2bbox_cfg`` for details
        smooth (bool): If set to ``True``, a :class:`Smoother` will be used to
            refine the pose estimation result. Default: ``False``
        smooth_filter_cfg (str): The filter config path to build the smoother.
            Only valid when ``smooth==True``. Default to use an OneEuro filter
        min_frame (int): The minimum number of buffered frames required before
            running the PoseC3D recognizer. Default: 16
        fps (int): The expected input frame rate. Default: 30
        score_thr (float): The score threshold for accepting an action
            prediction. Default: 0.7

    Example::
        >>> cfg = dict(
        ...    type='PoseTrackerNode',
        ...    name='pose tracker',
        ...    det_model_config='demo/mmdetection_cfg/'
        ...    'ssdlite_mobilenetv2_scratch_600e_coco.py',
        ...    det_model_checkpoint='https://download.openmmlab.com'
        ...    '/mmdetection/v2.0/ssd/'
        ...    'ssdlite_mobilenetv2_scratch_600e_coco/ssdlite_mobilenetv2_'
        ...    'scratch_600e_coco_20210629_110627-974d9307.pth',
        ...    pose_model_config='configs/wholebody/2d_kpt_sview_rgb_img/'
        ...    'topdown_heatmap/coco-wholebody/'
        ...    'vipnas_mbv3_coco_wholebody_256x192_dark.py',
        ...    pose_model_checkpoint='https://download.openmmlab.com/mmpose/'
        ...    'top_down/vipnas/vipnas_mbv3_coco_wholebody_256x192_dark'
        ...    '-e2158108_20211205.pth',
        ...    det_interval=10,
        ...    labels=['person'],
        ...    smooth=True,
        ...    device='cuda:0',
        ...    # `_input_` is an executor-reserved buffer
        ...    input_buffer='_input_',
        ...    output_buffer='human_pose')

        >>> from mmpose.apis.webcam.nodes import NODES
        >>> node = NODES.build(cfg)
    """

    default_kpt2bbox_cfg: Dict = dict(scale=1.5, kpt_thr=0.3)

    def __init__(
            self,
            name: str,
            model_config: str,
            model_checkpoint: str,
            det_model_config: str,
            det_model_checkpoint: str,
            pose_model_config: str,
            pose_model_checkpoint: str,
            input_buffer: str,
            output_buffer: Union[str, List[str]],
            enable_key: Optional[Union[str, int]] = None,
            enable: bool = True,
            device: str = 'cuda:0',
            det_interval: int = 1,
            class_ids: Optional[List] = None,
            labels: Optional[List] = None,
            bbox_thr: float = 0.5,
            kpt2bbox_cfg: Optional[dict] = None,
            smooth: bool = False,
            smooth_filter_cfg: str = 'configs/_base_/filters/one_euro.py',
            min_frame: int = 16,
            fps: int = 30,
            score_thr: float = 0.7):

        assert has_mmdet, \
            f'MMDetection is required for {self.__class__.__name__}.'

        super().__init__(name=name, enable_key=enable_key, enable=enable)
        self.model_config = mmcv.Config.fromfile(model_config)
        self.model_checkpoint = model_checkpoint
        self.det_model_config = get_config_path(det_model_config, 'mmdet')
        self.det_model_checkpoint = det_model_checkpoint
        self.pose_model_config = get_config_path(pose_model_config, 'mmpose')
        self.pose_model_checkpoint = pose_model_checkpoint
        self.device = device.lower()
        self.class_ids = class_ids
        self.labels = labels
        self.bbox_thr = bbox_thr
        self.det_interval = det_interval

        if not kpt2bbox_cfg:
            kpt2bbox_cfg = self.default_kpt2bbox_cfg
        self.kpt2bbox_cfg = copy.deepcopy(kpt2bbox_cfg)

        self.det_countdown = 0
        self.track_info = TrackInfo()

        if smooth:
            smooth_filter_cfg = get_config_path(smooth_filter_cfg, 'mmpose')
            self.smoother = Smoother(smooth_filter_cfg, keypoint_dim=2)
        else:
            self.smoother = None

        self._clip_buffer = []  # items: (clip message, num of frames)
        self.score_thr = score_thr
        self.min_frame = min_frame
        self.fps = fps

        self.det_model = init_detector(
            self.det_model_config,
            self.det_model_checkpoint,
            device=self.device)

        self.pose_model = init_pose_model(
            self.pose_model_config,
            self.pose_model_checkpoint,
            device=self.device)

        # register buffers
        self.register_input_buffer(input_buffer, 'input', trigger=True)
        self.register_output_buffer(output_buffer)

    def bypass(self, input_msgs):
        return input_msgs['input']

    @property
    def total_clip_length(self):
        """Total number of frames currently buffered."""
        return sum(clip[1] for clip in self._clip_buffer)

    def _extend_clips(self, clips: List[Message]):
        """Push the newly loaded clips from buffer, and discard old clips."""
        for clip in clips:
            # Each message from the input buffer carries a single frame;
            # ``clip.get_image().shape[0]`` would be the image height, not a
            # frame count, so count every buffered message as one frame.
            self._clip_buffer.append((clip, 1))

        # Trim the buffer: scanning backwards from the second-newest clip,
        # keep just enough trailing clips to cover ``min_frame`` frames and
        # drop the rest.
        total_length = 0
        for i in range(-2, -len(self._clip_buffer) - 1, -1):
            total_length += self._clip_buffer[i][1]
            if total_length >= self.min_frame:
                self._clip_buffer = self._clip_buffer[i:]
                break
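
        # Worked example of the trim above: with ``min_frame=16`` and 40
        # buffered single-frame clips, the backward scan reaches 16 at the
        # 17th-newest clip, so the buffer is trimmed to the newest 17 entries
        # (the current frame plus 16 frames of history).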

    def _merge_clips(self):
        """Concat the clips into a longer video, and gather bboxes."""
        # Each buffered message carries one frame, so stacking along a new
        # leading axis yields a (T, H, W, C) video.
        videos = [clip[0].get_image() for clip in self._clip_buffer]
        video = np.stack(videos, axis=0)

        bboxes = []
        for clip in self._clip_buffer:
            objects = clip[0].get_objects(lambda x: x.get('label') == 'person')
            bboxes.append(_merge_bbox(objects))
        bboxes = list(filter(len, bboxes))
        return video, bboxes, self._clip_buffer[0][0].get_image().shape

    def process(self, input_msgs):
        input_msg = input_msgs['input']
        img = input_msg.get_image()

        if self.det_countdown == 0:
            # get objects by detection model
            self.det_countdown = self.det_interval
            preds = inference_detector(self.det_model, img)
            single_objects_det = self._post_process_det(preds)
        else:
            # get object by pose tracking
            single_objects_det = self._get_objects_by_tracking(img.shape)

        self.det_countdown -= 1

        single_objects_pose, _ = inference_top_down_pose_model(
            self.pose_model,
            img,
            single_objects_det,
            bbox_thr=self.bbox_thr,
            format='xyxy')

        single_objects, next_id = get_track_id(
            single_objects_pose,
            self.track_info.last_objects,
            self.track_info.next_id,
            use_oks=False,
            tracking_thr=0.3)

        self.track_info.next_id = next_id
        self.track_info.last_objects = single_objects.copy()

        # Pose smoothing
        if self.smoother:
            single_objects = self.smoother.smooth(single_objects)

        for obj in single_objects:
            obj['det_model_cfg'] = self.det_model.cfg
            obj['pose_model_cfg'] = self.pose_model.cfg

        input_msg.update_objects(single_objects)

        self._extend_clips([input_msg])
        # ``_merge_clips`` also returns the merged per-frame bboxes, but
        # ``recognize_pose_model_batch`` re-runs detection itself.
        video, _, shape = self._merge_clips()

        if self.total_clip_length >= self.min_frame and len(
                single_objects) > 0 and max(map(len, single_objects)) > 0:
            # Configure PoseNormalize for the current frame size
            h, w = shape[0], shape[1]
            for component in self.model_config.data.test.pipeline:
                if component['type'] == 'PoseNormalize':
                    component['mean'] = (w // 2, h // 2, .5)
                    component['max_value'] = (w, h, 1.)

            # Build the PoseC3D recognizer once and reuse it; re-initializing
            # it on every clip reloads the checkpoint and dominates the
            # per-clip latency (part of the delay described above).
            if not hasattr(self, 'model'):
                self.model = init_recognizer(self.model_config,
                                             self.model_checkpoint,
                                             self.device)
            # Inference pose
            print('Start inferencing...')
            pred_label, pred_score, bboxes = recognize_pose_model_batch(
                self.model,
                self.det_model,
                self.pose_model,
                self.score_thr,
                self.bbox_thr,
                video,
                shape)

            result = bboxes[-1][0]  # first detection of the most recent frame

            if pred_score > self.score_thr:
                result['label'] = pred_label

            input_msg.update_objects([result])
        return input_msg

2. mmpose/apis/inference.py

# Additions to mmpose/apis/inference.py; these imports assume mmdet and
# mmaction2 are installed alongside mmpose
# (``inference_top_down_pose_model`` already lives in this module).
import numpy as np
from mmaction.apis import inference_recognizer
from mmdet.apis import inference_detector


def recognize_pose_model(
        model,
        object_pose,
        shape,
        labels=('Arrest', 'Arson', 'Fighting', 'Normal_Videos_event',
                'RoadAccidents', 'Robbery', 'Shooting')):

    # ``object_pose`` holds the per-person results of a single frame, matching
    # the ``keypoint[j, 0]`` indexing below (person j, frame 0).
    num_frame = 1

    # data preprocessing
    data = dict(frame_dir='',
                label=-1,
                img_shape=(shape[0], shape[1]),
                original_shape=(shape[0], shape[1]),
                start_index=0,
                modality='Pose',
                total_frames=num_frame)

    # One person per entry of ``object_pose``; ``len(x)`` over the pose dicts
    # would count dict keys, not people.
    num_person = len(object_pose)
    num_keypoint = 17

    keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
                        dtype=np.float16)
    keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
                              dtype=np.float16)

    for j, pose in enumerate(object_pose):
        pose = pose['keypoints']
        try:
            keypoint[j, 0] = pose[:, :2]
            keypoint_score[j, 0] = pose[:, 2]
        except IndexError:
            pass
    data['keypoint'] = keypoint
    data['keypoint_score'] = keypoint_score

    results = inference_recognizer(model, data)
    pred_idx, pred_score = results[0]
    pred_label = labels[pred_idx]
    return pred_label, pred_score
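
# Example (hypothetical inputs): for one frame with two detected people,
#   recognize_pose_model(model, [pose_a, pose_b], shape=(480, 640, 3))
# packs the two (17, 3) keypoint arrays into a (2, 1, 17, 2) keypoint tensor
# plus scores and returns the recognizer's top-1 label and score.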


def recognize_pose_model_batch(
        model,
        det_model,
        pose_model,
        score_thr,
        bbox_thr,
        video,
        shape,
        labels=('Arrest', 'Arson', 'Fighting', 'Normal_Videos_event',
                'RoadAccidents', 'Robbery', 'Shooting')):

    # Detect people in every frame; result[0] holds class index 0 ('person'
    # for COCO-trained detectors), filtered by detection score.
    objects_det = []
    for v in video:
        result = inference_detector(det_model, v)
        result = result[0][result[0][:, 4] >= score_thr]
        objects_det.append(result)

    objects_pose = []
    bboxes = []
    for v, o in zip(video, objects_det):
        d = [dict(bbox=x) for x in list(o)]
        op, _ = inference_top_down_pose_model(
            pose_model,
            v,
            d,
            bbox_thr=bbox_thr,
            format='xyxy')
        bboxes.append(d)
        objects_pose.append(op)

    num_frame = len(objects_pose)

    # data preprocessing
    data = dict(frame_dir='',
                label=-1,
                img_shape=(shape[0], shape[1]),
                original_shape=(shape[0], shape[1]),
                start_index=0,
                modality='Pose',
                total_frames=num_frame)

    # Reserve at least one person slot so the recognizer input is never empty.
    num_person = max(1, max(len(x) for x in objects_pose))
    num_keypoint = 17

    keypoint = np.zeros((num_person, num_frame, num_keypoint, 2),
                        dtype=np.float16)
    keypoint_score = np.zeros((num_person, num_frame, num_keypoint),
                              dtype=np.float16)

    for i, poses in enumerate(objects_pose):
        for j, pose in enumerate(poses):
            pose = pose['keypoints']
            try:
                keypoint[j, i] = pose[:, :2]
                keypoint_score[j, i] = pose[:, 2]
            except IndexError:
                pass
    data['keypoint'] = keypoint
    data['keypoint_score'] = keypoint_score

    results = inference_recognizer(model, data)
    pred_idx, pred_score = results[0]
    pred_label = labels[pred_idx]

    return pred_label, pred_score, bboxes
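
For challenge 2 (several bounding boxes in one frame), one option I'm considering is to group pose results by track id before building the PoseC3D input, so each tracked person gets an independent clip and label instead of sharing one. A minimal sketch, assuming each pose dict carries the 'track_id' assigned by get_track_id (group_poses_by_track is a hypothetical helper, not an mmpose API):

from collections import defaultdict

import numpy as np


def group_poses_by_track(objects_pose, num_keypoint=17):
    """Split per-frame pose lists into one (num_frame, K, 3) array per track.

    ``objects_pose`` is a list over frames of per-person pose dicts, each
    carrying 'keypoints' (a (K, 3) array) and a 'track_id'.
    """
    num_frame = len(objects_pose)
    tracks = defaultdict(
        lambda: np.zeros((num_frame, num_keypoint, 3), dtype=np.float16))
    for i, poses in enumerate(objects_pose):
        for pose in poses:
            tracks[pose['track_id']][i] = pose['keypoints']
    return dict(tracks)

Each per-track array could then be packed into its own single-person data dict and passed to inference_recognizer, yielding one action label per person.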
