# Hugging Face Space page status: Sleeping
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
import torch.nn as nn | |
from transformers import AutoProcessor, AutoModel | |
from peft import PeftModel | |
from PIL import Image | |
class ClassificationHead(nn.Module):
    """Single linear layer that maps a feature vector to class logits.

    Args:
        input_dim: Size of the last dimension of the incoming feature tensor.
        num_classes: Number of logits to produce. Defaults to 2, which is the
            width this head has always had, so existing callers and saved
            weights keep working unchanged.
    """

    def __init__(self, input_dim: int, num_classes: int = 2) -> None:
        super().__init__()
        # The attribute must stay named ``linear``: state-dict keys are
        # derived from it (``linear.weight`` / ``linear.bias``).
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalised) logits for ``x``."""
        return self.linear(x)
def load_model():
    """Build everything needed for CPU inference.

    Loads the ``vision_model`` of the SigLIP checkpoint, wraps it with the
    LoRA adapter found in the local ``fumo_lora`` directory, loads the
    matching processor, and restores the classification head from
    ``fumo_lora/classification_head.pth``. The adapter-wrapped model and the
    head are switched to eval mode before being returned.

    Returns:
        tuple: ``(model, processor, head, device)`` where ``device`` is the
        CPU ``torch.device``.
    """
    checkpoint = "google/siglip-so400m-patch14-384"
    cpu = torch.device("cpu")

    # Only the vision sub-model of the checkpoint is kept; chaining the
    # attribute access avoids holding a reference to the rest of it.
    vision_tower = AutoModel.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float32,
        attn_implementation="sdpa",
    ).vision_model
    lora_model = PeftModel.from_pretrained(
        vision_tower,
        "fumo_lora",
        local_files_only=True,
    )
    image_processor = AutoProcessor.from_pretrained(checkpoint)

    # NOTE(review): 1152 has to match the feature width produced by the
    # vision model for the saved head weights to fit -- confirm if the
    # checkpoint is ever swapped.
    classifier = ClassificationHead(1152)
    head_weights = torch.load(
        "fumo_lora/classification_head.pth",
        weights_only=True,
        map_location="cpu",
    )
    classifier.load_state_dict(head_weights)

    for module in (lora_model, classifier):
        module.eval()

    return lora_model, image_processor, classifier, cpu


# Loaded once when the module is imported; predict_image reads these globals.
model, processor, head, device = load_model()
def predict_image(image):
    """Classify a single image and format the outcome as display text.

    Args:
        image: The image to classify, passed straight to the module-level
            ``processor``; ``None`` when no image was supplied.

    Returns:
        str: A prompt asking for an image when ``image`` is ``None``;
        otherwise a multi-line report with both class probabilities and a
        verdict. If anything raises while processing, the text
        ``"Error: <message>"`` is returned instead of propagating.
    """
    if image is None:
        return "Please provide an image."

    try:
        batch = processor(images=image, return_tensors="pt")
        pixels = batch.pixel_values.to(device, dtype=torch.float32)

        with torch.no_grad():
            hidden = model(pixel_values=pixels).last_hidden_state
            # Average over dim 1 of the hidden states to get one feature
            # vector per image, then score it with the classification head.
            features = hidden.mean(dim=1)
            probabilities = F.softmax(head(features), dim=1)

        # Column 1 is reported as "fumo", column 0 as "not fumo".
        first_row = probabilities[0]
        fumo_prob = first_row[1].item()
        not_fumo_prob = first_row[0].item()

        verdict = "FUMO!" if fumo_prob > 0.5 else "Not a fumo"
        report_lines = [
            "Results:",
            f"Fumo probability: {fumo_prob:.3f}",
            f"Not fumo probability: {not_fumo_prob:.3f}",
            "",
            f"Verdict: {verdict}",
        ]
        return "\n".join(report_lines)
    except Exception as e:
        # UI boundary: surface the failure as text rather than raising.
        return f"Error: {e!s}"
# Markup handed to gr.Interface through its ``head=`` argument. The script
# disables the "Flag as ..." buttons and re-enables them when the Submit
# button is clicked. It waits 1500 ms after page load before looking for
# the buttons (presumably so the UI has rendered by then -- confirm).
htmlhead = """
<script>
function onLoad() {
    setTimeout(() => {
        const allButtons = [...document.querySelectorAll("button")];
        const buttons = allButtons.filter(v => v.innerText.includes("Flag as"));
        const submit = allButtons.find(v => v.innerText.includes("Submit"));
        // Without a Submit button nothing could ever re-enable the flag
        // buttons, so leave them untouched instead of throwing.
        if (!submit) {
            return;
        }
        buttons.forEach(v => v.disabled = true);
        submit.addEventListener("click", function() {
            buttons.forEach(v => v.disabled = false);
        });
    }, 1500);
}
if (document.readyState === 'complete') {
    onLoad();
} else {
    window.addEventListener('load', onLoad);
}
</script>
"""
# Assemble the Gradio interface around predict_image.
_image_input = gr.Image(type="pil", width=384, height=384)
_report_output = gr.Textbox()
_example_paths = [
    "examples/fumo1.jpg",
    "examples/fumo2.jpg",
    "examples/no_fumo1.jpg",
    "examples/no_fumo2.jpg",
    "examples/no_fumo3.png",
]
# NOTE(review): the "π" in these labels looks like a mis-decoded emoji;
# the literals are kept exactly as found -- confirm the intended characters.
_flag_choices = ["Correct π", "Incorrect π"]

demo = gr.Interface(
    fn=predict_image,
    inputs=_image_input,
    outputs=_report_output,
    title="Fumo Classifier (LoRA)",
    description="Drop an image to check if it's a Fumo!",
    examples=_example_paths,
    flagging_mode="manual",
    flagging_options=_flag_choices,
    head=htmlhead,
)
if __name__ == "__main__":
    # Diagnostic only: the bfloat16 flags are printed and not used afterwards.
    has_bf16 = False
    if torch.cuda.is_available():
        has_bf16 = torch.cuda.is_bf16_supported()

    # The CPU probe is only called when this torch build exposes it.
    has_bf16_cpu = False
    if hasattr(torch.cpu, "is_bf16_supported"):
        has_bf16_cpu = torch.cpu.is_bf16_supported()

    print(f"BF16 support: {has_bf16} (GPU), {has_bf16_cpu} (CPU)")
    demo.launch()