# enphase_all / app.py
# Last commit: "Update app.py" by arjunanand13 (76dbd5a, verified)
import os
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
import gradio as gr
# Fine-tuned checkpoint on the Hugging Face Hub. An earlier checkpoint,
# "arjunanand13/Florence-enphase2", can be swapped in here if needed.
model_name = "arjunanand13/florence-enphaseall2-30e"

# Use CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# trust_remote_code=True lets transformers run the custom model/processor
# code that ships with the checkpoint repository.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
).to(device)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

# Release cached, currently-unused GPU memory held by PyTorch's allocator
# after loading.
torch.cuda.empty_cache()
# Instruction pre-filled in the prompt textbox of the UI; users may edit it.
# NOTE(review): presumably mirrors the prompt used when the checkpoint was
# fine-tuned (including the wording "throughput component"), so changing the
# text may change model behavior — TODO confirm before editing.
DEFAULT_PROMPT = (
    "You are a Leg Lift Classifier. "
    "There is an image of a throughput component and we need to identify "
    "if the leg is inserted in the hole or not. "
    "Return 'True' if any leg is not completely seated in the hole; "
    "return 'False' if the leg is inserted in the hole. "
    "Return only the required JSON in this format: {Leg_lift: , Reason: }."
)
def predict(image: Image.Image, question: str, *, max_length: int = 256) -> str:
    """Run the vision-language model on one image and one text prompt.

    Args:
        image: Input picture; must be a ``PIL.Image.Image``.
        question: Prompt text passed to the processor together with the image.
        max_length: Length cap forwarded to ``model.generate`` as
            ``max_length``. Defaults to 256, the value previously hard-coded.

    Returns:
        The first decoded sequence produced by the model, with special
        tokens stripped.

    Raises:
        ValueError: If ``image`` is not a ``PIL.Image.Image``.
    """
    if not isinstance(image, Image.Image):
        raise ValueError(f"Expected image to be PIL.Image, but got {type(image)}")

    # Preprocess image + text into PyTorch tensors and place them on the
    # same device as the model.
    encoding = processor(images=image, text=question, return_tensors="pt").to(device)

    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        outputs = model.generate(**encoding, max_length=max_length)

    # batch_decode yields one string per output sequence; a single image was
    # submitted, so the first entry is the answer.
    return processor.batch_decode(outputs, skip_special_tokens=True)[0]
def gradio_interface(image, question):
    """Gradio callback: validate the UI inputs and return the model's answer.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the user submits without uploading anything.
        question: Prompt text from the textbox. A blank value (e.g. the user
            cleared the box) falls back to ``DEFAULT_PROMPT``.

    Returns:
        The decoded answer string produced by ``predict``.

    Raises:
        gr.Error: If no image was uploaded. Gradio shows this message in the
            UI instead of failing with an ``AttributeError`` on ``None.mode``.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # An empty or whitespace-only prompt gives the model nothing to follow.
    if not (question and question.strip()):
        question = DEFAULT_PROMPT

    # Normalize non-RGB uploads (e.g. RGBA PNGs, grayscale, palette images)
    # to RGB before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    return predict(image, question)
# Input widgets: a PIL image upload and a multi-line prompt pre-filled with
# DEFAULT_PROMPT so it can be used as-is or edited.
_image_input = gr.Image(type="pil", label="Upload Image")
_prompt_input = gr.Textbox(
    label="Enter your question or edit the default prompt",
    lines=6,
    value=DEFAULT_PROMPT,
)

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[_image_input, _prompt_input],
    outputs=gr.Textbox(label="Answer"),
    title="Florence-enphase Leg Lift Classifier",
    description=(
        "Upload an image and ask a question about the leg lift. "
        "The model will classify whether the leg is inserted in the hole "
        "or not based on the image. You can edit the default prompt if needed."
    ),
)

# Start the web server; debug=True keeps the main thread blocked while the
# app runs (see Gradio launch() documentation).
iface.launch(debug=True)