Bils committed on
Commit 3dc7f18 · verified · 1 Parent(s): 601ca08

Update app.py

Files changed (1)
  1. app.py +209 -115
app.py CHANGED
@@ -4,144 +4,238 @@ from transformers import AutoConfig, AutoModelForCausalLM
  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from PIL import Image
  import numpy as np
- import spaces
-
- # Load the model and processor
- model_path = "deepseek-ai/Janus-Pro-7B"
- config = AutoConfig.from_pretrained(model_path)
- language_config = config.language_config
- language_config._attn_implementation = 'eager'
-
- vl_gpt = AutoModelForCausalLM.from_pretrained(
-     model_path,
-     language_config=language_config,
-     trust_remote_code=True
- )
- vl_gpt = vl_gpt.to(torch.bfloat16).cuda() if torch.cuda.is_available() else vl_gpt.to(torch.float16)
-
- vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
- tokenizer = vl_chat_processor.tokenizer
- cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- # Helper functions
- def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, patch_size=16):
-     torch.cuda.empty_cache()
-
-     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-     for i in range(parallel_size * 2):
-         tokens[i, :] = input_ids
-         if i % 2 != 0:
-             tokens[i, 1:-1] = vl_chat_processor.pad_id

-     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-     generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int).to(cuda_device)
-
-     pkv = None
-     for i in range(576):
          with torch.no_grad():
-             outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
-             pkv = outputs.past_key_values
-             hidden_states = outputs.last_hidden_state
-             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-
-             logit_cond = logits[0::2, :]
-             logit_uncond = logits[1::2, :]
-             logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-
-             probs = torch.softmax(logits / temperature, dim=-1)
-             next_token = torch.multinomial(probs, num_samples=1)
-             generated_tokens[:, i] = next_token.squeeze(dim=-1)
-
-             next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
-             img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-             inputs_embeds = img_embeds.unsqueeze(dim=1)
-
-     patches = vl_gpt.gen_vision_model.decode_code(
-         generated_tokens.to(dtype=torch.int),
-         shape=[parallel_size, 8, width // patch_size, height // patch_size]
-     )
-     return patches

  def unpack(patches, width, height, parallel_size=5):
-     # Detach the tensor before converting to numpy
-     patches = patches.detach().to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-     patches = np.clip((patches + 1) / 2 * 255, 0, 255)
-
-     images = [Image.fromarray(patches[i].astype(np.uint8)) for i in range(parallel_size)]
-     return images

  @torch.inference_mode()
  @spaces.GPU(duration=120)
- def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
-     torch.cuda.empty_cache()
-
-     if seed is not None:
          torch.manual_seed(seed)
-         torch.cuda.manual_seed(seed)
-         np.random.seed(seed)
-
-     width, height, parallel_size = 384, 384, 5
-
-     messages = [
-         {'role': '<|User|>', 'content': prompt},
-         {'role': '<|Assistant|>', 'content': ''}
-     ]
-
-     text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
-         conversations=messages, sft_format=vl_chat_processor.sft_format, system_prompt=''
-     )
-     text += vl_chat_processor.image_start_tag
-
-     input_ids = torch.LongTensor(tokenizer.encode(text))
-     patches = generate(input_ids, width, height, cfg_weight=guidance, temperature=t2i_temperature, parallel_size=parallel_size)
-
-     return unpack(patches, width, height, parallel_size)

- # Gradio interface
  def create_interface():
-     with gr.Blocks() as demo:
          gr.Markdown("""
          # Text-to-Image Generation with Janus-Pro-7B
-
-         Welcome to the Janus-Pro-7B Text-to-Image Generator! This advanced AI model by DeepSeek offers state-of-the-art capabilities in generating images from textual descriptions. Leveraging a unified multimodal framework, Janus-Pro-7B excels in both understanding and generating content, providing detailed and accurate visual representations based on your prompts.
-
-         **Key Features:**
-         - **High-Quality Image Generation:** Produces stable and detailed images that often surpass those from other leading models.
-
-         For more information about Janus-Pro-7B, visit the [official Hugging Face model page](https://huggingface.co/deepseek-ai/Janus-Pro-7B).
          """)
-
-         prompt_input = gr.Textbox(label="Prompt (describe the image)")
-
-         # Option to toggle additional parameters
-         with gr.Accordion("Advanced Parameters", open=False):
-             seed_input = gr.Number(label="Seed (Optional)", value=12345, precision=0)
-             guidance_slider = gr.Slider(label="CFG Guidance Weight", minimum=1, maximum=10, value=5, step=0.5)
-             temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, value=1.0, step=0.05)

-         generate_button = gr.Button("Generate Images")
-         output_gallery = gr.Gallery(label="Generated Images", columns=2, height=300)
-
-         generate_button.click(
-             generate_image,
-             inputs=[prompt_input, seed_input, guidance_slider, temperature_slider],
-             outputs=output_gallery
          )

-         # Footer
          gr.Markdown("""
-         <hr>
-         <p style="text-align: center; font-size: 0.9em;">
-         Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
-         </p>
          """)
-
          # Visitor Badge
          gr.HTML("""
-         <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F"><img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759" /></a>
          """)

      return demo

- demo = create_interface()
- demo.launch(share=True)

  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from PIL import Image
  import numpy as np
+ import spaces
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Constants
+ DEFAULT_WIDTH = 384
+ DEFAULT_HEIGHT = 384
+ PARALLEL_SIZE = 5
+ PATCH_SIZE = 16
+
+ # Load model and processor with error handling
+ def load_model():
+     try:
+         model_path = "deepseek-ai/Janus-Pro-7B"
+         config = AutoConfig.from_pretrained(model_path)
+         language_config = config.language_config
+         language_config._attn_implementation = 'eager'
+
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Loading model on device: {device}")
+
+         vl_gpt = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             language_config=language_config,
+             trust_remote_code=True,
+             torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32
+         ).to(device)
+
+         vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+         return vl_gpt, vl_chat_processor, device
+
+     except Exception as e:
+         logger.error(f"Model loading failed: {str(e)}")
+         raise RuntimeError("Failed to load model. Please check the model path and dependencies.")
+
+ try:
+     vl_gpt, vl_chat_processor, device = load_model()
+     tokenizer = vl_chat_processor.tokenizer
+ except RuntimeError as e:
+     raise e
+
+ # Helper functions with improved memory management
+ def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, progress=None):
+     try:
+         torch.cuda.empty_cache()
+         tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=device)
+
+         for i in range(parallel_size * 2):
+             tokens[i, :] = input_ids
+             if i % 2 != 0:
+                 tokens[i, 1:-1] = vl_chat_processor.pad_id

          with torch.no_grad():
+             inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+             generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int, device=device)
+
+             pkv = None
+             for i in range(576):
+                 if progress:
+                     progress((i + 1) / 576, desc="Generating image tokens")
+
+                 outputs = vl_gpt.language_model.model(
+                     inputs_embeds=inputs_embeds,
+                     use_cache=True,
+                     past_key_values=pkv
+                 )
+                 pkv = outputs.past_key_values
+                 hidden_states = outputs.last_hidden_state
+                 logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+
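+                 # Classifier-free guidance: even batch rows hold the
+                 # conditional (prompted) logits, odd rows the unconditional
+                 # (pad-masked) ones; cfg_weight sets how strongly the prompt
+                 # steers sampling.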
+                 logit_cond = logits[0::2, :]
+                 logit_uncond = logits[1::2, :]
+                 logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+
+                 probs = torch.softmax(logits / temperature, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+                 generated_tokens[:, i] = next_token.squeeze(dim=-1)
+
+                 next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
+                 img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+                 inputs_embeds = img_embeds.unsqueeze(dim=1)
+
+         return generated_tokens
+
+     except RuntimeError as e:
+         logger.error(f"Generation error: {str(e)}")
+         raise RuntimeError("Generation failed due to memory constraints. Try reducing the parallel size.")
+     finally:
+         torch.cuda.empty_cache()

  def unpack(patches, width, height, parallel_size=5):
+     try:
+         patches = patches.detach().to(device='cpu', dtype=torch.float32).numpy()
+         patches = patches.transpose(0, 2, 3, 1)
+         patches = np.clip((patches + 1) / 2 * 255, 0, 255)
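+         # Decoder output lies in [-1, 1]; rescale to [0, 255] for 8-bit RGB.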
+         return [Image.fromarray(patch.astype(np.uint8)) for patch in patches]
+     except Exception as e:
+         logger.error(f"Unpacking error: {str(e)}")
+         raise RuntimeError("Failed to process generated image data.")

  @torch.inference_mode()
  @spaces.GPU(duration=120)
+ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress()):
+     try:
+         if not prompt.strip():
+             raise gr.Error("Please enter a valid prompt.")
+
+         progress(0, desc="Initializing...")
+         torch.cuda.empty_cache()
+
+         # Seed management
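+         # torch.seed() seeds the RNG non-deterministically and returns the value used.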
+         if seed is None:
+             seed = torch.seed()
+         else:
+             seed = int(seed)
+
          torch.manual_seed(seed)
+         if device.type == "cuda":
+             torch.cuda.manual_seed(seed)
+
+         messages = [{'role': '<|User|>', 'content': prompt}, {'role': '<|Assistant|>', 'content': ''}]
+         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+             conversations=messages,
+             sft_format=vl_chat_processor.sft_format,
+             system_prompt=''
+         ) + vl_chat_processor.image_start_tag
+
+         input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)
+         progress(0.1, desc="Generating image tokens...")
+
+         generated_tokens = generate(
+             input_ids,
+             DEFAULT_WIDTH,
+             DEFAULT_HEIGHT,
+             cfg_weight=guidance,
+             temperature=t2i_temperature,
+             parallel_size=PARALLEL_SIZE,
+             progress=progress
+         )
+
+         progress(0.9, desc="Processing images...")
+         patches = vl_gpt.gen_vision_model.decode_code(
+             generated_tokens.to(dtype=torch.int),
+             shape=[PARALLEL_SIZE, 8, DEFAULT_WIDTH // PATCH_SIZE, DEFAULT_HEIGHT // PATCH_SIZE]
+         )
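+         # 576 tokens per image = (384 // 16)^2, i.e. a 24 x 24 latent grid
+         # that the vision decoder upsamples back to 384 x 384 pixels.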
+
+         images = unpack(patches, DEFAULT_WIDTH, DEFAULT_HEIGHT, PARALLEL_SIZE)
+         return images
+
+     except Exception as e:
+         logger.error(f"Generation failed: {str(e)}")
+         raise gr.Error(f"Image generation failed: {str(e)}")

  def create_interface():
+     with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
          gr.Markdown("""
          # Text-to-Image Generation with Janus-Pro-7B
+         **Generate high-quality images from text prompts using DeepSeek's advanced multimodal AI model.**
          """)

+         with gr.Row():
+             with gr.Column(scale=3):
+                 prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate...", lines=3)
+                 generate_btn = gr.Button("Generate Images", variant="primary")
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Group():
+                         seed_input = gr.Number(label="Seed", value=None, precision=0, info="Leave empty for random seed")
+                         guidance_slider = gr.Slider(3, 10, value=5, step=0.5,
+                                                     label="CFG Guidance Weight",
+                                                     info="Higher values = more prompt adherence, lower values = more creativity")
+                         temp_slider = gr.Slider(0.1, 1.0, value=1.0, step=0.1,
+                                                 label="Temperature",
+                                                 info="Higher values = more randomness, lower values = more deterministic")
+
+             with gr.Column(scale=2):
+                 output_gallery = gr.Gallery(label="Generated Images", columns=2, height=600, preview=True)
+                 status = gr.Textbox(label="Status", interactive=False)
+
+         gr.Examples(
+             examples=[
+                 ["A futuristic cityscape at sunset with flying cars and holographic advertisements"],
+                 ["An astronaut riding a horse in photorealistic style"],
+                 ["A cute robotic cat sitting on a stack of ancient books, digital art"]
+             ],
+             inputs=prompt_input
          )

          gr.Markdown("""
+         ## Model Information
+         - **Model:** [Janus-Pro-7B](https://huggingface.co/deepseek-ai/Janus-Pro-7B)
+         - **Output Resolution:** 384x384 pixels
+         - **Parallel Generation:** 5 images per request
          """)
+
+         # Footer Section
+         gr.Markdown("""
+         <hr style="margin-top: 2em; margin-bottom: 1em;">
+         <div style="text-align: center; color: #666; font-size: 0.9em;">
+             Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #2563eb; text-decoration: none;">bilsimaging.com</a>
+         </div>
+         """)
+
          # Visitor Badge
          gr.HTML("""
+         <div style="text-align: center; margin-top: 1em;">
+             <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F">
+                 <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759"
+                      alt="Visitor Badge"
+                      style="display: inline-block; margin: 0 auto;">
+             </a>
+         </div>
          """)

+         generate_btn.click(
+             generate_image,
+             inputs=[prompt_input, seed_input, guidance_slider, temp_slider],
+             outputs=output_gallery,
+             api_name="generate"
+         )
+
+         demo.load(
+             fn=lambda: f"Device Status: {'GPU ✅' if device.type == 'cuda' else 'CPU ⚠️'}",
+             outputs=status,
+             queue=False
+         )

      return demo

+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True)
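
The click handler above registers `api_name="generate"`, so the updated Space can also be driven programmatically. Below is a minimal sketch using `gradio_client`, assuming the Space id `Bils/DeepseekJanusPro` (taken from the badge URL), that the Space is running and public, and that the endpoint takes the four inputs wired to the button:

```python
from gradio_client import Client

# Space id and endpoint name are assumptions based on the badge URL and the
# api_name="generate" registered in this commit.
client = Client("Bils/DeepseekJanusPro")
result = client.predict(
    "A cute robotic cat sitting on a stack of ancient books, digital art",  # prompt
    12345,  # seed
    5,      # CFG guidance weight
    1.0,    # temperature
    api_name="/generate",
)
print(result)  # gallery output, typically a list of generated image files
```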