deletesoon / app.py
Mohaddz's picture
Update app.py
a4bb933 verified
raw
history blame
3.99 kB
import gradio as gr
import torch
from PIL import Image
import numpy as np
import tensorflow as tf
from transformers import SegformerForSemanticSegmentation, AutoFeatureExtractor
import cv2
import json
# Hugging Face Hub repositories that hold the SegFormer segmentation checkpoints
PART_SEG_REPO = "Mohaddz/huggingCars"
DAMAGE_SEG_REPO = "Mohaddz/DamageSeg"

part_seg_model = SegformerForSemanticSegmentation.from_pretrained(PART_SEG_REPO)
damage_seg_model = SegformerForSemanticSegmentation.from_pretrained(DAMAGE_SEG_REPO)
# A single preprocessor, taken from the part-segmentation repo, prepares the
# input for both segmentation models.
feature_extractor = AutoFeatureExtractor.from_pretrained(PART_SEG_REPO)
def create_model(input_shape, num_classes):
    """Recreate the dense classifier architecture so saved weights can be loaded.

    The network is Input -> Dense(64, relu) -> Dense(32, relu) ->
    Dense(num_classes, sigmoid).

    Args:
        input_shape: Per-sample input shape, excluding the batch axis. Either
            a shape tuple/list, or a bare int which is treated as the length
            of a flat feature vector.
        num_classes: Number of sigmoid output units.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    # This file passes a bare int. Keras 3 refuses to convert a non-iterable
    # value into a shape, so normalise an int to a 1-tuple before using it.
    if isinstance(input_shape, int):
        input_shape = (input_shape,)

    inputs = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Dense(64, activation='relu')(inputs)
    x = tf.keras.layers.Dense(32, activation='relu')(x)
    outputs = tf.keras.layers.Dense(num_classes, activation='sigmoid')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)
# Rebuild the classifier, then restore its weights from the HDF5 file.
# Both sizes below have to agree with the architecture the weights were saved from.
input_shape = 33  # Adjust this based on your actual input shape
num_classes = 29  # Adjust this based on your actual number of classes
dl_model = create_model(input_shape=input_shape, num_classes=num_classes)
dl_model.load_weights('improved_car_damage_prediction_model_weights.h5')
# Build the sorted, de-duplicated vocabulary of part names found under the
# 'replaced_parts' key of every entry in the JSON file.
with open('cars117.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

all_parts = sorted({
    part
    for entry in data.values()
    for part in entry.get('replaced_parts', [])
})
def _first_sample_logits(model, inputs):
    """Run ``model`` on ``inputs`` and return the first sample's logits as a NumPy array."""
    with torch.no_grad():
        logits = model(**inputs).logits
    # Index the batch axis instead of calling squeeze(): squeeze() would also
    # drop a size-1 class axis and break the per-class reductions downstream.
    return logits[0].cpu().numpy()


def _heatmap_image(features, size):
    """Render ``features`` as a colour heatmap PIL image of ``size`` (width, height)."""
    heatmap_bgr = cv2.resize(create_heatmap(features), size)
    # cv2.applyColorMap (used by create_heatmap) emits BGR, whereas PIL reads
    # arrays as RGB; without this swap the colour scale is displayed reversed.
    return Image.fromarray(cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB))


def process_image(image):
    """Assess an uploaded car photo.

    Runs both segmentation models on the image, draws the damage mask and the
    two activation heatmaps, and feeds the per-class mean logits of both
    models to the dense classifier to score parts for replacement.

    Args:
        image: PIL image supplied by the Gradio input component.

    Returns:
        A 4-tuple ``(annotated, damage_heatmap, part_heatmap, text)``: three
        PIL images at the input resolution, and a newline-separated list of
        at most five ``"part: probability"`` lines for parts scoring above
        0.1, highest first.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Convert to RGB if it's not
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # PIL reports (width, height), which is the order cv2.resize expects.
    size = image.size

    # One preprocessing pass feeds both segmentation models
    inputs = feature_extractor(images=image, return_tensors="pt")
    damage_features = _first_sample_logits(damage_seg_model, inputs)
    part_features = _first_sample_logits(part_seg_model, inputs)

    # Create annotated damage image
    image_array = np.array(image)
    # np.argmax yields int64, which cv2.resize (OpenCV 4.x) cannot take;
    # uint8 is enough as long as there are fewer than 256 classes.
    damage_mask = np.argmax(damage_features, axis=0).astype(np.uint8)
    damage_mask_resized = cv2.resize(damage_mask, size, interpolation=cv2.INTER_NEAREST)
    overlay = np.zeros_like(image_array)
    overlay[damage_mask_resized > 0] = [255, 0, 0]  # Red color for damage
    annotated_image = cv2.addWeighted(image_array, 1, overlay, 0.5, 0)

    # Predict parts to replace
    input_vector = np.concatenate([
        part_features.mean(axis=(1, 2)),
        damage_features.mean(axis=(1, 2)),
    ])
    prediction = dl_model.predict(np.array([input_vector]))
    # zip() pairs scores with names without indexing past the end of
    # all_parts if the classifier has more outputs than there are names.
    predicted_parts = [
        (part, float(prob))
        for part, prob in zip(all_parts, prediction[0])
        if prob > 0.1
    ]
    predicted_parts.sort(key=lambda x: x[1], reverse=True)

    return (Image.fromarray(annotated_image),
            _heatmap_image(damage_features, size),
            _heatmap_image(part_features, size),
            "\n".join([f"{part}: {prob:.2f}" for part, prob in predicted_parts[:5]]))
def create_heatmap(features):
    """Collapse a feature stack into a JET-coloured heatmap.

    Args:
        features: Array whose axis 0 is summed away; the remaining axes form
            the 2-D map that gets coloured.

    Returns:
        A uint8 colour image as produced by ``cv2.applyColorMap``, i.e. with
        channels in OpenCV's BGR order.
    """
    heatmap = np.sum(features, axis=0)
    low = heatmap.min()
    span = heatmap.max() - low
    if span > 0:
        normalized = (heatmap - low) / span
    else:
        # A constant (or NaN) map would otherwise divide 0 by 0 and push NaN
        # through the uint8 cast; fall back to a uniform lowest-value map.
        normalized = np.zeros_like(heatmap, dtype=np.float32)
    return cv2.applyColorMap(np.uint8(255 * normalized), cv2.COLORMAP_JET)
# Gradio UI: one uploaded PIL image in; three PIL images and a text box out,
# in the same order as the tuple returned by process_image.
image_input = gr.Image(type="pil")
result_components = [
    gr.Image(type="pil", label="Annotated Damage"),
    gr.Image(type="pil", label="Damage Heatmap"),
    gr.Image(type="pil", label="Part Segmentation Heatmap"),
    gr.Textbox(label="Predicted Parts to Replace"),
]

iface = gr.Interface(
    fn=process_image,
    inputs=image_input,
    outputs=result_components,
    title="Car Damage Assessment",
    description="Upload an image of a damaged car to get an assessment.",
)

# Start serving the interface
iface.launch()