Skip to content

Commit

Permalink
use load image, fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Dec 6, 2023
1 parent d5b2b13 commit 7dd0602
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 28 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,21 @@ pip3 install autodistill-dinov2
```python
from autodistill_dinov2 import DINOv2

target_model = DINOv2()
target_model = DINOv2(None)

# train a model
# specify the directory where your annotations are stored (in multiclass classification folder format)
# DINOv2 embeddings are saved in a file called "embeddings.json" in the folder in which you are working
# with the structure {filename: embedding}
target_model.train("./context_images_labeled")

# get class list
# print(target_model.ontology.classes())

# run inference on the new model
pred = target_model.predict("./context_images_labeled/train/images/dog-7.jpg")

print(pred)
```


Expand Down
2 changes: 1 addition & 1 deletion autodistill_dinov2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dinov2_model import DINOv2

__version__ = "0.1.0"
__version__ = "0.1.1"
41 changes: 20 additions & 21 deletions autodistill_dinov2/dinov2_model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import json
import os
import warnings
from dataclasses import dataclass
from typing import Any

import numpy as np
import supervision as sv
import torch
import torchvision.transforms as T
from autodistill.detection import CaptionOntology
from autodistill.classification import ClassificationBaseModel
from PIL import Image
from sklearn import svm
from tqdm import tqdm

import warnings
from autodistill.classification import ClassificationBaseModel
from autodistill.detection import CaptionOntology
from autodistill.helpers import load_image as load_image_autodistill_main

warnings.filterwarnings("ignore", category=UserWarning)

Expand All @@ -24,11 +25,11 @@
)


def load_image(img: str) -> torch.Tensor:
def load_image(img: Any) -> torch.Tensor:
"""
Load an image and return a tensor that can be used as an input to DINOv2.
"""
img = Image.open(img)
img = load_image_autodistill_main(img, return_format="PIL")

transformed_img = transform_image(img)[:3].unsqueeze(0)

Expand Down Expand Up @@ -66,13 +67,13 @@ def __init__(self, ontology: CaptionOntology):
self.dinov2_model = dinov2_vits14
self.ontology = ontology

def predict(self, input: str) -> sv.Classifications:
def predict(self, input: Any) -> sv.Classifications:
embedding = compute_embeddings([input], self.dinov2_model)

class_id = self.model.predict(np.array(embedding[input]).reshape(-1, 384))

return sv.Classifications(
class_id=np.array([self.ontology.classes().index(class_id)]),
class_id=np.array([self.ontology.classes().index(class_id[0])]),
confidence=np.array([1]),
)

Expand All @@ -82,35 +83,33 @@ def train(self, dataset_location: str):
clf = svm.SVC(gamma="scale")

classes = dataset.classes

images = list(dataset.images.keys())
annotations = dataset.annotations

all_images = []

for image in images:
class_label = classes[annotations[image].class_id[0]]

all_images.append(os.path.join(dataset_location, class_label, image))

embeddings = compute_embeddings(all_images, self.dinov2_model)
embeddings = compute_embeddings(images, self.dinov2_model)

with open("embeddings.json", "w") as f:
json.dump(embeddings, f)

y = [
classes[annotations[os.path.basename(file)].class_id[0]]
for file in all_images
classes[annotations[file].class_id[0]]
for file in images
]

embedding_list = [embeddings[file] for file in all_images]
embedding_list = [embeddings[file] for file in images]

# svm needs at least 2 classes
unqiue_classes = list(set(y))
unique_classes = list(set(y))

if len(unqiue_classes) == 1:
if len(unique_classes) == 1:
raise ValueError("Only one class in dataset")

# DINOv2 has 384 dimensions
clf.fit(np.array(embedding_list).reshape(-1, 384), y)

self.ontology = CaptionOntology(
{prompt: prompt for prompt in classes}
)

self.model = clf
11 changes: 6 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import re

import setuptools
from setuptools import find_packages
import re

with open("./autodistill_dinov2/__init__.py", 'r') as f:
with open("./autodistill_dinov2/__init__.py", "r") as f:
content = f.read()
# from https://www.py4u.net/discuss/139845
version = re.search(r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', content).group(1)

with open("README.md", "r") as fh:
long_description = fh.read()

setuptools.setup(
name="autodistill-dinov2",
name="autodistill-dinov2",
version=version,
author="Roboflow",
author_email="support@roboflow.com",
Expand All @@ -24,7 +25,7 @@
"torchvision",
"supervision",
"numpy",
"PIL",
"Pillow",
"tqdm",
"scikit-learn",
"autodistill",
Expand Down

0 comments on commit 7dd0602

Please sign in to comment.