|
import os |
|
import sys |
|
import git |
|
import gradio as gr |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
# TokenEmbeddingsHandler (used below to load the fine-tune's token embeddings)
# comes from Replicate's cog-sdxl repository. Fetch it on first run into a
# directory relative to the current working directory.
repo_url = "https://github.com/replicate/cog-sdxl.git"
repo_dir = "./cog-sdxl"

if not os.path.exists(repo_dir):
    print("Cloning cog-sdxl repository...")
    # Only the current tree is needed, not the full history.
    git.Repo.clone_from(repo_url, repo_dir, depth=1)

# Make the clone's root directory importable.
sys.path.append(os.path.abspath(repo_dir))

# dataset_and_utils.py lives at the repository root, and the directory name
# "cog-sdxl" (with a hyphen) is not a valid Python package name. With the clone
# root on sys.path, the module must therefore be imported by its bare name;
# a `cog_sdxl.`-prefixed import cannot resolve against this directory.
from dataset_and_utils import TokenEmbeddingsHandler  # noqa: E402
|
|
|
|
|
# Base SDXL 1.0 checkpoint, requested as the "fp16" weight variant and loaded
# with float16 tensors, then moved onto the GPU.
# NOTE(review): a CUDA device is required here; there is no CPU fallback.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float16,
)
# DiffusionPipeline.to() relocates the components in place (and returns the
# same pipeline object), so rebinding the name is unnecessary.
pipe.to("cuda")
|
|
|
|
|
# Attach the LoRA adapter published in the fofr/sdxl-emoji model repo.
pipe.load_lora_weights("fofr/sdxl-emoji", weight_name="lora.safetensors")

# The same repo ships an embeddings file (embeddings.pti). It is loaded into
# both SDXL text encoders / tokenizers via cog-sdxl's TokenEmbeddingsHandler,
# presumably registering the special prompt tokens used by generate_emoji
# (e.g. "<s0><s1>") — confirm against the handler's implementation.
embedding_path = hf_hub_download(
    repo_id="fofr/sdxl-emoji",
    filename="embeddings.pti",
    repo_type="model",
)
text_encoders, tokenizers = (
    [pipe.text_encoder, pipe.text_encoder_2],
    [pipe.tokenizer, pipe.tokenizer_2],
)
embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
embhandler.load_embeddings(embedding_path)
|
|
|
def generate_emoji(prompt, lora_scale=0.8, negative_prompt=None):
    """Generate an emoji image based on the user's prompt.

    The description is inserted into the template
    ``"A <s0><s1> emoji of {description}"``. The ``<s0><s1>`` tokens are
    presumably the ones added by the module-level embedding handler. The
    result is run through the module-level SDXL ``pipe``.

    Args:
        prompt: Free-text description of the emoji subject. Leading and
            trailing whitespace is ignored; ``None`` is treated as empty.
        lora_scale: Strength of the LoRA adapter, forwarded to the pipeline
            as ``cross_attention_kwargs["scale"]``. Defaults to 0.8.
        negative_prompt: Optional text describing what to avoid. ``None``
            (the default) is also the pipeline's own default.

    Returns:
        The first element of the pipeline output's ``images`` list
        (presumably a PIL image with the pipeline's default output type).
    """
    # Gradio passes the raw textbox contents; normalise stray whitespace and
    # guard against a missing value so the template stays well-formed.
    subject = (prompt or "").strip()
    full_prompt = f"A <s0><s1> emoji of {subject}"

    result = pipe(
        full_prompt,
        negative_prompt=negative_prompt,
        # For LoRA-enabled SDXL pipelines, "scale" sets the adapter strength.
        cross_attention_kwargs={"scale": lora_scale},
    )
    return result.images[0]
|
|
|
|
|
# Minimal web UI: one text box in, one generated image out, backed by
# generate_emoji.
iface = gr.Interface(
    title="SDXL Emoji Generator",
    description="Generate a custom emoji using SDXL model with LoRA fine-tuning.",
    fn=generate_emoji,
    inputs=gr.Textbox(label="Enter description for emoji"),
    outputs=gr.Image(label="Generated Emoji"),
)

# Serve the UI only when run as a script, not when this module is imported.
if __name__ == "__main__":
    iface.launch()
|
|