arjunanand13 commited on
Commit
f64ef42
·
verified ·
1 Parent(s): 76dbd5a

Create 11dec24_app.py

Browse files
Files changed (1) hide show
  1. 11dec24_app.py +56 -0
11dec24_app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoModelForCausalLM, AutoProcessor
5
+ import torch
6
+ import gradio as gr
7
+
8
+ # model_name = "arjunanand13/Florence-enphase2"
9
+ model_name = "arjunanand13/florence-enphaseall2-30e"
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ print(f"Using device: {device}")
12
+
13
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)
14
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
15
+
16
+ torch.cuda.empty_cache()
17
+
18
+ DEFAULT_PROMPT = ("You are a Leg Lift Classifier. There is an image of a throughput component "
19
+ "and we need to identify if the leg is inserted in the hole or not. Return 'True' "
20
+ "if any leg is not completely seated in the hole; return 'False' if the leg is inserted "
21
+ "in the hole. Return only the required JSON in this format: {Leg_lift: , Reason: }.")
22
+
23
+ def predict(image, question):
24
+
25
+ if not isinstance(image, Image.Image):
26
+ raise ValueError(f"Expected image to be PIL.Image, but got {type(image)}")
27
+
28
+
29
+ encoding = processor(images=image, text=question, return_tensors="pt").to(device)
30
+
31
+ with torch.no_grad():
32
+ outputs = model.generate(**encoding, max_length=256)
33
+
34
+ answer = processor.batch_decode(outputs, skip_special_tokens=True)[0]
35
+ return answer
36
+
37
+ def gradio_interface(image, question):
38
+ if image.mode != "RGB":
39
+ image = image.convert("RGB")
40
+
41
+ answer = predict(image, question)
42
+ return answer
43
+
44
+ iface = gr.Interface(
45
+ fn=gradio_interface,
46
+ inputs=[
47
+ gr.Image(type="pil", label="Upload Image"), # Ensures image is passed as a PIL object
48
+ gr.Textbox(label="Enter your question or edit the default prompt", lines=6, value=DEFAULT_PROMPT) # Default prompt pre-filled and editable
49
+ ],
50
+ outputs=gr.Textbox(label="Answer"),
51
+ title="Florence-enphase Leg Lift Classifier",
52
+ description=("Upload an image and ask a question about the leg lift. The model will classify whether "
53
+ "the leg is inserted in the hole or not based on the image. You can edit the default prompt if needed.")
54
+ )
55
+
56
+ iface.launch(debug=True)