# vedAi / predict.py
# Uploaded by randomshit11 ("Upload 2 files", commit 7d51abb, verified).
# Size: 1.11 kB. Hosting-page scan status: no virus detected.
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import numpy as np
from io import BytesIO # Add this import statement
# Hugging Face Hub checkpoint shared by the preprocessor and the classifier,
# kept in one place so the two can never drift apart.
_MODEL_ID = "dima806/medicinal_plants_image_detection"

# Both objects are created once, at import time, and reused by every call
# into this module instead of being reloaded per request.
processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
def read_image(file) -> Image.Image:
    """Decode an encoded image held in memory into a PIL image.

    Args:
        file: The raw encoded image content (for example the body of an
            uploaded PNG/JPEG) as a bytes-like object.

    Returns:
        The opened ``PIL.Image.Image``. Pillow's ``Image.open`` reads the
        header eagerly but defers loading pixel data; the in-memory buffer
        is kept referenced by the returned image, so that is safe here.
    """
    return Image.open(BytesIO(file))
def transformacao(file: Image.Image, top_k: int = 3):
    """Classify an image and return the ``top_k`` most likely classes.

    Args:
        file: The image to classify (for example the result of
            ``read_image``). Images not already in RGB mode (RGBA,
            grayscale, palette, ...) are converted to RGB first.
        top_k: How many predictions to return. Defaults to 3, the original
            behavior. It is capped at the number of labels the model has.

    Returns:
        A list of dicts ordered from most to least likely. Each dict has
        ``"class"`` (the label name from ``model.config.id2label``) and
        ``"confidence"`` (the softmax probability as a percentage string
        such as ``"97.31 %"``).
    """
    # The processor normalizes with per-channel statistics. A 4-channel
    # (RGBA) or 1-channel (L) image can fail there, so normalize the mode
    # first. Converting an RGB image would be a no-op, so skip it.
    if file.mode != "RGB":
        file = file.convert("RGB")

    # `padding` is a tokenizer option; image processors do not use it.
    inputs = processor(images=file, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # One image was sent, so take batch row 0. Unlike squeeze(), this
    # keeps a 1-D tensor even when the model has only one label.
    probabilities = logits.softmax(dim=-1)[0]

    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probabilities.numel())
    top_probabilities, top_indices = torch.topk(probabilities, k)

    labels = model.config.id2label
    return [
        {"class": labels[idx.item()], "confidence": f"{prob.item()*100:0.2f} %"}
        for prob, idx in zip(top_probabilities, top_indices)
    ]