Skip to content

Commit

Permalink
Add OpenVINO backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Y-T-G committed Sep 30, 2023
1 parent 26e8891 commit dcec5d5
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 21 deletions.
13 changes: 11 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Blur Anything is an adaptation of the excellent [Track Anything](https://github.
</a>
</div>

https://github.com/Y-T-G/Blur-Anything/raw/main/assets/sample-1-blurred-stacked.mp4

## Get Started
```shell
# Clone the repository:
Expand All @@ -18,15 +20,22 @@ cd Blur-Anything
pip install -r requirements.txt

# Run the Blur-Anything gradio demo.
python app.py --device cuda:0
# python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage
python app.py --device [cpu|cuda:0|cuda:1|...] --sam_model_type [vit_t|vit_b|vit_h|vit_l] [--backend [onnx|openvino]]
```
## Features
- MobileSAM (vit_t) with ONNX and OpenVINO support.
- Lower memory usage.
## To Do
- [x] Add a gradio demo
- [ ] Add support to use YouTube video URL
- [ ] Add option to completely black out the object
- [ ] Convert XMem to ONNX
## Acknowledgements
The project is an adaptation of [Track Anything](https://github.com/gaomingqi/Track-Anything) which is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem).
Thanks to [PIMS](https://github.com/soft-matter/pims) which is used to process video files while keeping memory usage low.
6 changes: 4 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
sys.path.append(sys.path[0] + "/tracker/model")

from track_anything import TrackingAnything
from track_anything import parse_augment
from track_anything import parse_argument

from utils.painter import mask_painter
from utils.blur import blur_frames_and_write
Expand Down Expand Up @@ -501,7 +501,7 @@ def convert_to_onnx(args, checkpoint, quantized=True):


# args, defined in track_anything.py
args = parse_augment()
args = parse_argument()

# check and download checkpoints if needed
SAM_checkpoint_dict = {
Expand Down Expand Up @@ -529,6 +529,8 @@ def convert_to_onnx(args, checkpoint, quantized=True):
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)

if args.sam_model_type == "vit_t":
if args.backend not in ("", "onnx", "openvino"):
print("vit_t only supports `onnx` and `openvino` backends. Falling back to `onnx`")
sam_onnx_checkpoint = convert_to_onnx(args, sam_pt_checkpoint, quantized=True)
else:
sam_onnx_checkpoint = ""
Expand Down
3 changes: 3 additions & 0 deletions assets/sample-1-blurred-stacked.mp4
Git LFS file not shown
15 changes: 13 additions & 2 deletions track_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, xmem_checkpoint, args
self.sam_onnx_checkpoint = sam_onnx_checkpoint
self.xmem_checkpoint = xmem_checkpoint
self.samcontroler = SamControler(
self.sam_pt_checkpoint, self.sam_onnx_checkpoint, args.sam_model_type, args.device
self.sam_pt_checkpoint, self.sam_onnx_checkpoint,
args.sam_model_type, args.backend, args.device
)
self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)

Expand Down Expand Up @@ -69,9 +70,15 @@ def generator(
return masks, logits, painted_images


def parse_augment():
def parse_argument():
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument(
"--backend",
type=str,
default="",
choices=["onnx", "openvino"],
help="Specify either `onnx` or `openvino` backend for vit_t model. Not applicable for other models.")
parser.add_argument("--sam_model_type", type=str, default="vit_t")
parser.add_argument(
"--port",
Expand All @@ -83,6 +90,10 @@ def parse_augment():
parser.add_argument("--mask_save", default=False)
args = parser.parse_args()

if args.backend in ("onnx", "openvino") and args.sam_model_type != "vit_t":
print(f" {args.sam_model_type} does not support `onnx` or `openvino` \
backend. Using PyTorch backend...")

if args.debug:
print(args)
return args
54 changes: 41 additions & 13 deletions utils/base_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@


class BaseSegmenter:
def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device="cuda:0"):
def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type,
backend, device="cuda:0"):
"""
device: model device
SAM_checkpoint: path of SAM checkpoint
model_type: vit_b, vit_l, vit_h, vit_t
"""
print(f"Initializing BaseSegmenter to {device}")
print(f"Initializing BaseSegmenter to {device} with {backend} backend")
assert model_type in [
"vit_b",
"vit_l",
Expand All @@ -20,11 +21,25 @@ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device="c
self.device = device
self.torch_dtype = torch.float16 if "cuda" in device else torch.float32

self.backend = backend

if (model_type == "vit_t"):
from mobile_sam import sam_model_registry, SamPredictor
from onnxruntime import InferenceSession
self.ort_session = InferenceSession(sam_onnx_checkpoint)
self.predict = self.predict_onnx
if self.backend == "onnx":
from onnxruntime import InferenceSession
self.ort_session = InferenceSession(sam_onnx_checkpoint)
self.predict = self.predict_onnx_ov
elif self.backend == "openvino":
from openvino import Core
ov_core = Core()
ov_model = ov_core.read_model(sam_onnx_checkpoint)
ov_device = "CPU" if device == "cpu" else "AUTO"
self.ir_model = ov_core.compile_model(model=ov_model,
device_name=ov_device)
self.ov_ir = self.ir_model.create_infer_request()
self.predict = self.predict_onnx_ov
else:
raise ("Unsupported Backend")
else:
from segment_anything import sam_model_registry, SamPredictor
self.predict = self.predict_pt
Expand Down Expand Up @@ -55,6 +70,8 @@ def reset_image(self):

def predict_pt(self, prompts, mode, multimask=True):
"""
Prediction using PyTorch backend.
image: numpy array, h, w, 3
prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
prompts['point_coords']: numpy array [N,2]
Expand Down Expand Up @@ -91,8 +108,10 @@ def predict_pt(self, prompts, mode, multimask=True):
# masks (n, h, w), scores (n,), logits (n, 256, 256)
return masks, scores, logits

def predict_onnx(self, prompts, mode, multimask=True):
def predict_onnx_ov(self, prompts, mode, multimask=True):
"""
Prediction using ONNX or OpenVINO backend.
image: numpy array, h, w, 3
prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
prompts['point_coords']: numpy array [N,2]
Expand All @@ -108,42 +127,51 @@ def predict_onnx(self, prompts, mode, multimask=True):
assert mode in ["point", "mask", "both"], "mode must be point, mask, or both"

if mode == "point":
ort_inputs = {
inputs = {
"image_embeddings": self.image_embedding,
"point_coords": prompts["point_coords"],
"point_labels": prompts["point_labels"],
"mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
"has_mask_input": np.zeros(1, dtype=np.float32),
"orig_im_size": prompts["orig_im_size"],
}
masks, scores, logits = self.ort_session.run(None, ort_inputs)
if self.backend == "onnx":
masks, scores, logits = self.ort_session.run(None, inputs)
elif self.backend == "openvino":
masks, scores, logits = self.ov_ir.infer(inputs).to_tuple()
masks = masks > self.predictor.model.mask_threshold

elif mode == "mask":
ort_inputs = {
inputs = {
"image_embeddings": self.image_embedding,
"point_coords": np.zeros((len(prompts["point_labels"]), 2), dtype=np.float32),
"point_labels": prompts["point_labels"],
"mask_input": prompts["mask_input"],
"has_mask_input": np.ones(1, dtype=np.float32),
"orig_im_size": prompts["orig_im_size"],
}
masks, scores, logits = self.ort_session.run(None, ort_inputs)
if self.backend == "onnx":
masks, scores, logits = self.ort_session.run(None, inputs)
elif self.backend == "openvino":
masks, scores, logits = self.ov_ir.infer(inputs).to_tuple()
masks = masks > self.predictor.model.mask_threshold

elif mode == "both": # both
ort_inputs = {
inputs = {
"image_embeddings": self.image_embedding,
"point_coords": prompts["point_coords"],
"point_labels": prompts["point_labels"],
"mask_input": prompts["mask_input"],
"has_mask_input": np.ones(1, dtype=np.float32),
"orig_im_size": prompts["orig_im_size"],
}
masks, scores, logits = self.ort_session.run(None, ort_inputs)
if self.backend == "onnx":
masks, scores, logits = self.ort_session.run(None, inputs)
elif self.backend == "openvino":
masks, scores, logits = self.ov_ir.infer(inputs).to_tuple()
masks = masks > self.predictor.model.mask_threshold

else:
raise ("Not implement now!")
# masks (n, h, w), scores (n,), logits (n, 256, 256)
return masks[0], scores[0], logits[0]
return masks[0], scores[0], logits[0]
4 changes: 2 additions & 2 deletions utils/interact_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@


class SamControler:
def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, backend, device):
    """
    Initialize the SAM controller.

    sam_pt_checkpoint: path to the SAM PyTorch checkpoint
    sam_onnx_checkpoint: path to the exported ONNX model (used for vit_t)
    model_type: one of vit_b, vit_l, vit_h, vit_t
    backend: inference backend selector ("", "onnx", or "openvino");
             only relevant for the vit_t model
    device: device string, e.g. "cpu" or "cuda:0"
    """
    # BaseSegmenter selects the actual inference path (PyTorch,
    # ONNX Runtime, or OpenVINO) from model_type and backend.
    self.sam_controler = BaseSegmenter(
        sam_pt_checkpoint, sam_onnx_checkpoint, model_type, backend, device
    )
    # Only vit_t (MobileSAM) runs through the ONNX/OpenVINO path.
    self.onnx = model_type == "vit_t"

def first_frame_click(
Expand Down

0 comments on commit dcec5d5

Please sign in to comment.