Spaces: Running on Zero
import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
# Load the text-to-image pipeline in half precision and move it onto the GPU.
model_name = 'UnfilteredAI/NSFW-gen-v2'
pipe = DiffusionPipeline.from_pretrained(
    model_name,
    torch_dtype=torch.float16
)
pipe.to('cuda')
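# Note: on ZeroGPU Spaces the `spaces` package is designed so the pipeline can
# be moved to 'cuda' at startup even though a physical GPU is only attached
# while a @spaces.GPU-decorated function is running.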
def build_embeddings(enhanced_prompt, negative_prompt=None):
    # CLIP's text encoder only accepts `model_max_length` (typically 77) tokens
    # at a time, so the prompt is encoded in chunks and the chunk embeddings are
    # concatenated, allowing prompts longer than the tokenizer limit.
    max_length = pipe.tokenizer.model_max_length
    input_ids = pipe.tokenizer(enhanced_prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to("cuda")
    # Pad the negative prompt to the same token length as the prompt so the two
    # embedding tensors line up.
    negative_ids = pipe.tokenizer(
        negative_prompt or "",
        truncation=False,
        padding="max_length",
        max_length=input_ids.shape[-1],
        return_tensors="pt"
    ).input_ids
    negative_ids = negative_ids.to("cuda")
    concat_embeds = []
    neg_embeds = []
    for i in range(0, input_ids.shape[-1], max_length):
        concat_embeds.append(pipe.text_encoder(input_ids[:, i: i + max_length])[0])
        neg_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])
    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    return prompt_embeds, negative_prompt_embeds
@spaces.GPU
def generate(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, num_samples):
    # A ZeroGPU device is attached only for the duration of this call.
    prompt_embeds, neg_prompt_embeds = build_embeddings(prompt, negative_prompt)
    # gr.Number passes floats, so cast the integer-valued inputs before calling
    # the pipeline.
    return pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=neg_prompt_embeds,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        width=int(width),
        height=int(height),
        num_images_per_prompt=int(num_samples)
    ).images
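# Hypothetical direct call for local testing (assumes a GPU is available
# outside the Gradio app); the values mirror the interface defaults below:
#   images = generate("a lighthouse at dusk", "", 7, 3, 512, 512, 1)
#   images[0].save("out.png")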
gr.Interface(
    fn=generate,
    inputs=[
        gr.Text(label="Prompt"),
        gr.Text("", label="Negative Prompt"),
        gr.Number(7, label="Number of inference steps"),
        gr.Number(3, label="Guidance scale"),
        gr.Number(512, label="Width"),
        gr.Number(512, label="Height"),
        gr.Number(1, label="Number of images"),
    ],
    outputs=gr.Gallery(),
).launch()