From 591bc90d87eafb825a116f8189253cc5cb67bfc3 Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Mon, 10 Jul 2023 14:29:38 +0200 Subject: [PATCH 1/4] Add MobileSAM support, allow to select device --- .gitignore | 1 + app.py | 15 ++++++++++----- tools/base_segmenter.py | 4 ++-- track_anything.py | 3 ++- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 029399d..ca65d89 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ test_sample/ result/ vots/ vots.py +iris_conf.yaml diff --git a/app.py b/app.py index 870fae0..32318c2 100644 --- a/app.py +++ b/app.py @@ -28,6 +28,9 @@ def download_checkpoint(url, folder, filename): if not os.path.exists(filepath): print("download checkpoints ......") + + if url is None: + raise FileNotFoundError(f"Model checkpoint {folder}/{filename} does not exist and it cannot be downloaded automatically") response = requests.get(url, stream=True) with open(filepath, "wb") as f: for chunk in response.iter_content(chunk_size=8192): @@ -358,12 +361,14 @@ def generate_video_from_frames(frames, output_path, fps=30): SAM_checkpoint_dict = { 'vit_h': "sam_vit_h_4b8939.pth", 'vit_l': "sam_vit_l_0b3195.pth", - "vit_b": "sam_vit_b_01ec64.pth" + "vit_b": "sam_vit_b_01ec64.pth", + "vit_t": "mobile_sam.pt" } SAM_checkpoint_url_dict = { 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", - 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" + 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + 'vit_t': None } sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type] sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type] @@ -378,7 +383,7 @@ def generate_video_from_frames(frames, output_path, fps=30): xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint) e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint) args.port = 12212 -args.device = "cuda:3" +# args.device = "cuda:3" # args.mask_save = True # initialize sam, xmem, e2fgvi models @@ -598,5 +603,5 @@ def generate_video_from_frames(frames, output_path, fps=30): # cache_examples=True, ) iface.queue(concurrency_count=1) -iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0") -# iface.launch(debug=True, enable_queue=True) \ No newline at end of file +iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name=args.server_name) +# iface.launch(debug=True, enable_queue=True) diff --git a/tools/base_segmenter.py b/tools/base_segmenter.py index 2b975bb..d02e0fd 100644 --- a/tools/base_segmenter.py +++ b/tools/base_segmenter.py @@ -4,7 +4,7 @@ from PIL import Image, ImageDraw, ImageOps import numpy as np from typing import Union -from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator +from mobile_sam import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator import matplotlib.pyplot as plt import PIL from .mask_painter import mask_painter @@ -18,7 +18,7 @@ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): model_type: vit_b, vit_l, vit_h """ print(f"Initializing BaseSegmenter to {device}") - assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h' + assert model_type in ['vit_b', 'vit_l', 'vit_h', 'vit_t'], 'model_type must be vit_b, vit_l, vit_h or vit_t (for MobileSAM)' self.device = device self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 diff --git a/track_anything.py b/track_anything.py index 5275252..b56c7b4 100644 --- a/track_anything.py +++ b/track_anything.py @@ -63,6 +63,7 @@ def parse_augment(): parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default="cuda:0") parser.add_argument('--sam_model_type', type=str, default="vit_h") + parser.add_argument('--server-name', type=str, help="only useful when running gradio applications", default="0.0.0.0") parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications") parser.add_argument('--debug', action="store_true") parser.add_argument('--mask_save', default=False) @@ -93,4 +94,4 @@ def parse_augment(): - \ No newline at end of file + From 44b5d0a8458583ca997d162bb5c77cef89177721 Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Mon, 10 Jul 2023 14:33:26 +0200 Subject: [PATCH 2/4] Update README.md --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e364ce4..9dd1dbd 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ ## :rocket: Updates +- 2023/07/10: [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) support is available, resulting in a faster experience on CPU and slow devices. To use it add the `mobile_sam.pt` to the checkpoints folder and run `app.py` with `--sam_model_type vit_t`. + - 2023/05/02: We uploaded tutorials in steps :world_map:. Check [HERE](./doc/tutorials.md) for more details. - 2023/04/29: We improved inpainting by decoupling GPU memory usage and video length. Now Track-Anything can inpaint videos with any length! :smiley_cat: Check [HERE](https://github.com/gaomingqi/Track-Anything/issues/4#issuecomment-1528198165) for our GPU memory requirements. @@ -77,7 +79,8 @@ pip install -r requirements.txt # Run the Track-Anything gradio demo. python app.py --device cuda:0 -# python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage +# python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage +# python app.py --device cuda:0 --sam_model_type vit_t # to use MobileSAM ``` From 36599f5740c9dd59ed7db19acbf971346e6f0c13 Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Mon, 10 Jul 2023 14:34:50 +0200 Subject: [PATCH 3/4] Update requirements.txt --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dc6195a..007113a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ hickle tensorboard numpy git+https://github.com/facebookresearch/segment-anything.git +pip install git+https://github.com/ChaoningZhang/MobileSAM.git gradio opencv-python matplotlib @@ -13,4 +14,4 @@ pyyaml av openmim tqdm -psutil \ No newline at end of file +psutil From 83c3a67c7b85e530975e4c42b0c513553b51ca65 Mon Sep 17 00:00:00 2001 From: JunkyByte Date: Fri, 8 Sep 2023 17:47:59 +0200 Subject: [PATCH 4/4] fix --- app.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index 32318c2..97c67f9 100644 --- a/app.py +++ b/app.py @@ -368,7 +368,7 @@ def generate_video_from_frames(frames, output_path, fps=30): 'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - 'vit_t': None + 'vit_t': "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt" } sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type] sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type] diff --git a/requirements.txt b/requirements.txt index 007113a..a1d3a54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ hickle tensorboard numpy git+https://github.com/facebookresearch/segment-anything.git -pip install git+https://github.com/ChaoningZhang/MobileSAM.git +git+https://github.com/ChaoningZhang/MobileSAM.git gradio opencv-python matplotlib