philipp-zettl's picture
use prompt embeddings rather than prompt strings
223ef25 verified
raw
history blame
No virus
2.02 kB
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
# Hub repository id of the diffusion checkpoint this app serves.
model_name = 'UnfilteredAI/NSFW-gen-v2'

# Load the pipeline with float16 weights and move it to the CUDA device in one
# step; DiffusionPipeline.to() returns the pipeline itself, so the chained
# result is the same object that a separate `pipe.to('cuda')` would move.
pipe = DiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path=model_name,
    torch_dtype=torch.float16,
).to('cuda')
def build_embeddings(enhanced_prompt, negative_prompt=None):
    """Encode a positive and a negative prompt into text-encoder embeddings.

    Prompts are tokenized without truncation, so text longer than
    ``pipe.tokenizer.model_max_length`` tokens is kept. The token ids are run
    through ``pipe.text_encoder`` in windows of ``model_max_length`` tokens,
    and the per-window outputs are concatenated along dim 1.

    Both id sequences are first padded to the length of the longer one, so
    the two returned tensors always have the same shape. The pipeline
    requires ``prompt_embeds`` and ``negative_prompt_embeds`` to match.

    Args:
        enhanced_prompt: Positive prompt text.
        negative_prompt: Negative prompt text; ``None``/empty is treated as "".

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)`` with identical shapes.
    """
    tokenizer = pipe.tokenizer
    max_length = tokenizer.model_max_length
    negative_prompt = negative_prompt or ""

    # Measure both prompts untruncated, then pad both to the longer length.
    # Padding only the negative ids to the positive length, as before,
    # breaks whenever the negative prompt is the longer of the two.
    prompt_len = tokenizer(enhanced_prompt, return_tensors="pt").input_ids.shape[-1]
    negative_len = tokenizer(negative_prompt, return_tensors="pt").input_ids.shape[-1]
    common_len = max(prompt_len, negative_len)

    def _tokenize(text):
        # Tokenize `text` padded (never truncated) to `common_len` ids on CUDA.
        return tokenizer(
            text,
            truncation=False,
            padding="max_length",
            max_length=common_len,
            return_tensors="pt",
        ).input_ids.to("cuda")

    input_ids = _tokenize(enhanced_prompt)
    negative_ids = _tokenize(negative_prompt)

    concat_embeds = []
    neg_embeds = []
    # Inference only: don't record an autograd graph for the encoder passes.
    with torch.no_grad():
        # Feed the encoder at most `max_length` ids at a time.
        for i in range(0, common_len, max_length):
            concat_embeds.append(pipe.text_encoder(input_ids[:, i: i + max_length])[0])
            neg_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])

    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    return prompt_embeds, negative_prompt_embeds
@spaces.GPU
def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, num_samples):
    """Generate images for a prompt using precomputed text embeddings.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text (may be empty).
        num_inference_steps: Number of denoising steps (cast to int).
        guidance_scale: Classifier-free guidance scale, passed through as-is.
        width: Output width in pixels (cast to int).
        height: Output height in pixels (cast to int).
        num_samples: Images to generate for the prompt (cast to int).

    Returns:
        The ``.images`` attribute of the pipeline output.
    """
    # Bug fix: the original bound `neg_prompt_embeds` but passed the undefined
    # name `negative_prompt_embeds`, raising NameError on every call.
    prompt_embeds, negative_prompt_embeds = build_embeddings(prompt, negative_prompt)
    return pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        # gr.Number delivers floats by default; these arguments must be ints.
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        width=int(width),
        height=int(height),
        num_images_per_prompt=int(num_samples),
    ).images
# Build the web UI around `generate` and start serving it. The input order
# matches generate()'s positional parameters; results are shown in a gallery.
gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Prompt"),
        gr.Text("", label="Negative Prompt"),
        # precision=0 makes Gradio round and pass an int instead of a float,
        # which the step count, image size and image count require.
        gr.Number(7, label="Number inference steps", precision=0),
        gr.Number(3, label="Guidance scale"),
        gr.Number(512, label="Width", precision=0),
        gr.Number(512, label="Height", precision=0),
        gr.Number(1, label="# images", precision=0),
    ],
    outputs=gr.Gallery(),
).launch()