Commit 3a46dde: Fixed labels

kshitijkg committed Oct 18, 2023 (1 parent: 68edade)
Showing 1 changed file with 11 additions and 3 deletions.
megatron/training.py: 14 changes (11 additions & 3 deletions)
```diff
@@ -288,20 +288,29 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
     # Unpack multimodal position ids.
     multimodal_position_ids = data_b["multimodal_position_ids"].long().contiguous()
 
     # Padded text length
     max_text_length = text_input.shape[1]
+    text_positions = multimodal_position_ids[:, MODALITY_DICT['text'], :max_text_length]
+    text_labels = labels[:, MODALITY_DICT['text'], :max_text_length]
+    assert torch.all(multimodal_position_ids[:, MODALITY_DICT['text'], max_text_length:] == -1)
+    assert torch.all(labels[:, MODALITY_DICT['text'], max_text_length:] == -1)
     text_input_info = {
         "input": text_input,
         "labels": text_labels,
         "positions": text_positions,
         "seq_length": 1
     }
 
-    # Unpack vision_input.
+    # Unpack vision_input and get padded vision length
     vision_input = data_b["vision_input"].half().contiguous()
+    max_vision_length = vision_input.shape[1]
+    vision_positions = multimodal_position_ids[:, MODALITY_DICT['vision'], :max_vision_length]
+    vision_labels = labels[:, MODALITY_DICT['vision'], :max_vision_length]
+    assert torch.all(multimodal_position_ids[:, MODALITY_DICT['vision'], max_vision_length:] == -1)
+    assert torch.all(labels[:, MODALITY_DICT['vision'], max_vision_length:] == -1)
     vision_input_info = {
         "input": vision_input,
         "labels": vision_labels,
         "positions": vision_positions,
         "seq_length": neox_args.vision_seq_length
     }
```
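The added asserts document the packing convention at work here: position ids and labels for all modalities share one padded tensor of shape (batch, num_modalities, padded_length), with -1 as the pad sentinel beyond each modality's true length, so everything past max_text_length or max_vision_length must still be -1. Below is a minimal, self-contained sketch of that convention on toy tensors; the MODALITY_DICT contents, batch size, and lengths are assumptions made for illustration, not values from the repository.

```python
# Toy illustration of the -1 padding convention the new asserts verify.
# MODALITY_DICT's contents and all sizes here are assumed for the sketch.
import torch

MODALITY_DICT = {"text": 0, "vision": 1}  # assumed ordering

batch, num_modalities, buf_len = 2, 2, 8
max_text_length = 5  # pretend the padded text input is 5 tokens wide

# Shared buffer: each modality gets a row, padded with the -1 sentinel.
multimodal_position_ids = torch.full(
    (batch, num_modalities, buf_len), -1, dtype=torch.long
)
multimodal_position_ids[:, MODALITY_DICT["text"], :max_text_length] = torch.arange(
    max_text_length
)

# Slice out the live text positions, then check that only the sentinel
# remains past the padded length: the same shape of check the commit adds.
text_positions = multimodal_position_ids[:, MODALITY_DICT["text"], :max_text_length]
assert torch.all(multimodal_position_ids[:, MODALITY_DICT["text"], max_text_length:] == -1)
print(text_positions.shape)  # torch.Size([2, 5])
```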
```diff
@@ -322,9 +331,8 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
     }
 
     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids = get_multimodal_ltor_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids, labels = get_multimodal_ltor_masks_and_position_ids(
         input_info=input_info,
-        labels=labels,
         input_seq_length=neox_args.seq_length,
         eod_token=neox_args.tokenizer.eod_id,
         bos_token=neox_args.tokenizer.bos_id if hasattr(neox_args.tokenizer, "bos_id") else None,
```
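The second hunk inverts the data flow for labels: rather than the caller passing labels in, the helper now hands them back as a fifth return value, so the caller consumes whatever label tensor the mask-building step produces. The stub below is a hypothetical, runnable sketch that mimics only this new call shape; its body (concatenating per-modality labels, placeholder masks) is invented for illustration and is not the real megatron implementation.

```python
# Hypothetical stand-in for get_multimodal_ltor_masks_and_position_ids that
# mirrors only the post-commit signature: labels come out, not in.
import torch

def get_multimodal_ltor_masks_and_position_ids(
    input_info, input_seq_length, eod_token, bos_token=None
):
    # Invented body: rebuild a flat label tensor from the per-modality
    # dicts and return placeholder masks alongside it.
    labels = torch.cat([info["labels"] for info in input_info.values()], dim=-1)
    n = labels.shape[-1]
    attention_mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
    loss_mask = (labels != -1).float()
    position_ids = torch.arange(n).expand_as(labels)
    shifted_multimodal_position_ids = position_ids  # placeholder only
    return attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids, labels

# After this commit the caller unpacks labels from the helper instead of
# passing labels= as a keyword argument:
input_info = {
    "text": {"labels": torch.tensor([[5, 6, -1]])},
    "vision": {"labels": torch.tensor([[-1, -1]])},
}
attention_mask, loss_mask, position_ids, shifted_ids, labels = (
    get_multimodal_ltor_masks_and_position_ids(input_info, input_seq_length=5, eod_token=0)
)
print(labels)  # tensor([[ 5,  6, -1, -1, -1]])
```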