Add tests for resiliency feature integration #11406

Open · wants to merge 3 commits into base: main
46 changes: 44 additions & 2 deletions .github/workflows/cicd-main.yml
@@ -78,7 +78,7 @@ jobs:
run: |
docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --env PYTHONUNBUFFERED=1 nemoci.azurecr.io/nemo_container:${{ github.run_id }} bash -c '\
# PyTorch Lightning version
python -c "import pytorch_lightning; print(pytorch_lightning.__version__)"
python -c "import lightning.pytorch; print(lightning.pytorch.__version__)"

# PyTorch Lightning DDP Checks
CUDA_VISIBLE_DEVICES="0,1" python "tests/core_ptl/check_for_ranks.py"
@@ -3921,6 +3921,46 @@ jobs:
--index-mapping-dir=/tmp/llm_tests/llama_index_mappings \
--cp 1 --tp 2 --sp 1

L2_NeMo_2_llama3_fault_tolerance_plugin:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_llama3_fault_tolerance_plugin') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |

mkdir -p /tmp/llm_tests/llama_pretrain_results; \
export FAULT_TOL_CFG_PATH="/tmp/llm_tests/llama_pretrain_results/sample_job_ft_cfg.yml"; \
export FAULT_TOL_FINISHED_FLAG_FILE="/tmp/llm_tests/llama_pretrain_results/sample_job_finished_flag"; \
python tests/collections/llm/test_fault_nvrx.py \
--devices=2 \
--crash-step=4 \
--experiment-dir=/tmp/llm_tests/llama_pretrain_results \
--data-path=/home/TestData/nlp/megatron_llama/data/rp2_sample_sentencepiece_preproc_text_document \
--tokenizer-path=/home/TestData/nlp/megatron_llama/tokenizer.model \
--index-mapping-dir=/tmp/llm_tests/llama_index_mappings \
2>&1 | tee /tmp/llm_tests/llama_pretrain_results/run.log

L2_NeMo_2_llama3_straggler_detection:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_llama3_straggler_detection') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |

mkdir -p /tmp/llm_tests/llama_pretrain_results; \
export FAULT_TOL_CFG_PATH="/tmp/llm_tests/llama_pretrain_results/sample_job_ft_cfg.yml"; \
export FAULT_TOL_FINISHED_FLAG_FILE="/tmp/llm_tests/llama_pretrain_results/sample_job_finished_flag"; \
python tests/collections/llm/test_fault_nvrx.py \
--devices=2 \
--check-report=True \
--experiment-dir=/tmp/llm_tests/llama_pretrain_results \
--data-path=/home/TestData/nlp/megatron_llama/data/rp2_sample_sentencepiece_preproc_text_document \
--tokenizer-path=/home/TestData/nlp/megatron_llama/tokenizer.model \
--index-mapping-dir=/tmp/llm_tests/llama_index_mappings \
2>&1 | tee /tmp/llm_tests/llama_pretrain_results/run.log

L2_NeMo_2_GPT_DDP_Param_Parity_check:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
@@ -4587,6 +4627,8 @@ jobs:
- L2_NeMo_2_GPT_DDP_Param_Parity_check
- L2_NeMo_2_HF_MODEL_IMPORT
- L2_NeMo_2_llama3_pretraining_recipe
- L2_NeMo_2_llama3_fault_tolerance_plugin
- L2_NeMo_2_llama3_straggler_detection
- L2_HF_Transformer_SFT_TE_Acceleration
- L2_NeMo_2_SSM_Pretraining
- L2_NeMo_2_SSM_Finetuning
@@ -4760,4 +4802,4 @@ jobs:

- name: "Pipeline not successful, set exit code to 1"
if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'false' }}
run: exit 1
run: exit 1
3 changes: 3 additions & 0 deletions Dockerfile.ci
@@ -84,6 +84,9 @@ git checkout ${MCORE_TAG} && \
popd
export PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM"

# Install nvidia-resiliency-ext
pip install --no-cache-dir "git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@97aad77609d2e25ed38ac5c99f0c13f93c48464e"

EOF

# Copy over NeMo code
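A quick import probe, mirroring the HAVE_RES guard in nemo/lightning/run/plugins.py, can confirm the dependency landed in the rebuilt CI image. This is a hedged sketch; the module path is assumed from that guard's assertion message.

# Hedged sketch: verify nvidia-resiliency-ext is importable inside the CI container.
try:
    import nvidia_resiliency_ext.ptl_resiliency  # noqa: F401  # assumed module path
    print("nvidia-resiliency-ext ptl_resiliency is available")
except ImportError as err:
    print(f"nvidia-resiliency-ext is missing: {err}")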
2 changes: 1 addition & 1 deletion nemo/lightning/megatron_parallel.py
@@ -589,7 +589,7 @@ def init_ddp(self):
self.ddp_config,
module,
data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
expert_data_parallel_group=parallel_state.get_expert_data_parallel_group(),
Collaborator Author (@maanug-nv) commented:
I don't think MCore in CI container needs this change
disable_bucketing=disable_bucketing,
)

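If the Megatron-Core version pinned in the CI container predates the accessor rename discussed above, a small compatibility guard could bridge the two names. This is a hedged sketch, not part of the PR, and it assumes only the function name differs between versions.

# Hedged sketch: prefer the renamed Megatron-Core accessor, fall back to the older name.
from megatron.core import parallel_state

if hasattr(parallel_state, "get_expert_data_parallel_group"):
    expert_dp_group = parallel_state.get_expert_data_parallel_group()
else:
    # Older Megatron-Core versions expose the same group under this name.
    expert_dp_group = parallel_state.get_data_modulo_expert_parallel_group()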
6 changes: 3 additions & 3 deletions nemo/lightning/run/plugins.py
@@ -91,13 +91,13 @@ class FaultTolerancePlugin(run.Plugin):
This plugin enables workload hang detection, automatic calculation of timeouts used for hang detection, detection of rank(s) terminated due to an error, and workload respawning in case of a failure.
Note: FaultTolerancePlugin does not work with the NsysPlugin.
Args:
num_in_process_restarts (int): Max number of restarts on failure, within the same job. Default is 3.
num_in_job_restarts (int): Max number of restarts on failure, within the same job. Default is 3.
num_job_retries_on_failure (int): Max number of new job restarts on failure. Default is 2.
initial_rank_heartbeat_timeout (int): Timeouts are time intervals used by a rank monitor to detect that a rank is not alive. This is the max timeout for the initial heartbeat. Default is 1800.
rank_heartbeat_timeout (int): This is the timeout for subsequent heartbeats after the initial heartbeat. Default is 300.
"""

num_in_process_restarts: int = 3
num_in_job_restarts: int = 3
num_job_retries_on_failure: int = 2
initial_rank_heartbeat_timeout: int = 1800
rank_heartbeat_timeout: int = 300
@@ -107,7 +107,7 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor):
assert HAVE_RES, "nvidia-resiliency-ext.ptl_resiliency is required to use the FaultTolerancePlugin."

executor.launcher = run.FaultTolerance(
max_restarts=self.num_in_process_restarts,
max_restarts=self.num_in_job_restarts,
initial_rank_heartbeat_timeout=self.initial_rank_heartbeat_timeout,
rank_heartbeat_timeout=self.rank_heartbeat_timeout,
)
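For context, a minimal usage sketch of the plugin as exercised by the new test. The plugin arguments and the "ft" launcher value come from this PR; the recipe object and device count are illustrative.

# Minimal sketch, assuming a nemo_run pretraining recipe is already configured elsewhere.
import nemo_run as run
from nemo.lightning.run.plugins import FaultTolerancePlugin

# Mirrors tests/collections/llm/test_fault_nvrx.py: one in-job restart, no new-job retries.
plugins = [FaultTolerancePlugin(num_in_job_restarts=1, num_job_retries_on_failure=0)]

# launcher="ft" matches the LocalExecutor setup used by the test.
executor = run.LocalExecutor(ntasks_per_node=2, launcher="ft")

run.run(pretrain_recipe, plugins=plugins, executor=executor)  # pretrain_recipe assumed defined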
9 changes: 6 additions & 3 deletions tests/collections/llm/common.py
@@ -15,6 +15,7 @@
import os

import lightning.pytorch as pl
import nemo_run as run
import torch

from nemo import lightning as nl
@@ -27,8 +28,9 @@ def train_data(
data_path: str, tokenizer_path: str, index_mapping_dir: str, seq_length: int
) -> llm.PreTrainingDataModule:
"""Single shard dataset tokenized by SentencePiece"""
tokenizer = SentencePieceTokenizer(model_path=tokenizer_path)
return llm.PreTrainingDataModule(
tokenizer = run.Config(SentencePieceTokenizer, model_path=tokenizer_path)
return run.Config(
llm.PreTrainingDataModule,
paths=data_path,
tokenizer=tokenizer,
seq_length=seq_length,
@@ -41,7 +43,8 @@

def small_llama_cfg(seq_length: int) -> llm.GPTConfig:
"""Small 145m model"""
return llm.Llama3Config8B(
return run.Config(
llm.Llama3Config8B,
rotary_base=500_000,
seq_length=seq_length,
num_layers=12,
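The edits above switch these helpers from returning live objects to returning nemo_run configs. A minimal sketch of the pattern, with an assumed import path for the tokenizer class:

# Minimal sketch: run.Config records the target class plus kwargs instead of
# instantiating it, so the recipe stays serializable and is only materialized
# when nemo_run executes the task.
import nemo_run as run
from nemo.collections.common.tokenizers import SentencePieceTokenizer  # assumed import path

tokenizer_cfg = run.Config(SentencePieceTokenizer, model_path="/path/to/tokenizer.model")
# tokenizer_cfg is a lightweight config object, not a SentencePieceTokenizer instance.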
2 changes: 1 addition & 1 deletion tests/collections/llm/llama3_pretraining.py
@@ -80,7 +80,7 @@ def main():
dir=args.experiment_dir, name=exp_name, num_gpus_per_node=args.devices
)

pretrain_recipe.model = llm.LlamaModel(small_llama_cfg(args.seq_length))
pretrain_recipe.model = run.Config(llm.LlamaModel, small_llama_cfg(args.seq_length))

if args.data_path and args.tokenizer_path:
pretrain_recipe.data = train_data(
134 changes: 134 additions & 0 deletions tests/collections/llm/test_fault_nvrx.py
@@ -0,0 +1,134 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test fault tolerance with LLaMA3 recipe and a smaller model.
"""

import argparse
import os

import nemo_run as run

from lightning.pytorch.callbacks import Callback

from nemo.collections import llm
from nemo.collections.llm.recipes.callbacks.common import straggler_det_callback
from nemo.lightning.run.plugins import FaultTolerancePlugin
from nemo.utils.exp_manager import TimingCallback
from tests.collections.llm.common import small_llama_cfg, train_data


class CrashCallback(Callback):
def __init__(self, crash_step=16):
self.crash_step = crash_step
self.current_step = 0
print(f"Setup to simulate a crash if step == {self.crash_step}")

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self.current_step = self.current_step + 1
if self.crash_step and self.current_step == self.crash_step:
raise Exception(f"Simulating a crash at step {self.crash_step}!")


def get_args():
parser = argparse.ArgumentParser(prog="", description="")
parser.add_argument('--devices', type=int, required=True, help="Number of devices to use for training")
parser.add_argument(
'--crash-step',
type=int,
help="Step when a crash should be simulated",
)
parser.add_argument(
'--check-report', type=bool, default=False, help="Check if StragglerDetection reports performance scores"
)
parser.add_argument(
'--experiment-dir', type=str, required=True, help="directory to write results and checkpoints to"
)
parser.add_argument(
'--data-path', type=str, default=None, help="Path to data file. If not specified, uses mock data."
)
parser.add_argument(
'--tokenizer-path',
type=str,
default=None,
help="Path to a sentencepiece tokenizer model file. If not specified, uses mock data.",
)
parser.add_argument('--index-mapping-dir', type=str, help="directory to write index mappings to")

return parser.parse_args()


def main():
args = get_args()

exp_name = "L2_llama3_small_pretrain_fault_tolerance_test"
pretrain_recipe = llm.llama3_8b.pretrain_recipe(
dir=args.experiment_dir, name=exp_name, num_gpus_per_node=args.devices
)

pretrain_recipe.model = run.Config(llm.LlamaModel, small_llama_cfg(1024))

if args.data_path and args.tokenizer_path:
pretrain_recipe.data = train_data(
data_path=args.data_path,
tokenizer_path=args.tokenizer_path,
index_mapping_dir=args.index_mapping_dir,
seq_length=1024,
)

# Recipe Overrides
pretrain_recipe.trainer.max_steps = 20
pretrain_recipe.trainer.log_every_n_steps = 1
# Enable ckpt save so that after the simulated crash, training can resume from ckpt
pretrain_recipe.log.ckpt.every_n_train_steps = 10
pretrain_recipe.log.ckpt.train_time_interval = None
# Disable async ckpt because the simulated crash happens during ckpt save
# So only an unfinished ckpt would be available for resume which can cause errors
pretrain_recipe.trainer.strategy.ckpt_async_save = False
pretrain_recipe.trainer.val_check_interval = 30
pretrain_recipe.trainer.limit_val_batches = 2

executor: run.LocalExecutor = run.LocalExecutor(ntasks_per_node=args.devices, launcher="ft")
# Add the fault tolerance plugin which enables restart after a crash
run_plugins: list[run.Plugin] = [FaultTolerancePlugin(num_in_job_restarts=1, num_job_retries_on_failure=0)]
pretrain_recipe.trainer.callbacks = [
run.Config(TimingCallback),
straggler_det_callback(straggler_report_time_interval=0.5),
]

if args.crash_step:
pretrain_recipe.trainer.callbacks.append(run.Config(CrashCallback, crash_step=args.crash_step))

run.run(pretrain_recipe, plugins=run_plugins, executor=executor)

# Assumes that NeMo logs are written into "run.log"
# When a crash is simulated, the error shows up on the terminal but is not written to a file,
# so the CI job tees the run output to run.log in the experiment-dir
log_content = None
with open(os.path.join(args.experiment_dir, "run.log")) as f:
log_content = f.read()

if args.check_report:
assert "GPU relative performance" in log_content
assert "GPU individual performance" in log_content
assert "Straggler report processing time" in log_content
if args.crash_step:
assert f"Exception: Simulating a crash at step {args.crash_step}!" in log_content
assert "Restored all states from the checkpoint" in log_content
assert "`Trainer.fit` stopped: `max_steps=20` reached" in log_content


if __name__ == '__main__':
main()