|
import os |
|
import torch |
|
import gradio as gr |
|
from PIL import Image |
|
from transformers import AutoModelForCausalLM,AutoProcessor |
|
|
|
# Run inference on GPU when one is available; otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Processor (image preprocessing + tokenizer) comes from the base GIT
# checkpoint; the captioning model below is a fine-tune on top of it.
processor = AutoProcessor.from_pretrained("microsoft/git-base")

# Fine-tuned portrait-captioning model, moved to the selected device.
model = AutoModelForCausalLM.from_pretrained("sam749/sd-portrait-caption").to(device)
|
|
|
def generate_captions(images: "list[Image.Image] | Image.Image", max_length: int = 200):
    """Generate captions for one or more PIL images.

    Args:
        images: A single PIL image or a list of PIL images (the processor
            accepts either form).
        max_length: Maximum token length of each generated caption.

    Returns:
        A list of caption strings, one per input image.
    """
    inputs = processor(images=images, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values
    # Inference only: disable autograd tracking to save memory and time.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_caption
|
|
|
def generate_caption(image, max_length=200):
    """Caption a single image by delegating to the batch captioner.

    Returns the first (and only) caption string produced for ``image``.
    """
    captions = generate_captions(image, max_length)
    return captions[0]
|
|
|
|
|
# UI widgets: the image to caption plus a slider bounding caption length.
image_input = gr.Image(
    sources=["upload", "clipboard"],
    height=400,
    type="pil",
)
length_slider = gr.Slider(
    minimum=10,
    maximum=400,
    value=200,
    label='max length',
    step=8,
)

inputs = [image_input, length_slider]
outputs = [gr.Text(label="Generated Caption")]
|
|
|
# Wire the captioning function into a simple Gradio interface.
demo = gr.Interface(
    fn=generate_caption,
    inputs=inputs,
    outputs=outputs,
    title="Stable Diffusion Portrait Captioner",
    theme="gradio/monochrome",
    api_name="caption",
    submit_btn=gr.Button("caption it", variant="primary"),
    allow_flagging="never",
)
# Bound the request queue so concurrent submissions don't pile up unboundedly.
demo.queue(
    max_size=10,
)

if __name__ == "__main__":
    # Guard the server start so importing this module (e.g. from tests or a
    # deployment wrapper) only builds `demo` without launching a web server.
    demo.launch()
|
|