diff --git a/mytests/multimodal_utils_test.py b/mytests/multimodal_utils_test.py index 4cbd794c3..d7673a60e 100644 --- a/mytests/multimodal_utils_test.py +++ b/mytests/multimodal_utils_test.py @@ -187,32 +187,49 @@ def test_get_shifted_multimodal_position_ids(): def test_get_proxy_tokens(): eps = 1e-7 + # Case 0: 0 images # TODO + vision_positions = torch.tensor([[-1]]) + vision_seq_len = 1 + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) + assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[0]])) < eps) + # Case 1: 1 image vision_positions = torch.tensor([[1]]) vision_seq_len = 1 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100]])) < eps) # Case 2: multiple images vision_positions = torch.tensor([[1, 2]]) vision_seq_len = 1 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100, -101]])) < eps) # Case 3: multiple images, seq len > 1 vision_positions = torch.tensor([[3, 5]]) vision_seq_len = 3 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100, -100, -100, -101, -101, -101]])) < eps) # Case 4: multiple samples with padding vision_positions = torch.tensor([[1, 2], [3, -1]]) vision_seq_len = 2 - proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, text_pad_id=0) + proxy_vision_tokens = utils.get_proxy_tokens(position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0) assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100, -100, -101, -101], [-100, -100, 0, 0]])) < eps) def test_get_multimodal_mask(): + # Case 0: 0 images + tokens = torch.tensor([[2, 2]]) + multimodal_mask = utils.get_multimodal_mask(tokens, text_pad_id=0) + correct_mask = torch.tensor([ + [ + [False, False], + [False, False] + ] + ]) + assert torch.all(multimodal_mask == correct_mask) + # Case 1: 1 image tokens = torch.tensor([[-100]]) multimodal_mask = utils.get_multimodal_mask(tokens, text_pad_id=0)