Skip to content

Commit

Permalink
feat(api): add retry method to zeroshot calling
Browse files Browse the repository at this point in the history
  • Loading branch information
thangbuiq committed May 4, 2024
1 parent b46abf1 commit 2a07957
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import requests
# from openai import OpenAI
import time
from groq import Groq
import base64

Expand Down Expand Up @@ -33,31 +34,40 @@ async def predict(new_image_path):
return f"Error processing image: {str(e)}", 0

async def zeroshot(new_image_path):
try:
API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-large-patch14-336"
headers = {"Authorization": f"Bearer {HF_API}"}
retry_count = 0
max_retries = 3

while retry_count < max_retries:
try:
API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-large-patch14-336"
headers = {"Authorization": f"Bearer {HF_API}"}

def query(data):
with open(data["image_path"], "rb") as f:
img = f.read()
payload={
"parameters": data["parameters"],
"inputs": base64.b64encode(img).decode("utf-8")
}
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()
def query(data):
with open(data["image_path"], "rb") as f:
img = f.read()
payload={
"parameters": data["parameters"],
"inputs": base64.b64encode(img).decode("utf-8")
}
response = requests.post(API_URL, headers=headers, json=payload)
return response.json()

output = query({
"image_path": f"{new_image_path}",
"parameters": {"candidate_labels": output_class},
})
max_component = max(output, key=lambda x: x['score'])
predicted_value, predicted_accuracy = max_component['label'], max_component['score']
return predicted_value, predicted_accuracy
output = query({
"image_path": f"{new_image_path}",
"parameters": {"candidate_labels": output_class},
})
max_component = max(output, key=lambda x: x['score'])
predicted_value, predicted_accuracy = max_component['label'], max_component['score']
return predicted_value, predicted_accuracy

except Exception as e:
print(f"Error processing image: {str(e)}")
return "unknown", 0
except Exception as e:
print(f"Error processing image: {str(e)}")
retry_count += 1
if retry_count < max_retries:
print("Retrying...")
time.sleep(1)

return "unknown", 0


def input_trash(input):
Expand Down

0 comments on commit 2a07957

Please sign in to comment.