Merge pull request #3 from mjkanji/shell_invoke
Explore invoking SkyPilot using dagster_shell vs. the Python API
mjkanji authored Mar 6, 2024
2 parents b4877cb + 1a7a8ea commit 66be9a9
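
The PR weighs two ways of kicking off the SkyPilot job from a Dagster asset: shelling out to the `sky` CLI via dagster_shell, or driving SkyPilot's Python API directly. A minimal sketch of both routes follows; the asset names, the hard-coded YAML path, and the CLI flags are illustrative assumptions, not the code merged here.

# Illustrative sketch only; asset names and paths are hypothetical.
import sky
from dagster import AssetExecutionContext, asset
from dagster_shell import execute_shell_command


@asset(group_name="ai")
def skypilot_via_shell(context: AssetExecutionContext) -> None:
    # Shell route: run the SkyPilot CLI as a subprocess and stream its logs.
    execute_shell_command(
        "sky launch -y -c gemma dagster_skypilot/finetune.yaml",
        output_logging="STREAM",
        log=context.log,
    )


@asset(group_name="ai")
def skypilot_via_python_api(context: AssetExecutionContext) -> None:
    # Python API route: build the Task in-process and launch it directly.
    task = sky.Task.from_yaml("dagster_skypilot/finetune.yaml")
    sky.launch(task, cluster_name="gemma", idle_minutes_to_autostop=5)

The merged assets.py below takes the Python API route, loading finetune.yaml with from_yaml_config so the Dagster run ID and tokens can be injected as environment overrides.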
Showing 8 changed files with 334 additions and 129 deletions.
2 changes: 0 additions & 2 deletions dagster_cloud_post_install.sh

This file was deleted.

105 changes: 36 additions & 69 deletions dagster_skypilot/assets.py
@@ -1,38 +1,28 @@
+import json
 import os
-from pathlib import Path

 import sky
+import yaml
 from dagster import AssetExecutionContext, asset
+from dagster_shell import execute_shell_command
+from upath import UPath

-from dagster_skypilot.consts import DEPLOYMENT_TYPE
+from dagster_skypilot.utils import populate_keyfiles


-def populate_keyfiles():
-    """
-    SkyPilot only supports reading credentials from key files and not environment
-    variables.
-
-    This reads the credentials for AWS and Lambda Labs from env vars (set in the
-    Dagster Cloud UI) and then populates the expected key files accordingly.
-    """
-    lambda_key_file = Path.home() / ".lambda_cloud" / "lambda_keys"
-    aws_key_file = Path.home() / ".aws" / "credentials"
-
-    # Don't overwrite local keys, but always populate them dynamically in
-    # Dagster Cloud
-    if not DEPLOYMENT_TYPE == "local":
-        lambda_key_file.parent.mkdir(parents=True, exist_ok=True)
-        aws_key_file.parent.mkdir(parents=True, exist_ok=True)
-
-        with lambda_key_file.open("w") as f:
-            f.write("api_key = {}".format(os.getenv("LAMBDA_LABS_API_KEY")))
-
-        with aws_key_file.open("w") as f:
-            f.write(
-                "[default]\n"
-                f"aws_access_key_id = {os.getenv('AWS_ACCESS_KEY_ID')}\n"
-                f"aws_secret_access_key = {os.getenv('AWS_SECRET_ACCESS_KEY')}\n"
-            )
+def get_metrics(context: AssetExecutionContext, bucket):
+    with (UPath(bucket) / context.run_id / "train_results.json").open("r") as f:
+        return json.load(f)
+
+
+def teardown_all_clusters(logger):
+    clusters = sky.status(refresh=True)
+
+    for c in clusters:
+        logger.info(f"Shutting down cluster: {c['name']}.")
+        sky.down(c["name"])
+
+    logger.info("All clusters shut down.")


@asset(group_name="ai")
@@ -41,51 +31,28 @@ def skypilot_model(context: AssetExecutionContext) -> None:
     # So, we need to populate the required keyfiles.
     populate_keyfiles()

-    # The setup command.
-    setup = r"""
-    set -e # Exit if any command failed.
-    git clone https://github.com/huggingface/transformers/ || true
-    cd transformers
-    pip install .
-    cd examples/pytorch/text-classification
-    pip install -r requirements.txt
-    """
-
-    # The command to run. Will be run under the working directory.
-    run = r"""
-    set -e # Exit if any command failed.
-    cd transformers/examples/pytorch/text-classification
-    python run_glue.py \
-        --model_name_or_path bert-base-cased \
-        --dataset_name imdb \
-        --do_train \
-        --max_seq_length 128 \
-        --per_device_train_batch_size 32 \
-        --learning_rate 2e-5 \
-        --max_steps 50 \
-        --output_dir /tmp/imdb/ --overwrite_output_dir \
-        --fp16
-    """
-
-    # Mount an external bucket
-    storage_mounts = {
-        "/dagster-skypilot-bucket": sky.Storage(
-            source="s3://dagster-skypilot-bucket", mode=sky.StorageMode.MOUNT
-        )
-    }
-
-    task = sky.Task(
-        "huggingface",
-        workdir=".",
-        setup=setup,
-        run=run,
-    )
-
-    task.set_resources(
-        sky.Resources(sky.Lambda(), accelerators={"A10": 1})
-    ).set_storage_mounts(storage_mounts)
-
-    # sky.launch(task, dryrun=True)
-    sky.launch(task, cluster_name="dnn", idle_minutes_to_autostop=5, down=True) # type: ignore
-
-    return None
+    skypilot_bucket = os.getenv("SKYPILOT_BUCKET")
+
+    # The parent of the current script
+    parent_dir = UPath(__file__).parent
+    yaml_file = parent_dir / "finetune.yaml"
+    with yaml_file.open("r", encoding="utf-8") as f:
+        task_config = yaml.safe_load(f)
+
+    task = sky.Task().from_yaml_config(
+        config=task_config,
+        env_overrides={ # type: ignore
+            "HF_TOKEN": os.getenv("HF_TOKEN", ""),
+            "DAGSTER_RUN_ID": context.run_id,
+            "BUCKET_NAME": skypilot_bucket,
+        },
+    )
+    task.workdir = str(parent_dir.absolute() / "scripts")
+
+    try:
+        sky.launch(task, cluster_name="gemma", idle_minutes_to_autostop=5) # type: ignore
+        context.add_output_metadata(get_metrics(context, skypilot_bucket))
+
+        return None
+    finally:
+        teardown_all_clusters(context.log)
...
7 changes: 0 additions & 7 deletions dagster_skypilot/consts.py

This file was deleted.

57 changes: 57 additions & 0 deletions dagster_skypilot/finetune.yaml
@@ -0,0 +1,57 @@
resources:
  cloud: aws
  accelerators: {L4, A10g, A10, L40, A40, A100, A100-80GB}
  disk_tier: best

envs:
  # The ID of the Dagster run that triggered the job.
  # Overwritten by the Dagster process.
  DAGSTER_RUN_ID: "no-run"
  HF_TOKEN: "" # We'll pass this via the Dagster Cloud UI or a .env file instead
  SKYPILOT_BUCKET: s3://dagster-skypilot-bucket # Change to your own bucket name
  TERM: "dumb"
  NO_COLOR: 1

workdir: dagster_skypilot/scripts

file_mounts:
  /artifacts:
    source: ${SKYPILOT_BUCKET}
    mode: MOUNT

# The '|' separator indicates a multiline string.
setup: |
  conda activate gemma
  if [ $? -ne 0 ]; then
    conda create -q -y -n gemma python=3.10
    conda activate gemma
  fi
  echo "Installing Python dependencies."
  pip install -q -U bitsandbytes==0.42.0
  pip install -q -U peft==0.8.2
  pip install -q -U trl==0.7.10
  pip install -q -U accelerate==0.27.1
  pip install -q -U datasets==2.17.0
  pip install -q -U transformers==4.38.1
  pip install -q "torch<2.2" torchvision --index-url https://download.pytorch.org/whl/cu121

run: |
  conda activate gemma
  NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
  # Turn off wandb
  WANDB_MODE="offline"
  TERM=dumb NO_COLOR=1 torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    lora.py \
    --model_name_or_path google/gemma-7b \
    --save_steps 4 \
    --output_dir /artifacts/${DAGSTER_RUN_ID}
167 changes: 167 additions & 0 deletions dagster_skypilot/scripts/lora.py
@@ -0,0 +1,167 @@
import os
import pathlib
import shutil
import subprocess
from dataclasses import dataclass, field
from typing import Optional

import torch
import transformers
from datasets import disable_progress_bars, load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GemmaTokenizer,
)
from trl import SFTTrainer


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="google/gemma-7b")


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    disable_tqdm: bool = field(default=True)
    per_device_train_batch_size: int = field(default=1)
    gradient_accumulation_steps: int = field(default=4)
    warmup_steps: int = field(default=2)
    max_steps: int = field(default=10)
    learning_rate: float = field(default=2e-4)
    fp16: bool = field(default=True)
    logging_steps: int = field(default=1)

    output_dir: str = field(default="outputs")
    optim: str = field(default="paged_adamw_8bit")
    save_steps: int = field(default=1)


class CheckpointCallback(transformers.TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        """Add complete indicator to avoid incomplete checkpoints."""
        if state.is_world_process_zero:
            ckpt_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            with open(os.path.join(ckpt_path, "complete"), "w") as f:
                f.write("")
            print(f"Checkpoint {state.global_step} saved.")
        torch.distributed.barrier()


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict) # noqa


def cleanup_incomplete_checkpoints(output_dir):
    """Remove incomplete checkpoints."""
    checkpoints = list(pathlib.Path(output_dir).glob("checkpoint-*"))
    checkpoints = [c for c in checkpoints if c.name.split("-")[-1].isdigit()]
    checkpoints = sorted(
        checkpoints, key=lambda x: int(x.name.split("-")[-1]), reverse=True
    )
    for checkpoint in checkpoints:
        if not (checkpoint / "complete").exists():
            print(f"Removing incomplete checkpoint {checkpoint}")
            shutil.rmtree(checkpoint)
        else:
            print(
                f"Using checkpoint {checkpoint}, copying to ~/tmp/ for "
                "optimization of loading."
            )
            tmp_dir = os.path.expanduser("~/tmp")
            os.makedirs(tmp_dir, exist_ok=True)
            try:
                # Optimization for checkpoint loading. This is to force the
                # mounting tool to download the checkpoints in parallel first.
                # It will improve the loading speed of the checkpoints
                # significantly.
                subprocess.run(
                    ["gsutil", "-m", "rsync", "-r", checkpoint, tmp_dir], check=True
                )
            except:
                print("Failed to optimize checkpoint loading. Skip.")
            break


def train():
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank
    if local_rank == 0:
        cleanup_incomplete_checkpoints(training_args.output_dir)
    torch.distributed.barrier()

    # Check the existence of checkpoints in all processes
    # All ranks must simultaneously resume from a checkpoint if it exists.
    # Otherwise, upon recovery the model weights may not reload correctly,
    # causing loss spikes.
    resume_from_checkpoint = False
    checkpoints = list(pathlib.Path(training_args.output_dir).glob("checkpoint-*"))
    checkpoints = [c for c in checkpoints if c.name.split("-")[-1].isdigit()]
    if checkpoints:
        resume_from_checkpoint = True

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, token=os.environ["HF_TOKEN"]
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        quantization_config=bnb_config,
        device_map="auto",
        token=os.environ["HF_TOKEN"],
    )

    lora_config = LoraConfig(
        r=8,
        target_modules=[
            "q_proj",
            "o_proj",
            "k_proj",
            "v_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )

    data = load_dataset("Abirate/english_quotes")
    data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

    def formatting_func(example):
        text = f'Quote: {example["quote"][0]}\nAuthor: {example["author"][0]}'
        return [text]

    trainer = SFTTrainer(
        model=model,
        train_dataset=data["train"],
        args=training_args,
        peft_config=lora_config,
        formatting_func=formatting_func,
    )
    trainer.add_callback(CheckpointCallback)
    train_results = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

    metrics = train_results.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)


if __name__ == "__main__":
    # Disable progress bars for datasets operations
    disable_progress_bars()
    train()