From 4131b07349c62c7279193573b6bd22ffdea33188 Mon Sep 17 00:00:00 2001 From: Tuomas Rintamaki Date: Fri, 15 Nov 2024 23:48:26 -0800 Subject: [PATCH] ADLR/megatron-lm!2306 - NVLM example scripts --- examples/multimodal/README.md | 2 +- .../combine_lm_vision_checkpoints.sh | 57 ++++++ examples/multimodal/combine_mistral_clip.sh | 23 --- examples/multimodal/config.py | 19 +- .../model_converter/internvit_converter.py | 0 .../model_converter/siglip_converter.py | 6 +- examples/multimodal/nvlm/README.md | 5 + examples/multimodal/nvlm/nvlm_prompts.json | 165 ++++++++++++++++ .../nvlm/pp_checkpoint_converter.py | 180 ++++++++++++++++++ examples/multimodal/nvlm/pretrain_blend.yaml | 28 +++ .../nvlm/pretrain_qwen20_72b_internvit_6b.sh | 158 +++++++++++++++ .../nvlm/pretrain_yi_34b_internvit_6b.sh | 154 +++++++++++++++ ...text_generation_qwen20_72b_internvit_6b.sh | 139 ++++++++++++++ ...run_text_generation_yi_34b_internvit_6b.sh | 138 ++++++++++++++ examples/multimodal/nvlm/sft_34b_internvit.sh | 160 ++++++++++++++++ examples/multimodal/nvlm/sft_blend.yaml | 23 +++ .../nvlm/sft_qwen20_72b_internvit_6b.sh | 166 ++++++++++++++++ 17 files changed, 1395 insertions(+), 28 deletions(-) create mode 100755 examples/multimodal/combine_lm_vision_checkpoints.sh delete mode 100755 examples/multimodal/combine_mistral_clip.sh mode change 100644 => 100755 examples/multimodal/model_converter/internvit_converter.py create mode 100644 examples/multimodal/nvlm/README.md create mode 100644 examples/multimodal/nvlm/nvlm_prompts.json create mode 100644 examples/multimodal/nvlm/pp_checkpoint_converter.py create mode 100644 examples/multimodal/nvlm/pretrain_blend.yaml create mode 100644 examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh create mode 100644 examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh create mode 100644 examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh create mode 100644 examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh create mode 100644 examples/multimodal/nvlm/sft_34b_internvit.sh create mode 100644 examples/multimodal/nvlm/sft_blend.yaml create mode 100644 examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md index 5ab0c7bf0b..afd0ad2e25 100644 --- a/examples/multimodal/README.md +++ b/examples/multimodal/README.md @@ -31,7 +31,7 @@ python examples/multimodal/model_converter/clip_converter.py --download-root /so Update the paths to point to the mcore converted CLIP and Mistral models and run the following script to combine the Mistral and CLIP models into a single multimodal checkpoint folder: ``` -examples/multimodal/combine_mistral_clip.sh /path/to/mistral/model /path/to/clip/model /output/dir +examples/multimodal/combine_lm_vision_checkpoints.sh /path/to/mistral/model /path/to/clip/model /output/dir ``` ## Training diff --git a/examples/multimodal/combine_lm_vision_checkpoints.sh b/examples/multimodal/combine_lm_vision_checkpoints.sh new file mode 100755 index 0000000000..52de16ecd2 --- /dev/null +++ b/examples/multimodal/combine_lm_vision_checkpoints.sh @@ -0,0 +1,57 @@ +#/bin/bash +MCORE_LM=$1 # +MCORE_VISION=$2 # +OUTPUT_DIR=$3 # +MODEL_TYPE=$4 # Model type. Default: Mistral CLIP example. + +if [[ $MODEL_TYPE == "nvlm" ]]; then + # NVLM TP=8 + python examples/multimodal/combine_state_dicts.py \ + --input \ + ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_07/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_07/model_optim_rng.pt \ + --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ + --output \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_04/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_05/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_06/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_07/model_optim_rng.pt +else + # Mistral CLIP example TP=4. + python examples/multimodal/combine_state_dicts.py \ + --input \ + ${MCORE_LM}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${MCORE_LM}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + ${MCORE_VISION}/iter_0000001/mp_rank_03/model_optim_rng.pt \ + --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ + --output \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_00/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_01/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_02/model_optim_rng.pt \ + ${OUTPUT_DIR}/iter_0000001/mp_rank_03/model_optim_rng.pt +fi + +echo 1 > ${OUTPUT_DIR}/latest_checkpointed_iteration.txt diff --git a/examples/multimodal/combine_mistral_clip.sh b/examples/multimodal/combine_mistral_clip.sh deleted file mode 100755 index ff866c7f72..0000000000 --- a/examples/multimodal/combine_mistral_clip.sh +++ /dev/null @@ -1,23 +0,0 @@ -#/bin/bash -MCORE_MISTRAL=$1 # -MCORE_CLIP=$2 # -OUTPUT_DIR=$3 # - -python examples/multimodal/combine_state_dicts.py \ - --input \ - ${MCORE_MISTRAL}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${MCORE_CLIP}/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${MCORE_MISTRAL}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${MCORE_CLIP}/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${MCORE_MISTRAL}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${MCORE_CLIP}/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${MCORE_MISTRAL}/iter_0000001/mp_rank_03/model_optim_rng.pt \ - ${MCORE_CLIP}/iter_0000001/mp_rank_03/model_optim_rng.pt \ - --prefixes language_model vision_model language_model vision_model language_model vision_model language_model vision_model \ - --output \ - ${OUTPUT_DIR}/mistral_instruct_clip336_tp4_combined_mcore/iter_0000001/mp_rank_00/model_optim_rng.pt \ - ${OUTPUT_DIR}/mistral_instruct_clip336_tp4_combined_mcore/iter_0000001/mp_rank_01/model_optim_rng.pt \ - ${OUTPUT_DIR}/mistral_instruct_clip336_tp4_combined_mcore/iter_0000001/mp_rank_02/model_optim_rng.pt \ - ${OUTPUT_DIR}/mistral_instruct_clip336_tp4_combined_mcore/iter_0000001/mp_rank_03/model_optim_rng.pt - -echo 1 > ${OUTPUT_DIR}/mistral_instruct_clip336_tp4_combined_mcore/latest_checkpointed_iteration.txt diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py index 4524df4480..4d7b915c19 100644 --- a/examples/multimodal/config.py +++ b/examples/multimodal/config.py @@ -73,6 +73,20 @@ def get_language_model_config(config): config.apply_rope_fusion = False config.attention_softmax_in_fp32 = True config.ffn_hidden_size = 20480 + elif config.language_model_type == "qwen2.0_72B": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.add_qkv_bias = True + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 29568 else: raise ValueError(f"unknown language model type {config.language_model_type}") @@ -146,7 +160,6 @@ def get_vision_model_config(config, apply_query_key_layer_scaling): else: raise ValueError(f"unknown vision model type {config.vision_model_type}") - return config @@ -171,6 +184,10 @@ def get_vision_projection_config(config, hidden_size): config.ffn_hidden_size = 20480 config.normalization = 'LayerNorm' config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "qwen2.0_72B": + config.ffn_hidden_size = 29568 + config.normalization = 'LayerNorm' + config.activation_func = torch.nn.functional.gelu else: raise ValueError(f"unknown language model type {config.language_model_type}") diff --git a/examples/multimodal/model_converter/internvit_converter.py b/examples/multimodal/model_converter/internvit_converter.py old mode 100644 new mode 100755 diff --git a/examples/multimodal/model_converter/siglip_converter.py b/examples/multimodal/model_converter/siglip_converter.py index 117f8b8924..666cda15eb 100644 --- a/examples/multimodal/model_converter/siglip_converter.py +++ b/examples/multimodal/model_converter/siglip_converter.py @@ -61,9 +61,9 @@ def add_chunck_tensor(new_tensor, new_name, chunk_dim=None): head_dim = 72 num_head = 16 for layer_idx in range(27): - origin_base = f"vision_tower.vision_model.encoder.layers.{layer_idx}" + origin_base = f"vision_tower.vision_model.encoder.layers.{layer_idx}" target_base = f"decoder.layers.{layer_idx}" - + for param_type in ["weight", "bias"]: # QKV q_proj_params = state_dict[f"{origin_base}.self_attn.q_proj.{param_type}"] @@ -135,7 +135,7 @@ def add_chunck_tensor(new_tensor, new_name, chunk_dim=None): Example usage: python siglip_converter.py --tensor-parallel-size 4 --output google_paligemma_3b_pt_44_mcore_tp_4 --use-te -examples/multimodal/combine_mistral_clip.sh /lustre/fsw/portfolios/llmservice/users/jbarker/workspace/checkpoints/Mistral-7B-Instruct-v0.3-mcore-tp4 google_paligemma_3b_pt_44_mcore_tp_4 mistral_7b_instruct_v0p3_google_paligemma_3b_pt_44_mcore_tp_4 +examples/multimodal/combine_mistral_clip.sh Mistral-7B-Instruct-v0.3-mcore-tp4 google_paligemma_3b_pt_44_mcore_tp_4 mistral_7b_instruct_v0p3_google_paligemma_3b_pt_44_mcore_tp_4 """, formatter_class=argparse.RawDescriptionHelpFormatter, ) diff --git a/examples/multimodal/nvlm/README.md b/examples/multimodal/nvlm/README.md new file mode 100644 index 0000000000..9bcca10dc8 --- /dev/null +++ b/examples/multimodal/nvlm/README.md @@ -0,0 +1,5 @@ +NVLM +==== + +Work in progress. +Please refer to the [NVLM paper](https://arxiv.org/pdf/2409.11402) for details. diff --git a/examples/multimodal/nvlm/nvlm_prompts.json b/examples/multimodal/nvlm/nvlm_prompts.json new file mode 100644 index 0000000000..ab36adc765 --- /dev/null +++ b/examples/multimodal/nvlm/nvlm_prompts.json @@ -0,0 +1,165 @@ +{ + "COMMENT": "Mixture of our own custom prompts and some prompts from https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/viewer and https://huggingface.co/datasets/HuggingFaceM4/M3IT", + "Captioning": { + "raw": [ + "Can you briefly explain what you see in the image?", + "Describe what's happening in this image in one short sentence.", + "Write a short caption that accurately represents the content of this image.", + "Please generate a descriptive caption for the image provided.", + "How would you summarize the scene depicted in the picture in short?", + "Describe the image briefly.", + "Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details.", + "Create a concise caption that accurately describes the main elements in the image provided.", + "Write a brief, yet comprehensive, description of the image.", + "Describe the image in a clear and concise manner.", + "For the given image, provide a one-sentence summary that captures the most important details.", + "Generate a short caption for the picture.", + "Write a short and informative description that highlights the primary subjects and actions occurring in the given image.", + "Provide a concise and informative caption for the image, focusing on the primary subjects.", + "Write a clear description of the image, make sure the key features are well covered.", + "Offer a succinct explanation of the picture presented." + ] + }, + "CaptioningPretraining": { + "raw": [ + "Give a brief description of image.", + "Give a brief description of the image.", + "Provide a brief description of the given image.", + "Provide a one-sentence caption for the provided image.", + "Write a terse but informative summary of the picture.", + "Describe the image concisely.", + "Generate a clear and concise summary of the photo." + ] + }, + "CaptioningSFT": { + "raw": [ + "Give a brief description of the image.", + "Give a short and clear explanation of the subsequent image.", + "Present a compact description of the photo's key features.", + "Provide a brief description of the given image.", + "Provide a one-sentence caption for the provided image.", + "Render a clear and concise summary of the photo.", + "Share a concise interpretation of the image provided.", + "Summarize the visual content of the image.", + "Write a terse but informative summary of the picture.", + "Describe the image concisely." + ] + }, + "VQAPretraining": { + "raw": [ + "Question: {} Short answer:", + "Question: {} Answer:" + ] + }, + "VQASFT": { + "raw": [ + "{}", + "{}\nAnswer the question using a single word or phrase." + ], + "docvqa": [ + "{}", + "{}\nAnswer this question using the text in the image directly." + ] + }, + "DocPretraining": { + "raw": [ + "Retrieve the text from the given pdf image.", + "Extract the text from the provided document.", + "Transcribe the text displayed in the image." + ], + "ocr_multi": [ + "Apply grounded Optical Character Recognition (OCR) to the provided image.", + "Extract all texts and their bounding boxes from the given image using grounded OCR.", + "Extract and transcribe all visible text from the provided image, ensuring accurate spatial recognition.", + "Conduct a detailed optical character recognition analysis on this image, maintaining the text's original layout and positioning.", + "Execute a thorough text recognition procedure on this visual input, ensuring that the spatial arrangement of the text is accurately represented.", + "Perform an in-depth OCR scan of the image, capturing both the content and contextual positioning of all textual information.", + "OCR with grounding:" + ], + "md": [ + "Extract the text from the given image and format it in Markdown.", + "Convert the text from the provided image into Markdown format.", + "Transform the text from the given image into Markdown syntax.", + "Extract and convert the text from the image to Markdown.", + "Retrieve the text from the image and present it in Markdown format." + ], + "grounded_ocr": [ + "{}. Text:", + "Recognize the text in this region: {}.", + "Identify the text in this area: {}.", + "Detect the text within this section: {}." + ], + "referring_grounding": [ + "Region of \"{}\" is:", + "Locate the text \"{}\" in the image.", + "Identify the text \"{}\" in the image and provide the coordinates." + ] + }, + "CaptioningDetailed": { + "raw": [ + "Create a comprehensive paragraph that captures the essence of the image while weaving a cohesive narrative around its elements.", + "Compose a paragraph that thoroughly describes the image's content, providing context and connections between different aspects of the scene.", + "Provide a detailed, paragraph-length description of the image that paints a vivid picture and tells a coherent story.", + "Write a rich and engaging paragraph that delves into the image's components, describing not only what is seen but also how the elements relate to one another.", + "Give a well-rounded, paragraph-length explanation of the image, describing the scene and its components while forming a complete and engaging narrative.", + "Produce a paragraph that not only describes the individual elements in the image but also weaves them together to form a cohesive, connected account.", + "Construct a paragraph that captures the image's details and context, offering a more in-depth and engaging story than a simple caption.", + "Compose a descriptive paragraph that brings the image to life through detailed storytelling, connecting the various visual elements into a unified narrative.", + "Create a paragraph that provides an extensive and interconnected description of the image, ensuring that the narrative is both detailed and cohesive.", + "Write a compelling and detailed paragraph that delves into the image's components, linking them together to create a unified and engaging story." + ] + }, + "OCR": { + "raw": [ + "Can you read the text from image and output here?", + "Extract and document the text from the provided image.", + "Converting the text embedded in this image into a readable document.", + "Transcribe all the text you find.", + "Can you extract all visible text from the image here?" + ], + "markdown": [ + "Can you extract all visible text from the provided image?", + "Converting the text embedded in this image into a readable markdown document.", + "Can you read the text in the document as markdown?", + "Transcribe the document as markdown.", + "Extract and document the text from the provided image." + ], + "table_markdown": [ + "Can you extract all visible text from the provided table?", + "Can you read the text in the provided table as markdown?", + "Transcribe the table as markdown.", + "Extract and document the text from the provided table image." + ], + "plain": [ + "Transcribe the document as plain text.", + "Extract and document the text from the provided image.", + "Converting the text embedded in this image into a readable document.", + "Transcribe all the text you find.", + "Can you extract all visible text from the image here?" + ], + "bbox_plain": [ + "Transcribe the document as plain text along with bounding boxes.", + "Extract and document the text from the provided image along with bounding boxes.", + "Converting the text embedded in this image into a readable documen along with bounding boxes.", + "Can you extract all visible text with bounding boxes from the image here?" + ] + }, + "VQA": { + "raw": [ + "Given the image, answer the following question with few words.", + "Answer the following question: ", + "What is the answer to this question?", + "Write the answer: ", + "Please answer this question: " + ] + }, + "Embedded": { + "raw": [ + "Given the image, answer the following question with few words.", + "Answer the following question: ", + "What is the answer to this question?", + "Write the answer: ", + "Please answer this question: " + ] + } +} diff --git a/examples/multimodal/nvlm/pp_checkpoint_converter.py b/examples/multimodal/nvlm/pp_checkpoint_converter.py new file mode 100644 index 0000000000..cde63e5ad2 --- /dev/null +++ b/examples/multimodal/nvlm/pp_checkpoint_converter.py @@ -0,0 +1,180 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os +import sys + +import torch + +# Add megatron to the path. +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir, os.path.pardir)) +) + + +def split(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_per_pp_rank): + """Split pipeline parallel size = 1 checkpoint to pipeline parallel size N.""" + for tp in range(num_tp): + path = os.path.join(input_dir, f"mp_rank_0{tp}", "model_optim_rng.pt") + sd = torch.load(path) + + if num_layers_per_pp_rank is None: + num_layers = sd["args"].num_layers + assert num_layers % output_pp == 0, "specify --num-layers-per-pp-rank for an uneven split" + num_layers_per_pp_rank = [num_layers // output_pp] * output_pp + + layer_lb = 0 + for pp in range(output_pp): + assert num_layers_per_pp_rank[pp] > 0, "each pp rank must have at least 1 layer" + layer_ub = layer_lb + num_layers_per_pp_rank[pp] + + new_sd = sd.copy() + new_sd["model"] = dict() + for k, v in sd["model"].items(): + # First pp rank has vision model. + if pp == 0 and ("vision_model" in k or "vision_projection" in k): + new_sd["model"][k] = v + continue + + # Only the first pp rank has the word embeddings. + if "language_model.embedding.word_embeddings" in k and pp == 0: + new_sd["model"][k] = v + + # Only the last pp rank has the output layer. + if "language_model.output_layer" in k and pp == input_pp - 1: + new_sd["model"][k] = v + + # Only the last pp rank has final layer norm. + if "language_model.decoder.final_layernorm" in k and pp == input_pp - 1: + new_sd["model"][k] = v + + if "language_model.decoder.layers" in k: + layer_num = int(k.split(".")[3]) + + if layer_lb <= layer_num and layer_num < layer_ub: + # On all pp ranks, megatron starts layer nums from 0! + new_layer_num = int(layer_num - layer_lb) + + k_splitted = k.split(".") + k_splitted[3] = str(new_layer_num) + new_k = ".".join(k_splitted) + + new_sd["model"][new_k] = v + + output_dir = os.path.join(base_output_dir, f"iter_0000001/mp_rank_0{tp}_00{pp}") + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "model_optim_rng.pt") + torch.save(new_sd, output_path) + + print(f"processed tp rank: {tp}/{num_tp - 1} and pp rank: {pp}/{output_pp - 1}") + + layer_lb = layer_ub + + # This is needed for megatron checkpoint loading. + with open(os.path.join(base_output_dir, "iter_0000001/latest_checkpointed_iteration.txt"), "w") as f: + f.write("1") + + +def combine(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_per_pp_rank): + """Combine pipeline parallel size = N checkpoint to pipeline parallel size 1.""" + for tp in range(num_tp): + new_sd = None + + layer_num_offset = 0 + max_layer_num = 0 + + for pp in range(input_pp): + path = os.path.join(input_dir, f"mp_rank_0{tp}_00{pp}", "model_optim_rng.pt") + sd = torch.load(path) + + if pp == 0: + new_sd = sd.copy() + new_sd["model"] = dict() + new_sd["args"].pipeline_model_parallel_size = 1 + + assert new_sd is not None + + for k, v in sd["model"].items(): + # First pp rank has vision model. + if pp == 0 and ("vision_model" in k or "vision_projection" in k): + new_sd["model"][k] = v + continue + + # Only the first pp rank has the word embeddings. + if "language_model.embedding.word_embeddings" in k and pp == 0: + new_sd["model"][k] = v + + # Only the last pp rank has the output layer. + if "language_model.output_layer" in k and pp == input_pp - 1: + new_sd["model"][k] = v + + # Only the last pp rank has final layer norm. + if "language_model.decoder.final_layernorm" in k and pp == input_pp - 1: + new_sd["model"][k] = v + + if "language_model.decoder.layers" in k: + layer_num = int(k.split(".")[3]) + + # On all pp ranks, megatron starts layer nums from 0! + new_layer_num = layer_num_offset + layer_num + + if new_layer_num > max_layer_num: + max_layer_num = new_layer_num + + k_splitted = k.split(".") + k_splitted[3] = str(new_layer_num) + new_k = ".".join(k_splitted) + + new_sd["model"][new_k] = v + + print(f"processed tp rank: {tp}/{num_tp - 1} and pp rank: {pp}/{input_pp - 1}") + + layer_num_offset = max_layer_num + 1 + + output_dir = os.path.join(base_output_dir, f"iter_0000001/mp_rank_0{tp}") + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "model_optim_rng.pt") + torch.save(new_sd, output_path) + + # This is needed for megatron checkpoint loading. + with open(os.path.join(base_output_dir, "iter_0000001/latest_checkpointed_iteration.txt"), "w") as f: + f.write("1") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Change pipeline parallelism for a model", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--input", type=str, required=True, help="Input model directory" + ) + parser.add_argument( + "--input-pipeline-parallel", type=int, required=True, help="Input model pipeline parallelism" + ) + parser.add_argument( + "--output", type=str, required=True, help="Output model directory" + ) + parser.add_argument( + "--output-pipeline-parallel", type=int, required=True, help="Output model pipeline parallelism" + ) + parser.add_argument( + "--tensor-parallel", type=int, required=True, help="Model tensor parallel size", + ) + parser.add_argument( + "--num-layers-per-pp-rank", type=int, default=None, nargs="*", help="Specify this for uneven pipeline parallel split", + ) + + args = parser.parse_args() + + f = None + if args.input_pipeline_parallel == 1 and args.output_pipeline_parallel > 1: + f = split + elif args.input_pipeline_parallel > 1 and args.output_pipeline_parallel == 1: + f = combine + else: + raise NotImplementedError("Only pipeline parallel 1 to N and N to 1 are supported") + + f(args.input, args.output, args.input_pipeline_parallel, args.output_pipeline_parallel, args.tensor_parallel, args.num_layers_per_pp_rank) + + print("done.") diff --git a/examples/multimodal/nvlm/pretrain_blend.yaml b/examples/multimodal/nvlm/pretrain_blend.yaml new file mode 100644 index 0000000000..fbbcc54388 --- /dev/null +++ b/examples/multimodal/nvlm/pretrain_blend.yaml @@ -0,0 +1,28 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 0.579 # Datasets are weighted according to their size. Weights sum up to 1. + path: + subflavors: + augmentation: False + + - weight: 0.02 + path: + subflavors: + augmentation: False + + - weight: 0.01 + path: + subflavors: + augmentation: False + + # Please refer to Table 4 in https://arxiv.org/pdf/2409.11402 for full list of pretrain datasets. + # Please refer to https://nvidia.github.io/Megatron-Energon/data_prep.html on preparing datasets in the Megatron Energon format. + val: + datasets: + - weight: 1. + path: + subflavors: + augmentation: False diff --git a/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh new file mode 100644 index 0000000000..922ca6bc7b --- /dev/null +++ b/examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh @@ -0,0 +1,158 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM="false" + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-qwen20-72b-internvit-${DATETIME}" +else + MODEL_NAME="mcore-qwen20-72b-internvit" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +CHECKPOINT_DIR="${WORKSPACE}/combined-qwen2.0-72b-instruct-internvit-6b-448px-1.5-tp8-te" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/pretrain_blend.yaml" + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + AD=0.0 + HD=0.0 + LI=1 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=2048 + NW=8 + AD=0.1 + HD=0.1 + LI=5 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +SEQ_LEN=256 # Image embeddings sequence length. +DECODER_SEQ_LEN=512 # Language model sequence length. +MAX_POS_EMBED=512 + + +OPTIONS=" \ + --use-checkpoint-args \ + --exit-duration-in-mins 230 \ + --disable-bias-linear \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model ${WORKSPACE}/ \ + --tokenizer-prompt-format qwen2p0 \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --num-layers 80 \ + --hidden-size 8192 \ + --ffn-hidden-size 29568 \ + --add-qkv-bias \ + --num-attention-heads 64 \ + --use-distributed-optimizer \ + --use-te \ + --num-workers ${NW} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings 32768 \ + --train-samples 122880000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 1e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 500 \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --save-interval 5000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --split 100,0,0 \ + --clip-grad 10.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --bf16 \ + --eod-mask-loss \ + --freeze-ViT \ + --freeze-LM \ + --patch-dim 14 \ + --img-h 448 \ + --img-w 448 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type qwen2.0_72B \ + ${EXTRA_ARGS} \ + --allow-missing-vision-projection-checkpoint \ + --vision-model-type internvit \ + --disable-vision-class-token \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --ckpt-format torch \ + --pixel-shuffle \ + --use-image-tag +" + + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh b/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh new file mode 100644 index 0000000000..da1c4e0ac2 --- /dev/null +++ b/examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh @@ -0,0 +1,154 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export TOKENIZERS_PARALLELISM="false" + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-${DATETIME}" +else + MODEL_NAME="mcore-nous-yi34b-internvit-mlp" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +LOAD_NAME="combined-yi-34b-internvit-tp8-mcore" +CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/pretrain_blend.yaml" + + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + LI=1 + AD=0.0 + HD=0.0 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=2048 + NW=8 + LI=5 + AD=0.1 + HD=0.1 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +SEQ_LEN=256 # Image embeddings sequence length. +DECODER_SEQ_LEN=512 # Language model sequence length. +MAX_POS_EMBED=512 + + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-duration-in-mins 230 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model ${WORKSPACE}/ \ + --tokenizer-prompt-format chatml \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --train-samples 122880000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --lr 1e-4 \ + --min-lr 2.5e-5 \ + --lr-decay-style cosine \ + --clip-grad 10.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir=${TENSORBOARD_DIR} \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --split 100,0,0 \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --save-interval 2000 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --allow-missing-vision-projection-checkpoint \ + --disable-vision-class-token \ + --use-te \ + --use-checkpoint-args \ + --ckpt-format torch \ + --pixel-shuffle \ + --use-image-tag + " + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh new file mode 100644 index 0000000000..ffb5c30d1c --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh @@ -0,0 +1,139 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +USE_TILING=0 +USE_PIXEL_SHUFFLE_ONLY=0 + +while [[ $# -gt 0 ]]; do + case $1 in + --input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + --task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + --use-tiling) + USE_TILING=1 + shift + shift + ;; + --use-pixel-shuffle-only) + USE_PIXEL_SHUFFLE_ONLY=1 + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. +NUM_PARTITIONS=0 +START=0 +END=0 + +SEQ_LEN=1024 # Image embeddings sequence length. +DECODER_SEQ_LEN=8192 # Language model sequence length. +MAX_POS_EMBED=8192 + +# Additional arguments. +EXTRA_ARGS="" + +if [[ $USE_TILING -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 6 --use-thumbnail --use-tile-tags --use-image-tag" + SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). +fi + +if [[ $USE_PIXEL_SHUFFLE_ONLY -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-image-tag" + SEQ_LEN=256 +fi + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --no-masked-softmax-fusion \ + --swiglu \ + --num-layers 80 \ + --hidden-size 8192 \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --num-attention-heads 64 \ + --exit-on-missing-checkpoint \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 29568 \ + --load ${MODEL_PATH} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model \ + --tokenizer-prompt-format qwen2p0 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --disable-bias-linear \ + --add-qkv-bias \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --language-model-type qwen2.0_72B \ + --vision-model-type internvit \ + --micro-batch-size 1 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --bf16 \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --use-te \ + --transformer-impl transformer_engine \ + --use-checkpoint-args \ + --out-seq-length 16 \ + --temperature 1.0 \ + --patch-dim 14 \ + --seed 1234 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --disable-vision-class-token \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + ${EXTRA_ARGS} \ + --task ${TASK} +done diff --git a/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh b/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh new file mode 100644 index 0000000000..8ad070d94e --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +USE_TILING=0 +USE_PIXEL_SHUFFLE_ONLY=0 + +while [[ $# -gt 0 ]]; do + case $1 in + --input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + --task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + --use-tiling) + USE_TILING=1 + shift + shift + ;; + --use-pixel-shuffle-only) + USE_PIXEL_SHUFFLE_ONLY=1 + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. +NUM_PARTITIONS=0 +START=0 +END=0 + +SEQ_LEN=1024 # Image embeddings sequence length. +DECODER_SEQ_LEN=8192 # Language model sequence length. +MAX_POS_EMBED=8192 + +# Additional arguments. +EXTRA_ARGS="" + +if [[ $USE_TILING -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-tiling --max-num-tiles 6 --use-thumbnail --use-tile-tags --use-image-tag" + SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). +fi + +if [[ $USE_PIXEL_SHUFFLE_ONLY -eq 1 ]]; then + EXTRA_ARGS+=" --pixel-shuffle --use-image-tag" + SEQ_LEN=256 +fi + +for PARTITION_ID in $( eval echo {$START..$END} ) +do + torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + --attention-softmax-in-fp32 \ + --no-masked-softmax-fusion \ + --swiglu \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-on-missing-checkpoint \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --load ${MODEL_PATH} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model \ + --tokenizer-prompt-format chatml \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 1 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size 1 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --bf16 \ + --freeze-LM \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --use-te \ + --transformer-impl transformer_engine \ + --use-checkpoint-args \ + --out-seq-length 16 \ + --temperature 1.0 \ + --patch-dim 14 \ + --seed 1234 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --num-partitions ${NUM_PARTITIONS} \ + --partition-id ${PARTITION_ID} \ + --output-path ${OUTPUT_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + --disable-vision-class-token \ + --input-image-path ${INPUT_IMAGE_PATH} \ + --gt-path ${GROUNDTRUTH_PATH} \ + ${EXTRA_ARGS} \ + --task ${TASK} +done diff --git a/examples/multimodal/nvlm/sft_34b_internvit.sh b/examples/multimodal/nvlm/sft_34b_internvit.sh new file mode 100644 index 0000000000..5201b2d95a --- /dev/null +++ b/examples/multimodal/nvlm/sft_34b_internvit.sh @@ -0,0 +1,160 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_ALGO=^NVLS +export TOKENIZERS_PARALLELISM="false" + + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft-${DATETIME}" +else + MODEL_NAME="mcore-nous-yi34b-internvit-mlp-sft" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR=${OUTPUT}/checkpoints +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +LOAD_NAME="mcore-nous-yi34b-internvit-mlp" # From pretraining +CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" + + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + LI=1 + AD=0.0 + HD=0.0 + ALLOW_NONDETERMINISTIC=1 + + # Can run out of GPU memory in interactive memory without this. + # This is just for interactive testing purposes. Do not use for proper training. + EXTRA_ARGS=" --freeze-LM" +else + MBZ=1 + BZ=128 + NW=2 + LI=5 + AD=0.0 + HD=0.0 + ALLOW_NONDETERMINISTIC=1 + + EXTRA_ARGS="" +fi + +SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). +DECODER_SEQ_LEN=3200 # Language model sequence length. +MAX_POS_EMBED=3200 + +OPTIONS=" \ + --swiglu \ + --use-distributed-optimizer \ + --num-workers ${NW} \ + --num-layers 60 \ + --hidden-size 7168 \ + --normalization RMSNorm \ + --num-attention-heads 56 \ + --exit-duration-in-mins 230 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 20480 \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings ${MAX_POS_EMBED} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model ${WORKSPACE}/ \ + --tokenizer-prompt-format chatml \ + --vocab-size 64000 \ + --make-vocab-size-divisible-by 1 \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 5000000 \ + --disable-bias-linear \ + --tensor-model-parallel-size 8 \ + --language-model-type yi-34b \ + --vision-model-type internvit \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --train-samples 30000000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --lr 2e-6 \ + --min-lr 2.5e-7 \ + --lr-decay-style cosine \ + --split 100,0,0 \ + --clip-grad 10 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --eod-mask-loss \ + --bf16 \ + --tensorboard-dir=${TENSORBOARD_DIR} \ + --freeze-ViT \ + --img-h 448 \ + --img-w 448 \ + --patch-dim 14 \ + --data-path ${DATA_TRAIN} \ + --dataloader-type external \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --log-interval ${LI} \ + --load ${FINETUNE_DIR} \ + --save ${FINETUNE_DIR} \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --save-interval 5000 \ + --eval-interval 500 \ + --eval-iters 10 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + ${EXTRA_ARGS} \ + --disable-vision-class-token \ + --use-te \ + --ckpt-format torch \ + --pixel-shuffle \ + --use-tiling \ + --max-num-tiles 6 \ + --use-thumbnail \ + --use-tile-tags \ + --use-image-tag + " + +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} +export NVTE_APPLY_QK_LAYER_SCALING=0 + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi diff --git a/examples/multimodal/nvlm/sft_blend.yaml b/examples/multimodal/nvlm/sft_blend.yaml new file mode 100644 index 0000000000..56c8230a2a --- /dev/null +++ b/examples/multimodal/nvlm/sft_blend.yaml @@ -0,0 +1,23 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 0.01 # # Datasets are weighted according to their size. Weights sum up to 1. + path: + subflavors: + augmentation: False + + - weight: 0.02 + path: + subflavors: + augmentation: False + + # Please refer to Table 6 in https://arxiv.org/pdf/2409.11402 for full list of SFT datasets. + # Please refer to https://nvidia.github.io/Megatron-Energon/data_prep.html on preparing datasets in the Megatron Energon format. + val: + datasets: + - weight: 1. + path: + subflavors: + augmentation: False diff --git a/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh b/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh new file mode 100644 index 0000000000..ed207ae0f9 --- /dev/null +++ b/examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh @@ -0,0 +1,166 @@ +#!/bin/bash + +# Your SBATCH commands here if using SLURM. + +# Please launch this script from megatron-lm root. + +# Train a multimodal model. + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_ALGO=^NVLS +export TOKENIZERS_PARALLELISM="false" + +DEBUG=0 + +if [[ $BATCH -eq 0 ]]; then + DATETIME=`date +'%y-%m-%d-%H-%M-%S'` + MODEL_NAME="mcore-qwen20-72b-internvit-sft-${DATETIME}" +else + MODEL_NAME="mcore-qwen20-72b-internvit-sft" +fi + +WORKSPACE="" +SOURCE=`pwd` +OUTPUT_BASE="${WORKSPACE}/output" +OUTPUT="${OUTPUT_BASE}/${MODEL_NAME}" + +FINETUNE_DIR="${OUTPUT}/checkpoints" +LOGS_DIR="${OUTPUT}/logs" +TENSORBOARD_DIR="${OUTPUT}/tensorboard" + +# From pretraining. The pretraining checkpoint must be manually split to 4 pipeline parallel stages. +# Please refer to README.md and run examples/multimodal/nvlm/pp_checkpoint_converter.py. +LOAD_NAME="mcore-qwen20-72b-internvit-pp4" + +CHECKPOINT_DIR="${WORKSPACE}/output/${LOAD_NAME}/checkpoints" + +DATA_TRAIN="${SOURCE}/examples/multimodal/nvlm/sft_blend.yaml" + +if [[ $DEBUG -eq 1 ]]; then + MBZ=1 + BZ=1 + NW=0 + AD=0.0 + HD=0.0 + LI=1 + # This is just for interactive testing purposes. Do not use for proper training. + EXTRA_ARGS="--freeze-LM" + ALLOW_NONDETERMINISTIC=1 +else + MBZ=1 + BZ=256 + NW=8 + AD=0.0 + HD=0.0 + LI=5 + EXTRA_ARGS="" + ALLOW_NONDETERMINISTIC=1 +fi + +SEQ_LEN=261 # Image embeddings sequence length (256 image embeddings + 5 tile tag embeddings). +DECODER_SEQ_LEN=3200 # Language model sequence length. +MAX_POS_EMBED=8192 + +OPTIONS=" \ + --use-checkpoint-args \ + --exit-duration-in-mins 230 \ + --disable-bias-linear \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model ${WORKSPACE}/ \ + --tokenizer-prompt-format qwen2p0 \ + --transformer-impl transformer_engine \ + --normalization RMSNorm \ + --norm-epsilon 1e-06 \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --attention-dropout ${AD} \ + --hidden-dropout ${HD} \ + --untie-embeddings-and-output-weights \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --tensor-model-parallel-size 8 \ + --pipeline-model-parallel-size 4 \ + --num-layers 80 \ + --hidden-size 8192 \ + --ffn-hidden-size 29568 \ + --add-qkv-bias \ + --num-attention-heads 64 \ + --use-distributed-optimizer \ + --use-te \ + --num-workers ${NW} \ + --seq-length ${SEQ_LEN} \ + --decoder-seq-length ${DECODER_SEQ_LEN} \ + --max-position-embeddings 32768 \ + --train-samples 122880000 \ + --lr-decay-samples 25600000 \ + --lr-warmup-samples 83200 \ + --micro-batch-size ${MBZ} \ + --global-batch-size ${BZ} \ + --lr 2e-6 \ + --min-lr 2.5e-7 \ + --lr-decay-style cosine \ + --log-interval ${LI} \ + --eval-iters 10 \ + --eval-interval 500 \ + --data-path ${DATA_TRAIN} \ + --prompt-path ${SOURCE}/examples/multimodal/nvlm/nvlm_prompts.json \ + --save-interval 10000 \ + --save ${FINETUNE_DIR} \ + --load ${FINETUNE_DIR} \ + --dataloader-save ${FINETUNE_DIR}/dataloader \ + --pretrained-checkpoint ${CHECKPOINT_DIR} \ + --split 100,0,0 \ + --clip-grad 10.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --bf16 \ + --eod-mask-loss \ + --freeze-ViT \ + --patch-dim 14 \ + --img-h 448 \ + --img-w 448 \ + --dataloader-type external \ + --tensorboard-dir ${TENSORBOARD_DIR} \ + --language-model-type qwen2.0_72B \ + ${EXTRA_ARGS} \ + --allow-missing-vision-projection-checkpoint \ + --vision-model-type internvit \ + --disable-vision-class-token \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --ckpt-format torch \ + --pixel-shuffle \ + --use-tiling \ + --max-num-tiles 6 \ + --use-thumbnail \ + --use-tile-tags \ + --use-image-tag +" + + +export NVTE_APPLY_QK_LAYER_SCALING=0 +export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${ALLOW_NONDETERMINISTIC} + +# Interactive or batch mode +if [[ $BATCH -eq 0 ]]; then + torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +else + run_cmd="python -u ${SOURCE}/examples/multimodal/train.py ${OPTIONS}" + + DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` + + srun -l --verbose \ + --container-image \ + --container-mounts "" \ + --output=${LOGS_DIR}/%x_%j_$DATETIME.log \ + sh -c "${run_cmd}" + + set +x +fi