Gordonkl commited on
Commit
c54749c
·
verified ·
1 Parent(s): 8c72aec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +277 -123
app.py CHANGED
@@ -15,7 +15,6 @@ from diffusers import (
15
  AutoencoderKL,
16
  ControlNetModel,
17
  StableDiffusionXLControlNetPipeline,
18
-
19
  )
20
  from ip_adapter import CSGO
21
  from transformers import BlipProcessor, BlipForConditionalGeneration
@@ -29,7 +28,6 @@ os.system("mv IP-Adapter/sdxl_models sdxl_models")
29
 
30
  from huggingface_hub import hf_hub_download
31
 
32
- # hf_hub_download(repo_id="h94/IP-Adapter", filename="sdxl_models/image_encoder", local_dir="./sdxl_models/image_encoder")
33
  hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
34
  os.system('rm -rf IP-Adapter/models')
35
  base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -39,21 +37,13 @@ pretrained_vae_name_or_path ='madebyollin/sdxl-vae-fp16-fix'
39
  controlnet_path = "TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic"
40
  weight_dtype = torch.float16
41
 
42
-
43
  os.system("git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic")
44
  os.system("mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors")
45
  os.system('rm -rf TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v1_fp16.safetensors')
46
- os.system('rm -rf TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v1_fp16.safetensors')
47
  controlnet_path = "./TTPLanet_SDXL_Controlnet_Tile_Realistic"
48
 
49
-
50
- # os.system('git clone https://huggingface.co/InstantX/CSGO')
51
- # os.system('rm -rf CSGO/csgo.bin')
52
-
53
-
54
-
55
  vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)
56
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16,use_safetensors=True)
57
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
58
  base_model_path,
59
  controlnet=controlnet,
@@ -63,7 +53,6 @@ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
63
  )
64
  pipe.enable_vae_tiling()
65
 
66
-
67
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
68
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
69
 
@@ -88,10 +77,6 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
88
  seed = random.randint(0, MAX_SEED)
89
  return seed
90
 
91
-
92
-
93
-
94
-
95
  def get_example():
96
  case = [
97
  [
@@ -137,8 +122,7 @@ def get_example():
137
  ]
138
  return case
139
 
140
-
141
- def run_for_examples(content_image_pil,style_image_pil,target, prompt, scale_c, scale_s,guidance_scale,seed):
142
  return create_image(
143
  content_image_pil=content_image_pil,
144
  style_image_pil=style_image_pil,
@@ -151,11 +135,271 @@ def run_for_examples(content_image_pil,style_image_pil,target, prompt, scale_c,
151
  seed=seed,
152
  target=target,
153
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
155
  if randomize_seed:
156
  seed = random.randint(0, MAX_SEED)
157
  return seed
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def image_grid(imgs, rows, cols):
160
  assert len(imgs) == rows * cols
161
 
@@ -166,6 +410,7 @@ def image_grid(imgs, rows, cols):
166
  for i, img in enumerate(imgs):
167
  grid.paste(img, box=(i % cols * w, i // cols * h))
168
  return grid
 
169
  @spaces.GPU
170
  def create_image(content_image_pil,
171
  style_image_pil,
@@ -178,22 +423,20 @@ def create_image(content_image_pil,
178
  seed,
179
  target="Image-Driven Style Transfer",
180
  ):
181
-
182
-
183
  if content_image_pil is None:
184
  content_image_pil = Image.fromarray(
185
  np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
186
 
187
  if prompt == '':
188
-
189
  inputs = blip_processor(content_image_pil, return_tensors="pt").to(device)
190
  out = blip_model.generate(**inputs)
191
  prompt = blip_processor.decode(out[0], skip_special_tokens=True)
 
192
  width, height, content_image = resize_content(content_image_pil)
193
  style_image = style_image_pil
194
- neg_content_prompt='text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry'
195
- if target =="Image-Driven Style Transfer":
196
 
 
197
  images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
198
  prompt=prompt,
199
  negative_prompt=neg_content_prompt,
@@ -207,10 +450,9 @@ def create_image(content_image_pil,
207
  num_samples=1,
208
  seed=seed,
209
  image=content_image.convert('RGB'),
210
- controlnet_conditioning_scale=scale_c,
211
- )
212
 
213
- elif target =="Text-Driven Style Synthesis":
214
  content_image = Image.fromarray(
215
  np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
216
 
@@ -227,11 +469,9 @@ def create_image(content_image_pil,
227
  num_samples=1,
228
  seed=42,
229
  image=content_image.convert('RGB'),
230
- controlnet_conditioning_scale=scale_c,
231
- )
232
- elif target =="Text Edit-Driven Style Synthesis":
233
-
234
 
 
235
  images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
236
  prompt=prompt,
237
  negative_prompt=neg_content_prompt,
@@ -245,25 +485,17 @@ def create_image(content_image_pil,
245
  num_samples=1,
246
  seed=seed,
247
  image=content_image.convert('RGB'),
248
- controlnet_conditioning_scale=scale_c,
249
- )
250
 
251
  return [image_grid(images, 1, num_samples)]
252
 
253
-
254
- def pil_to_cv2(image_pil):
255
- image_np = np.array(image_pil)
256
- image_cv2 = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
257
- return image_cv2
258
-
259
-
260
  # Description
261
  title = r"""
262
  <h1 align="center">CSGO: Content-Style Composition in Text-to-Image Generation</h1>
263
  """
264
 
265
  description = r"""
266
- <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>
267
  How to use:<br>
268
  1. Upload a content image if you want to use image-driven style transfer.
269
  2. Upload a style image.
@@ -294,88 +526,10 @@ If our work is helpful for your research or applications, please cite us via:
294
  year={2024},
295
  journal = {arXiv 2408.16766},
296
  }
297
- ```
298
- 📧 **Contact**
299
- <br>
300
- If you have any questions, please feel free to open an issue or directly reach us out at <b>[email protected]</b>.
301
- """
302
-
303
- block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
304
- with block:
305
- # description
306
- gr.Markdown(title)
307
- gr.Markdown(description)
308
-
309
- with gr.Tabs():
310
- with gr.Row():
311
- with gr.Column():
312
- with gr.Row():
313
- with gr.Column():
314
- content_image_pil = gr.Image(label="Content Image (optional)", type='pil')
315
- style_image_pil = gr.Image(label="Style Image", type='pil')
316
-
317
- target = gr.Radio(["Image-Driven Style Transfer", "Text-Driven Style Synthesis", "Text Edit-Driven Style Synthesis"],
318
- value="Image-Driven Style Transfer",
319
- label="task")
320
-
321
- # prompt_type = gr.Radio(["caption of Blip", "user input"],
322
- # value="caption of Blip",
323
- # label="prompt type")
324
-
325
- prompt = gr.Textbox(label="Prompt",
326
- value="there is a small house with a sheep statue on top of it")
327
- prompt_type = gr.CheckboxGroup(
328
- ["caption of Blip", "user input"], label="prompt_type", value=["caption of Blip"],
329
- info="Choose to enter more detailed prompts yourself or use the blip model to describe content images."
330
- )
331
- if prompt_type == "caption of Blip" and target == "Image-Driven Style Transfer":
332
- prompt =''
333
-
334
- scale_c = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=0.6, label="Content Scale")
335
- scale_s = gr.Slider(minimum=0, maximum=2.0, step=0.01, value=1.0, label="Style Scale")
336
- with gr.Accordion(open=False, label="Advanced Options"):
337
-
338
- guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label="guidance scale")
339
- num_samples = gr.Slider(minimum=1, maximum=4.0, step=1.0, value=1.0, label="num samples")
340
- num_inference_steps = gr.Slider(minimum=5, maximum=100.0, step=1.0, value=50,
341
- label="num inference steps")
342
- seed = gr.Slider(minimum=-1000000, maximum=1000000, value=1, step=1, label="Seed Value")
343
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
344
-
345
- generate_button = gr.Button("Generate Image")
346
-
347
- with gr.Column():
348
- generated_image = gr.Gallery(label="Generated Image")
349
-
350
- generate_button.click(
351
- fn=randomize_seed_fn,
352
- inputs=[seed, randomize_seed],
353
- outputs=seed,
354
- queue=False,
355
- api_name=False,
356
- ).then(
357
- fn=create_image,
358
- inputs=[content_image_pil,
359
- style_image_pil,
360
- prompt,
361
- scale_c,
362
- scale_s,
363
- guidance_scale,
364
- num_samples,
365
- num_inference_steps,
366
- seed,
367
- target,],
368
- outputs=[generated_image])
369
-
370
- gr.Examples(
371
- examples=get_example(),
372
- inputs=[content_image_pil,style_image_pil,target, prompt, scale_c, scale_s,guidance_scale,seed],
373
- fn=run_for_examples,
374
- outputs=[generated_image],
375
- cache_examples=False,
376
- )
377
-
378
- gr.Markdown(article)
379
 
 
 
 
 
380
 
381
- block.launch()
 
15
  AutoencoderKL,
16
  ControlNetModel,
17
  StableDiffusionXLControlNetPipeline,
 
18
  )
19
  from ip_adapter import CSGO
20
  from transformers import BlipProcessor, BlipForConditionalGeneration
 
28
 
29
  from huggingface_hub import hf_hub_download
30
 
 
31
  hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
32
  os.system('rm -rf IP-Adapter/models')
33
  base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
 
37
  controlnet_path = "TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic"
38
  weight_dtype = torch.float16
39
 
 
40
  os.system("git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic")
41
  os.system("mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors")
42
  os.system('rm -rf TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v1_fp16.safetensors')
 
43
  controlnet_path = "./TTPLanet_SDXL_Controlnet_Tile_Realistic"
44
 
 
 
 
 
 
 
45
  vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)
46
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)
47
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
48
  base_model_path,
49
  controlnet=controlnet,
 
53
  )
54
  pipe.enable_vae_tiling()
55
 
 
56
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
57
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
58
 
 
77
  seed = random.randint(0, MAX_SEED)
78
  return seed
79
 
 
 
 
 
80
  def get_example():
81
  case = [
82
  [
 
122
  ]
123
  return case
124
 
125
+ def run_for_examples(content_image_pil, style_image_pil, target, prompt, scale_c, scale_s, guidance_scale, seed):
 
126
  return create_image(
127
  content_image_pil=content_image_pil,
128
  style_image_pil=style_image_pil,
 
135
  seed=seed,
136
  target=target,
137
  )
138
+
139
+ def image_grid(imgs, rows, cols):
140
+ assert len(imgs) == rows * cols
141
+
142
+ w, h = imgs[0].size
143
+ grid = Image.new('RGB', size=(cols * w, rows * h))
144
+ grid_w, grid_h = grid.size
145
+
146
+ for i, img in enumerate(imgs):
147
+ grid.paste(img, box=(i % cols * w, i // cols * h))
148
+ return grid
149
+
150
+ @spaces.GPU
151
+ def create_image(content_image_pil,
152
+ style_image_pil,
153
+ prompt,
154
+ scale_c,
155
+ scale_s,
156
+ guidance_scale,
157
+ num_samples,
158
+ num_inference_steps,
159
+ seed,
160
+ target="Image-Driven Style Transfer",
161
+ ):
162
+ if content_image_pil is None:
163
+ content_image_pil = Image.fromarray(
164
+ np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
165
+
166
+ if prompt == '':
167
+ inputs = blip_processor(content_image_pil, return_tensors="pt").to(device)
168
+ out = blip_model.generate(**inputs)
169
+ prompt = blip_processor.decode(out[0], skip_special_tokens=True)
170
+
171
+ width, height, content_image = resize_content(content_image_pil)
172
+ style_image = style_image_pil
173
+ neg_content_prompt = 'text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry'
174
+
175
+ if target == "Image-Driven Style Transfer":
176
+ images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
177
+ prompt=prompt,
178
+ negative_prompt=neg_content_prompt,
179
+ height=height,
180
+ width=width,
181
+ content_scale=1.0,
182
+ style_scale=scale_s,
183
+ guidance_scale=guidance_scale,
184
+ num_images_per_prompt=num_samples,
185
+ num_inference_steps=num_inference_steps,
186
+ num_samples=1,
187
+ seed=seed,
188
+ image=content_image.convert('RGB'),
189
+ controlnet_conditioning_scale=scale_c)
190
+
191
+ elif target == "Text-Driven Style Synthesis":
192
+ content_image = Image.fromarray(
193
+ np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
194
+
195
+ images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
196
+ prompt=prompt,
197
+ negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
198
+ height=height,
199
+ width=width,
200
+ content_scale=0.5,
201
+ style_scale=scale_s,
202
+ guidance_scale=7,
203
+ num_images_per_prompt=num_samples,
204
+ num_inference_steps=num_inference_steps,
205
+ num_samples=1,
206
+ seed=42,
207
+ image=content_image.convert('RGB'),
208
+ controlnet_conditioning_scale=scale_c)
209
+
210
+ elif target == "Text Edit-Driven Style Synthesis":
211
+ images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
212
+ prompt=prompt,
213
+ negative_prompt=neg_content_prompt,
214
+ height=height,
215
+ width=width,
216
+ content_scale=1.0,
217
+ style_scale=scale_s,
218
+ guidance_scale=guidance_scale,
219
+ num_images_per_prompt=num_samples,
220
+ num_inference_steps=num_inference_steps,
221
+ num_samples=1,
222
+ seed=seed,
223
+ image=content_image.convert('RGB'),
224
+ controlnet_conditioning_scale=scale_c)
225
+
226
+ return [image_grid(images, 1, num_samples)]
227
+
228
+ # Description
229
+ title = r"""
230
+ <h1 align="center">CSGO: Content-Style Composition in Text-to-Image Generation</h1>
231
+ """
232
+
233
+ description = r"""
234
+ <b>Official Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>
235
+ How to use:<br>
236
+ 1. Upload a content image if you want to use image-driven style transfer.
237
+ 2. Upload a style image.
238
+ 3. Sets the type of task to perform, by default image-driven style transfer is performed. Options are <b>Image-driven style transfer, Text-driven style synthesis, and Text editing-driven style synthesis<b>.
239
+ 4. <b>If you choose a text-driven task, enter your desired prompt<b>.
240
+ 5. If you don't provide a prompt, the default is to use the BLIP model to generate the caption. We suggest that by providing detailed prompts for Content images, CSGO is able to effectively guarantee content.
241
+ 6. Click the <b>Submit</b> button to begin customization.
242
+ 7. Share your stylized photo with your friends and enjoy! 😊
243
+
244
+ Advanced usage:<br>
245
+ 1. Click advanced options.
246
+ 2. Choose different guidance and steps.
247
+ """
248
+
249
+ article = r"""
250
+ ---
251
+ 📝 **Tips**
252
+ In CSGO, the more accurate the text prompts for content images, the better the content retention.
253
+ Text-driven style synthesis and text-edit-driven style synthesis are expected to be more stable in the next release.
254
+ ---
255
+ 📝 **Citation**
256
+ <br>
257
+ If our work is helpful for your research or applications, please cite us via:
258
+ ```bibtex
259
+ @article{xing2024csgo,
260
+ title={CSGO: Content-Style Composition in Text-to-Image Generation},
261
+ author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
262
+ year={2024},
263
+ journal = {arXiv 2408.16766},
264
+ }
265
+ import sys
266
+ sys.path.append('./')
267
+ import spaces
268
+ import gradio as gr
269
+ import torch
270
+ from ip_adapter.utils import BLOCKS as BLOCKS
271
+ from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS
272
+ from ip_adapter.utils import resize_content
273
+ import cv2
274
+ import numpy as np
275
+ import random
276
+ from PIL import Image
277
+ from transformers import AutoImageProcessor, AutoModel
278
+ from diffusers import (
279
+ AutoencoderKL,
280
+ ControlNetModel,
281
+ StableDiffusionXLControlNetPipeline,
282
+ )
283
+ from ip_adapter import CSGO
284
+ from transformers import BlipProcessor, BlipForConditionalGeneration
285
+
286
+ device = "cuda" if torch.cuda.is_available() else "cpu"
287
+ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
288
+ import os
289
+ os.system("git lfs install")
290
+ os.system("git clone https://huggingface.co/h94/IP-Adapter")
291
+ os.system("mv IP-Adapter/sdxl_models sdxl_models")
292
+
293
+ from huggingface_hub import hf_hub_download
294
+
295
+ hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
296
+ os.system('rm -rf IP-Adapter/models')
297
+ base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
298
+ image_encoder_path = "sdxl_models/image_encoder"
299
+ csgo_ckpt ='./CSGO/csgo_4_32.bin'
300
+ pretrained_vae_name_or_path ='madebyollin/sdxl-vae-fp16-fix'
301
+ controlnet_path = "TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic"
302
+ weight_dtype = torch.float16
303
+
304
+ os.system("git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic")
305
+ os.system("mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors")
306
+ os.system('rm -rf TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v1_fp16.safetensors')
307
+ controlnet_path = "./TTPLanet_SDXL_Controlnet_Tile_Realistic"
308
+
309
+ vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path,torch_dtype=torch.float16)
310
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)
311
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
312
+ base_model_path,
313
+ controlnet=controlnet,
314
+ torch_dtype=torch.float16,
315
+ add_watermarker=False,
316
+ vae=vae
317
+ )
318
+ pipe.enable_vae_tiling()
319
+
320
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
321
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)
322
+
323
+ target_content_blocks = BLOCKS['content']
324
+ target_style_blocks = BLOCKS['style']
325
+ controlnet_target_content_blocks = controlnet_BLOCKS['content']
326
+ controlnet_target_style_blocks = controlnet_BLOCKS['style']
327
+
328
+ csgo = CSGO(pipe, image_encoder_path, csgo_ckpt, device, num_content_tokens=4, num_style_tokens=32,
329
+ target_content_blocks=target_content_blocks, target_style_blocks=target_style_blocks,
330
+ controlnet_adapter=True,
331
+ controlnet_target_content_blocks=controlnet_target_content_blocks,
332
+ controlnet_target_style_blocks=controlnet_target_style_blocks,
333
+ content_model_resampler=True,
334
+ style_model_resampler=True,
335
+ )
336
+
337
+ MAX_SEED = np.iinfo(np.int32).max
338
+
339
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
340
  if randomize_seed:
341
  seed = random.randint(0, MAX_SEED)
342
  return seed
343
 
344
+ def get_example():
345
+ case = [
346
+ [
347
+ "./assets/img_0.png",
348
+ './assets/img_1.png',
349
+ "Image-Driven Style Transfer",
350
+ "there is a small house with a sheep statue on top of it",
351
+ 0.6,
352
+ 1.0,
353
+ 7.0,
354
+ 42
355
+ ],
356
+ [
357
+ None,
358
+ './assets/img_1.png',
359
+ "Text-Driven Style Synthesis",
360
+ "a cat",
361
+ 0.01,
362
+ 1.0,
363
+ 7.0,
364
+ 42
365
+ ],
366
+ [
367
+ None,
368
+ './assets/img_2.png',
369
+ "Text-Driven Style Synthesis",
370
+ "a cat",
371
+ 0.01,
372
+ 1.0,
373
+ 7.0,
374
+ 42,
375
+ ],
376
+ [
377
+ "./assets/img_0.png",
378
+ './assets/img_1.png',
379
+ "Text Edit-Driven Style Synthesis",
380
+ "there is a small house",
381
+ 0.4,
382
+ 1.0,
383
+ 7.0,
384
+ 42,
385
+ ],
386
+ ]
387
+ return case
388
+
389
+ def run_for_examples(content_image_pil, style_image_pil, target, prompt, scale_c, scale_s, guidance_scale, seed):
390
+ return create_image(
391
+ content_image_pil=content_image_pil,
392
+ style_image_pil=style_image_pil,
393
+ prompt=prompt,
394
+ scale_c=scale_c,
395
+ scale_s=scale_s,
396
+ guidance_scale=guidance_scale,
397
+ num_samples=2,
398
+ num_inference_steps=50,
399
+ seed=seed,
400
+ target=target,
401
+ )
402
+
403
  def image_grid(imgs, rows, cols):
404
  assert len(imgs) == rows * cols
405
 
 
410
  for i, img in enumerate(imgs):
411
  grid.paste(img, box=(i % cols * w, i // cols * h))
412
  return grid
413
+
414
  @spaces.GPU
415
  def create_image(content_image_pil,
416
  style_image_pil,
 
423
  seed,
424
  target="Image-Driven Style Transfer",
425
  ):
 
 
426
  if content_image_pil is None:
427
  content_image_pil = Image.fromarray(
428
  np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
429
 
430
  if prompt == '':
 
431
  inputs = blip_processor(content_image_pil, return_tensors="pt").to(device)
432
  out = blip_model.generate(**inputs)
433
  prompt = blip_processor.decode(out[0], skip_special_tokens=True)
434
+
435
  width, height, content_image = resize_content(content_image_pil)
436
  style_image = style_image_pil
437
+ neg_content_prompt = 'text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry'
 
438
 
439
+ if target == "Image-Driven Style Transfer":
440
  images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
441
  prompt=prompt,
442
  negative_prompt=neg_content_prompt,
 
450
  num_samples=1,
451
  seed=seed,
452
  image=content_image.convert('RGB'),
453
+ controlnet_conditioning_scale=scale_c)
 
454
 
455
+ elif target == "Text-Driven Style Synthesis":
456
  content_image = Image.fromarray(
457
  np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')
458
 
 
469
  num_samples=1,
470
  seed=42,
471
  image=content_image.convert('RGB'),
472
+ controlnet_conditioning_scale=scale_c)
 
 
 
473
 
474
+ elif target == "Text Edit-Driven Style Synthesis":
475
  images = csgo.generate(pil_content_image=content_image, pil_style_image=style_image,
476
  prompt=prompt,
477
  negative_prompt=neg_content_prompt,
 
485
  num_samples=1,
486
  seed=seed,
487
  image=content_image.convert('RGB'),
488
+ controlnet_conditioning_scale=scale_c)
 
489
 
490
  return [image_grid(images, 1, num_samples)]
491
 
 
 
 
 
 
 
 
492
  # Description
493
  title = r"""
494
  <h1 align="center">CSGO: Content-Style Composition in Text-to-Image Generation</h1>
495
  """
496
 
497
  description = r"""
498
+ <b>Official Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>
499
  How to use:<br>
500
  1. Upload a content image if you want to use image-driven style transfer.
501
  2. Upload a style image.
 
526
  year={2024},
527
  journal = {arXiv 2408.16766},
528
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
+ ### Changes made:
531
+ 1. Replaced the emoji with a plain text representation for compatibility.
532
+ 2. Removed the redundant function definition.
533
+ 3. Ensured that the HTML and Gradio block components work without syntax issues.
534
 
535
+ Now you can try running this modified version of your script. Let me know if you encounter any further issues! &#8203;:contentReference[oaicite:0]{index=0}&#8203;