Skip to content

Commit

Permalink
SDXL E2E Pipeline CI (#885)
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri authored Nov 21, 2024
1 parent 541572a commit d551ab1
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 22 deletions.
66 changes: 44 additions & 22 deletions .github/workflows/test_sdxl.yml
Original file line number Diff line number Diff line change
@@ -1,56 +1,78 @@
name: SDXL Models Nightly
name: SDXL E2E Pipeline CI

on:
workflow_dispatch:
pull_request:
schedule:
- cron: '30 6 * * *'
- cron: "*/50 * * * *"

concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
# queued and in-progress runs for the same PR (presubmit) or commit
# (postsubmit). The workflow name is prepended to avoid conflicts between
# different workflows.
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true

jobs:
test-sdxl-models:
strategy:
matrix:
version: [3.11]
os: [nodai-amdgpu-w7900-x86-64]
os: [nodai-amdgpu-mi300-x86-64]

runs-on: ${{matrix.os}}
env:
IREE_TOKEN: ${{ secrets.IREE_TOKEN }}
steps:
- name: "Setting up Python"
uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3
with:
python-version: ${{matrix.version}}

- name: "Checkout Code"
- name: "Checkout SHARK-ModelDev"
uses: actions/checkout@v4
with:
ref: ean-sd-fp16
ref: bump-punet-tom

- name: "Checkout iree-turbine"
uses: actions/checkout@v4
with:
repository: iree-org/iree-turbine
# TODO: Let the ref be passed as a parameter to run integration tests.
path: iree-turbine

- name: "Checkout iree"
uses: actions/checkout@v4
with:
repository: iree-org/iree
path: iree

- name: Sync source deps
# build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile
- name: Python deps
run: |
python3.11 -m venv sdxl_venv
source sdxl_venv/bin/activate
python -m pip install --upgrade pip
# Note: We install in three steps in order to satisfy requirements
# from non default locations first. Installing the PyTorch CPU
# wheels saves multiple minutes and a lot of bandwidth on runner setup.
pip install --no-compile --index-url https://download.pytorch.org/whl/cpu \
-r ${{ github.workspace }}/iree-turbine//pytorch-cpu-requirements.txt
pip install --no-compile --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
pip install --no-compile -e ${{ github.workspace }}/iree-turbine/[testing,torch-cpu-nightly]
pip install --no-compile --upgrade -r models/requirements.txt
pip install --no-compile -e models
pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt
pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
pip install --no-compile --pre --upgrade -e models -r models/requirements.txt
pip uninstall torch torchvision torchaudio -y
pip install https://download.pytorch.org/whl/nightly/pytorch_triton_rocm-3.0.0%2B21eae954ef-cp311-cp311-linux_x86_64.whl
pip install https://download.pytorch.org/whl/nightly/rocm6.1/torch-2.5.0.dev20240710%2Brocm6.1-cp311-cp311-linux_x86_64.whl
pip install https://download.pytorch.org/whl/nightly/rocm6.1/torchvision-0.20.0.dev20240711%2Brocm6.1-cp311-cp311-linux_x86_64.whl
pip install https://download.pytorch.org/whl/nightly/rocm6.1/torchaudio-2.4.0.dev20240711%2Brocm6.1-cp311-cp311-linux_x86_64.whl
pip uninstall iree-compiler iree-runtime iree-base-compiler iree-base-runtime -y
python ci-tools/latest-pkgci.py
cd wheels
unzip *.zip
pip install *.whl
cd ..
rm -rf wheels
- name: Show current free memory
run: |
free -mh
- name: Run sdxl tests
run: |
pip install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu
pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device rocm --iree_target_triple gfx90a
source sdxl_venv/bin/activate
python3 models/turbine_models/custom_models/sd_inference/sd_pipeline.py --device=hip --precision=fp16 --iree_target_triple=gfx942 --external_weights=safetensors --hf_model_name=stabilityai/stable-diffusion-xl-base-1.0 --width=1024 --height=1024 --batch_size=1 --use_i8_punet --attn_spec=punet --vae_decomp_attn --external_weights=safetensors --num_inference_steps=20 --benchmark=all --verbose
88 changes: 88 additions & 0 deletions ci-tools/latest-pkgci.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import requests
import json
import os

GITHUB_TOKEN = os.getenv("IREE_TOKEN")

OWNER = "iree-org"
REPO = "iree"

API_URL = (
f"https://api.github.com/repos/{OWNER}/{REPO}/actions/workflows/pkgci.yml/runs"
)


# Get the latest workflow run ID for pkgci.yml
def get_latest_pkgci_workflow_run():
headers = {
"Accept": "application/vnd.github+json",
"Authorization": f"Bearer {GITHUB_TOKEN}",
"X-GitHub-Api-Version": "2022-11-28",
}
params = {
"per_page": 1,
"event": "push",
"branch": "main",
}
response = requests.get(API_URL, headers=headers, params=params)

if response.status_code == 200:
data = response.json()
if data["total_count"] > 0:
latest_run = data["workflow_runs"][0]
return latest_run["id"], latest_run["artifacts_url"]
else:
print("No workflow runs found for pkgci.yml.")
return None
else:
print(f"Error fetching workflow runs: {response.status_code}")
return None


# Get the artifacts of a specific workflow run
def get_artifacts(workflow_run_id, artifacts_url):
headers = {
"Accept": "application/vnd.github+json",
"Authorization": f"Bearer {GITHUB_TOKEN}",
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get(artifacts_url, headers=headers)

if response.status_code == 200:
artifacts = response.json()["artifacts"]
if artifacts:
print(f"Artifacts for pkgci.yml workflow run {workflow_run_id}:")
for artifact in artifacts:
print(f"- {artifact['name']} (Size: {artifact['size_in_bytes']} bytes)")
download_artifact(artifact["archive_download_url"], artifact["name"])
else:
print("No artifacts found for the pkgci.yml workflow run.")
else:
print(f"Error fetching artifacts: {response.status_code}")


# Download an artifact
def download_artifact(download_url, artifact_name):
headers = {
"Accept": "application/vnd.github+json",
"Authorization": f"Bearer {GITHUB_TOKEN}",
"X-GitHub-Api-Version": "2022-11-28",
}
response = requests.get(download_url, headers=headers, stream=True)

if response.status_code == 200:
file_name = f"wheels/{artifact_name}.zip"
os.mkdir("wheels")
with open(file_name, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
print(f"Artifact '{artifact_name}' downloaded successfully as '{file_name}'.")
else:
print(f"Error downloading artifact '{artifact_name}': {response.status_code}")


if __name__ == "__main__":
workflow_run_id, artifact_url = get_latest_pkgci_workflow_run()
if workflow_run_id:
get_artifacts(workflow_run_id, artifact_url)

0 comments on commit d551ab1

Please sign in to comment.