from ultralytics import YOLO
import torch
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
import gradio as gr

# Global Color Palette
COLORS = np.random.uniform(0, 255, size=(80, 3))

def parse_detections(results):
    boxes, colors, names = [], [], []
    for result in results:
        # Accessing boxes directly from the result
        for box in result.boxes:
            xmin, ymin, xmax, ymax = box.xyxy[0].int().tolist()  # Convert to list of integers
            category = int(box.cls[0].item())  # Class index
            name = result.names[category]  # Get class name from names
            boxes.append((xmin, ymin, xmax, ymax))
            colors.append(COLORS[category])  # Per-class color from the global palette
            names.append(name)

    return boxes, colors, names

def draw_detections(boxes, colors, names, img):
    for box, color, name in zip(boxes, colors, names):
        xmin, ymin, xmax, ymax = box
        color = tuple(int(c) for c in color)  # OpenCV expects plain scalars rather than a NumPy array
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)
        cv2.putText(img, name, (xmin, ymin - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
                    lineType=cv2.LINE_AA)
    return img


def generate_cam_image(model, target_layers, tensor, rgb_img, boxes):
    cam = EigenCAM(model, target_layers)
    grayscale_cam = cam(tensor)[0, :, :]
    img_float = np.float32(rgb_img) / 255
    cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
    renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
    for x1, y1, x2, y2 in boxes:
        renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
    renormalized_cam = scale_cam_image(renormalized_cam)
    renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)

    return cam_image, renormalized_cam_image


def xai_yolov8n(image):
    model = YOLO('yolov8n.pt')  # Load YOLOv8n pre-trained weights
    model.eval()

    # Check if GPU is available and use it
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    target_layers = [model.model.model[-2]]  # Layer whose activations EigenCAM uses
    
    # Process the image through the model
    results = model([image])
    
    # model() returns a list of Results objects; keep the first (and only) one
    if isinstance(results, list):
        results = results[0]

    # Parse the detections
    boxes, colors, names = parse_detections([results])  # Ensure results are passed as a list
    detections_img = draw_detections(boxes, colors, names, image.copy())
    
    # Prepare the image tensor for EigenCAM
    img_float = np.float32(image) / 255
    transform = transforms.ToTensor()
    tensor = transform(img_float).unsqueeze(0).to(device)  # Ensure tensor is on the right device
    
    # Generate CAM images
    cam_image, renormalized_cam_image = generate_cam_image(model, target_layers, tensor, image, boxes)

    # Combine original image, CAM image, and renormalized CAM image
    final_image = np.hstack((image, cam_image, renormalized_cam_image))

    # Return final image and a caption
    caption = "Results using YOLOv8n"
    return Image.fromarray(final_image), caption
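

# gradio is imported above but no interface is built in this file; the block below is a
# minimal sketch (an assumption, not the original app's layout) of how xai_yolov8n could
# be exposed through a Gradio UI for quick local testing.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=xai_yolov8n,  # Runs detection + EigenCAM on the uploaded image
        inputs=gr.Image(type="numpy", label="Input image"),
        outputs=[
            gr.Image(label="Detections | EigenCAM | Box-renormalized EigenCAM"),
            gr.Textbox(label="Caption"),
        ],
        title="YOLOv8n + EigenCAM",
    )
    demo.launch()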