From bfacd08a7f3f83e12eb0daa8debf364e2298fdf2 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 15 Apr 2024 15:00:25 +0200 Subject: [PATCH] perf: pre-warm ImageToVideo SFAST pipeline (#57) This commit ensures that the model is pre-traced when SFAST is enabled for the ImageToVideo pipeline. Without this pre-tracing the first request will be slower than a non SFAST call. --- runner/app/pipelines/image_to_image.py | 14 +++++++-- runner/app/pipelines/image_to_video.py | 42 ++++++++++++++++++++++++-- runner/app/pipelines/text_to_image.py | 14 +++++++-- 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index efa4a958..68d96892 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -92,15 +92,25 @@ def __init__(self, model_id: str): model_id, **kwargs ).to(torch_device) - if os.environ.get("SFAST"): + if os.getenv("SFAST", "").strip().lower() == "true": logger.info( - "ImageToImagePipeline will be dynamicallly compiled with stable-fast for %s", + "ImageToImagePipeline will be dynamically compiled with stable-fast " + "for %s", model_id, ) from app.pipelines.sfast import compile_model self.ldm = compile_model(self.ldm) + # Warm-up the pipeline. + # TODO: Not yet supported for ImageToImagePipeline. + if os.getenv("SFAST_WARMUP", "true").lower() == "true": + logger.warning( + "The 'SFAST_WARMUP' flag is not yet supported for the " + "ImageToImagePipeline and will be ignored. As a result the first " + "call may be slow if 'SFAST' is enabled." 
+ ) + def __call__(self, prompt: str, image: PIL.Image, **kwargs) -> List[PIL.Image]: seed = kwargs.pop("seed", None) if seed is not None: diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index 5967a672..d2e8a907 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -8,6 +8,7 @@ from typing import List import logging import os +import time from PIL import ImageFile @@ -15,6 +16,8 @@ logger = logging.getLogger(__name__) +SFAST_WARMUP_ITERATIONS = 2 # Model warm-up iterations when SFAST is enabled. + class ImageToVideoPipeline(Pipeline): def __init__(self, model_id: str): @@ -40,17 +43,52 @@ def __init__(self, model_id: str): self.ldm = StableVideoDiffusionPipeline.from_pretrained(model_id, **kwargs) self.ldm.to(get_torch_device()) - if os.environ.get("SFAST"): + if os.getenv("SFAST", "").strip().lower() == "true": logger.info( - "ImageToVideoPipeline will be dynamicallly compiled with stable-fast for %s", + "ImageToVideoPipeline will be dynamically compiled with stable-fast " + "for %s", model_id, ) from app.pipelines.sfast import compile_model self.ldm = compile_model(self.ldm) + # Warm-up the pipeline. + # NOTE: Initial calls may be slow due to compilation. Subsequent calls will + # be faster. + if os.getenv("SFAST_WARMUP", "true").lower() == "true": + # Retrieve default model params. + # TODO: Retrieve defaults from Pydantic class in route. + warmup_kwargs = { + "image": PIL.Image.new("RGB", (576, 1024)), + "height": 576, + "width": 1024, + "fps": 6, + "motion_bucket_id": 127, + "noise_aug_strength": 0.02, + "decode_chunk_size": 25, + } + + logger.info("Warming up ImageToVideoPipeline pipeline...") + total_time = 0 + for ii in range(SFAST_WARMUP_ITERATIONS): + t = time.time() + try: + self.ldm(**warmup_kwargs).frames + except Exception as e: + # FIXME: When out of memory, pipeline is corrupted. 
+ logger.error(f"ImageToVideoPipeline warmup error: {e}") + raise e + iteration_time = time.time() - t + total_time += iteration_time + logger.info( + "Warmup iteration %s took %s seconds", ii + 1, iteration_time + ) + logger.info("Total warmup time: %s seconds", total_time) + def __call__(self, image: PIL.Image, **kwargs) -> List[List[PIL.Image]]: if "decode_chunk_size" not in kwargs: + # Decrease decode_chunk_size to reduce memory usage. kwargs["decode_chunk_size"] = 4 seed = kwargs.pop("seed", None) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 4b05e871..9c58ad9b 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -111,15 +111,25 @@ def __init__(self, model_id: str): self.ldm.vae.decode, mode="max-autotune", fullgraph=True ) - if os.environ.get("SFAST"): + if os.getenv("SFAST", "").strip().lower() == "true": logger.info( - "TextToImagePipeline will be dynamicallly compiled with stable-fast for %s", + "TextToImagePipeline will be dynamically compiled with stable-fast for " + "%s", model_id, ) from app.pipelines.sfast import compile_model self.ldm = compile_model(self.ldm) + # Warm-up the pipeline. + # TODO: Not yet supported for TextToImagePipeline. + if os.getenv("SFAST_WARMUP", "true").lower() == "true": + logger.warning( + "The 'SFAST_WARMUP' flag is not yet supported for the " + "TextToImagePipeline and will be ignored. As a result the first " + "call may be slow if 'SFAST' is enabled." + ) + def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]: seed = kwargs.pop("seed", None) if seed is not None: