Mohaddz committed
Commit 558c1bc · verified · 1 Parent(s): c22d417

Update app.py

Files changed (1): app.py +21 -21
app.py CHANGED
@@ -1,12 +1,13 @@
 import gradio as gr
 import numpy as np
-import tensorflow as tf
+import torch
 from PIL import Image, ImageDraw
-from transformers import TFSegformerForSemanticSegmentation
+from transformers import SegformerForSemanticSegmentation
+from torchvision.transforms import Resize, ToTensor, Normalize
 
 # Load models from Hugging Face
-part_seg_model = TFSegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
-damage_seg_model = TFSegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")
+part_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
+damage_seg_model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/DamageSeg")
 
 # Define your labels
 part_labels = ["front-bumper", "fender", "hood", "door", "trunk", "cars-8I1q"] # Add all your part labels
@@ -14,23 +15,26 @@ damage_labels = ["dent", "scratch", "misalignment", "crack", "etc"] # Add all your damage labels
 
 def preprocess_image(image):
     # Resize and normalize the image
-    image = tf.image.resize(image, (512, 512))
-    image = tf.keras.applications.imagenet_utils.preprocess_input(image)
-    return image
+    transform = Resize((512, 512))
+    image = transform(Image.fromarray(image))
+    image = ToTensor()(image)
+    image = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
+    return image.unsqueeze(0) # Add batch dimension
+
+def inference_seg(model, image):
+    with torch.no_grad():
+        outputs = model(image)
+        logits = outputs.logits
+        mask = torch.argmax(logits, dim=1).squeeze().numpy()
+    return mask
 
 def inference_part_seg(image):
     preprocessed_image = preprocess_image(image)
-    outputs = part_seg_model(preprocessed_image, training=False)
-    logits = outputs.logits
-    mask = tf.argmax(logits, axis=-1)
-    return mask.numpy()
+    return inference_seg(part_seg_model, preprocessed_image)
 
 def inference_damage_seg(image):
     preprocessed_image = preprocess_image(image)
-    outputs = damage_seg_model(preprocessed_image, training=False)
-    logits = outputs.logits
-    mask = tf.argmax(logits, axis=-1)
-    return mask.numpy()
+    return inference_seg(damage_seg_model, preprocessed_image)
 
 def combine_masks(part_mask, damage_mask):
     part_damage_pairs = []
@@ -58,7 +62,7 @@ def create_one_hot_vector(part_damage_pairs):
     return vector
 
 def visualize_results(image, part_mask, damage_mask):
-    img = Image.fromarray((image * 255).astype(np.uint8))
+    img = Image.fromarray(image)
     draw = ImageDraw.Draw(img)
 
     for i in range(img.width):
@@ -71,10 +75,6 @@ def visualize_results(image, part_mask, damage_mask):
     return img
 
 def process_image(image):
-    # Convert to numpy array if it's not already
-    if isinstance(image, Image.Image):
-        image = np.array(image)
-
     # Perform inference
     part_mask = inference_part_seg(image)
     damage_mask = inference_damage_seg(image)
@@ -100,7 +100,7 @@ def gradio_interface(input_image):
 
 iface = gr.Interface(
     fn=gradio_interface,
-    inputs=gr.Image(type="pil"),
+    inputs=gr.Image(type="numpy"),
     outputs=[
         gr.Image(type="pil", label="Detected Damage"),
         gr.Textbox(label="Damage Description"),
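
For context, a minimal sketch (not part of the commit) of how the converted PyTorch path above can be exercised outside Gradio. It assumes the model repos named in the diff ("Mohaddz/huggingCars", "Mohaddz/DamageSeg") are available on the Hub, that the input is an RGB uint8 numpy array (the format gr.Image(type="numpy") passes to the callback), and a hypothetical local file car.jpg:

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Resize, ToTensor, Normalize
from transformers import SegformerForSemanticSegmentation

# Load one of the segmentation models named in the diff (assumed to exist on the Hub).
model = SegformerForSemanticSegmentation.from_pretrained("Mohaddz/huggingCars")
model.eval()

def preprocess(image: np.ndarray) -> torch.Tensor:
    # Mirror preprocess_image from the new app.py: resize to 512x512,
    # scale to [0, 1], normalize with ImageNet statistics, add a batch dim.
    pil = Resize((512, 512))(Image.fromarray(image))
    tensor = Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])(ToTensor()(pil))
    return tensor.unsqueeze(0)

# Hypothetical input image; the Gradio callback receives the same kind of array.
image = np.array(Image.open("car.jpg").convert("RGB"))

with torch.no_grad():
    logits = model(preprocess(image)).logits  # (1, num_labels, 128, 128) for a 512x512 input
part_mask = torch.argmax(logits, dim=1).squeeze().numpy()  # per-pixel class ids
print(part_mask.shape, np.unique(part_mask))

Note that SegFormer emits logits at 1/4 of the input resolution, so the mask here is 128x128; if the app indexes the mask at full image resolution (as visualize_results appears to), an explicit upsampling step would likely be needed.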