import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces # Import spaces for ZeroGPU compatibility

# Load the model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
# Use the eager attention implementation for broad hardware compatibility
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True
)
# bfloat16 on GPU; fall back to float16 on CPU-only machines
vl_gpt = vl_gpt.to(torch.bfloat16).cuda() if torch.cuda.is_available() else vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
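# Note: on ZeroGPU Spaces, the `spaces` package arranges for CUDA work to run
# inside @spaces.GPU-decorated functions, so the .cuda() call above works
# there. On a GPU-less machine this falls back to CPU in float16; float32 may
# be more robust for pure CPU inference, but the original behavior is kept.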
# Helper functions
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, patch_size=16):
    torch.cuda.empty_cache()
    # Duplicate the prompt for classifier-free guidance: even rows keep the
    # full prompt (conditional), odd rows are padded out (unconditional)
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int).to(cuda_device)

    # Autoregressively sample all 576 image tokens
    pkv = None
    for i in range(576):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=pkv
            )
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            # Classifier-free guidance: blend conditional and unconditional logits
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Feed the sampled token back in for both CFG streams
            next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the image tokens back into pixel patches
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size]
    )
    return patches
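# Shape check (illustrative; mirrors the defaults used in generate_image
# below): at 384x384 with patch_size=16, the image is a 24x24 grid of
# patches, i.e. (384 // 16) ** 2 == 576 tokens, which is why the sampling
# loop above runs exactly 576 steps.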
def unpack(patches, width, height, parallel_size=5):
    # Detach the tensor before converting to numpy
    patches = patches.detach().to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    patches = np.clip((patches + 1) / 2 * 255, 0, 255)
    images = [Image.fromarray(patches[i].astype(np.uint8)) for i in range(parallel_size)]
    return images
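# The vision decoder emits pixel values in [-1, 1]; unpack rescales them to
# [0, 255] before casting to uint8 for PIL (e.g. -1 -> 0, 0 -> 127.5, 1 -> 255).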
@torch.inference_mode()
@spaces.GPU(duration=120)  # Request a ZeroGPU slot for up to 120 seconds
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    torch.cuda.empty_cache()
    # Seed all RNGs for reproducible sampling
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
    width, height, parallel_size = 384, 384, 5

    # Build the chat-formatted prompt and append the image-start tag so the
    # model switches to image-token generation
    messages = [
        {'role': '<|User|>', 'content': prompt},
        {'role': '<|Assistant|>', 'content': ''}
    ]
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages,
        sft_format=vl_chat_processor.sft_format,
        system_prompt=''
    )
    text += vl_chat_processor.image_start_tag
    input_ids = torch.LongTensor(tokenizer.encode(text))

    patches = generate(
        input_ids, width, height,
        cfg_weight=guidance,
        temperature=t2i_temperature,
        parallel_size=parallel_size
    )
    return unpack(patches, width, height, parallel_size)
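# Example usage outside the Gradio UI (a sketch; assumes the model weights
# downloaded successfully and, for reasonable speed, a CUDA GPU):
#
#   images = generate_image("a watercolor fox in a snowy forest", seed=42)
#   for idx, img in enumerate(images):
#       img.save(f"janus_sample_{idx}.png")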
# Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # Text-to-Image Generation with Janus-Pro-7B

        Welcome to the Janus-Pro-7B text-to-image generator! This model from DeepSeek
        uses a unified multimodal framework to generate images from textual
        descriptions, producing detailed visual representations of your prompts.

        **Key Features:**
        - **Unified multimodal framework:** integrates visual encoding and language processing for both understanding and generation.
        - **High-quality image generation:** produces stable, detailed images.
        - **Flexible and efficient:** suited to a wide range of applications, from creative image generation to detailed visualizations.

        For more information about Janus-Pro-7B, visit the [official Hugging Face model page](https://huggingface.co/deepseek-ai/Janus-Pro-7B).
        """)
        prompt_input = gr.Textbox(label="Prompt (describe the image)")

        # Option to toggle additional parameters
        with gr.Accordion("Advanced Parameters", open=False):
            seed_input = gr.Number(label="Seed (Optional)", value=12345, precision=0)
            guidance_slider = gr.Slider(label="CFG Guidance Weight", minimum=1, maximum=10, value=5, step=0.5)
            temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, value=1.0, step=0.05)

        generate_button = gr.Button("Generate Images")
        output_gallery = gr.Gallery(label="Generated Images", columns=2, height=300)

        generate_button.click(
            generate_image,
            inputs=[prompt_input, seed_input, guidance_slider, temperature_slider],
            outputs=output_gallery
        )
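        # generate_image returns a list of PIL images, which gr.Gallery
        # accepts directly as its output value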
        # Footer
        gr.Markdown("""
        <hr>
        <p style="text-align: center; font-size: 0.9em;">
        Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
        </p>
        """)

        # Visitor badge
        gr.HTML("""
        <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F"><img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759" /></a>
        """)

    return demo

demo = create_interface()
demo.launch(share=True)