Skip to content

Commit

Permalink
Added test for 0 images
Browse files: browse the repository at this point in the history.
  • Loading branch information
kshitijkg committed Oct 23, 2023
1 parent ea005d9 commit 8e52a4f
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions mytests/multimodal_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,32 +187,49 @@ def test_get_shifted_multimodal_position_ids():
def test_get_proxy_tokens():
    """Check ``utils.get_proxy_tokens`` against hand-written expected tensors.

    Every case calls the helper with ``pad_id=0`` and compares the result
    with a literal tensor. The cases cover a single ``-1`` position (the
    "0 images" case), one image, several images, a ``seq_length`` greater
    than one, and a two-sample batch whose second row contains a ``-1``
    position.

    In the expected tensors each non-negative position contributes
    ``seq_length`` entries, numbered ``-100``, ``-101``, ... in order of
    appearance within a row, and a ``-1`` position contributes ``seq_length``
    entries equal to the pad id.
    """
    # Tolerance for the element-wise comparison with the expected tensor.
    eps = 1e-7

    # Case 0: 0 images # TODO
    vision_positions = torch.tensor([[-1]])
    vision_seq_len = 1
    proxy_vision_tokens = utils.get_proxy_tokens(
        position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0
    )
    assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[0]])) < eps)

    # Case 1: 1 image
    vision_positions = torch.tensor([[1]])
    vision_seq_len = 1
    proxy_vision_tokens = utils.get_proxy_tokens(
        position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0
    )
    assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100]])) < eps)

    # Case 2: multiple images
    vision_positions = torch.tensor([[1, 2]])
    vision_seq_len = 1
    proxy_vision_tokens = utils.get_proxy_tokens(
        position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0
    )
    assert torch.all(torch.abs(proxy_vision_tokens - torch.tensor([[-100, -101]])) < eps)

    # Case 3: multiple images, seq len > 1
    vision_positions = torch.tensor([[3, 5]])
    vision_seq_len = 3
    proxy_vision_tokens = utils.get_proxy_tokens(
        position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0
    )
    assert torch.all(
        torch.abs(
            proxy_vision_tokens
            - torch.tensor([[-100, -100, -100, -101, -101, -101]])
        )
        < eps
    )

    # Case 4: multiple samples with padding
    vision_positions = torch.tensor([[1, 2], [3, -1]])
    vision_seq_len = 2
    proxy_vision_tokens = utils.get_proxy_tokens(
        position_ids=vision_positions, seq_length=vision_seq_len, pad_id=0
    )
    assert torch.all(
        torch.abs(
            proxy_vision_tokens
            - torch.tensor([[-100, -100, -101, -101], [-100, -100, 0, 0]])
        )
        < eps
    )

def test_get_multimodal_mask():

# Case 0: 0 images
tokens = torch.tensor([[2, 2]])
multimodal_mask = utils.get_multimodal_mask(tokens, text_pad_id=0)
correct_mask = torch.tensor([
[
[False, False],
[False, False]
]
])
assert torch.all(multimodal_mask == correct_mask)

# Case 1: 1 image
tokens = torch.tensor([[-100]])
multimodal_mask = utils.get_multimodal_mask(tokens, text_pad_id=0)
Expand Down

0 comments on commit 8e52a4f

Please sign in to comment.