# deletesoon / app.py
# Last commit 81a3c33 by Mohaddz: "Update app.py" (verified) - file size 3.73 kB
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw
import io
# Segmentation label vocabularies. Mask values produced by the segmentation
# step are used as indices into these lists, so the order of entries matters.
# Extend both lists with every class your models were trained on.
part_labels = [
    "front-bumper",
    "fender",
    "hood",
    "door",
    "trunk",
]
damage_labels = [
    "dent",
    "scratch",
    "misalignment",
    "crack",
]

# Pre-trained Keras segmentation models, loaded once when the module is imported.
# NOTE(review): both paths are placeholders and must point at real saved models,
# otherwise loading fails at import time. The inference_* placeholders in this
# file currently return random masks and never call these models.
part_seg_model = tf.keras.models.load_model("path_to_part_seg_model")
damage_seg_model = tf.keras.models.load_model("path_to_damage_seg_model")
def inference_part_seg(image):
    """Predict a per-pixel part-label mask for ``image``.

    Placeholder implementation: no model is run. Every position receives a
    uniformly random id in ``range(len(part_labels))``.

    Args:
        image: array-like with a ``shape`` attribute; the mask takes the
            shape of its first two dimensions.

    Returns:
        Integer ``np.ndarray`` of shape ``image.shape[:2]``.
    """
    # TODO: replace the random sampling with real part-segmentation inference.
    mask_shape = image.shape[:2]
    num_classes = len(part_labels)
    return np.random.randint(low=0, high=num_classes, size=mask_shape)
def inference_damage_seg(image):
    """Predict a per-pixel damage-label mask for ``image``.

    Placeholder implementation: no model is run. Every position receives a
    uniformly random id in ``range(len(damage_labels))``.

    Args:
        image: array-like with a ``shape`` attribute; the mask takes the
            shape of its first two dimensions.

    Returns:
        Integer ``np.ndarray`` of shape ``image.shape[:2]``.
    """
    # TODO: replace the random sampling with real damage-segmentation inference.
    return np.random.randint(
        low=0,
        high=len(damage_labels),
        size=image.shape[:2],
    )
def combine_masks(part_mask, damage_mask):
    """List every (part, damage) label pair that overlaps somewhere in the masks.

    A pair ``(part_labels[p], damage_labels[d])`` is reported when at least one
    position has part id ``p`` in ``part_mask`` and damage id ``d`` in
    ``damage_mask``. Pairs involving the part label ``"cars-8I1q"`` or the
    damage label ``"etc"`` are skipped. Mask values that are not a valid
    position in the corresponding label list never match and are ignored.

    Args:
        part_mask: array of part-label ids.
        damage_mask: array of damage-label ids; must be broadcast-compatible
            with ``part_mask`` (normally the same shape).

    Returns:
        List of ``(part_name, damage_name)`` tuples, ordered by part id and
        then damage id, each pair at most once.

    Raises:
        ValueError: if the two masks cannot be broadcast together.
    """
    # Broadcast up front so the boolean indexing below sees matching shapes,
    # mirroring the broadcasting an element-wise AND of the masks would apply.
    part_mask, damage_mask = np.broadcast_arrays(
        np.asarray(part_mask), np.asarray(damage_mask)
    )
    part_damage_pairs = []
    for part_id, part_name in enumerate(part_labels):
        if part_name == "cars-8I1q":
            continue
        # Build this part's region once instead of once per damage label.
        part_binary = part_mask == part_id
        if not part_binary.any():
            continue
        # Every damage id occurring anywhere inside this part's region.
        damage_ids_in_part = set(np.unique(damage_mask[part_binary]).tolist())
        for damage_id, damage_name in enumerate(damage_labels):
            if damage_name != "etc" and damage_id in damage_ids_in_part:
                part_damage_pairs.append((part_name, damage_name))
    return part_damage_pairs
def create_one_hot_vector(part_damage_pairs):
    """Encode (part, damage) name pairs as a flat 0/1 vector.

    The vector has one slot per (part, damage) combination, laid out
    part-major: the slot for ``(part_labels[p], damage_labels[d])`` is
    ``p * len(damage_labels) + d``. Several slots can be set at once, so this
    is a multi-hot encoding despite the function name.

    Args:
        part_damage_pairs: iterable of ``(part_name, damage_name)`` tuples.
            Pairs naming a label not present in the lists are ignored.

    Returns:
        Float ``np.ndarray`` of length ``len(part_labels) * len(damage_labels)``.
    """
    n_damage = len(damage_labels)
    encoded = np.zeros(len(part_labels) * n_damage)
    for part, damage in part_damage_pairs:
        if part not in part_labels or damage not in damage_labels:
            continue
        row = part_labels.index(part)
        col = damage_labels.index(damage)
        encoded[row * n_damage + col] = 1
    return encoded
def visualize_results(image, part_mask, damage_mask):
    """Return ``image`` as a PIL image with flagged pixels painted red.

    A pixel at (row ``j``, column ``i``) is painted red unless
    ``part_labels[part_mask[j, i]]`` is ``"cars-8I1q"`` or
    ``damage_labels[damage_mask[j, i]]`` is ``"etc"``.

    NOTE(review): neither ``"cars-8I1q"`` nor ``"etc"`` is currently in the
    label lists, so with the present labels every pixel is painted red.
    Confirm whether those background classes belong in the lists.

    Args:
        image: array accepted by ``Image.fromarray``.
        part_mask: integer part-label ids indexed ``[row, column]``; must cover
            at least the image's height x width (extra rows/columns are ignored).
        damage_mask: integer damage-label ids with the same layout.

    Returns:
        A new ``PIL.Image.Image`` built from ``image`` with the overlay applied.

    Raises:
        ValueError: if a mask does not cover the whole image.
        IndexError: if a mask value is not a valid index into its label list.
    """
    img = Image.fromarray(image)
    width, height = img.size
    # One mask entry per image pixel: use only the top-left height x width region.
    part_ids = np.asarray(part_mask)[:height, :width]
    damage_ids = np.asarray(damage_mask)[:height, :width]
    if part_ids.shape != (height, width) or damage_ids.shape != (height, width):
        raise ValueError(
            f"masks must cover the {height}x{width} image; got part mask "
            f"{np.shape(part_mask)} and damage mask {np.shape(damage_mask)}"
        )
    # Per-label lookup tables + fancy indexing replace a Python loop over pixels.
    part_keep = np.array([name != "cars-8I1q" for name in part_labels], dtype=bool)
    damage_keep = np.array([name != "etc" for name in damage_labels], dtype=bool)
    highlight = part_keep[part_ids] & damage_keep[damage_ids]
    if highlight.any():
        # Mask value 255 means the red fill fully replaces those pixels.
        overlay_mask = Image.fromarray(highlight.astype(np.uint8) * 255)
        img.paste("red", mask=overlay_mask)
    return img
def process_image(image):
    """Run the full damage-assessment pipeline on one image.

    Args:
        image: a ``PIL.Image.Image``, which is converted to a NumPy array first;
            any other value is passed through unchanged (presumably already an
            array).

    Returns:
        Tuple ``(result_image, part_damage_pairs, one_hot_list)``: the PIL image
        with the red overlay, the list of ``(part, damage)`` name pairs, and the
        encoded vector as a plain Python list.
    """
    # Everything downstream works on array data.
    pixels = np.array(image) if isinstance(image, Image.Image) else image

    # Segmentation: part mask first, then damage mask.
    part_mask = inference_part_seg(pixels)
    damage_mask = inference_damage_seg(pixels)

    pairs = combine_masks(part_mask, damage_mask)
    encoded = create_one_hot_vector(pairs)
    overlay = visualize_results(pixels, part_mask, damage_mask)

    return overlay, pairs, encoded.tolist()
def gradio_interface(input_image):
    """Gradio callback: assess ``input_image`` and format the results for display.

    Args:
        input_image: the value from the ``gr.Image(type="pil")`` input. It may
            be ``None`` when the user submits without uploading an image.

    Returns:
        ``(result_image, damage_description, one_hot_text)``. The description
        has one ``"part : damage"`` line per detected pair. ``one_hot_text``
        is ``str()`` of the encoded vector list.

    Raises:
        gr.Error: if no image was provided. Gradio shows this message to the
            user instead of failing with an AttributeError inside inference.
    """
    if input_image is None:
        raise gr.Error("Please upload an image of a car.")
    result_image, part_damage_pairs, one_hot_vector = process_image(input_image)
    # Convert part_damage_pairs to a string for display, one pair per line.
    damage_description = "\n".join(
        f"{part} : {damage}" for part, damage in part_damage_pairs
    )
    return result_image, damage_description, str(one_hot_vector)
# Gradio UI. Input: one PIL image. Outputs: the image with the red overlay,
# a text description of the detected (part, damage) pairs, and the encoded
# vector as text.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Damage"),
        gr.Textbox(label="Damage Description"),
        gr.Textbox(label="One-hot Encoded Vector"),
    ],
    title="Car Damage Assessment",
    description="Upload an image of a damaged car to get an assessment of the damage.",
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or other tooling) does not launch a web server.
if __name__ == "__main__":
    iface.launch()