import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import spaces  # Import spaces for ZeroGPU compatibility

# Load the model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
# Use eager attention so the demo does not require flash-attention to be installed.
language_config._attn_implementation = 'eager'

vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True
)
# bfloat16 on GPU; float32 on CPU (half precision is poorly supported on CPU)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda() if torch.cuda.is_available() else vl_gpt.to(torch.float32)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
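# Quick sanity check for the pieces loaded above (a sketch, not executed by
# the app; output shown is illustrative):
#
#   ids = tokenizer.encode("a cat")   # -> list of token ids
#   tokenizer.decode(ids)             # -> the prompt text (plus any special tokens)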

# Helper functions
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, patch_size=16):
    torch.cuda.empty_cache()

    # Build 2 * parallel_size sequences: even rows keep the full prompt
    # (conditional); odd rows are padded out (unconditional) so each pair can
    # be combined with classifier-free guidance below.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    # 576 image tokens per image: 384x384 pixels = a 24x24 grid of 16-px patches.
    generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int).to(cuda_device)

    # Autoregressively sample one image token per step, reusing the KV cache.
    pkv = None
    for i in range(576):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])

            # Classifier-free guidance: blend the conditional (even rows) and
            # unconditional (odd rows) logits.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            # Sample the next token for each of the parallel images.
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Duplicate each sampled token for its cond/uncond pair and feed
            # its image embedding back in as the next step's input.
            next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the discrete token grid back to pixels with the VQ vision decoder.
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size]
    )
    return patches
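# A minimal, self-contained sketch of the classifier-free-guidance blend used
# in `generate` (toy logits; `w` plays the role of `cfg_weight`):
#
#   cond   = torch.tensor([2.0, 0.0])   # prompt-conditioned logits
#   uncond = torch.tensor([1.0, 1.0])   # unconditional logits
#   w = 5.0
#   uncond + w * (cond - uncond)        # -> tensor([ 6., -4.])
#
# Larger w sharpens the preference for tokens favored by the prompt.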

def unpack(patches, width, height, parallel_size=5):
    # Detach before converting to numpy, and map the decoder's [-1, 1] output
    # range onto 8-bit pixel values.
    patches = patches.detach().to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    patches = np.clip((patches + 1) / 2 * 255, 0, 255)

    images = [Image.fromarray(patches[i].astype(np.uint8)) for i in range(parallel_size)]
    return images
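# Worked values for the mapping in `unpack` (x is one decoder output value):
#   x = -1.0 -> (x + 1) / 2 * 255 =   0.0   (black)
#   x =  0.0 ->                     127.5   (mid gray)
#   x =  1.0 ->                     255.0   (white)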

@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    torch.cuda.empty_cache()
    
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

    # Janus-Pro generates 384x384 images; sample 5 variants in parallel.
    width, height, parallel_size = 384, 384, 5

    messages = [
        {'role': '<|User|>', 'content': prompt},
        {'role': '<|Assistant|>', 'content': ''}
    ]

    # Format the prompt with the model's chat template, then append the
    # image-start tag so generation switches to image tokens.
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages, sft_format=vl_chat_processor.sft_format, system_prompt=''
    )
    text += vl_chat_processor.image_start_tag

    input_ids = torch.LongTensor(tokenizer.encode(text))
    patches = generate(input_ids, width, height, cfg_weight=guidance, temperature=t2i_temperature, parallel_size=parallel_size)

    return unpack(patches, width, height, parallel_size)
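# Example programmatic use (a sketch; assumes a CUDA GPU with enough memory
# for the 7B weights, and is not executed by the app itself):
#
#   images = generate_image("a watercolor fox in the snow", seed=42, guidance=5)
#   images[0].save("sample_0.png")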

# Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # Text-to-Image Generation with Janus-Pro-7B

        Welcome to the Janus-Pro-7B Text-to-Image Generator! Janus-Pro-7B is DeepSeek's unified multimodal model: a single framework handles both image understanding and image generation, turning your text prompts into detailed, faithful images.

        **Key Features:**
        - **Unified Multimodal Framework:** Integrates visual encoding and language processing for seamless understanding and generation.
        - **High-Quality Image Generation:** Produces stable, detailed images reported to be competitive with other leading text-to-image models.
        - **Flexible and Efficient:** Designed for a wide range of applications, from creative image generation to detailed visualizations.

        For more information about Janus-Pro-7B, visit the [official Hugging Face model page](https://huggingface.co/deepseek-ai/Janus-Pro-7B).
        """)
        
        prompt_input = gr.Textbox(label="Prompt (describe the image)")

        # Option to toggle additional parameters
        with gr.Accordion("Advanced Parameters", open=False):
            seed_input = gr.Number(label="Seed (Optional)", value=12345, precision=0)
            guidance_slider = gr.Slider(label="CFG Guidance Weight", minimum=1, maximum=10, value=5, step=0.5)
            # Keep the minimum above zero: `generate` divides the logits by the temperature.
            temperature_slider = gr.Slider(label="Temperature", minimum=0.05, maximum=1, value=1.0, step=0.05)

        generate_button = gr.Button("Generate Images")
        output_gallery = gr.Gallery(label="Generated Images", columns=2, height=300)

        generate_button.click(
            generate_image,
            inputs=[prompt_input, seed_input, guidance_slider, temperature_slider],
            outputs=output_gallery
        )

        # Footer
        gr.Markdown("""
        <hr>
        <p style="text-align: center; font-size: 0.9em;">
            Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
        </p>
        """)
        
        # Visitor Badge
        gr.HTML("""
        <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F"><img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759" /></a>
        """)

    return demo

demo = create_interface()
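# Sketch: under concurrent traffic, GPU work can be serialized through
# Gradio's built-in request queue, e.g. `demo.queue().launch()`.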
demo.launch(share=True)