Mohaddz commited on
Commit
e6423a5
·
verified ·
1 Parent(s): d31aa3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -23
app.py CHANGED
@@ -2,30 +2,16 @@ import gradio as gr
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
- import tensorflow as tf
6
  from transformers import SegformerForSemanticSegmentation, AutoFeatureExtractor
7
  import cv2
8
  import json
 
9
 
10
  # Load models
11
  part_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
12
  damage_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")
13
  feature_extractor = AutoFeatureExtractor.from_pretrained("Mohaddz/huggingCars")
14
 
15
- # Recreate the model architecture
16
- def create_model(input_shape, num_classes):
17
- inputs = tf.keras.Input(shape=input_shape)
18
- x = tf.keras.layers.Dense(64, activation='relu')(inputs)
19
- x = tf.keras.layers.Dense(32, activation='relu')(x)
20
- outputs = tf.keras.layers.Dense(num_classes, activation='sigmoid')(x)
21
- return tf.keras.Model(inputs=inputs, outputs=outputs)
22
-
23
- # Load model weights
24
- input_shape = 33 # Adjust this based on your actual input shape
25
- num_classes = 29 # Adjust this based on your actual number of classes
26
- dl_model = create_model(input_shape, num_classes)
27
- dl_model.load_weights('improved_car_damage_prediction_model_weights.h5')
28
-
29
  # Load parts list
30
  with open('cars117.json', 'r', encoding='utf-8') as f:
31
  data = json.load(f)
@@ -63,16 +49,15 @@ def process_image(image):
63
  part_heatmap = create_heatmap(part_features)
64
  part_heatmap_resized = cv2.resize(part_heatmap, (image.size[0], image.size[1]))
65
 
66
- # Predict parts to replace
67
- input_vector = np.concatenate([part_features.mean(axis=(1, 2)), damage_features.mean(axis=(1, 2))])
68
- prediction = dl_model.predict(np.array([input_vector]))
69
- predicted_parts = [(all_parts[i], float(prob)) for i, prob in enumerate(prediction[0]) if prob > 0.1]
70
  predicted_parts.sort(key=lambda x: x[1], reverse=True)
71
 
72
  return (Image.fromarray(annotated_image),
73
  Image.fromarray(damage_heatmap_resized),
74
  Image.fromarray(part_heatmap_resized),
75
- "\n".join([f"{part}: {prob:.2f}" for part, prob in predicted_parts[:5]]))
76
 
77
  def create_heatmap(features):
78
  heatmap = np.sum(features, axis=0)
@@ -87,10 +72,10 @@ iface = gr.Interface(
87
  gr.Image(type="pil", label="Annotated Damage"),
88
  gr.Image(type="pil", label="Damage Heatmap"),
89
  gr.Image(type="pil", label="Part Segmentation Heatmap"),
90
- gr.Textbox(label="Predicted Parts to Replace")
91
  ],
92
- title="Car Damage Assessment",
93
- description="Upload an image of a damaged car to get an assessment."
94
  )
95
 
96
  iface.launch()
 
2
  import torch
3
  from PIL import Image
4
  import numpy as np
 
5
  from transformers import SegformerForSemanticSegmentation, AutoFeatureExtractor
6
  import cv2
7
  import json
8
+ import random
9
 
10
  # Load models
11
  part_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
12
  damage_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")
13
  feature_extractor = AutoFeatureExtractor.from_pretrained("Mohaddz/huggingCars")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Load parts list
16
  with open('cars117.json', 'r', encoding='utf-8') as f:
17
  data = json.load(f)
 
49
  part_heatmap = create_heatmap(part_features)
50
  part_heatmap_resized = cv2.resize(part_heatmap, (image.size[0], image.size[1]))
51
 
52
+ # Simulate part prediction (for demonstration purposes)
53
+ num_predictions = random.randint(3, 5)
54
+ predicted_parts = [(part, random.random()) for part in random.sample(all_parts, num_predictions)]
 
55
  predicted_parts.sort(key=lambda x: x[1], reverse=True)
56
 
57
  return (Image.fromarray(annotated_image),
58
  Image.fromarray(damage_heatmap_resized),
59
  Image.fromarray(part_heatmap_resized),
60
+ "\n".join([f"{part}: {prob:.2f}" for part, prob in predicted_parts]))
61
 
62
  def create_heatmap(features):
63
  heatmap = np.sum(features, axis=0)
 
72
  gr.Image(type="pil", label="Annotated Damage"),
73
  gr.Image(type="pil", label="Damage Heatmap"),
74
  gr.Image(type="pil", label="Part Segmentation Heatmap"),
75
+ gr.Textbox(label="Predicted Parts to Replace (Simulated)")
76
  ],
77
+ title="Car Damage Assessment (Demo Version)",
78
+ description="Upload an image of a damaged car to get a simulated assessment. Note: Part predictions are randomly generated for demonstration purposes."
79
  )
80
 
81
  iface.launch()