From fbfb27772cc88e6819e73afaf92560dec760e859 Mon Sep 17 00:00:00 2001
From: kshitij
Date: Wed, 18 Oct 2023 22:35:20 +0200
Subject: [PATCH] fixed labels utils

---
 megatron/utils.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/megatron/utils.py b/megatron/utils.py
index 0860516a3..6271428ba 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -265,7 +265,6 @@ def get_multimodal_attn_mask(
 
 def get_multimodal_ltor_masks_and_position_ids(
     input_info,
-    labels,
     input_seq_length,
     eod_token,
     bos_token,
@@ -276,8 +275,7 @@ def get_multimodal_ltor_masks_and_position_ids(
     """Build masks and position id for left to right model."""
 
     # Extract batch size and label length
-    batch_size = labels.shape[0]
-    label_length = labels.shape[1]
+    batch_size = input_info["text"]["input"].shape[0]
 
     shifted_text_positions, shifted_vision_positions, shited_audio_positions = get_shifted_multimodal_position_ids(input_info, position_pad_id=-1)
 
@@ -300,18 +298,31 @@ def get_multimodal_ltor_masks_and_position_ids(
         text_pad_token_id=pad_token,
         concat_data=concat_data,
         attn_uses_sequence_id=attn_uses_sequence_id,
-        device=labels.device,
+        device=input_info["text"]["input"].device,
     )
 
+    # Prepare labels
+    vision_labels = torch.repeat_interleave(input_info["vision"]["labels"], input_info["vision"]["seq_length"], dim=1)
+
+    # Concatenate vision proxy tokens with text tokens
+    concat_labels = torch.cat((input_info["text"]["labels"], vision_labels), dim=1)
+
+    # Rearrange tokens in interleaved format using shifted multimodal position ids
+    interleaved_labels = torch.zeros_like(concat_labels, dtype=concat_labels.dtype, device=concat_labels.device)
+    labels = interleaved_labels.scatter_(1, shifted_multimodal_position_ids, concat_labels)[:,:input_seq_length]
+
     # Loss mask.
+    label_length = labels.shape[1]
     loss_mask = torch.ones((batch_size, label_length), dtype=torch.float, device=labels.device)
     if pad_token is not None:
         loss_mask[labels == pad_token] = 0.0
+        loss_mask[labels == bos_token] = 0.0
+        loss_mask[labels == eod_token] = 0.0
 
-    position_ids = torch.arange(input_seq_length, dtype=torch.long, device=labels.device)
+    position_ids = torch.arange(input_seq_length, dtype=torch.long, device=labels.device) # FIX THIS #TODO
     position_ids = position_ids.unsqueeze(0).expand(batch_size, input_seq_length)
 
-    return attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids
+    return attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids, labels
 
 def local_rank():
     """Local rank of process"""
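
Note (not part of the patch): the block added above builds the interleaved labels with torch.repeat_interleave, torch.cat, and Tensor.scatter_. Below is a minimal, self-contained sketch of that step with toy tensors; the values and the shifted position-id layout are invented for illustration and are not taken from the repository.

import torch

input_seq_length = 8

# Toy stand-ins for input_info["text"]["labels"], input_info["vision"]["labels"]
# and input_info["vision"]["seq_length"] (values invented for illustration).
text_labels = torch.tensor([[10, 11, 12, 13, 14]])   # 5 text tokens
vision_labels = torch.tensor([[-100]])                # 1 image label, ignored by the loss
vision_seq_length = torch.tensor([3])                 # 3 proxy tokens for that image

# Expand each image label to cover all of its proxy tokens.
expanded_vision_labels = torch.repeat_interleave(vision_labels, vision_seq_length, dim=1)

# Concatenate text labels and expanded vision labels along the sequence dimension.
concat_labels = torch.cat((text_labels, expanded_vision_labels), dim=1)

# Invented interleaving: the image's proxy tokens occupy positions 2-4.
shifted_multimodal_position_ids = torch.tensor([[0, 1, 5, 6, 7, 2, 3, 4]])

# Scatter into interleaved order and truncate to the model sequence length.
interleaved_labels = torch.zeros_like(concat_labels)
labels = interleaved_labels.scatter_(1, shifted_multimodal_position_ids, concat_labels)[:, :input_seq_length]

print(labels)  # tensor([[  10,   11, -100, -100, -100,   12,   13,   14]])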