# Source: Hugging Face Space "deletesoon", file app.py (author: Mohaddz)
# Commit 558c1bc (verified): "Update app.py" — 4.17 kB
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import SegformerForSemanticSegmentation
from torchvision.transforms import Resize, ToTensor, Normalize
# Load the two segmentation models from the Hugging Face Hub.
part_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
# Hub repo ids take the form "namespace/name" (exactly one slash).
damage_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")

# Class names for each model. The rest of this app treats a label's position
# in these lists as the class id the corresponding model predicts.
# NOTE(review): confirm the lists are complete and ordered like each model's
# config.id2label — the original comments suggested more labels may exist.
part_labels = ["front-bumper", "fender", "hood", "door", "trunk", "cars-8I1q"]
damage_labels = ["dent", "scratch", "misalignment", "crack", "etc"]
def preprocess_image(image):
    """Convert a raw image array into a normalized single-image model batch.

    The image is first forced to 3-channel RGB: the normalization step below
    carries exactly three per-channel mean/std values, so grayscale or RGBA
    input would otherwise fail there. It is then resized to 512x512,
    converted to a tensor by ``ToTensor`` and normalized with the standard
    ImageNet mean/std values.

    Args:
        image: Image as a numpy array (as delivered by
            ``gr.Image(type="numpy")``).

    Returns:
        A float tensor of shape ``(1, 3, 512, 512)``.
    """
    pil_image = Image.fromarray(image).convert("RGB")
    pil_image = Resize((512, 512))(pil_image)
    tensor = ToTensor()(pil_image)
    tensor = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor)
    return tensor.unsqueeze(0)  # add the batch dimension
def inference_seg(model, image, size=None):
    """Run a segmentation model and return per-pixel class ids.

    Args:
        model: Callable whose output exposes ``.logits``; the class axis is
            expected at dim 1, i.e. ``(batch, num_classes, height, width)``.
        image: Preprocessed input batch (e.g. from ``preprocess_image``).
        size: Optional ``(height, width)``. When given, the logits are
            bilinearly upsampled to this size before the argmax so the mask
            lines up with an image of that size. SegFormer logits are
            documented to come out at a reduced (1/4) resolution, so without
            this the mask is smaller than the model input.

    Returns:
        A numpy integer array of class ids: ``(height, width)`` when the
        batch holds one image, otherwise ``(batch, height, width)``.
    """
    with torch.no_grad():
        logits = model(image).logits
        if size is not None:
            logits = torch.nn.functional.interpolate(
                logits, size=size, mode="bilinear", align_corners=False
            )
        mask = logits.argmax(dim=1)
    # Drop only the batch axis: a bare squeeze() would also collapse a
    # spatial dimension of size 1 and break 2-D indexing downstream.
    if mask.shape[0] == 1:
        mask = mask[0]
    return mask.cpu().numpy()  # .cpu() so this also works for GPU models
def inference_part_seg(image):
    """Segment car parts in a raw image array and return the part class-id mask."""
    return inference_seg(part_seg_model, preprocess_image(image))
def inference_damage_seg(image):
    """Segment damage in a raw image array and return the damage class-id mask."""
    return inference_seg(damage_seg_model, preprocess_image(image))
def combine_masks(part_mask, damage_mask):
    """List the (part, damage) name pairs whose masks overlap anywhere.

    A label's position in ``part_labels`` / ``damage_labels`` is taken as its
    class id in the corresponding mask. Parts named "cars-8I1q" and damages
    named "etc" are skipped. Pairs are ordered part-major, each in label order.
    """
    parts = [(pid, pname) for pid, pname in enumerate(part_labels) if pname != "cars-8I1q"]
    damages = [(did, dname) for did, dname in enumerate(damage_labels) if dname != "etc"]
    return [
        (pname, dname)
        for pid, pname in parts
        for did, dname in damages
        if np.any(np.logical_and(part_mask == pid, damage_mask == did))
    ]
def create_one_hot_vector(part_damage_pairs):
    """Encode (part, damage) name pairs as a flat multi-hot numpy vector.

    The vector has ``len(part_labels) * len(damage_labels)`` float slots.
    Slot ``part_index * len(damage_labels) + damage_index`` is set to 1 for
    every pair whose names both appear in the label lists. Pairs with unknown
    names are ignored.
    """
    width = len(damage_labels)
    encoded = np.zeros(len(part_labels) * width)
    for part, damage in part_damage_pairs:
        if part not in part_labels or damage not in damage_labels:
            continue
        encoded[part_labels.index(part) * width + damage_labels.index(damage)] = 1
    return encoded
def _resize_mask_nearest(mask, height, width):
    """Nearest-neighbour resample a 2-D class-id mask to ``(height, width)``."""
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f"expected a 2-D mask, got shape {mask.shape}")
    if mask.shape == (height, width):
        return mask
    rows = np.arange(height) * mask.shape[0] // height
    cols = np.arange(width) * mask.shape[1] // width
    return mask[np.ix_(rows, cols)]


def _class_flags(mask, flags):
    """Map each class id in *mask* to ``flags[id]``; ids outside *flags* map to False."""
    ids = np.asarray(mask).astype(np.int64, copy=False)
    inside = (ids >= 0) & (ids < len(flags))
    result = np.zeros(ids.shape, dtype=bool)
    result[inside] = flags[ids[inside]]
    return result


def visualize_results(image, part_mask, damage_mask):
    """Return a copy of *image* with damaged-part pixels painted pure red.

    A pixel is highlighted when its part class is anything other than
    "cars-8I1q" and its damage class is anything other than "etc". A label's
    position in ``part_labels`` / ``damage_labels`` is its class id. Ids that
    have no label are never highlighted; previously they raised IndexError.

    The masks may be smaller than the image, since segmentation logits are
    typically produced at reduced resolution. They are nearest-neighbour
    resampled to the image size instead of being indexed with image
    coordinates, which used to raise IndexError.

    Args:
        image: Input image as a numpy array.
        part_mask: 2-D array of part class ids.
        damage_mask: 2-D array of damage class ids.

    Returns:
        A new RGB ``PIL.Image``; *image* itself is not modified.
    """
    canvas = np.array(Image.fromarray(image).convert("RGB"))
    height, width = canvas.shape[:2]

    # Per-class lookup tables: True where that class should be highlighted.
    part_ok = np.array([name != "cars-8I1q" for name in part_labels], dtype=bool)
    damage_ok = np.array([name != "etc" for name in damage_labels], dtype=bool)

    highlight = _class_flags(_resize_mask_nearest(part_mask, height, width), part_ok) & _class_flags(
        _resize_mask_nearest(damage_mask, height, width), damage_ok
    )
    # One vectorized assignment replaces a Python-level draw.point per pixel.
    canvas[highlight] = (255, 0, 0)
    return Image.fromarray(canvas)
def process_image(image):
    """Run the whole assessment pipeline on one raw image array.

    Returns:
        A 3-tuple: the image from ``visualize_results``, the list of
        (part, damage) name pairs from ``combine_masks``, and the one-hot
        vector from ``create_one_hot_vector`` converted to a plain list.
    """
    part_mask, damage_mask = inference_part_seg(image), inference_damage_seg(image)
    pairs = combine_masks(part_mask, damage_mask)
    encoded = create_one_hot_vector(pairs).tolist()
    overlay = visualize_results(image, part_mask, damage_mask)
    return overlay, pairs, encoded
def gradio_interface(input_image):
    """Gradio callback: assess one uploaded image and format the results.

    Args:
        input_image: The uploaded image as a numpy array, or ``None``.

    Returns:
        ``(annotated image, "part : damage" lines joined by newlines,
        string form of the one-hot vector list)``.

    Raises:
        gr.Error: if no image was supplied.
    """
    if input_image is None:
        # gr.Image presumably delivers None when nothing was uploaded; show a
        # readable message instead of a traceback from deep inside PIL.
        raise gr.Error("Please upload an image of a car.")
    result_image, part_damage_pairs, one_hot_vector = process_image(input_image)
    damage_description = "\n".join(f"{part} : {damage}" for part, damage in part_damage_pairs)
    return result_image, damage_description, str(one_hot_vector)
# Gradio UI: one image in; annotated image, damage description and one-hot
# vector out (matching the 3-tuple returned by gradio_interface).
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Image(type="pil", label="Detected Damage"),
        gr.Textbox(label="Damage Description"),
        gr.Textbox(label="One-hot Encoded Vector"),
    ],
    title="Car Damage Assessment",
    description="Upload an image of a damaged car to get an assessment of the damage.",
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launch().
if __name__ == "__main__":
    iface.launch()