diff --git a/docs/source/images/gptvad/broken.png b/docs/source/images/gptvad/broken.png new file mode 100644 index 0000000000..37748f63be Binary files /dev/null and b/docs/source/images/gptvad/broken.png differ diff --git a/docs/source/images/gptvad/good.png b/docs/source/images/gptvad/good.png new file mode 100644 index 0000000000..20688478d6 Binary files /dev/null and b/docs/source/images/gptvad/good.png differ diff --git a/pyproject.toml b/pyproject.toml index bbfd0fe1a3..989ed6d8e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ core = [ "lightning>=2.2", "torch>=2", "torchmetrics>=1.3.2", + "openai>=1.38.0", # NOTE: open-clip-torch throws the following error on v2.26.1 # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator # 'aten::_native_multi_head_attention' to ONNX opset version 14 is not supported diff --git a/src/anomalib/__init__.py b/src/anomalib/__init__.py index 1b7a30497c..05abd94901 100644 --- a/src/anomalib/__init__.py +++ b/src/anomalib/__init__.py @@ -22,3 +22,4 @@ class TaskType(str, Enum): CLASSIFICATION = "classification" DETECTION = "detection" SEGMENTATION = "segmentation" + VISUAL_PROMPTING = "visual prompting" diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 081e43d2aa..ac66160efa 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -75,10 +75,10 @@ def setup( pixel_metric_names: list[str] | dict[str, dict[str, Any]] if self.pixel_metric_names is None: pixel_metric_names = [] - elif self.task == TaskType.CLASSIFICATION: + elif self.task in (TaskType.CLASSIFICATION, TaskType.VISUAL_PROMPTING): pixel_metric_names = [] logger.warning( - "Cannot perform pixel-level evaluation when task type is classification. " + "Cannot perform pixel-level evaluation when task type is classification or visual prompting. 
" "Ignoring the following pixel-level metrics: %s", self.pixel_metric_names, ) diff --git a/src/anomalib/data/base/dataset.py b/src/anomalib/data/base/dataset.py index 7cfba278ac..e255613d0d 100644 --- a/src/anomalib/data/base/dataset.py +++ b/src/anomalib/data/base/dataset.py @@ -20,9 +20,11 @@ from anomalib.data.utils import LabelName, masks_to_boxes, read_image, read_mask _EXPECTED_COLUMNS_CLASSIFICATION = ["image_path", "split"] +_EXPECTED_COLUMNS_VISUAL_PROMPTING = ["image_path", "split"] _EXPECTED_COLUMNS_SEGMENTATION = [*_EXPECTED_COLUMNS_CLASSIFICATION, "mask_path"] _EXPECTED_COLUMNS_PERTASK = { "classification": _EXPECTED_COLUMNS_CLASSIFICATION, + "visual prompting": _EXPECTED_COLUMNS_VISUAL_PROMPTING, "segmentation": _EXPECTED_COLUMNS_SEGMENTATION, "detection": _EXPECTED_COLUMNS_SEGMENTATION, } @@ -169,7 +171,7 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: image = read_image(image_path, as_tensor=True) item = {"image_path": image_path, "label": label_index} - if self.task == TaskType.CLASSIFICATION: + if self.task in (TaskType.CLASSIFICATION, TaskType.VISUAL_PROMPTING): item["image"] = self.transform(image) if self.transform else image elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): # Only Anomalous (1) images have masks in anomaly datasets diff --git a/src/anomalib/data/base/depth.py b/src/anomalib/data/base/depth.py index dbd5377cb6..87c3e388a6 100644 --- a/src/anomalib/data/base/depth.py +++ b/src/anomalib/data/base/depth.py @@ -48,7 +48,7 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: depth_image = to_tensor(read_depth_image(depth_path)) item = {"image_path": image_path, "depth_path": depth_path, "label": label_index} - if self.task == TaskType.CLASSIFICATION: + if self.task in (TaskType.CLASSIFICATION, TaskType.VISUAL_PROMPTING): item["image"], item["depth_image"] = ( self.transform(image, depth_image) if self.transform else (image, depth_image) ) diff --git a/src/anomalib/deploy/inferencers/openvino_inferencer.py b/src/anomalib/deploy/inferencers/openvino_inferencer.py index 7ed44a99da..ccae672c54 100644 --- a/src/anomalib/deploy/inferencers/openvino_inferencer.py +++ b/src/anomalib/deploy/inferencers/openvino_inferencer.py @@ -277,7 +277,7 @@ def post_process(self, predictions: np.ndarray, metadata: dict | DictConfig | No pred_idx = pred_score >= metadata["image_threshold"] pred_label = LabelName.ABNORMAL if pred_idx else LabelName.NORMAL - if task == TaskType.CLASSIFICATION: + if task in (TaskType.CLASSIFICATION, TaskType.VISUAL_PROMPTING): _, pred_score = self._normalize(pred_scores=pred_score, metadata=metadata) elif task in (TaskType.SEGMENTATION, TaskType.DETECTION): if "pixel_threshold" in metadata: diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py index b4bb36a875..04fc19aa1b 100644 --- a/src/anomalib/models/__init__.py +++ b/src/anomalib/models/__init__.py @@ -24,6 +24,7 @@ Fastflow, Fre, Ganomaly, + GptVad, Padim, Patchcore, ReverseDistillation, @@ -51,6 +52,7 @@ class UnknownModelError(ModuleNotFoundError): "Fastflow", "Fre", "Ganomaly", + "GptVad", "Padim", "Patchcore", "ReverseDistillation", diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py index f3a5435038..58b5ec6a49 100644 --- a/src/anomalib/models/image/__init__.py +++ b/src/anomalib/models/image/__init__.py @@ -14,6 +14,7 @@ from .fastflow import Fastflow from .fre import Fre from .ganomaly import Ganomaly +from .gptvad import GptVad from .padim import Padim from .patchcore 
import Patchcore from .reverse_distillation import ReverseDistillation @@ -34,6 +35,7 @@ "Fastflow", "Fre", "Ganomaly", + "GptVad", "Padim", "Patchcore", "ReverseDistillation", diff --git a/src/anomalib/models/image/gptvad/__init__.py b/src/anomalib/models/image/gptvad/__init__.py new file mode 100644 index 0000000000..76bd5d75bb --- /dev/null +++ b/src/anomalib/models/image/gptvad/__init__.py @@ -0,0 +1,7 @@ +"""GPT-based Vision Language Model (VLM) for zero-/few-shot anomaly classification.""" +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .lightning_model import GptVad + +__all__ = ["GptVad"] diff --git a/src/anomalib/models/image/gptvad/chatgpt.py b/src/anomalib/models/image/gptvad/chatgpt.py new file mode 100644 index 0000000000..5ad1656dc7 --- /dev/null +++ b/src/anomalib/models/image/gptvad/chatgpt.py @@ -0,0 +1,169 @@ +"""Wrapper for the OpenAI API calls to the VLM.""" +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from typing import Any + +import openai + + +class APIKeyError(Exception): + """Raised when the OpenAI API key is not set in the environment.""" + + +class GPTWrapper: + """A wrapper class for making API calls to OpenAI's GPT-4 model to detect anomalies in images. + + Environment variable OPENAI_API_KEY (str): API key for OpenAI. + https://platform.openai.com/docs/quickstart/step-2-set-up-your-api-key + Other possible models: https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4 + All models with vision capabilities: 'gpt-4-turbo-2024-04-09', 'gpt-4-turbo', + all versions of 'gpt-4o-mini', and 'gpt-4o' + + Args: + images (list[str]): List of base64 images. If only one image is provided, + it is treated as the anomalous image. If multiple images are provided, + the last one is considered anomalous, and the rest are treated as normal examples. + model_name (str): Model name for OpenAI API VLM. Default "gpt-4o" + detail (bool): Whether the images will be sent with high detail or low detail. + + """ + + def __init__(self, model_name: str = "gpt-4o", detail: bool = True) -> None: + openai_key = os.getenv("OPENAI_API_KEY") + self.model_name = model_name + self.detail = detail + if not openai_key: + msg = "OpenAI API key not found in the environment (OPENAI_API_KEY)." + raise APIKeyError(msg) + + def api_call( + self, + images: list[str], + extension: str = "png", + ) -> str: + """Makes an API call to OpenAI's GPT-4 model to detect anomalies in an image. + + Args: + images (list[str]): List of base64 images. If only one image is provided, + it is treated as the anomalous image. If multiple images are provided, + the last one is considered anomalous, and the rest are treated as normal examples. + extension (str): Extension of the group of images that needs to be checked for anomalies. Default = 'png' + + Returns: + str: The response from the GPT-4 model indicating whether the image has anomalies or not. + It returns 'NO' if there are no anomalies and 'YES: description' if there are anomalies, + where 'description' provides details of the anomaly and its position. + + Raises: + openai.OpenAIError: If there is an error during the API call. + """ + prompt: str = "" + + detail_img = "high" if self.detail else "low" + messages: list[dict[str, Any]] = [] + + if len(images) > 1: + # If multiple images are provided, the last one is considered anomalous, + # and the rest are treated as normal examples. 
+ prompt = """ + You will receive a group of images that are going to be an example + of the typical image without any anomaly, + and the last image that you need to decide if it has an anomaly or not. + Answer with a 'NO' if it does not have any anomalies and 'YES: description' + where description is a description of the anomaly provided, position. + """ + + messages.append( + { + "role": "system", + "content": prompt, + }, + ) + for image in images: + image_message = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/{extension};base64,{image}", + "detail": detail_img, + }, + }, + ], + }, + ] + messages.extend(image_message) + + elif len(images) == 1: + # If only one image is provided, + # it is treated as the anomalous image. + prompt = """ + Examine the provided image carefully to determine if there is an obvious anomaly present. + Anomalies may include mechanical malfunctions, unexpected objects, safety hazards, structural damages, + or unusual patterns or defects in the objects. + + Instructions: + + 1. Thoroughly inspect the image for any irregularities or deviations from normal operating conditions. + + 2. Clearly state if an obvious anomaly is detected. + - If an anomaly is detected, begin with 'YES,' followed by a detailed description of the anomaly. + - If no anomaly is detected, simply state 'NO' and end the analysis. + + Example Output Structure: + + 'YES: + - Description: Conveyor belt misalignment causing potential blockages. + This may result in production delays and equipment damage. + Immediate realignment and inspection are recommended.' + + 'NO' + + Considerations: + + - Ensure accuracy in identifying anomalies to prevent overlooking critical issues. + - Provide clear and concise descriptions for any detected anomalies. + - Focus on obvious anomalies that could impact final use of the object operation or safety. + """ + messages.append( + { + "role": "system", + "content": prompt, + }, + ) + # Add the single image + messages.append( + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/{extension};base64,{images[0]}", + "detail": detail_img, + }, + }, + ], + }, + ) + else: + msg = "No images provided for anomaly detection." + raise ValueError(msg) + + try: + # Make the API call using the openai library + response = openai.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=300, + ) + return response.choices[-1].message.content or "" + except Exception: + msg = "Error generating a response with OpenAI API." + logging.exception(msg) + raise diff --git a/src/anomalib/models/image/gptvad/lightning_model.py b/src/anomalib/models/image/gptvad/lightning_model.py new file mode 100644 index 0000000000..ab8a3e76a3 --- /dev/null +++ b/src/anomalib/models/image/gptvad/lightning_model.py @@ -0,0 +1,155 @@ +"""OpenAI Visual Large Model: Zero-/Few-Shot Anomaly Classification. 
+ +Paper (No paper) +""" +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import logging +from pathlib import Path + +import torch +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch.utils.data import DataLoader + +from anomalib import LearningType +from anomalib.metrics.threshold import ManualThreshold +from anomalib.models.components import AnomalyModule + +from .chatgpt import GPTWrapper + +logger = logging.getLogger(__name__) + +__all__ = ["GptVad"] + + +class GptVad(AnomalyModule): + """OpenAI VLM Lightning model using OpenAI's GPT-4 for image anomaly detection. + + Args: + k_shot (int): The number of normal reference images used for comparison when detecting anomalies. + model_name (str): The OpenAI VLM for visual anomaly detection. + detail (bool): Whether images are sent to the VLM in 'high' (True) or 'low' (False) detail. + """ + + def __init__( + self, + k_shot: int = 0, + model_name: str = "gpt-4o", + detail: bool = True, + ) -> None: + super().__init__() + + self.k_shot = k_shot + + self.model_name = model_name + self.detail = detail + self.image_threshold = ManualThreshold() + self.vlm = GPTWrapper(model_name=self.model_name, detail=self.detail) + + def _setup(self) -> None: + dataloader = self.trainer.datamodule.train_dataloader() + pre_images = self.collect_reference_images(dataloader) + self.pre_images = pre_images + + def _encode_image(self, image_path: str) -> str: + """Encode the image into base64 so it can be sent with the prompt.""" + path = Path(image_path) + with path.open("rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict[str, str | torch.Tensor]: + """Training step; the VLM requires no training.""" + del args, kwargs # These variables are not used. + # No training is performed for the VLM. + return batch + + @staticmethod + def configure_optimizers() -> None: + """GptVad doesn't require optimization, therefore returns no optimizers.""" + return + + def validation_step( + self, + batch: dict[str, str | list[str] | torch.Tensor], + *args, + **kwargs, + ) -> STEP_OUTPUT: + """Run the VLM on a batch of input images. + + Args: + batch (dict[str, str | list[str] | torch.Tensor]): Batch containing image filename, image, label and mask + args: Additional arguments. + kwargs: Additional keyword arguments. + + Returns: + STEP_OUTPUT: The batch extended with ``str_output`` (the VLM response) and ``pred_scores`` (1.0 if the image is anomalous, 0.0 otherwise). + """ + del args, kwargs # These variables are not used. 
+ batch_size = len(batch["image_path"]) + outputs: list[str] = [] + predictions: list[float] = [] + for i in range(batch_size): + # Encode the reference images and the current image to base64 + base64_images = [self._encode_image(img) for img in self.pre_images] + base64_images.append(self._encode_image(batch["image_path"][i])) + + try: + output = self.vlm.api_call(base64_images) + except Exception: + logging.exception( + f"Error calling the OpenAI API for image {batch['image_path'][i]}", + ) + output = "Error" + + # Default the prediction to normal (0.0) if the output cannot be interpreted + prediction = 0.0 + if output.startswith("N"): + prediction = 0.0 + elif output.startswith("Y"): + prediction = 1.0 + else: + logging.warning( + f"Could not determine whether there is an anomaly from the output (setting prediction to 0, normal):\n{output}", + ) + + outputs.append(output) + predictions.append(prediction) + logging.debug(f"Output: {output}, Prediction: {prediction}") + + batch["str_output"] = outputs + batch["pred_scores"] = torch.tensor(predictions).to(self.device) + batch["pred_labels"] = torch.tensor(predictions).to(self.device) + return batch + + @property + def trainer_arguments(self) -> dict[str, int | float]: + """Set model-specific trainer arguments.""" + return {} + + @property + def learning_type(self) -> LearningType: + """The learning type of the model. + + GptVad is a zero-/few-shot model, depending on the user configuration. Therefore, the learning type is + set to ``LearningType.FEW_SHOT`` when ``k_shot`` is greater than zero and ``LearningType.ZERO_SHOT`` otherwise. + """ + return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT + + def collect_reference_images(self, dataloader: DataLoader) -> list[str]: + """Collect reference images for few-shot inference. + + The reference images are collected by iterating the training dataset until the required number of images has been + collected. + + Returns: + list[str]: A list containing the paths of the reference images. + """ + reference_images_paths: list[str] = [] + for batch in dataloader: + image_paths = batch["image_path"][: self.k_shot - len(reference_images_paths)] + reference_images_paths.extend(image_paths) + if self.k_shot == len(reference_images_paths): + break + return reference_images_paths diff --git a/src/anomalib/models/image/gptvad/readme.md b/src/anomalib/models/image/gptvad/readme.md new file mode 100644 index 0000000000..f775caf9f2 --- /dev/null +++ b/src/anomalib/models/image/gptvad/readme.md @@ -0,0 +1,100 @@ +# GptVad: Zero-/Few-Shot Anomaly Classification + +This is the implementation of `GptVad`, a model designed for zero-shot and few-shot anomaly detection using OpenAI's GPT-4 for image analysis. + +## Description + +`GptVad` is an anomaly detection model that leverages OpenAI's GPT-4 to identify anomalies in images. It supports both zero-shot and few-shot modes: + +- **Zero-Shot Mode**: Direct anomaly detection without any prior examples of normal images. +- **Few-Shot Mode**: Anomaly detection using a small set of normal reference images to improve accuracy. + +The model operates by encoding images into base64 format and passing them to the GPT-4 API. In zero-shot mode, the model analyzes the image directly. In few-shot mode, the model compares the target image with a set of reference images to detect anomalies. + +## Features + +- **Zero-/Few-Shot Learning**: Capable of performing anomaly detection without training (zero-shot) or with a few normal examples (few-shot). 
+- **OpenAI GPT-4 Integration**: Utilizes the latest advancements in natural language processing and image understanding for anomaly detection. + +## Usage + +### Zero-Shot Anomaly Detection + +In zero-shot mode, the model does not require any reference images: + +```python +from anomalib import TaskType +from anomalib.data import MVTec +from anomalib.engine import Engine +from anomalib.models import GptVad +from dotenv import load_dotenv + +# Load the environment variables from the .env file +# The implementation searches for an environment variable OPENAI_API_KEY +# that contains the OpenAI API key. + +# Load it from .env into the environment. +load_dotenv() + +model = GptVad(k_shot=0) +engine = Engine(task=TaskType.VISUAL_PROMPTING) +datamodule = MVTec( + category="bottle", + train_batch_size=1, + eval_batch_size=1, + num_workers=0, + ) +engine.test(model=model, datamodule=datamodule) +``` + +### Few-Shot Anomaly Detection + +In few-shot mode, the model uses a small set of normal reference images: + +```python +from anomalib import TaskType +from anomalib.data import MVTec +from anomalib.engine import Engine +from anomalib.models import GptVad +from dotenv import load_dotenv + +# Load the environment variables from the .env file +# load_dotenv(dotenv_path=env_path) +load_dotenv() + +model = GptVad(k_shot=2) +engine = Engine(task=TaskType.VISUAL_PROMPTING) +datamodule = MVTec( + category="bottle", + train_batch_size=1, + eval_batch_size=1, + num_workers=0, + ) +engine.test(model=model, datamodule=datamodule) +``` + +## Parameters + +| Parameter | Type | Description | Default | +| ------------ | ---- | ----------------------------------------------------------------------------------------------- | -------------------------- | +| `k_shot` | int | Number of normal reference images used in few-shot mode. | `0` | +| `model_name` | str | The OpenAI VLM used for image anomaly detection. | `"gpt-4o"` | +| `detail` | bool | The detail level of the input in the VLM for image detection: 'high' (`True`), 'low' (`False`). | `True` | + +## Example Outputs + +The model returns a response indicating whether an anomaly is detected: + +- **Zero-Shot/Few-Shot Example**: + + ```plaintext + "NO" + ``` + + ![GptVad result no anomaly](/docs/source/images/gptvad/good.png "GptVad without anomaly result") + + ```plaintext + "YES: Description of the detected anomaly." 
+ ``` + + ![GptVad result with anomaly](/docs/source/images/gptvad/broken.png "GptVad with Anomaly result") diff --git a/src/anomalib/utils/visualization/image.py b/src/anomalib/utils/visualization/image.py index d2e1cb0d6e..2e7eec9f7c 100644 --- a/src/anomalib/utils/visualization/image.py +++ b/src/anomalib/utils/visualization/image.py @@ -1,8 +1,10 @@ """Image/video generator.""" + # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import textwrap from collections.abc import Iterator from enum import Enum from pathlib import Path @@ -39,6 +41,7 @@ def __init__( image: np.ndarray, pred_score: float, pred_label: str, + text_descr: str | None = None, anomaly_map: np.ndarray | None = None, gt_mask: np.ndarray | None = None, pred_mask: np.ndarray | None = None, @@ -47,6 +50,7 @@ def __init__( box_labels: np.ndarray | None = None, normalize: bool = False, ) -> None: + self.text_descr = text_descr self.anomaly_map = anomaly_map self.box_labels = box_labels self.gt_boxes = gt_boxes @@ -93,6 +97,7 @@ def __repr__(self) -> str: repr_str += f", segmentations={self.segmentations}" if self.segmentations is not None else "" repr_str += f", normal_boxes={self.normal_boxes}" if self.normal_boxes is not None else "" repr_str += f", anomalous_boxes={self.anomalous_boxes}" if self.anomalous_boxes is not None else "" + repr_str += f", text_descr={self.text_descr}" if self.text_descr is not None else "" repr_str += ")" return repr_str @@ -160,6 +165,7 @@ def _visualize_batch(self, batch: dict) -> Iterator[GeneratorResult]: image_result = ImageResult( image=image, + text_descr=batch["str_output"][i] if "str_output" in batch else None, pred_score=batch["pred_scores"][i].cpu().numpy().item() if "pred_scores" in batch else None, pred_label=batch["pred_labels"][i].cpu().numpy().item() if "pred_labels" in batch else None, anomaly_map=batch["anomaly_maps"][i].cpu().numpy() if "anomaly_maps" in batch else None, @@ -236,6 +242,15 @@ def _visualize_full(self, image_result: ImageResult) -> np.ndarray: else: image_classified = add_normal_label(image_result.image, 1 - image_result.pred_score) image_grid.add_image(image=image_classified, title="Prediction") + elif self.task == TaskType.VISUAL_PROMPTING: + description = "" + if image_result.text_descr: + description = image_result.text_descr + if image_result.pred_label: + image_classified = add_anomalous_label(image_result.image, image_result.pred_score) + else: + image_classified = add_normal_label(image_result.image, 1 - image_result.pred_score) + image_grid.add_image(image_classified, title=description) return image_grid.generate() @@ -274,6 +289,22 @@ def _visualize_simple(self, image_result: ImageResult) -> np.ndarray: else: image_classified = add_normal_label(image_result.image, 1 - image_result.pred_score) return image_classified + + if self.task == TaskType.VISUAL_PROMPTING: + image_grid = _ImageGrid() + description = "" + if image_result.text_descr: + description = image_result.text_descr + + if image_result.pred_label: + image_classified = add_anomalous_label(image_result.image, image_result.pred_score) + else: + image_classified = add_normal_label(image_result.image, 1 - image_result.pred_score) + + image_grid.add_image(image_classified, title=description) + + return image_grid.generate() + msg = f"Unknown task type: {self.task}" raise ValueError(msg) @@ -290,7 +321,12 @@ def __init__(self) -> None: self.figure: matplotlib.figure.Figure | None = None self.axis: Axes | np.ndarray | None = None - def add_image(self, image: 
np.ndarray, title: str | None = None, color_map: str | None = None) -> None: + def add_image( + self, + image: np.ndarray, + title: str | None = None, + color_map: str | None = None, + ) -> None: """Add an image to the grid. Args: @@ -323,7 +359,15 @@ def generate(self) -> np.ndarray: axis.axes.yaxis.set_visible(b=False) axis.imshow(image_dict["image"], image_dict["color_map"], vmin=0, vmax=255) if image_dict["title"] is not None: - axis.title.set_text(image_dict["title"]) + wrapped_text = textwrap.fill( + image_dict["title"], + width=70 // num_cols, + ) # Adjust 'width' based on your subplot size and preference + + axis.set_title(wrapped_text, fontsize=10) + + self.figure.subplots_adjust(top=0.7) + + self.figure.canvas.draw() # convert canvas to numpy array to prepare for visualization with opencv img = np.frombuffer(self.figure.canvas.tostring_rgb(), dtype=np.uint8) diff --git a/tests/integration/model/test_models.py b/tests/integration/model/test_models.py index 09a4749b84..8c3262dec2 100644 --- a/tests/integration/model/test_models.py +++ b/tests/integration/model/test_models.py @@ -6,7 +6,9 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import os from pathlib import Path +from unittest.mock import patch import pytest @@ -143,6 +145,9 @@ def test_export( dataset_path (Path): Root to dataset from fixture. project_path (Path): Path to temporary project folder from fixture. """ + if model_name == "gpt_vad": + pytest.skip(f"{model_name} cannot be exported") + if model_name == "rkde": # TODO(ashwinvaidya17): Restore this test after fixing the issue # https://github.com/openvinotoolkit/anomalib/issues/1513 @@ -176,11 +181,23 @@ def _get_objects( tuple[AnomalyModule, AnomalibDataModule, Engine]: Returns the created objects for model, dataset, and engine """ + # Mock the GPTWrapper if the model_name is "gpt_vad" + if model_name == "gpt_vad": + os.environ["OPENAI_API_KEY"] = "fake-api-key" + with ( + patch("anomalib.models.image.gptvad.chatgpt.GPTWrapper") as mock_gptwrapper, + ): + mock_instance = mock_gptwrapper.return_value + mock_instance.api_call.return_value = "NO" + self.mock_gptwrapper = mock_gptwrapper # Store the mock for potential later use + # select task type if model_name in ("rkde", "ai_vad"): task_type = TaskType.DETECTION elif model_name in ("ganomaly", "dfkde"): task_type = TaskType.CLASSIFICATION + elif model_name == "gpt_vad": + task_type = TaskType.VISUAL_PROMPTING + else: task_type = TaskType.SEGMENTATION diff --git a/tests/unit/models/image/gptvad/__init__.py b/tests/unit/models/image/gptvad/__init__.py new file mode 100644 index 0000000000..724f891c65 --- /dev/null +++ b/tests/unit/models/image/gptvad/__init__.py @@ -0,0 +1,4 @@ +"""Unit tests for GptVad zero-/few-shot anomaly detection model.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/unit/models/image/gptvad/test_api.py b/tests/unit/models/image/gptvad/test_api.py new file mode 100644 index 0000000000..7545c2ee6b --- /dev/null +++ b/tests/unit/models/image/gptvad/test_api.py @@ -0,0 +1,72 @@ +"""Unit tests for GptVad OpenAI API functions.""" +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +from pytest_mock import MockerFixture + +from anomalib.models.image.gptvad.chatgpt import GPTWrapper + + +class TestGPTWrapper: + """Unit tests for api_call.""" + + @pytest.fixture(autouse=True) + def _mock_env(self, mocker: MockerFixture) -> None: + """Fixture to automatically 
patch environment variables.""" + mocker.patch.dict(os.environ, {"OPENAI_API_KEY": "fake-api-key"}) + + def test_api_call(self, mocker: MockerFixture) -> None: + """Tests for api_call positive response and few shot.""" + # Set up the mock response from OpenAI + mock_response = mocker.MagicMock() + mock_response.choices = [mocker.MagicMock(message=mocker.MagicMock(content="YES: Anomaly detected."))] + + # Mock the openai.chat.completions.create function + mock_openai_create = mocker.patch("anomalib.models.image.gptvad.chatgpt.openai.chat.completions.create") + mock_openai_create.return_value = mock_response + + # Initialize the GPTWrapper instance + wrapper = GPTWrapper(model_name="gpt-4o-mini-2024-07-18", detail=True) + + # Prepare test images (simulated base64 encoded strings) + test_images = ["base64encodedimage1", "base64encodedimage2"] + + # Call the api_call method + response = wrapper.api_call(images=test_images) + + # Check if the response matches the expected output + assert response == "YES: Anomaly detected." + + # Check if the openai API was called with the expected parameters + mock_openai_create.assert_called_once_with( + model="gpt-4o-mini-2024-07-18", + messages=mocker.ANY, # Ignore specific messages content in this check + max_tokens=300, + ) + + def test_api_call_no_anomaly(self, mocker: MockerFixture) -> None: + """Tests for api_call negative response and zero shot.""" + # Set up the mock response from OpenAI + mock_response = mocker.MagicMock() + mock_response.choices = [mocker.MagicMock(message=mocker.MagicMock(content="NO"))] + # Mock the openai.chat.completions.create function + mock_openai_create = mocker.patch("anomalib.models.image.gptvad.chatgpt.openai.chat.completions.create") + mock_openai_create.return_value = mock_response + + # Initialize the GPTWrapper instance + wrapper = GPTWrapper(model_name="gpt-4o-mini-2024-07-18", detail=False) + + # Prepare test images (simulated base64 encoded strings) + test_images = ["base64encodedimage1"] + + # Call the api_call method + response = wrapper.api_call(images=test_images) + + # Check if the response matches the expected output + assert response == "NO" + + # Check if the openai API was called correctly + mock_openai_create.assert_called_once() diff --git a/tests/unit/utils/test_visualizer.py b/tests/unit/utils/test_visualizer.py index 19a905e558..be6003ed5b 100644 --- a/tests/unit/utils/test_visualizer.py +++ b/tests/unit/utils/test_visualizer.py @@ -15,7 +15,7 @@ from anomalib.data import MVTec, PredictDataset from anomalib.engine import Engine from anomalib.models import get_model -from anomalib.utils.visualization.image import _ImageGrid +from anomalib.utils.visualization.image import ImageResult, ImageVisualizer, VisualizationMode, _ImageGrid def test_visualize_fully_defected_masks() -> None: @@ -35,6 +35,37 @@ def test_visualize_fully_defected_masks() -> None: assert np.all(plotted_img[0][..., 0] == 255) +def test_model_visualizer_visual_prompting() -> None: + """Test visualizer image on TaskType.VISUAL_PROMPTING.""" + anomaly_map = np.zeros((100, 100), dtype=np.float64) + anomaly_map[10:20, 10:20] = 1.0 + gt_mask = np.zeros((100, 100)) + gt_mask[15:25, 15:25] = 1.0 + rng = np.random.default_rng() + image = rng.integers(0, 255, size=(100, 100, 3), dtype=np.uint8) + + image_result = ImageResult( + image=image, + pred_score=0.9, + pred_label="abnormal", + text_descr=( + "Some very long text to see how it is formatted in the image" + " Some very long text to see how it is formatted in the image" + ), + 
anomaly_map=anomaly_map, + gt_mask=gt_mask, + pred_mask=anomaly_map, + ) + + image_visualizer = ImageVisualizer( + mode=VisualizationMode.FULL, + task=TaskType.VISUAL_PROMPTING, + ) + result = image_visualizer.visualize_image(image_result) + + assert result.shape == (500, 500, 3) + + class TestVisualizer: """Test visualization callback for test and predict with different task types."""
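As a complement to the engine-level usage in the readme above, the following is a minimal sketch of exercising the new `GPTWrapper` directly, outside the Lightning pipeline. It assumes this diff is applied and `OPENAI_API_KEY` is set; the image file names are placeholders, not files shipped with the PR.

```python
import base64
from pathlib import Path

from anomalib.models.image.gptvad.chatgpt import GPTWrapper


def encode(path: str) -> str:
    """Base64-encode an image file, mirroring GptVad._encode_image."""
    return base64.b64encode(Path(path).read_bytes()).decode("utf-8")


# Placeholder paths: the leading images act as normal references (few-shot);
# the last image is the one checked for anomalies.
images = [encode("normal_1.png"), encode("normal_2.png"), encode("query.png")]

wrapper = GPTWrapper(model_name="gpt-4o", detail=True)  # raises APIKeyError if OPENAI_API_KEY is unset
answer = wrapper.api_call(images, extension="png")
print(answer)  # "NO" or "YES: <description of the anomaly and its position>"
```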