File size: 4,143 Bytes
bf0d4d8
81a3c33
 
 
c22d417
bf0d4d8
c22d417
 
 
81a3c33
 
c22d417
 
 
 
 
 
 
 
81a3c33
 
c22d417
 
 
 
 
81a3c33
 
c22d417
 
 
 
 
81a3c33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c22d417
81a3c33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image, ImageDraw
from transformers import TFSegformerForSemanticSegmentation

# Load models from Hugging Face
# NOTE: from_pretrained downloads checkpoint weights on first run (requires
# network access) and caches them locally afterwards.
part_seg_model = TFSegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
damage_seg_model = TFSegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")

# Define your labels
# Class-index -> name tables; ordering must match each model's training-time
# label ids ("cars-8I1q" / "etc" act as the background/ignore classes below).
part_labels = ["front-bumper", "fender", "hood", "door", "trunk", "cars-8I1q"]  # Add all your part labels
damage_labels = ["dent", "scratch", "misalignment", "crack", "etc"]  # Add all your damage labels

def preprocess_image(image):
    """Resize an image to 512x512 and shape it into a model-ready batch.

    Parameters
    ----------
    image : np.ndarray
        HWC RGB image array.

    Returns
    -------
    tf.Tensor
        Float tensor of shape (1, 3, 512, 512) — channels-first with a
        leading batch dimension, as TFSegformerForSemanticSegmentation
        expects for its ``pixel_values`` input. (The original returned an
        unbatched HWC tensor, which the model cannot consume.)
    """
    image = tf.image.resize(image, (512, 512))
    image = tf.keras.applications.imagenet_utils.preprocess_input(image)
    # NOTE(review): SegFormer checkpoints are normally preprocessed by their
    # image processor (rescale to [0,1] + ImageNet mean/std), while
    # preprocess_input defaults to caffe-style BGR mean subtraction —
    # confirm against the checkpoint's preprocessing config.
    image = tf.transpose(image, (2, 0, 1))  # HWC -> CHW (channels-first)
    return tf.expand_dims(image, axis=0)    # add batch dimension

def inference_part_seg(image):
    """Run the car-part segmentation model and return a per-pixel class mask.

    Parameters
    ----------
    image : np.ndarray
        HWC RGB image array.

    Returns
    -------
    np.ndarray
        Integer mask of shape (H, W) matching the input image, each value an
        index into ``part_labels``.
    """
    preprocessed = preprocess_image(image)
    outputs = part_seg_model(preprocessed, training=False)
    # TF SegFormer logits are channels-first: (batch, num_labels, h/4, w/4).
    # The original argmax(axis=-1) therefore reduced over *width*, not the
    # class dimension. Transpose to NHWC first, then reduce over classes.
    logits = tf.transpose(outputs.logits, (0, 2, 3, 1))
    # Upsample from 1/4 resolution so the mask aligns pixel-for-pixel with
    # the original image (downstream code indexes mask[j, i] per pixel).
    logits = tf.image.resize(logits, image.shape[:2], method="bilinear")
    mask = tf.argmax(logits, axis=-1)
    return tf.squeeze(mask, axis=0).numpy()  # drop batch dim -> (H, W)

def inference_damage_seg(image):
    """Run the damage segmentation model and return a per-pixel class mask.

    Parameters
    ----------
    image : np.ndarray
        HWC RGB image array.

    Returns
    -------
    np.ndarray
        Integer mask of shape (H, W) matching the input image, each value an
        index into ``damage_labels``.
    """
    preprocessed = preprocess_image(image)
    outputs = damage_seg_model(preprocessed, training=False)
    # TF SegFormer logits are channels-first: (batch, num_labels, h/4, w/4).
    # The original argmax(axis=-1) therefore reduced over *width*, not the
    # class dimension. Transpose to NHWC first, then reduce over classes.
    logits = tf.transpose(outputs.logits, (0, 2, 3, 1))
    # Upsample from 1/4 resolution so the mask aligns pixel-for-pixel with
    # the original image (downstream code indexes mask[j, i] per pixel).
    logits = tf.image.resize(logits, image.shape[:2], method="bilinear")
    mask = tf.argmax(logits, axis=-1)
    return tf.squeeze(mask, axis=0).numpy()  # drop batch dim -> (H, W)

def combine_masks(part_mask, damage_mask):
    """List every (part, damage) label pair whose mask regions overlap.

    Background classes ("cars-8I1q" parts, "etc" damage) are skipped.
    Returns a list of (part_name, damage_name) tuples in label order.
    """
    pairs = []
    # Pre-filter the damage classes once instead of re-checking per part.
    real_damages = [
        (idx, name) for idx, name in enumerate(damage_labels) if name != "etc"
    ]
    for part_idx, part_name in enumerate(part_labels):
        if part_name == "cars-8I1q":
            continue
        part_region = part_mask == part_idx
        for damage_idx, damage_name in real_damages:
            overlap = part_region & (damage_mask == damage_idx)
            if overlap.any():
                pairs.append((part_name, damage_name))
    return pairs

def create_one_hot_vector(part_damage_pairs):
    """Encode (part, damage) pairs as a flat multi-hot vector.

    The vector has length len(part_labels) * len(damage_labels); the slot for
    a pair is part_index * len(damage_labels) + damage_index. Unknown labels
    are silently ignored.
    """
    n_damage = len(damage_labels)
    vector = np.zeros(len(part_labels) * n_damage)
    for part_name, damage_name in part_damage_pairs:
        if part_name not in part_labels or damage_name not in damage_labels:
            continue
        slot = part_labels.index(part_name) * n_damage + damage_labels.index(damage_name)
        vector[slot] = 1
    return vector

def visualize_results(image, part_mask, damage_mask):
    """Paint red every pixel where a real part and real damage coincide.

    Parameters
    ----------
    image : np.ndarray
        HWC RGB image, either uint8 in [0, 255] or float in [0, 1].
    part_mask, damage_mask : np.ndarray
        2D integer masks (H, W) indexing into part_labels / damage_labels.

    Returns
    -------
    PIL.Image.Image
        Copy of the image with overlapping damage regions highlighted.
    """
    arr = np.asarray(image)
    # Bug fix: the original did (image * 255) unconditionally; for the usual
    # uint8 np.array(PIL) input that overflows and corrupts the output.
    # Only scale genuinely normalized float input.
    if arr.dtype != np.uint8:
        arr = np.clip(arr * 255, 0, 255).astype(np.uint8)

    background_part = part_labels.index("cars-8I1q")
    background_damage = damage_labels.index("etc")
    # Vectorized replacement of the original per-pixel Python loop (which was
    # quadratic Python-level work and raised IndexError on out-of-range ids).
    highlight = (np.asarray(part_mask) != background_part) & (
        np.asarray(damage_mask) != background_damage
    )
    overlay = arr.copy()
    overlay[highlight] = (255, 0, 0)
    return Image.fromarray(overlay)

def process_image(image):
    """Run the full assessment pipeline on one image.

    Accepts a PIL image or ndarray and returns a triple:
    (highlighted image, list of (part, damage) pairs, one-hot vector as list).
    """
    # Normalize the input to an ndarray up front.
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Segment parts and damage independently, then intersect the masks.
    part_mask = inference_part_seg(image)
    damage_mask = inference_damage_seg(image)
    pairs = combine_masks(part_mask, damage_mask)

    # Derived outputs: fixed-length encoding and a visual overlay.
    encoding = create_one_hot_vector(pairs)
    overlay = visualize_results(image, part_mask, damage_mask)

    return overlay, pairs, encoding.tolist()

def gradio_interface(input_image):
    """Gradio callback: run the pipeline and format its outputs for display."""
    overlay, pairs, vector = process_image(input_image)

    # One "part : damage" line per detected pair, for the description textbox.
    damage_description = "\n".join(f"{part} : {damage}" for part, damage in pairs)

    return overlay, damage_description, str(vector)

# Wire the pipeline into a Gradio UI: one image input, three outputs
# (highlighted image, human-readable pair list, encoded vector string).
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Damage"),
        gr.Textbox(label="Damage Description"),
        gr.Textbox(label="One-hot Encoded Vector")
    ],
    title="Car Damage Assessment",
    description="Upload an image of a damaged car to get an assessment of the damage."
)

# Start the web server (blocks until interrupted).
iface.launch()