From ed928d2056f604828f007cabadb677366636a888 Mon Sep 17 00:00:00 2001 From: Onuralp SEZER Date: Fri, 1 Nov 2024 22:30:12 +0300 Subject: [PATCH] =?UTF-8?q?feat:=20=E2=9C=A8=20descriptors=20custom=20fiel?= =?UTF-8?q?d=20added=20for=20keypoint=20and=20as=20constant=20name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Onuralp SEZER --- supervision/config.py | 1 + supervision/keypoint/core.py | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/supervision/config.py b/supervision/config.py index b18d2e20b..18236c8a7 100644 --- a/supervision/config.py +++ b/supervision/config.py @@ -1,2 +1,3 @@ CLASS_NAME_DATA_FIELD = "class_name" ORIENTED_BOX_COORDINATES = "xyxyxyxy" +DESCRIPTORS_FIELD = "descriptors" diff --git a/supervision/keypoint/core.py b/supervision/keypoint/core.py index 803b818eb..a55cab725 100644 --- a/supervision/keypoint/core.py +++ b/supervision/keypoint/core.py @@ -7,7 +7,7 @@ import numpy as np import numpy.typing as npt -from supervision.config import CLASS_NAME_DATA_FIELD +from supervision.config import CLASS_NAME_DATA_FIELD, DESCRIPTORS_FIELD from supervision.detection.utils import get_data_item, is_data_equal from supervision.validators import validate_keypoints_fields @@ -542,8 +542,10 @@ def from_transformers(cls, transformers_results: List) -> KeyPoints: ``` """ # noqa: E501 // docs - keypoints_list = [] - scores_list = [] + keypoints_list: List[np.ndarray] = [] + scores_list: List[np.ndarray] = [] + descriptors_list: List[np.ndarray] = [] + data: Dict[str, Any] = {} for result in transformers_results: if "keypoints" in result: @@ -554,12 +556,22 @@ def from_transformers(cls, transformers_results: List) -> KeyPoints: keypoints_list.append(keypoints) scores_list.append(scores) + if "descriptors" in result: + descriptors = result["descriptors"].detach().numpy() + + if descriptors.size > 0: + descriptors_list.append(descriptors) + if not keypoints_list: return cls.empty() + if descriptors_list: + data[DESCRIPTORS_FIELD] = np.array(descriptors_list) + return cls( - xy=np.array(keypoints_list,dtype=np.float32), - confidence=np.array(scores_list,dtype=np.float32), + xy=np.array(keypoints_list, dtype=np.float32), + confidence=np.array(scores_list, dtype=np.float32), + data=data if data else {}, ) def __getitem__(