diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 740bb4b0719c61..ed28aea3ff4252 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -832,6 +832,8 @@ title: Perceiver - local: model_doc/pix2struct title: Pix2Struct + - local: model_doc/sam2 + title: SAM2 - local: model_doc/sam title: Segment Anything - local: model_doc/siglip diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index beeea517fa3028..1d3a19a9ff7198 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -67,6 +67,7 @@ "ToolCollection", "launch_gradio_demo", "load_tool", + "stream_to_gradio", ], "audio_utils": [], "benchmark": [], @@ -681,6 +682,13 @@ "SamPromptEncoderConfig", "SamVisionConfig", ], + "models.sam2": [ + "Sam2Config", + "Sam2MaskDecoderConfig", + "Sam2Processor", + "Sam2PromptEncoderConfig", + "Sam2VisionConfig", + ], "models.seamless_m4t": [ "SeamlessM4TConfig", "SeamlessM4TFeatureExtractor", @@ -1179,6 +1187,7 @@ _import_structure["models.pvt"].extend(["PvtImageProcessor"]) _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) _import_structure["models.sam"].extend(["SamImageProcessor"]) + _import_structure["models.sam2"].extend(["Sam2ImageProcessor"]) _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) _import_structure["models.siglip"].append("SiglipImageProcessor") @@ -3096,6 +3105,12 @@ "SamPreTrainedModel", ] ) + _import_structure["models.sam2"].extend( + [ + "Sam2Model", + "Sam2PreTrainedModel", + ] + ) _import_structure["models.seamless_m4t"].extend( [ "SeamlessM4TCodeHifiGan", @@ -4733,6 +4748,7 @@ ToolCollection, launch_gradio_demo, load_tool, + stream_to_gradio, ) from .configuration_utils import PretrainedConfig @@ -5913,6 +5929,7 @@ from .models.pvt import PvtImageProcessor from .models.rt_detr import RTDetrImageProcessor from .models.sam import SamImageProcessor + from .models.sam2 import Sam2ImageProcessor from .models.segformer import SegformerFeatureExtractor, SegformerImageProcessor from .models.seggpt import SegGptImageProcessor from .models.siglip import SiglipImageProcessor @@ -7465,6 +7482,10 @@ SamModel, SamPreTrainedModel, ) + from .models.sam2 import ( + Sam2Model, + Sam2PreTrainedModel, + ) from .models.seamless_m4t import ( SeamlessM4TCodeHifiGan, SeamlessM4TForSpeechToSpeech, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index cc1e41b3fc4076..4ba4933eba9a5b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -201,6 +201,7 @@ rt_detr, rwkv, sam, + sam2, seamless_m4t, seamless_m4t_v2, segformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 512c1eaaf5e01a..1fadc7a92c9122 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -223,6 +223,7 @@ ("rt_detr_resnet", "RTDetrResNetConfig"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), + ("sam2", "Sam2Config"), ("seamless_m4t", "SeamlessM4TConfig"), ("seamless_m4t_v2", "SeamlessM4Tv2Config"), ("segformer", "SegformerConfig"), @@ -517,6 +518,7 @@ ("rt_detr_resnet", "RT-DETR-ResNet"), ("rwkv", "RWKV"), ("sam", "SAM"), + ("sam2", "SAM2"), ("seamless_m4t", "SeamlessM4T"), ("seamless_m4t_v2", "SeamlessM4Tv2"), ("segformer", "SegFormer"), diff --git 
a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 8bfc61b9bea349..6916dbc8f3df7b 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -120,6 +120,7 @@ ("resnet", ("ConvNextImageProcessor",)), ("rt_detr", "RTDetrImageProcessor"), ("sam", ("SamImageProcessor",)), + ("sam2", ("Sam2ImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("seggpt", ("SegGptImageProcessor",)), ("siglip", ("SiglipImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d096abf4342614..8a1308a26c9e8b 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -208,6 +208,7 @@ ("rt_detr", "RTDetrModel"), ("rwkv", "RwkvModel"), ("sam", "SamModel"), + ("sam2", "Sam2Model"), ("seamless_m4t", "SeamlessM4TModel"), ("seamless_m4t_v2", "SeamlessM4Tv2Model"), ("segformer", "SegformerModel"), @@ -1280,6 +1281,7 @@ MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict( [ ("sam", "SamModel"), + ("sam2", "Sam2Model"), ] ) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 1ab136a1e74ca7..0f7663d93d07bb 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -83,6 +83,7 @@ ("pix2struct", "Pix2StructProcessor"), ("pop2piano", "Pop2PianoProcessor"), ("sam", "SamProcessor"), + ("sam2", "Sam2Processor"), ("seamless_m4t", "SeamlessM4TProcessor"), ("sew", "Wav2Vec2Processor"), ("sew-d", "Wav2Vec2Processor"), diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index de739c6e70044a..b61a89c0639a0f 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -7719,6 +7719,20 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Sam2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Sam2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SeamlessM4TCodeHifiGan(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 19f8dc1b1d9c9e..178c10295c8bb4 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -527,6 +527,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class Sam2ImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class SegformerFeatureExtractor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/sam2/__init__.py b/tests/models/sam2/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/sam2/test_modeling_sam2.py b/tests/models/sam2/test_modeling_sam2.py new file mode 100644 index 00000000000000..9f12337771c7dd --- /dev/null +++ b/tests/models/sam2/test_modeling_sam2.py @@ -0,0 +1,733 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch SAM2 model.""" + +import gc +import unittest + +import requests + +from transformers import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig, Sam2VisionConfig, pipeline +from transformers.testing_utils import backend_empty_cache, require_torch, slow, torch_device +from transformers.utils import is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + from torch import nn + + from transformers import Sam2Model, SamProcessor + + +if is_vision_available(): + from PIL import Image + + +class Sam2PromptEncoderTester: + def __init__( + self, + hidden_size=32, + input_image_size=24, + patch_size=2, + mask_input_channels=4, + num_point_embeddings=4, + hidden_act="gelu", + ): + self.hidden_size = hidden_size + self.input_image_size = input_image_size + self.patch_size = patch_size + self.mask_input_channels = mask_input_channels + self.num_point_embeddings = num_point_embeddings + self.hidden_act = hidden_act + + def get_config(self): + return Sam2PromptEncoderConfig( + image_size=self.input_image_size, + patch_size=self.patch_size, + mask_input_channels=self.mask_input_channels, + hidden_size=self.hidden_size, + num_point_embeddings=self.num_point_embeddings, + hidden_act=self.hidden_act, + ) + + def prepare_config_and_inputs(self): + dummy_points = floats_tensor([self.batch_size, 3, 2]) + config = self.get_config() + + return config, dummy_points + + +class Sam2MaskDecoderTester: + def __init__( + self, + hidden_size=32, + hidden_act="relu", + mlp_dim=64, + num_hidden_layers=2, + num_attention_heads=4, + attention_downsample_rate=2, + num_multimask_outputs=3, + iou_head_depth=3, + iou_head_hidden_dim=32, + layer_norm_eps=1e-6, + ): + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_dim = mlp_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_downsample_rate = attention_downsample_rate + self.num_multimask_outputs = num_multimask_outputs + self.iou_head_depth = iou_head_depth + self.iou_head_hidden_dim = iou_head_hidden_dim + self.layer_norm_eps = layer_norm_eps + + def get_config(self): + return Sam2MaskDecoderConfig( + hidden_size=self.hidden_size, + hidden_act=self.hidden_act, + mlp_dim=self.mlp_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + attention_downsample_rate=self.attention_downsample_rate, + num_multimask_outputs=self.num_multimask_outputs, + iou_head_depth=self.iou_head_depth, + iou_head_hidden_dim=self.iou_head_hidden_dim, + layer_norm_eps=self.layer_norm_eps, + ) + + def prepare_config_and_inputs(self): + config = self.get_config() + + dummy_inputs = { + "image_embedding": floats_tensor([self.batch_size, self.hidden_size]), + } + + return config,
dummy_inputs + + +class Sam2ModelTester: + def __init__( + self, + parent, + hidden_size=36, + intermediate_size=72, + projection_dim=62, + output_channels=32, + num_hidden_layers=2, + num_attention_heads=4, + num_channels=3, + image_size=24, + patch_size=2, + hidden_act="gelu", + layer_norm_eps=1e-06, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + qkv_bias=True, + mlp_ratio=4.0, + use_abs_pos=True, + use_rel_pos=True, + rel_pos_zero_init=False, + window_size=14, + global_attn_indexes=[2, 5, 8, 11], + num_pos_feats=16, + mlp_dim=None, + batch_size=2, + ): + self.parent = parent + self.image_size = image_size + self.patch_size = patch_size + self.output_channels = output_channels + self.num_channels = num_channels + self.hidden_size = hidden_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.dropout = dropout + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.mlp_ratio = mlp_ratio + self.use_abs_pos = use_abs_pos + self.use_rel_pos = use_rel_pos + self.rel_pos_zero_init = rel_pos_zero_init + self.window_size = window_size + self.global_attn_indexes = global_attn_indexes + self.num_pos_feats = num_pos_feats + self.mlp_dim = mlp_dim + self.batch_size = batch_size + + # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + self.prompt_encoder_tester = Sam2PromptEncoderTester() + self.mask_decoder_tester = Sam2MaskDecoderTester() + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + vision_config = Sam2VisionConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + projection_dim=self.projection_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + dropout=self.dropout, + attention_dropout=self.attention_dropout, + initializer_range=self.initializer_range, + initializer_factor=self.initializer_factor, + output_channels=self.output_channels, + qkv_bias=self.qkv_bias, + mlp_ratio=self.mlp_ratio, + use_abs_pos=self.use_abs_pos, + use_rel_pos=self.use_rel_pos, + rel_pos_zero_init=self.rel_pos_zero_init, + window_size=self.window_size, + global_attn_indexes=self.global_attn_indexes, + num_pos_feats=self.num_pos_feats, + mlp_dim=self.mlp_dim, + ) + + prompt_encoder_config = self.prompt_encoder_tester.get_config() + + mask_decoder_config = self.mask_decoder_tester.get_config() + + return Sam2Config( + vision_config=vision_config, + prompt_encoder_config=prompt_encoder_config, + mask_decoder_config=mask_decoder_config, + ) + + def create_and_check_model(self, config, pixel_values): + model = Sam2Model(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + self.parent.assertEqual(result.iou_scores.shape, (self.batch_size, 1, 3)) + self.parent.assertEqual(result.pred_masks.shape[:3], (self.batch_size, 1, 3)) + + def 
create_and_check_get_image_features(self, config, pixel_values): + model = Sam2Model(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.get_image_embeddings(pixel_values) + self.parent.assertEqual(result[0].shape, (self.output_channels, 12, 12)) + + def create_and_check_get_image_hidden_states(self, config, pixel_values): + model = Sam2Model(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=True, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + with torch.no_grad(): + result = model.vision_encoder( + pixel_values, + output_hidden_states=True, + return_dict=False, + ) + + # after computing the convolutional features + expected_hidden_states_shape = (self.batch_size, 12, 12, 36) + self.parent.assertEqual(len(result[1]), self.num_hidden_layers + 1) + self.parent.assertEqual(result[1][0].shape, expected_hidden_states_shape) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class Sam2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as SAM2's vision encoder does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (Sam2Model,) if is_torch_available() else () + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_torchscript = False + + @unittest.skip(reason="SAM2's vision encoder does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + def test_model_get_set_embeddings(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, nn.Linear)) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_get_image_features(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_features(*config_and_inputs) + + def test_image_hidden_states(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_get_image_hidden_states(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + expected_vision_attention_shape = ( + self.model_tester.batch_size * self.model_tester.num_attention_heads, + 196, + 196, + ) + expected_mask_decoder_attention_shape = (self.model_tester.batch_size, 1, 144, 32) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = 
model(**self._prepare_for_class(inputs_dict, model_class)) + + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + # check that output_attentions also works when set via the config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + vision_attentions = outputs.vision_attentions + self.assertEqual(len(vision_attentions), self.model_tester.num_hidden_layers) + + mask_decoder_attentions = outputs.mask_decoder_attentions + self.assertEqual(len(mask_decoder_attentions), self.model_tester.mask_decoder_tester.num_hidden_layers) + + self.assertListEqual( + list(vision_attentions[0].shape[-4:]), + list(expected_vision_attention_shape), + ) + + self.assertListEqual( + list(mask_decoder_attentions[0].shape[-4:]), + list(expected_mask_decoder_attention_shape), + ) + + @unittest.skip(reason="Sam2Model does not support training") + def test_training(self): + pass + + @unittest.skip(reason="Sam2Model does not support training") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Sam2Model has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_from_base(self): + pass + + @unittest.skip(reason="Sam2Model has no base class and is not available in MODEL_MAPPING") + def test_save_load_fast_init_to_base(self): + pass + + @unittest.skip(reason="Sam2Model does not support training") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="Hidden_states is tested in create_and_check_model tests") + def test_hidden_states_output(self): + pass + + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + # Use a slightly higher default tol to make the tests non-flaky + super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes) + + @slow + def test_model_from_pretrained(self): + model_name = "facebook/sam2-hiera-large" + model = Sam2Model.from_pretrained(model_name) + self.assertIsNotNone(model) + + +def prepare_image(): + img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +def prepare_dog_img(): + img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam2.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") + return raw_image + + +@slow +class Sam2ModelIntegrationTest(unittest.TestCase): + def tearDown(self): + super().tearDown() + # clean up as much GPU memory occupied by PyTorch as possible +
gc.collect() + backend_empty_cache(torch_device) + + def test_inference_mask_generation_no_point(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + inputs = processor(images=raw_image, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.4515), atol=2e-4)) + self.assertTrue(torch.allclose(masks, torch.tensor([-4.1800, -3.4948, -3.4481]).to(torch_device), atol=2e-4)) + + def test_inference_mask_generation_one_point_one_bb(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[[650, 900, 1000, 1250]]] + input_points = [[[820, 1080]]] + + inputs = processor( + images=raw_image, input_boxes=input_boxes, input_points=input_points, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + masks = outputs.pred_masks[0, 0, 0, 0, :3] + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9566), atol=2e-4)) + self.assertTrue( + torch.allclose(masks, torch.tensor([-12.7729, -12.3665, -12.6061]).to(torch_device), atol=2e-4) + ) + + def test_inference_mask_generation_batched_points_batched_images(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_points = [ + [[[820, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + [[[510, 1080]], [[820, 1080]], [[820, 1080]], [[820, 1080]]], + ] + + inputs = processor(images=[raw_image, raw_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze().cpu() + masks = outputs.pred_masks[0, 0, 0, 0, :3].cpu() + + EXPECTED_SCORES = torch.tensor( + [ + [ + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + ], + [ + [0.3317, 0.7264, 0.7646], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + [0.6765, 0.9379, 0.8803], + ], + ] + ) + EXPECTED_MASKS = torch.tensor([-2.8550, -2.7988, -2.9625]) + self.assertTrue(torch.allclose(scores, EXPECTED_SCORES, atol=1e-3)) + self.assertTrue(torch.allclose(masks, EXPECTED_MASKS, atol=1e-3)) + + def test_inference_mask_generation_one_point_one_bb_zero(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + input_boxes = [[[620, 900, 1000, 1255]]] + input_points = [[[820, 1080]]] + labels = [[0]] + + inputs = processor( + images=raw_image, + input_boxes=input_boxes, + input_points=input_points, + input_labels=labels, + return_tensors="pt", + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7894), atol=1e-4)) + + def test_inference_mask_generation_one_point(self): + model = 
Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650]]] + input_labels = [[1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4)) + + # With no label + input_points = [[[400, 650]]] + + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9675), atol=1e-4)) + + def test_inference_mask_generation_two_points(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]]] + input_labels = [[1, 1]] + + inputs = processor( + images=raw_image, input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4)) + + # no labels + inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.9762), atol=1e-4)) + + def test_inference_mask_generation_two_points_batched(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = [[[400, 650], [800, 650]], [[400, 650]]] + input_labels = [[1, 1], [1]] + + inputs = processor( + images=[raw_image, raw_image], input_points=input_points, input_labels=input_labels, return_tensors="pt" + ).to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[0][-1], torch.tensor(0.9762), atol=1e-4)) + self.assertTrue(torch.allclose(scores[1][-1], torch.tensor(0.9637), atol=1e-4)) + + def test_inference_mask_generation_one_box(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_boxes = [[[75, 275, 1725, 850]]] + + inputs = processor(images=raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores[-1], torch.tensor(0.7937), atol=1e-4)) + + def test_inference_mask_generation_batched_image_one_point(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + raw_dog_image = prepare_dog_img() + + input_points = [[[820, 1080]], [[220, 470]]] + + 
inputs = processor(images=[raw_image, raw_dog_image], input_points=input_points, return_tensors="pt").to( + torch_device + ) + + with torch.no_grad(): + outputs = model(**inputs) + scores_batched = outputs.iou_scores.squeeze() + + input_points = [[[220, 470]]] + + inputs = processor(images=raw_dog_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + scores_single = outputs.iou_scores.squeeze() + self.assertTrue(torch.allclose(scores_batched[1, :], scores_single, atol=1e-4)) + + def test_inference_mask_generation_two_points_point_batch(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + input_points = torch.Tensor([[[400, 650]], [[220, 470]]]).cpu() # fmt: skip + + input_points = input_points.unsqueeze(0) + + inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 2, 3)) + torch.testing.assert_close( + iou_scores, torch.tensor([[[0.9105, 0.9825, 0.9675], [0.7646, 0.7943, 0.7774]]]), atol=1e-4, rtol=1e-4 + ) + + def test_inference_mask_generation_three_boxes_point_batch(self): + model = Sam2Model.from_pretrained("facebook/sam2-vit-base") + processor = SamProcessor.from_pretrained("facebook/sam2-vit-base") + + model.to(torch_device) + model.eval() + + raw_image = prepare_image() + + # fmt: off + input_boxes = torch.Tensor([[[620, 900, 1000, 1255]], [[75, 275, 1725, 850]], [[75, 275, 1725, 850]]]).cpu() + EXPECTED_IOU = torch.tensor([[[0.9773, 0.9881, 0.9522], + [0.5996, 0.7661, 0.7937], + [0.5996, 0.7661, 0.7937]]]) + # fmt: on + input_boxes = input_boxes.unsqueeze(0) + + inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + outputs = model(**inputs) + + iou_scores = outputs.iou_scores.cpu() + self.assertTrue(iou_scores.shape == (1, 3, 3)) + torch.testing.assert_close(iou_scores, EXPECTED_IOU, atol=1e-4, rtol=1e-4) + + def test_dummy_pipeline_generation(self): + generator = pipeline("mask-generation", model="facebook/sam2-vit-base", device=torch_device) + raw_image = prepare_image() + + _ = generator(raw_image, points_per_batch=64) diff --git a/tests/models/sam2/test_processor_sam2.py b/tests/models/sam2/test_processor_sam2.py new file mode 100644 index 00000000000000..0146476f098782 --- /dev/null +++ b/tests/models/sam2/test_processor_sam2.py @@ -0,0 +1,151 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import shutil +import tempfile +import unittest + +import numpy as np + +from transformers.testing_utils import ( + is_pt_tf_cross_test, + require_tf, + require_torch, + require_torchvision, + require_vision, +) +from transformers.utils import is_tf_available, is_torch_available, is_vision_available + + +if is_vision_available(): + from PIL import Image + + from transformers import AutoProcessor, Sam2ImageProcessor, Sam2Processor + +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + + +@require_vision +@require_torchvision +class Sam2ProcessorTest(unittest.TestCase): + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = Sam2ImageProcessor() + processor = Sam2Processor(image_processor) + processor.save_pretrained(self.tmpdirname) + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of random PIL images.""" + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + return image_inputs + + def prepare_mask_inputs(self): + """This function prepares a list of random PIL segmentation maps (single-channel uint8 masks).""" + mask_inputs = [np.random.randint(255, size=(30, 400), dtype=np.uint8)] + mask_inputs = [Image.fromarray(x) for x in mask_inputs] + return mask_inputs + + def test_save_load_pretrained_additional_features(self): + processor = Sam2Processor(image_processor=self.get_image_processor()) + processor.save_pretrained(self.tmpdirname) + + image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0) + + processor = Sam2Processor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, Sam2ImageProcessor) + + def test_image_processor_no_masks(self): + image_processor = self.get_image_processor() + + processor = Sam2Processor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + + input_feat_extract = image_processor(image_input, return_tensors="np") + input_processor = processor(images=image_input, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + for image in input_feat_extract.pixel_values: + self.assertEqual(image.shape, (3, 1024, 1024)) + + for original_size in input_feat_extract.original_sizes: + np.testing.assert_array_equal(original_size, np.array([30, 400])) + + for reshaped_input_size in input_feat_extract.reshaped_input_sizes: + np.testing.assert_array_equal( + reshaped_input_size, np.array([77, 1024]) + ) # reshaped_input_size value is before padding + + def test_image_processor_with_masks(self): + image_processor = self.get_image_processor() + + processor = Sam2Processor(image_processor=image_processor) + + image_input = self.prepare_image_inputs() + mask_input = self.prepare_mask_inputs() + + input_feat_extract = image_processor(images=image_input, segmentation_maps=mask_input,
return_tensors="np") + input_processor = processor(images=image_input, segmentation_maps=mask_input, return_tensors="np") + + for key in input_feat_extract.keys(): + self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2) + + for label in input_feat_extract.labels: + self.assertEqual(label.shape, (256, 256)) + + @require_torch + def test_post_process_masks(self): + image_processor = self.get_image_processor() + + processor = Sam2Processor(image_processor=image_processor) + dummy_masks = [torch.ones((1, 3, 5, 5))] + + original_sizes = [[1764, 2646]] + + reshaped_input_size = [[683, 1024]] + masks = processor.post_process_masks(dummy_masks, original_sizes, reshaped_input_size) + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + masks = processor.post_process_masks( + dummy_masks, torch.tensor(original_sizes), torch.tensor(reshaped_input_size) + ) + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + # should also work with np + dummy_masks = [np.ones((1, 3, 5, 5))] + masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size)) + + self.assertEqual(masks[0].shape, (1, 3, 1764, 2646)) + + dummy_masks = [[1, 0], [0, 1]] + with self.assertRaises(ValueError): + masks = processor.post_process_masks(dummy_masks, np.array(original_sizes), np.array(reshaped_input_size))