File size: 3,823 Bytes
5c9bc3a
 
a1ee699
5c9bc3a
 
 
9e11359
5c9bc3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1ee699
5c9bc3a
 
 
a1ee699
5c9bc3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e11359
 
0b14df6
9e11359
d337955
0b14df6
d337955
 
 
 
 
 
0b14df6
d337955
 
 
0b14df6
d337955
9e11359
 
 
d337955
9e11359
 
 
 
 
 
 
 
 
 
5287106
9e11359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1ee699
d337955
0b14df6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr

# Model definition and setup
class VisionLanguageModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled outputs of the two encoders are concatenated along dim 1 and
    passed through a single linear layer that produces two logits
    (index 0 = benign, index 1 = malignant).
    """

    def __init__(self):
        super().__init__()
        self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.language_model = BertModel.from_pretrained('bert-base-uncased')
        fused_width = (
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size
        )
        # Two output classes: benign or malignant.
        self.classifier = nn.Linear(fused_width, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return classification logits for image/text pairs.

        Args:
            input_ids: token ids produced by the BERT tokenizer.
            attention_mask: the attention mask matching ``input_ids``.
            pixel_values: image tensor produced by the ViT image processor.

        Returns:
            The output of ``self.classifier`` applied to the concatenation of
            the vision and language pooled outputs.
        """
        image_features = self.vision_model(pixel_values=pixel_values).pooler_output
        text_features = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output
        fused = torch.cat([image_features, text_features], dim=1)
        return self.classifier(fused)

# Instantiate the fused network (its __init__ pulls the pretrained ViT and BERT
# backbones), then overwrite its parameters with the fine-tuned checkpoint.
model = VisionLanguageModel()
model.load_state_dict(
    torch.load(
        'best_model.pth',
        map_location=torch.device('cpu'),  # load every tensor onto the CPU
        weights_only=True,  # only unpickle tensors / plain containers
    )
)
model.eval()  # evaluation mode for inference (e.g. dropout disabled)

# Pre-processing matching the two backbones used by the model.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained(
    'google/vit-base-patch16-224-in21k'
)

def predict(image, text_input):
    """Classify one skin-lesion image together with its clinical text.

    Args:
        image: the lesion image (a PIL image when called from the Gradio UI,
            whose input component uses ``type="pil"``).
        text_input: free-text clinical information; ``None`` is treated as an
            empty string.

    Returns:
        int: the predicted class index -- 1 for Malignant, 0 for Benign.

    Raises:
        ValueError: if no image is supplied.
    """
    if image is None:
        raise ValueError("An image is required for prediction.")
    # The ViT processor expects 3-channel input; uploaded PNGs may be RGBA,
    # palette-based or grayscale, so normalize PIL images to RGB first.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    encoding = tokenizer(
        text_input if text_input is not None else "",
        add_special_tokens=True,
        # NOTE(review): presumably matches the training-time preprocessing -- confirm.
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        logits = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values,
        )
    # Index of the highest logit for the single example in the batch.
    return logits.argmax(dim=1).item()  # 1 for Malignant, 0 for Benign

# Gradio UI: inputs on the left, the two labels on the right with the predicted
# one highlighted; text is forced to black.
with gr.Blocks(css="""
    body { 
        color: black;
    }
    .benign, .malignant { 
        background-color: white; 
        border: 1px solid lightgray; 
        padding: 10px; 
        border-radius: 5px; 
        color: black;
    }
    .benign.correct, .malignant.correct { 
        background-color: lightgreen; 
        color: black;
    }
""") as demo:
    gr.Markdown(
        """
        # 🩺 SKIN LESION CLASSIFICATION
        Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")

        with gr.Column(scale=1):
            gr.Markdown("## PREDICTION RESULTS")
            benign_output = gr.HTML("<div class='benign'>Benign</div>")
            malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")

    def display_prediction(image, text_input):
        """Run the classifier and highlight the predicted label.

        Returns a ``(benign_html, malignant_html)`` pair; the predicted label's
        div gets the extra ``correct`` CSS class (green background per the CSS
        above), the other keeps the plain style.

        Raises:
            gr.Error: if no image was uploaded, so the user sees a readable
                message instead of a stack trace.
        """
        if image is None:
            raise gr.Error("Please upload a skin lesion image before requesting a prediction.")
        prediction = predict(image, text_input)
        benign_class = "benign correct" if prediction == 0 else "benign"
        malignant_class = "malignant correct" if prediction == 1 else "malignant"
        return (
            f"<div class='{benign_class}'>Benign</div>",
            f"<div class='{malignant_class}'>Malignant</div>",
        )

    # Submit button and prediction outputs
    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(display_prediction, inputs=[image_input, text_input], outputs=[benign_output, malignant_output])

# Start serving the Gradio interface assembled in the `demo` Blocks context.
demo.launch()