Spaces: Running on Zero
Update app.py
Browse files

app.py CHANGED
@@ -4,144 +4,238 @@ from transformers import AutoConfig, AutoModelForCausalLM
  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from PIL import Image
  import numpy as np
- import spaces
- …
-     pkv = None
-     for i in range(576):
          with torch.no_grad():
- …
  def unpack(patches, width, height, parallel_size=5):
- …
  @torch.inference_mode()
  @spaces.GPU(duration=120)
- def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
- …
      torch.manual_seed(seed)
- …
- # Gradio interface
  def create_interface():
-     with gr.Blocks() as demo:
          gr.Markdown("""
          # Text-to-Image Generation with Janus-Pro-7B
-
-         Welcome to the Janus-Pro-7B Text-to-Image Generator! This advanced AI model by DeepSeek offers state-of-the-art capabilities in generating images from textual descriptions. Leveraging a unified multimodal framework, Janus-Pro-7B excels in both understanding and generating content, providing detailed and accurate visual representations based on your prompts.
-
-         **Key Features:**
-         - **High-Quality Image Generation:** Produces stable and detailed images that often surpass those from other leading models.
-
-         For more information about Janus-Pro-7B, visit the [official Hugging Face model page](https://huggingface.co/deepseek-ai/Janus-Pro-7B).
          """)
-
-         prompt_input = gr.Textbox(label="Prompt (describe the image)")
-
-         # Option to toggle additional parameters
-         with gr.Accordion("Advanced Parameters", open=False):
-             seed_input = gr.Number(label="Seed (Optional)", value=12345, precision=0)
-             guidance_slider = gr.Slider(label="CFG Guidance Weight", minimum=1, maximum=10, value=5, step=0.5)
-             temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, value=1.0, step=0.05)
- …
          )
-         # Footer
          gr.Markdown("""
- …
          """)
          # Visitor Badge
          gr.HTML("""
-         <
          """)
      return demo
-
- demo
  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from PIL import Image
  import numpy as np
+ import spaces
+ import logging
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Constants
+ DEFAULT_WIDTH = 384
+ DEFAULT_HEIGHT = 384
+ PARALLEL_SIZE = 5
+ PATCH_SIZE = 16
+
+ # Load model and processor with error handling
+ def load_model():
+     try:
+         model_path = "deepseek-ai/Janus-Pro-7B"
+         config = AutoConfig.from_pretrained(model_path)
+         language_config = config.language_config
+         language_config._attn_implementation = 'eager'
+
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.info(f"Loading model on device: {device}")
+
+         vl_gpt = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             language_config=language_config,
+             trust_remote_code=True,
+             torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32
+         ).to(device)
+
+         vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+         return vl_gpt, vl_chat_processor, device

+     except Exception as e:
+         logger.error(f"Model loading failed: {str(e)}")
+         raise RuntimeError("Failed to load model. Please check the model path and dependencies.")
+
+ try:
+     vl_gpt, vl_chat_processor, device = load_model()
+     tokenizer = vl_chat_processor.tokenizer
+ except RuntimeError as e:
+     raise e
+
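For reference, a quick way to confirm what the loader above selected, kept separate from the commit itself. This sketch assumes it runs in the same module scope where `vl_gpt`, `vl_chat_processor`, and `device` were just created:

```python
# Hypothetical smoke test, not part of app.py: inspect the loaded model.
# On a GPU Space this should print cuda:0 / torch.bfloat16; on a CPU-only
# machine, cpu / torch.float32, matching the torch_dtype branch above.
param = next(vl_gpt.parameters())
print(param.device, param.dtype)
print(type(vl_chat_processor).__name__)  # expected: VLChatProcessor
```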
+ # Helper functions with improved memory management
+ def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, progress=None):
+     try:
+         torch.cuda.empty_cache()
+         tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=device)
+
+         for i in range(parallel_size * 2):
+             tokens[i, :] = input_ids
+             if i % 2 != 0:
+                 tokens[i, 1:-1] = vl_chat_processor.pad_id

          with torch.no_grad():
+             inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+             generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int, device=device)
+
+             pkv = None
+             for i in range(576):
+                 if progress:
+                     progress((i + 1) / 576, desc="Generating image tokens")
+
+                 outputs = vl_gpt.language_model.model(
+                     inputs_embeds=inputs_embeds,
+                     use_cache=True,
+                     past_key_values=pkv
+                 )
+                 pkv = outputs.past_key_values
+                 hidden_states = outputs.last_hidden_state
+                 logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+
+                 logit_cond = logits[0::2, :]
+                 logit_uncond = logits[1::2, :]
+                 logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+
+                 probs = torch.softmax(logits / temperature, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+                 generated_tokens[:, i] = next_token.squeeze(dim=-1)
+
+                 next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
+                 img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+                 inputs_embeds = img_embeds.unsqueeze(dim=1)
+
+         return generated_tokens
+
+     except RuntimeError as e:
+         logger.error(f"Generation error: {str(e)}")
+         raise RuntimeError("Generation failed due to memory constraints. Try reducing the parallel size.")
+     finally:
+         torch.cuda.empty_cache()

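The loop above is classifier-free guidance (CFG) over discrete image tokens: each prompt occupies a conditional row (even batch indices) and a pad-masked unconditional row (odd indices), and the two logit streams are blended as `uncond + cfg_weight * (cond - uncond)` before sampling. The 576 iterations match the 24 x 24 = 576 token grid of a 384x384 image with PATCH_SIZE 16. A toy illustration of the blend, with made-up logits for a two-token vocabulary:

```python
import torch

# Made-up logits, only to show the CFG arithmetic used in generate().
logit_cond = torch.tensor([[2.0, 0.5]])    # prompt-conditioned row
logit_uncond = torch.tensor([[1.0, 1.0]])  # pad-masked, unconditional row
cfg_weight = 5

# Same formula as above: move logits further along the conditional direction.
blended = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
print(blended)                         # tensor([[ 6.0000, -1.5000]])
print(torch.softmax(blended, dim=-1))  # mass shifts to the prompt-favored token
```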
  def unpack(patches, width, height, parallel_size=5):
+     try:
+         patches = patches.detach().to(device='cpu', dtype=torch.float32).numpy()
+         patches = patches.transpose(0, 2, 3, 1)
+         patches = np.clip((patches + 1) / 2 * 255, 0, 255)
+         return [Image.fromarray(patch.astype(np.uint8)) for patch in patches]
+     except Exception as e:
+         logger.error(f"Unpacking error: {str(e)}")
+         raise RuntimeError("Failed to process generated image data.")

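`unpack` assumes the vision decoder emits NCHW float images with values in [-1, 1]; it reorders them to NHWC and rescales to 8-bit RGB. A self-contained check of that mapping on random data:

```python
import numpy as np
from PIL import Image

# Fake decoder output: 1 image, 3 channels, 8x8 pixels, values in [-1, 1].
fake = np.random.uniform(-1, 1, size=(1, 3, 8, 8)).astype(np.float32)

# Same steps as unpack(): NCHW -> NHWC, then [-1, 1] -> [0, 255].
nhwc = fake.transpose(0, 2, 3, 1)
pixels = np.clip((nhwc + 1) / 2 * 255, 0, 255)
img = Image.fromarray(pixels[0].astype(np.uint8))
print(img.size, img.mode)  # (8, 8) RGB
```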
  @torch.inference_mode()
  @spaces.GPU(duration=120)
+ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress()):
+     try:
+         if not prompt.strip():
+             raise gr.Error("Please enter a valid prompt.")
+
+         progress(0, desc="Initializing...")
+         torch.cuda.empty_cache()
+
+         # Seed management
+         if seed is None:
+             seed = torch.seed()
+         else:
+             seed = int(seed)
          torch.manual_seed(seed)
+         if device.type == "cuda":
+             torch.cuda.manual_seed(seed)
+
+         messages = [{'role': '<|User|>', 'content': prompt}, {'role': '<|Assistant|>', 'content': ''}]
+         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+             conversations=messages,
+             sft_format=vl_chat_processor.sft_format,
+             system_prompt=''
+         ) + vl_chat_processor.image_start_tag
+
+         input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)
+         progress(0.1, desc="Generating image tokens...")
+
+         generated_tokens = generate(
+             input_ids,
+             DEFAULT_WIDTH,
+             DEFAULT_HEIGHT,
+             cfg_weight=guidance,
+             temperature=t2i_temperature,
+             parallel_size=PARALLEL_SIZE,
+             progress=progress
+         )
+
+         progress(0.9, desc="Processing images...")
+         patches = vl_gpt.gen_vision_model.decode_code(
+             generated_tokens.to(dtype=torch.int),
+             shape=[PARALLEL_SIZE, 8, DEFAULT_WIDTH // PATCH_SIZE, DEFAULT_HEIGHT // PATCH_SIZE]
+         )
+
+         images = unpack(patches, DEFAULT_WIDTH, DEFAULT_HEIGHT, PARALLEL_SIZE)
+         return images

+     except Exception as e:
+         logger.error(f"Generation failed: {str(e)}")
+         raise gr.Error(f"Image generation failed: {str(e)}")

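`generate_image` seeds both the CPU and CUDA generators, so a fixed seed reproduces the same token samples across runs. A minimal demonstration of why `torch.manual_seed` makes the `torch.multinomial` sampling in `generate` deterministic:

```python
import torch

def sample_once(seed: int) -> torch.Tensor:
    # Re-seeding restores the generator state, so the draw repeats exactly.
    torch.manual_seed(seed)
    logits = torch.randn(1, 16)
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

assert torch.equal(sample_once(12345), sample_once(12345))
print("same seed -> same sampled token")
```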
  def create_interface():
+     with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
          gr.Markdown("""
          # Text-to-Image Generation with Janus-Pro-7B
+         **Generate high-quality images from text prompts using DeepSeek's advanced multimodal AI model.**
          """)

+         with gr.Row():
+             with gr.Column(scale=3):
+                 prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate...", lines=3)
+                 generate_btn = gr.Button("Generate Images", variant="primary")
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Group():
+                         seed_input = gr.Number(label="Seed", value=None, precision=0, info="Leave empty for a random seed")
+                         guidance_slider = gr.Slider(3, 10, value=5, step=0.5,
+                                                     label="CFG Guidance Weight",
+                                                     info="Higher values = more prompt adherence, lower values = more creativity")
+                         temp_slider = gr.Slider(0.1, 1.0, value=1.0, step=0.1,
+                                                 label="Temperature",
+                                                 info="Higher values = more randomness, lower values = more deterministic")
+
+             with gr.Column(scale=2):
+                 output_gallery = gr.Gallery(label="Generated Images", columns=2, height=600, preview=True)
+                 status = gr.Textbox(label="Status", interactive=False)
+
+         gr.Examples(
+             examples=[
+                 ["A futuristic cityscape at sunset with flying cars and holographic advertisements"],
+                 ["An astronaut riding a horse in photorealistic style"],
+                 ["A cute robotic cat sitting on a stack of ancient books, digital art"]
+             ],
+             inputs=prompt_input
          )

          gr.Markdown("""
+         ## Model Information
+         - **Model:** [Janus-Pro-7B](https://huggingface.co/deepseek-ai/Janus-Pro-7B)
+         - **Output Resolution:** 384x384 pixels
+         - **Parallel Generation:** 5 images per request
          """)
+
+         # Footer Section
+         gr.Markdown("""
+         <hr style="margin-top: 2em; margin-bottom: 1em;">
+         <div style="text-align: center; color: #666; font-size: 0.9em;">
+         Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #2563eb; text-decoration: none;">bilsimaging.com</a>
+         </div>
+         """)
+
          # Visitor Badge
          gr.HTML("""
+         <div style="text-align: center; margin-top: 1em;">
+             <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F">
+                 <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759"
+                      alt="Visitor Badge"
+                      style="display: inline-block; margin: 0 auto;">
+             </a>
+         </div>
          """)

+         generate_btn.click(
+             generate_image,
+             inputs=[prompt_input, seed_input, guidance_slider, temp_slider],
+             outputs=output_gallery,
+             api_name="generate"
+         )
+
+         demo.load(
+             fn=lambda: f"Device Status: {'GPU ✅' if device.type == 'cuda' else 'CPU ⚠️'}",
+             outputs=status,
+             queue=False
+         )
+
      return demo

+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch(share=True)
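Because the click handler registers `api_name="generate"`, the Space also exposes a programmatic endpoint. A client-side sketch, assuming the Space id `Bils/DeepseekJanusPro` inferred from the visitor-badge URL above and the standard `gradio_client` package:

```python
from gradio_client import Client

# Space id is an assumption taken from the badge URL; adjust if the Space moves.
client = Client("Bils/DeepseekJanusPro")

result = client.predict(
    "A futuristic cityscape at sunset",  # prompt
    12345,                               # seed
    5,                                   # guidance (CFG weight)
    1.0,                                 # t2i_temperature
    api_name="/generate",
)
print(result)  # local paths/metadata for the generated gallery images
```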