diff --git a/megatron/training.py b/megatron/training.py
index 77c44810f..f136860e9 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -288,20 +288,29 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
     # Unpack multimodal position ids.
     multimodal_position_ids = data_b["multimodal_position_ids"].long().contiguous()
+    # Padded text length
     max_text_length = text_input.shape[1]
     text_positions = multimodal_position_ids[:, MODALITY_DICT['text'], :max_text_length]
+    text_labels = labels[:, MODALITY_DICT['text'], :max_text_length]
+    assert torch.all(multimodal_position_ids[:, MODALITY_DICT['text'], max_text_length:] == -1)
+    assert torch.all(labels[:, MODALITY_DICT['text'], max_text_length:] == -1)
     text_input_info = {
         "input": text_input,
+        "labels": text_labels,
         "positions": text_positions,
         "seq_length": 1
     }
 
-    # Unpack vision_input.
+    # Unpack vision_input and get padded vision length
     vision_input = data_b["vision_input"].half().contiguous()
     max_vision_length = vision_input.shape[1]
     vision_positions = multimodal_position_ids[:, MODALITY_DICT['vision'], :max_vision_length]
+    vision_labels = labels[:, MODALITY_DICT['vision'], :max_vision_length]
+    assert torch.all(multimodal_position_ids[:, MODALITY_DICT['vision'], max_vision_length:] == -1)
+    assert torch.all(labels[:, MODALITY_DICT['vision'], max_vision_length:] == -1)
     vision_input_info = {
         "input": vision_input,
+        "labels": vision_labels,
         "positions": vision_positions,
         "seq_length": neox_args.vision_seq_length
     }
@@ -322,9 +331,8 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
     }
 
     # Get the masks and position ids.
-    attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids = get_multimodal_ltor_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids, shifted_multimodal_position_ids, labels = get_multimodal_ltor_masks_and_position_ids(
         input_info=input_info,
-        labels=labels,
         input_seq_length=neox_args.seq_length,
         eod_token=neox_args.tokenizer.eod_id,
         bos_token=neox_args.tokenizer.bos_id if hasattr(neox_args.tokenizer, "bos_id") else None,
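For context, a minimal sketch of the pattern this diff moves toward: per-modality label slices now travel inside input_info, and get_multimodal_ltor_masks_and_position_ids returns the assembled labels instead of receiving them. The sketch below is an assumption, not the repository's implementation; the helper name scatter_modality_labels and the ignore_index default are hypothetical, and per-modality "seq_length" expansion (e.g. vision_seq_length) is omitted.

import torch

def scatter_modality_labels(input_info, input_seq_length, ignore_index=-100):
    # Each modality entry is assumed to carry "labels" [B, L_mod] and
    # "positions" [B, L_mod], with -1 marking padding in both (per the asserts
    # added in the diff above).
    first = next(iter(input_info.values()))
    batch_size = first["labels"].shape[0]
    device = first["labels"].device
    flat_labels = torch.full(
        (batch_size, input_seq_length), ignore_index, dtype=torch.long, device=device
    )
    for info in input_info.values():
        positions = info["positions"]   # [B, L_mod], -1 where padded
        labels = info["labels"].long()  # [B, L_mod], -1 where padded
        valid = positions >= 0
        b_idx, t_idx = valid.nonzero(as_tuple=True)
        # Scatter each modality's labels to its packed-sequence positions.
        flat_labels[b_idx, positions[b_idx, t_idx]] = labels[b_idx, t_idx]
    return flat_labels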