ankitkupadhyay commited on
Commit
9e11359
·
verified ·
1 Parent(s): 74c62de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -22
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import torch
2
- import torch
3
  import torch.nn as nn
4
  from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
5
  from PIL import Image
6
  import gradio as gr
7
 
 
8
  class VisionLanguageModel(nn.Module):
9
  def __init__(self):
10
  super(VisionLanguageModel, self).__init__()
@@ -33,7 +33,6 @@ class VisionLanguageModel(nn.Module):
33
  logits = self.classifier(combined_features)
34
  return logits
35
 
36
- # Load the model checkpoint with safer loading
37
  model = VisionLanguageModel()
38
  model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
39
  model.eval()
@@ -42,10 +41,7 @@ tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
42
  feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
43
 
44
  def predict(image, text_input):
45
- # Preprocess the image
46
  image = feature_extractor(images=image, return_tensors="pt").pixel_values
47
-
48
- # Preprocess the text
49
  encoding = tokenizer(
50
  text_input,
51
  add_special_tokens=True,
@@ -54,8 +50,6 @@ def predict(image, text_input):
54
  truncation=True,
55
  return_tensors='pt'
56
  )
57
-
58
- # Make a prediction
59
  with torch.no_grad():
60
  outputs = model(
61
  input_ids=encoding['input_ids'],
@@ -63,19 +57,44 @@ def predict(image, text_input):
63
  pixel_values=image
64
  )
65
  _, prediction = torch.max(outputs, dim=1)
66
- return "Malignant" if prediction.item() == 1 else "Benign"
67
-
68
- # Define Gradio interface with updated component syntax
69
- iface = gr.Interface(
70
- fn=predict,
71
- inputs=[
72
- gr.Image(type="pil", label="Upload Skin Lesion Image"),
73
- gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
74
- ],
75
- outputs="text",
76
- title="Skin Lesion Classification Demo",
77
- description="This model classifies skin lesions as benign or malignant based on an image and clinical information."
78
- )
79
-
80
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
 
1
  import torch
 
2
  import torch.nn as nn
3
  from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+ # Model definition and setup
8
  class VisionLanguageModel(nn.Module):
9
  def __init__(self):
10
  super(VisionLanguageModel, self).__init__()
 
33
  logits = self.classifier(combined_features)
34
  return logits
35
 
 
36
  model = VisionLanguageModel()
37
  model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
38
  model.eval()
 
41
  feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
42
 
43
  def predict(image, text_input):
 
44
  image = feature_extractor(images=image, return_tensors="pt").pixel_values
 
 
45
  encoding = tokenizer(
46
  text_input,
47
  add_special_tokens=True,
 
50
  truncation=True,
51
  return_tensors='pt'
52
  )
 
 
53
  with torch.no_grad():
54
  outputs = model(
55
  input_ids=encoding['input_ids'],
 
57
  pixel_values=image
58
  )
59
  _, prediction = torch.max(outputs, dim=1)
60
+ return prediction.item() # 1 for Malignant, 0 for Benign
61
+
62
+ # Enhanced UI with color-coded prediction display
63
+ with gr.Blocks(css="""
64
+ .benign {background-color: white; border: 1px solid lightgray; padding: 10px; border-radius: 5px;}
65
+ .malignant {background-color: white; border: 1px solid lightgray; padding: 10px; border-radius: 5px;}
66
+ .benign.correct {background-color: lightgreen;}
67
+ .malignant.correct {background-color: lightgreen;}
68
+ """) as demo:
69
+ gr.Markdown(
70
+ """
71
+ # 🩺 SKIN LESION CLASSIFICATION
72
+ Upload an image of a skin lesion and provide clinical details to get a prediction of benign or malignant.
73
+ """
74
+ )
75
+
76
+ with gr.Row():
77
+ with gr.Column(scale=1):
78
+ image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
79
+ text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")
80
+
81
+ with gr.Column(scale=1):
82
+ benign_output = gr.HTML("<div class='benign'>Benign</div>")
83
+ malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")
84
+ gr.Markdown("## Example:")
85
+ example_image = gr.Image(value="skin_cancer_detection/Unknown-4.png") # Provide path to an example image
86
+ example_text = gr.Textbox(value="consistent with resolving/involuting keratoacanthoma 67", interactive=False)
87
+
88
+ def display_prediction(image, text_input):
89
+ prediction = predict(image, text_input)
90
+ benign_html = "<div class='benign{}'>Benign</div>".format(" correct" if prediction == 0 else "")
91
+ malignant_html = "<div class='malignant{}'>Malignant</div>".format(" correct" if prediction == 1 else "")
92
+ return benign_html, malignant_html
93
+
94
+ # Submit button and prediction outputs
95
+ submit_btn = gr.Button("Get Prediction")
96
+ submit_btn.click(display_prediction, inputs=[image_input, text_input], outputs=[benign_output, malignant_output])
97
+
98
+ demo.launch()
99
+
100