-
Notifications
You must be signed in to change notification settings - Fork 657
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial scafold Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * Apply PR comments Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> * rename dir Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com> --------- Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
- Loading branch information
1 parent
21287ee
commit 660acf1
Showing
9 changed files
with
277 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Visual Anomaly Model.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .lightning_model import VlmAd | ||
|
||
__all__ = ["VlmAd"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""VLM backends.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .base import Backend | ||
from .ollama import Ollama | ||
|
||
__all__ = ["Backend", "Ollama"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
"""Base backend.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
|
||
|
||
class Backend(ABC): | ||
"""Base backend.""" | ||
|
||
@abstractmethod | ||
def __init__(self, api_key: str | None = None) -> None: | ||
"""Initialize the backend.""" | ||
|
||
@abstractmethod | ||
def add_reference_images(self, image: str | Path) -> None: | ||
"""Add reference images for k-shot.""" | ||
|
||
@abstractmethod | ||
def predict(self, image: str | Path) -> str: | ||
"""Predict the anomaly label.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
"""Ollama backend. | ||
Assumes that the Ollama service is running in the background. | ||
See: https://github.com/ollama/ollama | ||
Ensure that ollama is running. On linux: `ollama serve` | ||
""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import logging | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
|
||
from anomalib.utils.exceptions import try_import | ||
|
||
from .base import Backend | ||
|
||
if try_import("ollama"): | ||
from ollama import chat | ||
from ollama._client import _encode_image | ||
else: | ||
chat = None | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class Prompt: | ||
"""Ollama prompt.""" | ||
|
||
few_shot: str | ||
predict: str | ||
|
||
|
||
class Ollama(Backend): | ||
"""Ollama backend.""" | ||
|
||
def __init__(self, api_key: str | None = None, model_name: str = "llava") -> None: | ||
"""Initialize the Ollama backend.""" | ||
if api_key: | ||
logger.warning("API key is not required for Ollama backend.") | ||
self.model_name: str = model_name | ||
self._ref_images_encoded: list[str] = [] | ||
|
||
def add_reference_images(self, image: str | Path) -> None: | ||
"""Encode the image to base64.""" | ||
self._ref_images_encoded.append(_encode_image(image)) | ||
|
||
@property | ||
def prompt(self) -> Prompt: | ||
"""Get the Ollama prompt.""" | ||
return Prompt( | ||
predict=( | ||
"You are given an image. It is either normal or anomalous." | ||
"First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n" | ||
"Then give the reason for your decision.\n" | ||
"For example, 'YES: The image has a crack on the wall.'" | ||
), | ||
few_shot=( | ||
"These are a few examples of normal picture without any anomalies." | ||
" You have to use these to determine if the image I provide in the next" | ||
" chat is normal or anomalous." | ||
), | ||
) | ||
|
||
def predict(self, image: str | Path) -> str: | ||
"""Predict the anomaly label.""" | ||
if not chat: | ||
msg = "Ollama is not installed. Please install it using `pip install ollama`." | ||
raise ImportError(msg) | ||
image_encoded = _encode_image(image) | ||
messages = [] | ||
|
||
# few-shot | ||
if len(self._ref_images_encoded) > 0: | ||
messages.append({ | ||
"role": "user", | ||
"images": self._ref_images_encoded, | ||
"content": self.prompt.few_shot, | ||
}) | ||
|
||
messages.append({"role": "user", "images": [image_encoded], "content": self.prompt.predict}) | ||
|
||
response = chat( | ||
model=self.model_name, | ||
messages=messages, | ||
) | ||
return response["message"]["content"].strip() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""Visual Anomaly Model for Zero/Few-Shot Anomaly Classification.""" | ||
|
||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import logging | ||
from enum import Enum | ||
|
||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
from anomalib import LearningType | ||
from anomalib.models import AnomalyModule | ||
|
||
from .backends import Backend, Ollama | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class VlmAdBackend(Enum): | ||
"""Supported VLM backends.""" | ||
|
||
OLLAMA = "ollama" | ||
|
||
|
||
class VlmAd(AnomalyModule): | ||
"""Visual anomaly model.""" | ||
|
||
def __init__( | ||
self, | ||
backend: VlmAdBackend | str = VlmAdBackend.OLLAMA, | ||
api_key: str | None = None, | ||
k_shot: int = 3, | ||
) -> None: | ||
super().__init__() | ||
self.k_shot = k_shot | ||
backend = VlmAdBackend(backend) | ||
self.vlm_backend: Backend = self._setup_vlm(backend, api_key) | ||
|
||
@staticmethod | ||
def _setup_vlm(backend: VlmAdBackend, api_key: str | None) -> Backend: | ||
match backend: | ||
case VlmAdBackend.OLLAMA: | ||
return Ollama() | ||
case _: | ||
msg = f"Unsupported VLM backend: {backend}" | ||
raise ValueError(msg) | ||
|
||
def _setup(self) -> None: | ||
if self.k_shot: | ||
logger.info("Collecting reference images from training dataset.") | ||
dataloader = self.trainer.datamodule.train_dataloader() | ||
self.collect_reference_images(dataloader) | ||
|
||
def collect_reference_images(self, dataloader: DataLoader) -> None: | ||
"""Collect reference images for few-shot inference.""" | ||
count = 0 | ||
for batch in dataloader: | ||
for img_path in batch["image_path"]: | ||
self.vlm_backend.add_reference_images(img_path) | ||
count += 1 | ||
if count == self.k_shot: | ||
return | ||
|
||
def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict: | ||
"""Validation step.""" | ||
del args, kwargs # These variables are not used. | ||
responses = [(self.vlm_backend.predict(img_path)) for img_path in batch["image_path"]] | ||
|
||
batch["str_output"] = responses | ||
batch["pred_scores"] = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device) | ||
return batch | ||
|
||
@property | ||
def learning_type(self) -> LearningType: | ||
"""The learning type of the model.""" | ||
return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT | ||
|
||
@property | ||
def trainer_arguments(self) -> dict[str, int | float]: | ||
"""Doesn't need training.""" | ||
return {} | ||
|
||
@staticmethod | ||
def configure_transforms(image_size: tuple[int, int] | None = None) -> None: | ||
"""This modes does not require any transforms.""" | ||
if image_size is not None: | ||
logger.warning("Ignoring image_size argument as each backend has its own transforms.") |
Oops, something went wrong.