
Commit fbfb277

fixed labels utils
kshitijkg committed Oct 18, 2023
1 parent 3a46dde commit fbfb277
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions megatron/utils.py
@@ -265,7 +265,6 @@ def get_multimodal_attn_mask(

 def get_multimodal_ltor_masks_and_position_ids(
     input_info,
-    labels,
     input_seq_length,
     eod_token,
     bos_token,
@@ -276,8 +275,7 @@ def get_multimodal_ltor_masks_and_position_ids(
"""Build masks and position id for left to right model."""

# Extract batch size and label length
batch_size = labels.shape[0]
label_length = labels.shape[1]
batch_size = input_info["text"]["input"].shape[0]

shifted_text_positions, shifted_vision_positions, shited_audio_positions = get_shifted_multimodal_position_ids(input_info, position_pad_id=-1)

@@ -300,18 +298,31 @@
         text_pad_token_id=pad_token,
         concat_data=concat_data,
         attn_uses_sequence_id=attn_uses_sequence_id,
-        device=labels.device,
+        device=input_info["text"]["input"].device,
     )

+    # Prepare labels
+    vision_labels = torch.repeat_interleave(input_info["vision"]["labels"], input_info["vision"]["seq_length"], dim=1)
+
+    # Concatenate vision proxy tokens with text tokens
+    concat_labels = torch.cat((input_info["text"]["labels"], vision_labels), dim=1)
+
+    # Rearrange tokens in interleaved format using shifted multimodal position ids
+    interleaved_labels = torch.zeros_like(concat_labels, dtype=concat_labels.dtype, device=concat_labels.device)
+    labels = interleaved_labels.scatter_(1, shifted_multimodal_position_ids, concat_labels)[:,:input_seq_length]
+
     # Loss mask.
+    label_length = labels.shape[1]
     loss_mask = torch.ones((batch_size, label_length), dtype=torch.float, device=labels.device)

     if pad_token is not None:
         loss_mask[labels == pad_token] = 0.0
     loss_mask[labels == bos_token] = 0.0
     loss_mask[labels == eod_token] = 0.0

-    position_ids = torch.arange(input_seq_length, dtype=torch.long, device=labels.device)
+    position_ids = torch.arange(input_seq_length, dtype=torch.long, device=labels.device)  # FIX THIS #TODO
     position_ids = position_ids.unsqueeze(0).expand(batch_size, input_seq_length)
-    return attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids
+    return attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids, labels

def local_rank():
    """Local rank of process"""
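The heart of the change is the label interleaving added to this function: each vision label is expanded to one entry per proxy token, concatenated after the text labels, and then scattered into its interleaved position. A minimal toy sketch of that pattern follows; every tensor and value below is made up for illustration, standing in for input_info and the output of get_shifted_multimodal_position_ids.

import torch

# Toy stand-ins; the real values come from input_info and
# get_shifted_multimodal_position_ids in megatron/utils.py.
text_labels = torch.tensor([[10, 11]])          # (batch, text_len)
vision_labels = torch.tensor([[-100]])          # one label per vision segment
vision_seq_length = torch.tensor([3])           # proxy tokens per segment

# Expand each vision label to cover its proxy tokens, mirroring
# torch.repeat_interleave(input_info["vision"]["labels"], ..., dim=1).
vision_labels = torch.repeat_interleave(vision_labels, vision_seq_length, dim=1)
# tensor([[-100, -100, -100]])

# Text labels first, then the expanded vision labels.
concat_labels = torch.cat((text_labels, vision_labels), dim=1)
# tensor([[  10,   11, -100, -100, -100]])

# Position ids that place the vision proxies between the two text tokens.
shifted_multimodal_position_ids = torch.tensor([[0, 4, 1, 2, 3]])

interleaved_labels = torch.zeros_like(concat_labels)
labels = interleaved_labels.scatter_(1, shifted_multimodal_position_ids, concat_labels)
# tensor([[  10, -100, -100, -100,   11]])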

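The loss mask built right after that block zeroes out every position whose label is a pad, bos, or eod token. A self-contained toy check of that rule, using arbitrary token ids rather than the real tokenizer's:

import torch

pad_token, bos_token, eod_token = 0, 1, 2          # arbitrary ids for this sketch
labels = torch.tensor([[1, 5, 6, 2, 0, 0]])        # bos, two real tokens, eod, padding

loss_mask = torch.ones(labels.shape, dtype=torch.float, device=labels.device)
if pad_token is not None:
    loss_mask[labels == pad_token] = 0.0
loss_mask[labels == bos_token] = 0.0
loss_mask[labels == eod_token] = 0.0
# tensor([[0., 1., 1., 0., 0., 0.]])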