# deletesoon / app.py
# Hugging Face Space file by Mohaddz -- commit e6423a5 ("Update app.py", verified), 3.33 kB.
import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import SegformerForSemanticSegmentation, AutoFeatureExtractor
import cv2
import json
import random
# Load models
# Every checkpoint is fetched once at import time and shared by all requests.
# The single feature extractor (from the parts checkpoint) prepares the input
# for both segmentation models.
part_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
damage_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")
feature_extractor = AutoFeatureExtractor.from_pretrained("Mohaddz/huggingCars")

# Load parts list
# `data` is read as a mapping whose values may carry a 'replaced_parts' list;
# entries without that key contribute nothing.
with open('cars117.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Alphabetically ordered, de-duplicated vocabulary of every part name seen.
all_parts = sorted({
    part
    for entry in data.values()
    for part in entry.get('replaced_parts', [])
})
def _segmentation_logits(model, inputs):
    """Run *model* on *inputs* without gradients and return first-item logits as a NumPy array."""
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index the batch axis explicitly: a bare squeeze() would also drop a
    # size-1 class axis and break the axis-0 reductions done by the callers.
    return logits[0].cpu().numpy()


def _display_heatmap(features, size):
    """Build a heatmap from *features*, resize it to *size* (width, height) and return it in RGB order."""
    heatmap = cv2.resize(create_heatmap(features), size)
    # cv2.applyColorMap produces BGR, while PIL's Image.fromarray reads the
    # channels as RGB; without this swap the colour scale appears reversed.
    return cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)


def process_image(image):
    """Produce the damage overlay, both heatmaps and a simulated parts list.

    Args:
        image: PIL image supplied by the Gradio input component. It is
            converted to RGB when it is in any other mode.

    Returns:
        A 4-tuple ``(annotated, damage_heatmap, part_heatmap, text)``: three
        PIL images at the input resolution, followed by a newline-separated
        string of ``"<part>: <score>"`` lines. The part list is randomly
        generated and is not derived from the model output.
    """
    # Convert to RGB if it's not
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # PIL reports (width, height), which is the order cv2.resize expects.
    target_size = (image.size[0], image.size[1])

    # Prepare input for the model (shared by both segmentation models)
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Get damage segmentation and its heatmap
    damage_features = _segmentation_logits(damage_seg_model, inputs)
    damage_heatmap_resized = _display_heatmap(damage_features, target_size)

    # Create annotated damage image
    image_array = np.array(image)
    # Any non-zero class index counts as damage (index 0 is presumably the
    # background class). Reducing to a uint8 0/1 mask before resizing matters:
    # np.argmax yields int64, which cv2.resize does not accept.
    damage_mask = (np.argmax(damage_features, axis=0) > 0).astype(np.uint8)
    damage_mask_resized = cv2.resize(
        damage_mask, target_size, interpolation=cv2.INTER_NEAREST
    )
    overlay = np.zeros_like(image_array)
    overlay[damage_mask_resized > 0] = [255, 0, 0]  # Red color for damage
    annotated_image = cv2.addWeighted(image_array, 1, overlay, 0.5, 0)

    # Process for part segmentation heatmap
    part_features = _segmentation_logits(part_seg_model, inputs)
    part_heatmap_resized = _display_heatmap(part_features, target_size)

    # Simulate part prediction (for demonstration purposes).
    # Clamp the sample size so a short parts list cannot make
    # random.sample raise ValueError.
    num_predictions = min(random.randint(3, 5), len(all_parts))
    predicted_parts = [
        (part, random.random())
        for part in random.sample(all_parts, num_predictions)
    ]
    predicted_parts.sort(key=lambda x: x[1], reverse=True)

    return (Image.fromarray(annotated_image),
            Image.fromarray(damage_heatmap_resized),
            Image.fromarray(part_heatmap_resized),
            "\n".join([f"{part}: {prob:.2f}" for part, prob in predicted_parts]))
def create_heatmap(features):
    """Collapse stacked feature maps into a single JET-coloured heatmap.

    Args:
        features: NumPy array whose first axis is summed away; the remaining
            axes form the 2-D map that gets coloured.

    Returns:
        The uint8 colour image produced by ``cv2.applyColorMap`` (OpenCV
        channel order, i.e. BGR). A map with no value spread is rendered as
        the colour of the lowest intensity everywhere.
    """
    heatmap = np.sum(features, axis=0)

    lowest = heatmap.min()
    value_range = heatmap.max() - lowest
    if value_range > 0:
        # Min-max normalise into [0, 1].
        heatmap = (heatmap - lowest) / value_range
    else:
        # Flat (or non-finite) map: dividing would give 0/0 -> NaN, and the
        # NaN -> uint8 cast below is undefined. Fall back to all zeros.
        heatmap = np.zeros_like(heatmap, dtype=np.float64)

    heatmap = np.uint8(255 * heatmap)
    return cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
# Gradio UI: one uploaded picture in, three image panes plus a text box out.
# The output order must match the tuple returned by process_image.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        # The three image panes differ only in their label.
        *(
            gr.Image(type="pil", label=pane_label)
            for pane_label in (
                "Annotated Damage",
                "Damage Heatmap",
                "Part Segmentation Heatmap",
            )
        ),
        gr.Textbox(label="Predicted Parts to Replace (Simulated)"),
    ],
    title="Car Damage Assessment (Demo Version)",
    description="Upload an image of a damaged car to get a simulated assessment. Note: Part predictions are randomly generated for demonstration purposes.",
)

# Launch the interface as soon as the module is executed.
iface.launch()