Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for MobileSAM for snappier Gradio experience #100

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ test_sample/
result/
vots/
vots.py
iris_conf.yaml
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

## :rocket: Updates

- 2023/07/10: [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) support is available, resulting in a faster experience on CPU and slow devices. To use it, add `mobile_sam.pt` to the checkpoints folder and run `app.py` with `--sam_model_type vit_t`.

- 2023/05/02: We uploaded tutorials in steps :world_map:. Check [HERE](./doc/tutorials.md) for more details.

- 2023/04/29: We improved inpainting by decoupling GPU memory usage and video length. Now Track-Anything can inpaint videos with any length! :smiley_cat: Check [HERE](https://github.com/gaomingqi/Track-Anything/issues/4#issuecomment-1528198165) for our GPU memory requirements.
Expand Down Expand Up @@ -77,7 +79,8 @@ pip install -r requirements.txt

# Run the Track-Anything gradio demo.
python app.py --device cuda:0
# python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage
# python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage
# python app.py --device cuda:0 --sam_model_type vit_t # to use MobileSAM
```


Expand Down
15 changes: 10 additions & 5 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def download_checkpoint(url, folder, filename):

if not os.path.exists(filepath):
print("download checkpoints ......")

if url is None:
        raise FileNotFoundError(f"Model checkpoint {folder}/{filename} does not exist and it cannot be downloaded automatically")
response = requests.get(url, stream=True)
with open(filepath, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
Expand Down Expand Up @@ -358,12 +361,14 @@ def generate_video_from_frames(frames, output_path, fps=30):
SAM_checkpoint_dict = {
'vit_h': "sam_vit_h_4b8939.pth",
'vit_l': "sam_vit_l_0b3195.pth",
"vit_b": "sam_vit_b_01ec64.pth"
"vit_b": "sam_vit_b_01ec64.pth",
"vit_t": "mobile_sam.pt"
}
SAM_checkpoint_url_dict = {
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
'vit_t': "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
}
sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
Expand All @@ -378,7 +383,7 @@ def generate_video_from_frames(frames, output_path, fps=30):
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
args.port = 12212
args.device = "cuda:3"
# args.device = "cuda:3"
# args.mask_save = True

# initialize sam, xmem, e2fgvi models
Expand Down Expand Up @@ -598,5 +603,5 @@ def generate_video_from_frames(frames, output_path, fps=30):
# cache_examples=True,
)
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
# iface.launch(debug=True, enable_queue=True)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name=args.server_name)
# iface.launch(debug=True, enable_queue=True)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ hickle
tensorboard
numpy
git+https://github.com/facebookresearch/segment-anything.git
git+https://github.com/ChaoningZhang/MobileSAM.git
gradio
opencv-python
matplotlib
pyyaml
av
openmim
tqdm
psutil
psutil
4 changes: 2 additions & 2 deletions tools/base_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from PIL import Image, ImageDraw, ImageOps
import numpy as np
from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from mobile_sam import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from .mask_painter import mask_painter
Expand All @@ -18,7 +18,7 @@ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
model_type: vit_b, vit_l, vit_h
"""
print(f"Initializing BaseSegmenter to {device}")
assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'
assert model_type in ['vit_b', 'vit_l', 'vit_h', 'vit_t'], 'model_type must be vit_b, vit_l, vit_h or vit_t (for MobileSAM)'

self.device = device
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
Expand Down
3 changes: 2 additions & 1 deletion track_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def parse_augment():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default="cuda:0")
parser.add_argument('--sam_model_type', type=str, default="vit_h")
parser.add_argument('--server-name', type=str, help="only useful when running gradio applications", default="0.0.0.0")
parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
parser.add_argument('--debug', action="store_true")
parser.add_argument('--mask_save', default=False)
Expand Down Expand Up @@ -93,4 +94,4 @@ def parse_augment():