import gradio as gr
import numpy as np
from PIL import Image, ImageEnhance
from ultralytics import YOLO
import json

model_path = "best.pt"
model = YOLO(model_path)
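
# NOTE: best.pt is assumed to sit next to app.py, and predict() below keys on the
# class names "front" and "back", so the model is expected to expose those labels.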

def preprocess_image(image):
    """Sharpen, boost contrast, dim, and resize the input image."""
    image = Image.fromarray(np.array(image))
    image = ImageEnhance.Sharpness(image).enhance(2.0)   # Increase sharpness
    image = ImageEnhance.Contrast(image).enhance(1.5)    # Increase contrast
    image = ImageEnhance.Brightness(image).enhance(0.8)  # Reduce brightness

    # Resize to 800px width while maintaining the aspect ratio
    width = 800
    aspect_ratio = image.height / image.width
    height = int(width * aspect_ratio)
    image = image.resize((width, height))
    return image
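
# A quick sanity check of the resize math (a sketch, assuming a 1600x1200 input):
# >>> preprocess_image(Image.new("RGB", (1600, 1200))).size
# (800, 600)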

def imageRotation(image):
    """Dummy function for now."""
    return image
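
# A possible future implementation (a sketch, not part of the current app) would
# honour EXIF orientation metadata before detection, e.g.
#
#     from PIL import ImageOps
#     def imageRotation(image):
#         return ImageOps.exif_transpose(image)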

def vision_ai_api(image, label):
    """Dummy function simulating an API call. Returns a dummy JSON response."""
    return {
        "label": label,
        "extracted_data": {
            "name": "-------",
            "dob": "-------",
            "id_number": "-------"
        }
    }
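
# A real implementation might POST each crop to an OCR/vision endpoint. A minimal
# sketch, assuming a hypothetical endpoint URL and response shape:
#
#     import io, requests
#     buf = io.BytesIO()
#     image.save(buf, format="PNG")
#     resp = requests.post("https://example.com/vision/extract",  # hypothetical URL
#                          files={"file": buf.getvalue()}, data={"label": label})
#     return resp.json()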

def predict(image):
    image = preprocess_image(image)
    # YOLO accepts PIL images, but the cropping below uses NumPy slicing,
    # so run detection and cropping on an array copy.
    image_np = np.array(image)
    results = model(image_np, conf=0.80)

    detected_classes = set()
    labels = []
    cropped_images = {}

    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            conf = float(box.conf[0])
            cls = int(box.cls[0])
            class_name = model.names[cls]

            print(f"Detected: {class_name} ({conf:.2f}) at [{x1}, {y1}, {x2}, {y2}]")
            detected_classes.add(class_name)
            labels.append(f"{class_name} {conf:.2f}")

            # Clamp bounding boxes to the image bounds
            height, width = image_np.shape[:2]
            x1, y1, x2, y2 = max(0, x1), max(0, y1), min(width, x2), min(height, y2)
            if x1 >= x2 or y1 >= y2:
                print("Invalid bounding box, skipping.")
                continue

            cropped = image_np[y1:y2, x1:x2]
            cropped_pil = Image.fromarray(cropped)

            # Call the (dummy) Vision AI API for this card side
            api_response = vision_ai_api(cropped_pil, class_name)
            cropped_images[class_name] = {
                "image": cropped_pil,
                "api_response": json.dumps(api_response, indent=4),
            }

    if not cropped_images:
        return None, "No front detected", None, "No back detected", ["No valid detections"]

    return (
        cropped_images.get("front", {}).get("image", None),
        cropped_images.get("front", {}).get("api_response", "{}"),
        cropped_images.get("back", {}).get("image", None),
        cropped_images.get("back", {}).get("api_response", "{}"),
        labels,
    )

# Gradio Interface: the five outputs mirror predict()'s five return values
# (front crop, front API response, back crop, back API response, labels).
iface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=["image", "text", "image", "text", "text"],
    title="License Field Detection (Front & Back Card)",
    description="Detect the front & back of a license card, crop each side, and call the Vision AI API separately for each."
)

iface.launch()
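
# When running outside Hugging Face Spaces, a temporary public link can be
# requested with iface.launch(share=True) instead (optional; not needed on Spaces).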