import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw
from transformers import TFSegformerForSemanticSegmentation
# Load models from Hugging Face
part_seg_model = TFSegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
damage_seg_model = TFSegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")
# Define your labels
part_labels = ["front-bumper", "fender", "hood", "door", "trunk", "cars-8I1q"] # Add all your part labels
damage_labels = ["dent", "scratch", "misalignment", "crack", "etc"] # Add all your damage labels
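# Note: these lists are placeholders and must follow the class order the checkpoints were
# trained with. If the model configs expose id2label, the lists could instead be derived
# from the models themselves (illustrative, not part of the original app), e.g.:
# part_labels = [part_seg_model.config.id2label[i] for i in range(part_seg_model.config.num_labels)]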
def preprocess_image(image):
    # Resize, normalize, and reshape the image into the (batch, channels, height, width)
    # layout that TFSegformerForSemanticSegmentation expects.
    image = tf.image.resize(tf.cast(image, tf.float32), (512, 512))
    image = tf.keras.applications.imagenet_utils.preprocess_input(image)
    image = tf.transpose(image, (2, 0, 1))  # HWC -> CHW
    return tf.expand_dims(image, axis=0)  # add batch dimension
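# Alternative sketch (assumption, not used here): if the checkpoints ship a preprocessor
# config, transformers' SegformerImageProcessor could replace the manual resize/normalize:
#   processor = SegformerImageProcessor.from_pretrained("Mohaddz/huggingCars")
#   pixel_values = processor(images=image, return_tensors="tf").pixel_values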
def inference_part_seg(image):
    preprocessed_image = preprocess_image(image)
    outputs = part_seg_model(preprocessed_image, training=False)
    logits = outputs.logits  # (batch, num_labels, height/4, width/4)
    # Upsample the logits back to the input resolution before taking the per-pixel argmax.
    logits = tf.transpose(logits, (0, 2, 3, 1))
    logits = tf.image.resize(logits, image.shape[:2])
    mask = tf.argmax(logits, axis=-1)[0]
    return mask.numpy()
def inference_damage_seg(image):
    preprocessed_image = preprocess_image(image)
    outputs = damage_seg_model(preprocessed_image, training=False)
    logits = outputs.logits  # (batch, num_labels, height/4, width/4)
    # Upsample the logits back to the input resolution before taking the per-pixel argmax.
    logits = tf.transpose(logits, (0, 2, 3, 1))
    logits = tf.image.resize(logits, image.shape[:2])
    mask = tf.argmax(logits, axis=-1)[0]
    return mask.numpy()
def combine_masks(part_mask, damage_mask):
    part_damage_pairs = []
    for part_id, part_name in enumerate(part_labels):
        if part_name == "cars-8I1q":
            continue
        for damage_id, damage_name in enumerate(damage_labels):
            if damage_name == "etc":
                continue
            part_binary = (part_mask == part_id)
            damage_binary = (damage_mask == damage_id)
            intersection = np.logical_and(part_binary, damage_binary)
            if np.any(intersection):
                part_damage_pairs.append((part_name, damage_name))
    return part_damage_pairs
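# combine_masks returns (part, damage) name pairs, e.g. [("hood", "dent"), ("door", "scratch")]
# for an image where those masks overlap (illustrative values only).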
def create_one_hot_vector(part_damage_pairs):
    vector = np.zeros(len(part_labels) * len(damage_labels))
    for part, damage in part_damage_pairs:
        if part in part_labels and damage in damage_labels:
            part_index = part_labels.index(part)
            damage_index = damage_labels.index(damage)
            vector_index = part_index * len(damage_labels) + damage_index
            vector[vector_index] = 1
    return vector
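# Example with the label lists above: ("hood", "dent") -> index 2 * len(damage_labels) + 0 = 10
# in the flattened part x damage vector of length len(part_labels) * len(damage_labels) = 30.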
def visualize_results(image, part_mask, damage_mask):
    # Highlight every pixel that falls on a labelled part and carries a damage class in red.
    overlay = np.array(Image.fromarray(image.astype(np.uint8)).convert("RGB"))
    background_part = part_labels.index("cars-8I1q")
    background_damage = damage_labels.index("etc")
    damaged = (part_mask != background_part) & (damage_mask != background_damage)
    overlay[damaged] = [255, 0, 0]
    return Image.fromarray(overlay)
def process_image(image):
    # Convert to numpy array if it's not already
    if isinstance(image, Image.Image):
        image = np.array(image)
    # Perform inference
    part_mask = inference_part_seg(image)
    damage_mask = inference_damage_seg(image)
    # Combine masks
    part_damage_pairs = combine_masks(part_mask, damage_mask)
    # Create one-hot encoded vector
    one_hot_vector = create_one_hot_vector(part_damage_pairs)
    # Visualize results
    result_image = visualize_results(image, part_mask, damage_mask)
    return result_image, part_damage_pairs, one_hot_vector.tolist()
def gradio_interface(input_image):
    result_image, part_damage_pairs, one_hot_vector = process_image(input_image)
    # Convert part_damage_pairs to a string for display
    damage_description = "\n".join([f"{part} : {damage}" for part, damage in part_damage_pairs])
    return result_image, damage_description, str(one_hot_vector)
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Damage"),
        gr.Textbox(label="Damage Description"),
        gr.Textbox(label="One-hot Encoded Vector")
    ],
    title="Car Damage Assessment",
    description="Upload an image of a damaged car to get an assessment of the damage."
)
iface.launch()