Mohaddz commited on
Commit
81a3c33
·
verified ·
1 Parent(s): bf0d4d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -1
app.py CHANGED
@@ -1,3 +1,103 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- gr.load("models/Mohaddz/Detr-spareparts").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import Image, ImageDraw
5
+ import io
6
 
7
+ # Load your models here
8
+ part_seg_model = tf.keras.models.load_model('path_to_part_seg_model')
9
+ damage_seg_model = tf.keras.models.load_model('path_to_damage_seg_model')
10
+
11
+ # Define your labels
12
+ part_labels = ["front-bumper", "fender", "hood", "door", "trunk"] # Add all your part labels
13
+ damage_labels = ["dent", "scratch", "misalignment", "crack"] # Add all your damage labels
14
+
15
+ def inference_part_seg(image):
16
+ # Implement your part segmentation inference here
17
+ # For now, we'll return a random mask
18
+ return np.random.randint(0, len(part_labels), size=image.shape[:2])
19
+
20
+ def inference_damage_seg(image):
21
+ # Implement your damage segmentation inference here
22
+ # For now, we'll return a random mask
23
+ return np.random.randint(0, len(damage_labels), size=image.shape[:2])
24
+
25
+ def combine_masks(part_mask, damage_mask):
26
+ part_damage_pairs = []
27
+ for part_id, part_name in enumerate(part_labels):
28
+ if part_name == "cars-8I1q":
29
+ continue
30
+ for damage_id, damage_name in enumerate(damage_labels):
31
+ if damage_name == "etc":
32
+ continue
33
+ part_binary = (part_mask == part_id)
34
+ damage_binary = (damage_mask == damage_id)
35
+ intersection = np.logical_and(part_binary, damage_binary)
36
+ if np.any(intersection):
37
+ part_damage_pairs.append((part_name, damage_name))
38
+ return part_damage_pairs
39
+
40
+ def create_one_hot_vector(part_damage_pairs):
41
+ vector = np.zeros(len(part_labels) * len(damage_labels))
42
+ for part, damage in part_damage_pairs:
43
+ if part in part_labels and damage in damage_labels:
44
+ part_index = part_labels.index(part)
45
+ damage_index = damage_labels.index(damage)
46
+ vector_index = part_index * len(damage_labels) + damage_index
47
+ vector[vector_index] = 1
48
+ return vector
49
+
50
+ def visualize_results(image, part_mask, damage_mask):
51
+ img = Image.fromarray(image)
52
+ draw = ImageDraw.Draw(img)
53
+
54
+ for i in range(img.width):
55
+ for j in range(img.height):
56
+ part = part_labels[part_mask[j, i]]
57
+ damage = damage_labels[damage_mask[j, i]]
58
+ if part != "cars-8I1q" and damage != "etc":
59
+ draw.point((i, j), fill="red")
60
+
61
+ return img
62
+
63
+ def process_image(image):
64
+ # Convert to numpy array if it's not already
65
+ if isinstance(image, Image.Image):
66
+ image = np.array(image)
67
+
68
+ # Perform inference
69
+ part_mask = inference_part_seg(image)
70
+ damage_mask = inference_damage_seg(image)
71
+
72
+ # Combine masks
73
+ part_damage_pairs = combine_masks(part_mask, damage_mask)
74
+
75
+ # Create one-hot encoded vector
76
+ one_hot_vector = create_one_hot_vector(part_damage_pairs)
77
+
78
+ # Visualize results
79
+ result_image = visualize_results(image, part_mask, damage_mask)
80
+
81
+ return result_image, part_damage_pairs, one_hot_vector.tolist()
82
+
83
+ def gradio_interface(input_image):
84
+ result_image, part_damage_pairs, one_hot_vector = process_image(input_image)
85
+
86
+ # Convert part_damage_pairs to a string for display
87
+ damage_description = "\n".join([f"{part} : {damage}" for part, damage in part_damage_pairs])
88
+
89
+ return result_image, damage_description, str(one_hot_vector)
90
+
91
+ iface = gr.Interface(
92
+ fn=gradio_interface,
93
+ inputs=gr.Image(type="pil"),
94
+ outputs=[
95
+ gr.Image(type="pil", label="Detected Damage"),
96
+ gr.Textbox(label="Damage Description"),
97
+ gr.Textbox(label="One-hot Encoded Vector")
98
+ ],
99
+ title="Car Damage Assessment",
100
+ description="Upload an image of a damaged car to get an assessment of the damage."
101
+ )
102
+
103
+ iface.launch()