# Source: app.py of a Hugging Face Space (status shown on the page: Running).
# File size: 4,143 bytes.
# Commit hashes captured with the page: bf0d4d8, 81a3c33, c22d417.
# (The file viewer's line-number gutter, 1-113, was captured with the page
# and is not part of the program.)
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw
from transformers import TFSegformerForSemanticSegmentation
# Pretrained SegFormer checkpoints, fetched from the Hugging Face Hub when
# this module is loaded: one segments car parts, the other damage types.
part_seg_model = TFSegformerForSemanticSegmentation.from_pretrained(
    "Mohaddz/huggingCars"
)
damage_seg_model = TFSegformerForSemanticSegmentation.from_pretrained(
    "Mohaddz/DamageSeg"
)

# Class names; the position of a name in its list is the class id that the
# rest of this file compares against the predicted masks.
# NOTE(review): the order presumably has to match each checkpoint's own
# id-to-label mapping -- confirm against the model configs.
part_labels = [
    "front-bumper",
    "fender",
    "hood",
    "door",
    "trunk",
    "cars-8I1q",
]  # Add all your part labels
damage_labels = [
    "dent",
    "scratch",
    "misalignment",
    "crack",
    "etc",
]  # Add all your damage labels
def preprocess_image(image):
    """Resize an image to 512x512 and normalise it for the SegFormer models.

    Args:
        image: Array or tensor laid out as (height, width, channels) with
            pixel values in the 0-255 range.

    Returns:
        A float tensor of shape (512, 512, channels), scaled to [0, 1] and
        normalised per channel with the ImageNet mean and standard deviation.
    """
    image = tf.image.resize(image, (512, 512))
    # The default mode ("caffe") flips RGB to BGR and only subtracts a mean.
    # "torch" mode scales to [0, 1] and applies ImageNet mean/std instead.
    # NOTE(review): assumes the checkpoints were trained with the standard
    # SegFormer ImageNet normalisation -- confirm against their
    # preprocessor configs.
    image = tf.keras.applications.imagenet_utils.preprocess_input(
        image, mode="torch"
    )
    return image


def _predict_mask(model, image):
    """Run one segmentation model and return a per-pixel class-id mask.

    Args:
        model: A loaded TFSegformerForSemanticSegmentation instance.
        image: Array laid out as (height, width, channels).

    Returns:
        A NumPy integer array of shape (height, width) holding the predicted
        class id for every pixel of the input image.
    """
    height, width = np.shape(image)[:2]

    pixel_values = preprocess_image(image)
    # TFSegformer takes pixel_values as (batch, channels, height, width).
    pixel_values = tf.transpose(pixel_values, perm=(2, 0, 1))[tf.newaxis, ...]

    logits = model(pixel_values, training=False).logits
    # Logits come back as (batch, num_labels, h/4, w/4). Move the class axis
    # last so the map can be resized to the input resolution and the argmax
    # runs over classes rather than over image columns.
    logits = tf.transpose(logits, perm=(0, 2, 3, 1))
    logits = tf.image.resize(logits, (height, width))

    return tf.argmax(logits, axis=-1)[0].numpy()


def inference_part_seg(image):
    """Return the (height, width) car-part class-id mask for an image."""
    return _predict_mask(part_seg_model, image)


def inference_damage_seg(image):
    """Return the (height, width) damage class-id mask for an image."""
    return _predict_mask(damage_seg_model, image)
def combine_masks(part_mask, damage_mask):
    """List every (part, damage) label pair that shares at least one pixel.

    The "cars-8I1q" part label and the "etc" damage label are never
    reported. Pairs are ordered by part label first, then by damage label.

    Args:
        part_mask: Array of part class ids.
        damage_mask: Array of damage class ids, compared element-wise with
            part_mask.

    Returns:
        A list of (part_name, damage_name) tuples.
    """
    parts = [
        (idx, name) for idx, name in enumerate(part_labels)
        if name != "cars-8I1q"
    ]
    damages = [
        (idx, name) for idx, name in enumerate(damage_labels)
        if name != "etc"
    ]
    return [
        (part_name, damage_name)
        for part_id, part_name in parts
        for damage_id, damage_name in damages
        if np.any(np.logical_and(part_mask == part_id, damage_mask == damage_id))
    ]
def create_one_hot_vector(part_damage_pairs):
    """Encode detected (part, damage) pairs as a flat 0/1 vector.

    The vector has one slot per (part label, damage label) combination, laid
    out part-major: slot = part_index * len(damage_labels) + damage_index.
    Pairs naming an unknown part or damage label are ignored.

    Args:
        part_damage_pairs: Iterable of (part_name, damage_name) tuples.

    Returns:
        A NumPy float array of length len(part_labels) * len(damage_labels).
    """
    damage_count = len(damage_labels)
    encoded = np.zeros(len(part_labels) * damage_count)
    for part_name, damage_name in part_damage_pairs:
        try:
            row = part_labels.index(part_name)
            col = damage_labels.index(damage_name)
        except ValueError:
            # One of the two names is not a known label; skip the pair.
            continue
        encoded[row * damage_count + col] = 1
    return encoded
def visualize_results(image, part_mask, damage_mask):
    """Paint red every pixel that is both a reported part and a reported damage.

    A pixel is highlighted when its part label is not "cars-8I1q" and its
    damage label is not "etc". Class ids outside the label lists are left
    unhighlighted.

    Args:
        image: Array laid out as (height, width[, channels]); either integer
            pixel values in 0-255 or floats in [0, 1].
        part_mask: (height, width) array of part class ids.
        damage_mask: (height, width) array of damage class ids.

    Returns:
        An RGB PIL image of the same size as the input, with highlighted
        pixels set to pure red.

    Raises:
        ValueError: If either mask does not have the image's (height, width).
    """
    pixels = np.asarray(image)
    if np.issubdtype(pixels.dtype, np.floating):
        # Only float images are on a 0-1 scale. Multiplying a uint8 image by
        # 255 wraps around and corrupts every pixel.
        pixels = np.clip(pixels, 0.0, 1.0) * 255
    img = Image.fromarray(pixels.astype(np.uint8)).convert("RGB")

    part_mask = np.asarray(part_mask)
    damage_mask = np.asarray(damage_mask)
    expected_shape = (img.height, img.width)
    if part_mask.shape != expected_shape or damage_mask.shape != expected_shape:
        raise ValueError(
            f"mask shapes {part_mask.shape} and {damage_mask.shape} "
            f"must both equal the image shape {expected_shape}"
        )

    part_ids = [
        idx for idx, name in enumerate(part_labels) if name != "cars-8I1q"
    ]
    damage_ids = [
        idx for idx, name in enumerate(damage_labels) if name != "etc"
    ]
    damaged = np.isin(part_mask, part_ids) & np.isin(damage_mask, damage_ids)

    # One vectorised assignment replaces the per-pixel Python loop.
    canvas = np.array(img)
    canvas[damaged] = (255, 0, 0)
    return Image.fromarray(canvas)
def process_image(image):
    """Run the full damage-assessment pipeline on one image.

    Args:
        image: A PIL image or an array; PIL images are converted to NumPy
            arrays before inference.

    Returns:
        A tuple of (annotated image, list of (part, damage) pairs, encoded
        vector as a plain Python list).
    """
    # The inference helpers work on arrays, so unwrap PIL images up front.
    pixels = np.array(image) if isinstance(image, Image.Image) else image

    # Segment parts and damage separately, then intersect the two masks.
    part_mask = inference_part_seg(pixels)
    damage_mask = inference_damage_seg(pixels)
    pairs = combine_masks(part_mask, damage_mask)

    encoding = create_one_hot_vector(pairs)
    overlay = visualize_results(pixels, part_mask, damage_mask)

    return overlay, pairs, encoding.tolist()
def gradio_interface(input_image):
    """Adapt process_image's results to the three Gradio output widgets.

    Args:
        input_image: The image supplied by the Gradio input component.

    Returns:
        A tuple of (annotated image, one "part : damage" line per detected
        pair joined by newlines, the encoded vector rendered as a string).
    """
    overlay, pairs, encoding = process_image(input_image)

    # Render the pairs as human-readable text, one pair per line.
    lines = []
    for part, damage in pairs:
        lines.append(f"{part} : {damage}")

    return overlay, "\n".join(lines), str(encoding)
# Gradio UI: one image in; annotated image plus two text summaries out.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Damage"),
        gr.Textbox(label="Damage Description"),
        gr.Textbox(label="One-hot Encoded Vector"),
    ],
    title="Car Damage Assessment",
    description="Upload an image of a damaged car to get an assessment of the damage.",
)

# Start the web app when the module runs (the stray "|" that followed this
# call was not valid Python and has been removed).
iface.launch()