🔨 Scaffold for refactor (#2340)
* initial scaffold

Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>

* Apply PR comments

Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>

* rename dir

Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>

---------

Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
ashwinvaidya17 authored Oct 4, 2024
1 parent 21287ee commit 660acf1
Showing 9 changed files with 277 additions and 88 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -56,6 +56,7 @@ core = [
"open-clip-torch>=2.23.0,<2.26.1",
]
openvino = ["openvino>=2024.0", "nncf>=2.10.0", "onnx>=1.16.0"]
vlm = ["ollama", "transformers"]
loggers = [
"comet-ml>=3.31.7",
"gradio>=4",
@@ -84,7 +85,7 @@ test = [
"coverage[toml]",
"tox",
]
full = ["anomalib[core,openvino,loggers,notebooks]"]
full = ["anomalib[core,openvino,loggers,notebooks, vlm]"]
dev = ["anomalib[full,docs,test]"]

[project.scripts]
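This adds a new optional `vlm` extra alongside the existing ones, so the VLM dependencies stay out of the core install; presumably `pip install anomalib[vlm]` (or the `full` extra) pulls in `ollama` and `transformers` on top of the base package.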
2 changes: 2 additions & 0 deletions src/anomalib/models/__init__.py
@@ -34,6 +34,7 @@
    Rkde,
    Stfpm,
    Uflow,
    VlmAd,
    WinClip,
)
from .video import AiVad
@@ -62,6 +63,7 @@ class UnknownModelError(ModuleNotFoundError):
"Stfpm",
"Uflow",
"AiVad",
"VlmAd",
"WinClip",
"Llm",
"Llmollama",
2 changes: 2 additions & 0 deletions src/anomalib/models/image/__init__.py
@@ -24,6 +24,7 @@
from .rkde import Rkde
from .stfpm import Stfpm
from .uflow import Uflow
from .vlm import VlmAd
from .winclip import WinClip

__all__ = [
@@ -44,6 +45,7 @@
"Rkde",
"Stfpm",
"Uflow",
"VlmAd",
"WinClip",
"Llm",
"Llmollama",
8 changes: 8 additions & 0 deletions src/anomalib/models/image/vlm_ad/__init__.py
@@ -0,0 +1,8 @@
"""Visual Anomaly Model."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import VlmAd

__all__ = ["VlmAd"]
9 changes: 9 additions & 0 deletions src/anomalib/models/image/vlm_ad/backends/__init__.py
@@ -0,0 +1,9 @@
"""VLM backends."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .base import Backend
from .ollama import Ollama

__all__ = ["Backend", "Ollama"]
23 changes: 23 additions & 0 deletions src/anomalib/models/image/vlm_ad/backends/base.py
@@ -0,0 +1,23 @@
"""Base backend."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from pathlib import Path


class Backend(ABC):
    """Base backend."""

    @abstractmethod
    def __init__(self, api_key: str | None = None) -> None:
        """Initialize the backend."""

    @abstractmethod
    def add_reference_images(self, image: str | Path) -> None:
        """Add reference images for k-shot."""

    @abstractmethod
    def predict(self, image: str | Path) -> str:
        """Predict the anomaly label."""
89 changes: 89 additions & 0 deletions src/anomalib/models/image/vlm_ad/backends/ollama.py
@@ -0,0 +1,89 @@
"""Ollama backend.
Assumes that the Ollama service is running in the background.
See: https://github.com/ollama/ollama
Ensure that ollama is running. On linux: `ollama serve`
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from dataclasses import dataclass
from pathlib import Path

from anomalib.utils.exceptions import try_import

from .base import Backend

if try_import("ollama"):
from ollama import chat
from ollama._client import _encode_image
else:
chat = None

logger = logging.getLogger(__name__)


@dataclass
class Prompt:
    """Ollama prompt."""

    few_shot: str
    predict: str


class Ollama(Backend):
    """Ollama backend."""

    def __init__(self, api_key: str | None = None, model_name: str = "llava") -> None:
        """Initialize the Ollama backend."""
        if api_key:
            logger.warning("API key is not required for Ollama backend.")
        self.model_name: str = model_name
        self._ref_images_encoded: list[str] = []

    def add_reference_images(self, image: str | Path) -> None:
        """Encode the image to base64 and store it for few-shot prompting."""
        if _encode_image is None:
            msg = "Ollama is not installed. Please install it using `pip install ollama`."
            raise ImportError(msg)
        self._ref_images_encoded.append(_encode_image(image))

    @property
    def prompt(self) -> Prompt:
        """Get the Ollama prompt."""
        return Prompt(
            predict=(
                "You are given an image. It is either normal or anomalous.\n"
                "First say 'YES' if the image is anomalous, or 'NO' if it is normal.\n"
                "Then give the reason for your decision.\n"
                "For example, 'YES: The image has a crack on the wall.'"
            ),
            few_shot=(
                "These are a few examples of normal pictures without any anomalies."
                " You have to use these to determine if the image I provide in the next"
                " chat is normal or anomalous."
            ),
        )

    def predict(self, image: str | Path) -> str:
        """Predict the anomaly label."""
        if not chat:
            msg = "Ollama is not installed. Please install it using `pip install ollama`."
            raise ImportError(msg)
        image_encoded = _encode_image(image)
        messages = []

        # few-shot
        if len(self._ref_images_encoded) > 0:
            messages.append({
                "role": "user",
                "images": self._ref_images_encoded,
                "content": self.prompt.few_shot,
            })

        messages.append({"role": "user", "images": [image_encoded], "content": self.prompt.predict})

        response = chat(
            model=self.model_name,
            messages=messages,
        )
        return response["message"]["content"].strip()
88 changes: 88 additions & 0 deletions src/anomalib/models/image/vlm_ad/lightning_model.py
@@ -0,0 +1,88 @@
"""Visual Anomaly Model for Zero/Few-Shot Anomaly Classification."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from enum import Enum

import torch
from torch.utils.data import DataLoader

from anomalib import LearningType
from anomalib.models import AnomalyModule

from .backends import Backend, Ollama

logger = logging.getLogger(__name__)


class VlmAdBackend(Enum):
    """Supported VLM backends."""

    OLLAMA = "ollama"


class VlmAd(AnomalyModule):
    """Visual anomaly model."""

    def __init__(
        self,
        backend: VlmAdBackend | str = VlmAdBackend.OLLAMA,
        api_key: str | None = None,
        k_shot: int = 3,
    ) -> None:
        super().__init__()
        self.k_shot = k_shot
        backend = VlmAdBackend(backend)
        self.vlm_backend: Backend = self._setup_vlm(backend, api_key)

    @staticmethod
    def _setup_vlm(backend: VlmAdBackend, api_key: str | None) -> Backend:
        match backend:
            case VlmAdBackend.OLLAMA:
                return Ollama(api_key=api_key)
            case _:
                msg = f"Unsupported VLM backend: {backend}"
                raise ValueError(msg)

    def _setup(self) -> None:
        if self.k_shot:
            logger.info("Collecting reference images from training dataset.")
            dataloader = self.trainer.datamodule.train_dataloader()
            self.collect_reference_images(dataloader)

    def collect_reference_images(self, dataloader: DataLoader) -> None:
        """Collect reference images for few-shot inference."""
        count = 0
        for batch in dataloader:
            for img_path in batch["image_path"]:
                self.vlm_backend.add_reference_images(img_path)
                count += 1
                if count == self.k_shot:
                    return

    def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> dict:
        """Validation step."""
        del args, kwargs  # These variables are not used.
        responses = [self.vlm_backend.predict(img_path) for img_path in batch["image_path"]]

        batch["str_output"] = responses
        batch["pred_scores"] = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device)
        return batch

    @property
    def learning_type(self) -> LearningType:
        """The learning type of the model."""
        return LearningType.ZERO_SHOT if self.k_shot == 0 else LearningType.FEW_SHOT

    @property
    def trainer_arguments(self) -> dict[str, int | float]:
        """Doesn't need training."""
        return {}

    @staticmethod
    def configure_transforms(image_size: tuple[int, int] | None = None) -> None:
        """This model does not require any transforms."""
        if image_size is not None:
            logger.warning("Ignoring image_size argument as each backend has its own transforms.")