-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
132 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,78 @@ | ||
name: SDXL Models Nightly | ||
name: SDXL E2E Pipeline CI | ||
|
||
on: | ||
workflow_dispatch: | ||
pull_request: | ||
schedule: | ||
- cron: '30 6 * * *' | ||
- cron: "*/50 * * * *" | ||
|
||
concurrency: | ||
# A PR number if a pull request and otherwise the commit hash. This cancels | ||
# queued and in-progress runs for the same PR (presubmit) or commit | ||
# (postsubmit). The workflow name is prepended to avoid conflicts between | ||
# different workflows. | ||
group: ${{ github.workflow }}-${{ github.event.number || github.sha }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
test-sdxl-models: | ||
strategy: | ||
matrix: | ||
version: [3.11] | ||
os: [nodai-amdgpu-w7900-x86-64] | ||
os: [nodai-amdgpu-mi300-x86-64] | ||
|
||
runs-on: ${{matrix.os}} | ||
env: | ||
IREE_TOKEN: ${{ secrets.IREE_TOKEN }} | ||
steps: | ||
- name: "Setting up Python" | ||
uses: actions/setup-python@75f3110429a8c05be0e1bf360334e4cced2b63fa # v2.3.3 | ||
with: | ||
python-version: ${{matrix.version}} | ||
|
||
- name: "Checkout Code" | ||
- name: "Checkout SHARK-ModelDev" | ||
uses: actions/checkout@v4 | ||
with: | ||
ref: ean-sd-fp16 | ||
ref: bump-punet-tom | ||
|
||
- name: "Checkout iree-turbine" | ||
uses: actions/checkout@v4 | ||
with: | ||
repository: iree-org/iree-turbine | ||
# TODO: Let the ref be passed as a parameter to run integration tests. | ||
path: iree-turbine | ||
|
||
- name: "Checkout iree" | ||
uses: actions/checkout@v4 | ||
with: | ||
repository: iree-org/iree | ||
path: iree | ||
|
||
- name: Sync source deps | ||
# build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile | ||
- name: Python deps | ||
run: | | ||
python3.11 -m venv sdxl_venv | ||
source sdxl_venv/bin/activate | ||
python -m pip install --upgrade pip | ||
# Note: We install in three steps in order to satisfy requirements | ||
# from non default locations first. Installing the PyTorch CPU | ||
# wheels saves multiple minutes and a lot of bandwidth on runner setup. | ||
pip install --no-compile --index-url https://download.pytorch.org/whl/cpu \ | ||
-r ${{ github.workspace }}/iree-turbine//pytorch-cpu-requirements.txt | ||
pip install --no-compile --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt | ||
pip install --no-compile -e ${{ github.workspace }}/iree-turbine/[testing,torch-cpu-nightly] | ||
pip install --no-compile --upgrade -r models/requirements.txt | ||
pip install --no-compile -e models | ||
pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt | ||
pip install --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt | ||
pip install --no-compile --pre --upgrade -e models -r models/requirements.txt | ||
pip uninstall torch torchvision torchaudio -y | ||
pip install https://download.pytorch.org/whl/nightly/pytorch_triton_rocm-3.0.0%2B21eae954ef-cp311-cp311-linux_x86_64.whl | ||
pip install https://download.pytorch.org/whl/nightly/rocm6.1/torch-2.5.0.dev20240710%2Brocm6.1-cp311-cp311-linux_x86_64.whl | ||
pip install https://download.pytorch.org/whl/nightly/rocm6.1/torchvision-0.20.0.dev20240711%2Brocm6.1-cp311-cp311-linux_x86_64.whl | ||
pip install https://download.pytorch.org/whl/nightly/rocm6.1/torchaudio-2.4.0.dev20240711%2Brocm6.1-cp311-cp311-linux_x86_64.whl | ||
pip uninstall iree-compiler iree-runtime iree-base-compiler iree-base-runtime -y | ||
python ci-tools/latest-pkgci.py | ||
cd wheels | ||
unzip *.zip | ||
pip install *.whl | ||
cd .. | ||
rm -rf wheels | ||
- name: Show current free memory | ||
run: | | ||
free -mh | ||
- name: Run sdxl tests | ||
run: | | ||
pip install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu | ||
pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu | ||
pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux | ||
pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device rocm --iree_target_triple gfx90a | ||
source sdxl_venv/bin/activate | ||
python3 models/turbine_models/custom_models/sd_inference/sd_pipeline.py --device=hip --precision=fp16 --iree_target_triple=gfx942 --external_weights=safetensors --hf_model_name=stabilityai/stable-diffusion-xl-base-1.0 --width=1024 --height=1024 --batch_size=1 --use_i8_punet --attn_spec=punet --vae_decomp_attn --external_weights=safetensors --num_inference_steps=20 --benchmark=all --verbose |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import requests | ||
import json | ||
import os | ||
|
||
GITHUB_TOKEN = os.getenv("IREE_TOKEN") | ||
|
||
OWNER = "iree-org" | ||
REPO = "iree" | ||
|
||
API_URL = ( | ||
f"https://api.github.com/repos/{OWNER}/{REPO}/actions/workflows/pkgci.yml/runs" | ||
) | ||
|
||
|
||
# Get the latest workflow run ID for pkgci.yml | ||
def get_latest_pkgci_workflow_run(): | ||
headers = { | ||
"Accept": "application/vnd.github+json", | ||
"Authorization": f"Bearer {GITHUB_TOKEN}", | ||
"X-GitHub-Api-Version": "2022-11-28", | ||
} | ||
params = { | ||
"per_page": 1, | ||
"event": "push", | ||
"branch": "main", | ||
} | ||
response = requests.get(API_URL, headers=headers, params=params) | ||
|
||
if response.status_code == 200: | ||
data = response.json() | ||
if data["total_count"] > 0: | ||
latest_run = data["workflow_runs"][0] | ||
return latest_run["id"], latest_run["artifacts_url"] | ||
else: | ||
print("No workflow runs found for pkgci.yml.") | ||
return None | ||
else: | ||
print(f"Error fetching workflow runs: {response.status_code}") | ||
return None | ||
|
||
|
||
# Get the artifacts of a specific workflow run | ||
def get_artifacts(workflow_run_id, artifacts_url): | ||
headers = { | ||
"Accept": "application/vnd.github+json", | ||
"Authorization": f"Bearer {GITHUB_TOKEN}", | ||
"X-GitHub-Api-Version": "2022-11-28", | ||
} | ||
response = requests.get(artifacts_url, headers=headers) | ||
|
||
if response.status_code == 200: | ||
artifacts = response.json()["artifacts"] | ||
if artifacts: | ||
print(f"Artifacts for pkgci.yml workflow run {workflow_run_id}:") | ||
for artifact in artifacts: | ||
print(f"- {artifact['name']} (Size: {artifact['size_in_bytes']} bytes)") | ||
download_artifact(artifact["archive_download_url"], artifact["name"]) | ||
else: | ||
print("No artifacts found for the pkgci.yml workflow run.") | ||
else: | ||
print(f"Error fetching artifacts: {response.status_code}") | ||
|
||
|
||
# Download an artifact | ||
def download_artifact(download_url, artifact_name): | ||
headers = { | ||
"Accept": "application/vnd.github+json", | ||
"Authorization": f"Bearer {GITHUB_TOKEN}", | ||
"X-GitHub-Api-Version": "2022-11-28", | ||
} | ||
response = requests.get(download_url, headers=headers, stream=True) | ||
|
||
if response.status_code == 200: | ||
file_name = f"wheels/{artifact_name}.zip" | ||
os.mkdir("wheels") | ||
with open(file_name, "wb") as f: | ||
for chunk in response.iter_content(chunk_size=8192): | ||
if chunk: | ||
f.write(chunk) | ||
print(f"Artifact '{artifact_name}' downloaded successfully as '{file_name}'.") | ||
else: | ||
print(f"Error downloading artifact '{artifact_name}': {response.status_code}") | ||
|
||
|
||
if __name__ == "__main__": | ||
workflow_run_id, artifact_url = get_latest_pkgci_workflow_run() | ||
if workflow_run_id: | ||
get_artifacts(workflow_run_id, artifact_url) |