From d6859580bc3b047a2adc53d1e526be368c097d78 Mon Sep 17 00:00:00 2001 From: Y-T-G Date: Fri, 11 Aug 2023 16:35:32 +0800 Subject: [PATCH] enable onnx inference --- app.py | 2 +- utils/base_segmenter.py | 13 ++++++++++--- utils/interact_tools.py | 33 ++++++++++++++++++++------------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/app.py b/app.py index c7be510..99f8b55 100644 --- a/app.py +++ b/app.py @@ -56,7 +56,7 @@ def get_prompt(click_state, click_input): "prompt_type": ["click"], "input_point": click_state[0], "input_label": click_state[1], - "multimask_output": "True", + "multimask_output": "False", } return prompt diff --git a/utils/base_segmenter.py b/utils/base_segmenter.py index 7f54507..288a52a 100644 --- a/utils/base_segmenter.py +++ b/utils/base_segmenter.py @@ -24,8 +24,10 @@ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device="c from mobile_sam import sam_model_registry, SamPredictor from onnxruntime import InferenceSession self.ort_session = InferenceSession(sam_onnx_checkpoint) + self.predict = self.predict_onnx else: from segment_anything import sam_model_registry, SamPredictor + self.predict = self.predict_pt self.model = sam_model_registry[model_type](checkpoint=sam_pt_checkpoint) self.model.to(device=self.device) @@ -51,7 +53,7 @@ def reset_image(self): self.predictor.reset_image() self.embedded = False - def predict(self, prompts, mode, multimask=True): + def predict_pt(self, prompts, mode, multimask=True): """ image: numpy array, h, w, 3 prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input' @@ -115,17 +117,20 @@ def predict_onnx(self, prompts, mode, multimask=True): "orig_im_size": prompts["orig_im_size"], } masks, scores, logits = self.ort_session.run(None, ort_inputs) + masks = masks > self.predictor.model.mask_threshold elif mode == "mask": ort_inputs = { "image_embeddings": self.image_embedding, - "point_coords": prompts["point_coords"], + "point_coords": np.zeros((len(prompts["point_labels"]), 2), dtype=np.float32), "point_labels": prompts["point_labels"], "mask_input": prompts["mask_input"], "has_mask_input": np.ones(1, dtype=np.float32), "orig_im_size": prompts["orig_im_size"], } masks, scores, logits = self.ort_session.run(None, ort_inputs) + masks = masks > self.predictor.model.mask_threshold + elif mode == "both": # both ort_inputs = { "image_embeddings": self.image_embedding, @@ -136,7 +141,9 @@ def predict_onnx(self, prompts, mode, multimask=True): "orig_im_size": prompts["orig_im_size"], } masks, scores, logits = self.ort_session.run(None, ort_inputs) + masks = masks > self.predictor.model.mask_threshold + else: raise ("Not implement now!") # masks (n, h, w), scores (n,), logits (n, 256, 256) - return masks, scores, logits + return masks[0], scores[0], logits[0] diff --git a/utils/interact_tools.py b/utils/interact_tools.py index 4f2fc05..02a3a33 100644 --- a/utils/interact_tools.py +++ b/utils/interact_tools.py @@ -23,6 +23,7 @@ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device): """ self.sam_controler = BaseSegmenter(sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device) + self.onnx = model_type == "vit_t" def first_frame_click( self, @@ -38,32 +39,38 @@ def first_frame_click( """ # self.sam_controler.set_image(image) neg_flag = labels[-1] - if neg_flag == 1: - # find neg + + if self.onnx: + onnx_coord = np.concatenate([points, np.array([[0.0, 0.0]])], axis=0)[None, :, :] + onnx_label = np.concatenate([labels, np.array([-1])], axis=0)[None, :].astype(np.float32) + onnx_coord = self.sam_controler.predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32) + prompts = { + "point_coords": onnx_coord, + "point_labels": onnx_label, + "orig_im_size": np.array(image.shape[:2], dtype=np.float32), + } + + else: prompts = { "point_coords": points, "point_labels": labels, - "orig_im_size": image.shape[:2], } + + if neg_flag == 1: + # find positive masks, scores, logits = self.sam_controler.predict( prompts, "point", multimask ) mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] - prompts = { - "point_coords": points, - "point_labels": labels, - "mask_input": logit[None, :, :], - } + + prompts["mask_input"] = np.expand_dims(logit[None, :, :], 0) masks, scores, logits = self.sam_controler.predict( prompts, "both", multimask ) mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] + else: - # find positive - prompts = { - "point_coords": points, - "point_labels": labels, - } + # find neg masks, scores, logits = self.sam_controler.predict( prompts, "point", multimask )