|
import sys |
|
sys.path.append('./') |
|
import spaces |
|
import gradio as gr |
|
import torch |
|
from ip_adapter.utils import BLOCKS as BLOCKS |
|
from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS |
|
from ip_adapter.utils import resize_content |
|
import cv2 |
|
import numpy as np |
|
import random |
|
from PIL import Image |
|
from transformers import AutoImageProcessor, AutoModel |
|
from diffusers import ( |
|
AutoencoderKL, |
|
ControlNetModel, |
|
StableDiffusionXLControlNetPipeline, |
|
) |
|
from ip_adapter import CSGO |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
|
|
# Run on the GPU when one is visible; `dtype` mirrors that choice.
# NOTE(review): `dtype` is not read anywhere in the visible code -- every model
# below is loaded in float16 unconditionally.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if "cuda" in device else torch.float32

import os

# Fetch the IP-Adapter repository and move its SDXL models directory to
# ./sdxl_models (the image encoder is loaded from there below).
# Shell exit codes are deliberately not checked.
for _shell_cmd in (
    "git lfs install",
    "git clone https://huggingface.co/h94/IP-Adapter",
    "mv IP-Adapter/sdxl_models sdxl_models",
):
    os.system(_shell_cmd)

from huggingface_hub import hf_hub_download

# CSGO adapter weights; the file name matches the 4 content / 32 style tokens
# configured on the CSGO object below.
hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
os.system('rm -rf IP-Adapter/models')

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
csgo_ckpt = './CSGO/csgo_4_32.bin'
pretrained_vae_name_or_path = 'madebyollin/sdxl-vae-fp16-fix'
weight_dtype = torch.float16  # NOTE(review): not used in the visible code.

# Clone the tile ControlNet, rename its v2 fp16 weights to
# diffusion_pytorch_model.safetensors, delete the v1 weights, and load the
# model from that local clone.
for _shell_cmd in (
    "git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic",
    "mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors",
    'rm -rf TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v1_fp16.safetensors',
):
    os.system(_shell_cmd)
controlnet_path = "./TTPLanet_SDXL_Controlnet_Tile_Realistic"

# SDXL + tile ControlNet pipeline in float16, using the sdxl-vae-fp16-fix VAE.
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    vae=vae,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    add_watermarker=False,
)
pipe.enable_vae_tiling()

# BLIP captioner; create_image uses it when the prompt is left empty.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

# Block selections (from ip_adapter.utils) that CSGO targets with content and
# style features, for the UNet and for the ControlNet respectively.
target_content_blocks, target_style_blocks = BLOCKS['content'], BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']

csgo = CSGO(
    pipe,
    image_encoder_path,
    csgo_ckpt,
    device,
    num_content_tokens=4,
    num_style_tokens=32,
    target_content_blocks=target_content_blocks,
    target_style_blocks=target_style_blocks,
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_target_content_blocks,
    controlnet_target_style_blocks=controlnet_target_style_blocks,
    content_model_resampler=True,
    style_model_resampler=True,
)

# Inclusive upper bound for randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` unchanged, or a fresh seed in ``[0, MAX_SEED]`` when asked to randomize."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
def get_example():
    """Return the demo's example rows as a fresh list of lists.

    Each row holds, in order, the arguments of ``run_for_examples``:
    content image path (or None), style image path, target task, prompt,
    scale_c, scale_s, guidance_scale, seed.
    """
    house_image = "./assets/img_0.png"
    image_transfer = "Image-Driven Style Transfer"
    text_synthesis = "Text-Driven Style Synthesis"
    text_edit = "Text Edit-Driven Style Synthesis"

    rows = []
    rows.append([house_image, './assets/img_1.png', image_transfer,
                 "there is a small house with a sheep statue on top of it", 0.6, 1.0, 7.0, 42])
    for style_path in ('./assets/img_1.png', './assets/img_2.png'):
        rows.append([None, style_path, text_synthesis, "a cat", 0.01, 1.0, 7.0, 42])
    rows.append([house_image, './assets/img_1.png', text_edit,
                 "there is a small house", 0.4, 1.0, 7.0, 42])
    return rows
|
|
|
def run_for_examples(content_image_pil, style_image_pil, target, prompt, scale_c, scale_s, guidance_scale, seed):
    """Run ``create_image`` for a clicked example row.

    Example rows only carry the arguments listed in this signature; the sample
    count is fixed at 2 and the number of inference steps at 50.
    """
    fixed_settings = {"num_samples": 2, "num_inference_steps": 50}
    return create_image(
        content_image_pil=content_image_pil,
        style_image_pil=style_image_pil,
        target=target,
        prompt=prompt,
        scale_c=scale_c,
        scale_s=scale_s,
        guidance_scale=guidance_scale,
        seed=seed,
        **fixed_settings,
    )
|
|
|
def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` images into a single RGB grid image.

    Images are placed row-major (left to right, then top to bottom). Every
    cell has the size of ``imgs[0]``.

    Args:
        imgs: Sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of ``assert``: assertions are stripped under ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}")
    if not imgs:
        # rows * cols == 0: there is no first image to take the cell size from.
        raise ValueError("image_grid needs at least one image")

    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, img in enumerate(imgs):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid
|
|
|
# Negative prompt shared by every task.
_NEGATIVE_PROMPT = ('text, watermark, lowres, low quality, worst quality, deformed, glitch, '
                    'low contrast, noisy, saturation, blurry')

# Per-task settings: target name -> (content_scale passed to csgo.generate,
# whether the content image is replaced by a blank canvas).
_TARGET_SETTINGS = {
    "Image-Driven Style Transfer": (1.0, False),
    "Text-Driven Style Synthesis": (0.5, True),
    "Text Edit-Driven Style Synthesis": (1.0, False),
}


def _blank_canvas():
    """Return a 1024x1024 all-black RGB image."""
    return Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')


@spaces.GPU
def create_image(content_image_pil,
                 style_image_pil,
                 prompt,
                 scale_c,
                 scale_s,
                 guidance_scale,
                 num_samples,
                 num_inference_steps,
                 seed,
                 target="Image-Driven Style Transfer",
                 ):
    """Run CSGO for one request and return the samples laid out in one row.

    Args:
        content_image_pil: Content image, or None (a 1024x1024 black canvas is used).
        style_image_pil: Style reference image; required.
        prompt: Text prompt. If empty/blank/None, BLIP captions the content
            image and that caption is used as the prompt.
        scale_c: Forwarded as ``controlnet_conditioning_scale``.
        scale_s: Forwarded as ``style_scale``.
        guidance_scale: Forwarded as ``guidance_scale``.
        num_samples: Number of images generated; also the grid's column count.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        seed: Forwarded as ``seed``.
        target: One of the keys of ``_TARGET_SETTINGS``.

    Returns:
        A one-element list holding the grid image built by ``image_grid``.

    Raises:
        ValueError: If ``target`` is unknown or no style image was given.
    """
    # Validate before spending any GPU time on captioning or generation.
    if target not in _TARGET_SETTINGS:
        raise ValueError(
            f"Unknown target {target!r}; expected one of {sorted(_TARGET_SETTINGS)}")
    if style_image_pil is None:
        raise ValueError("A style image is required.")
    content_scale, use_blank_content = _TARGET_SETTINGS[target]

    if content_image_pil is None:
        content_image_pil = _blank_canvas()

    if not prompt or not prompt.strip():
        inputs = blip_processor(content_image_pil, return_tensors="pt").to(device)
        out = blip_model.generate(**inputs)
        prompt = blip_processor.decode(out[0], skip_special_tokens=True)

    # Output size follows the resized user content image, even when the
    # content image itself is swapped for a blank canvas below.
    width, height, content_image = resize_content(content_image_pil)
    if use_blank_content:
        content_image = _blank_canvas()

    # Every task honors the caller's guidance_scale and seed. Previously the
    # text-driven task silently forced guidance_scale=7 and seed=42.
    images = csgo.generate(pil_content_image=content_image,
                           pil_style_image=style_image_pil,
                           prompt=prompt,
                           negative_prompt=_NEGATIVE_PROMPT,
                           height=height,
                           width=width,
                           content_scale=content_scale,
                           style_scale=scale_s,
                           guidance_scale=guidance_scale,
                           num_images_per_prompt=num_samples,
                           num_inference_steps=num_inference_steps,
                           num_samples=1,
                           seed=seed,
                           image=content_image.convert('RGB'),
                           controlnet_conditioning_scale=scale_c)

    return [image_grid(images, 1, num_samples)]
|
|
|
|
|
# HTML heading (an <h1> element) for the demo page.
title = r""" |

<h1 align="center">CSGO: Content-Style Composition in Text-to-Image Generation</h1> |

""" |
|
|
|
# Markdown/HTML usage instructions for the demo page.
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>
How to use:<br>
1. Upload a content image if you want to use image-driven style transfer.
2. Upload a style image.
3. Set the type of task to perform; image-driven style transfer is the default. Options are <b>Image-driven style transfer, Text-driven style synthesis, and Text edit-driven style synthesis</b>.
4. <b>If you choose a text-driven task, enter your desired prompt</b>.
5. If you don't provide a prompt, the BLIP model captions the content image and that caption is used as the prompt. Detailed prompts describing the content image help CSGO preserve the content.
6. Click the <b>Submit</b> button to begin customization.
7. Share your stylized photo with your friends and enjoy! 😊

Advanced usage:<br>
1. Click advanced options.
2. Choose different guidance and steps.
"""
|
|
|
# Markdown footer: usage tips and the BibTeX citation. The string and its
# ```bibtex fence are explicitly closed. The blank line before the second
# `---` keeps Markdown from reading it as a setext heading underline for the
# preceding sentence.
article = r"""
---
📝 **Tips**
In CSGO, the more accurate the text prompts for content images, the better the content retention.
Text-driven style synthesis and text-edit-driven style synthesis are expected to be more stable in the next release.

---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{xing2024csgo,
    title={CSGO: Content-Style Composition in Text-to-Image Generation},
    author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
    year={2024},
    journal = {arXiv 2408.16766},
}
```
"""
|
import sys |
|
sys.path.append('./') |
|
import spaces |
|
import gradio as gr |
|
import torch |
|
from ip_adapter.utils import BLOCKS as BLOCKS |
|
from ip_adapter.utils import controlnet_BLOCKS as controlnet_BLOCKS |
|
from ip_adapter.utils import resize_content |
|
import cv2 |
|
import numpy as np |
|
import random |
|
from PIL import Image |
|
from transformers import AutoImageProcessor, AutoModel |
|
from diffusers import ( |
|
AutoencoderKL, |
|
ControlNetModel, |
|
StableDiffusionXLControlNetPipeline, |
|
) |
|
from ip_adapter import CSGO |
|
from transformers import BlipProcessor, BlipForConditionalGeneration |
|
|
|
# Run on the GPU when one is visible; `dtype` mirrors that choice.
# NOTE(review): `dtype` is not read anywhere in the visible code -- every model
# below is loaded in float16 unconditionally.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if "cuda" in device else torch.float32

import os

# Fetch the IP-Adapter repository and move its SDXL models directory to
# ./sdxl_models (the image encoder is loaded from there below).
# Shell exit codes are deliberately not checked.
for _shell_cmd in (
    "git lfs install",
    "git clone https://huggingface.co/h94/IP-Adapter",
    "mv IP-Adapter/sdxl_models sdxl_models",
):
    os.system(_shell_cmd)

from huggingface_hub import hf_hub_download

# CSGO adapter weights; the file name matches the 4 content / 32 style tokens
# configured on the CSGO object below.
hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
os.system('rm -rf IP-Adapter/models')

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
csgo_ckpt = './CSGO/csgo_4_32.bin'
pretrained_vae_name_or_path = 'madebyollin/sdxl-vae-fp16-fix'
weight_dtype = torch.float16  # NOTE(review): not used in the visible code.

# Clone the tile ControlNet, rename its v2 fp16 weights to
# diffusion_pytorch_model.safetensors, delete the v1 weights, and load the
# model from that local clone.
for _shell_cmd in (
    "git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic",
    "mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors",
    'rm -rf TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v1_fp16.safetensors',
):
    os.system(_shell_cmd)
controlnet_path = "./TTPLanet_SDXL_Controlnet_Tile_Realistic"

# SDXL + tile ControlNet pipeline in float16, using the sdxl-vae-fp16-fix VAE.
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    vae=vae,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    add_watermarker=False,
)
pipe.enable_vae_tiling()

# BLIP captioner; create_image uses it when the prompt is left empty.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

# Block selections (from ip_adapter.utils) that CSGO targets with content and
# style features, for the UNet and for the ControlNet respectively.
target_content_blocks, target_style_blocks = BLOCKS['content'], BLOCKS['style']
controlnet_target_content_blocks = controlnet_BLOCKS['content']
controlnet_target_style_blocks = controlnet_BLOCKS['style']

csgo = CSGO(
    pipe,
    image_encoder_path,
    csgo_ckpt,
    device,
    num_content_tokens=4,
    num_style_tokens=32,
    target_content_blocks=target_content_blocks,
    target_style_blocks=target_style_blocks,
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_target_content_blocks,
    controlnet_target_style_blocks=controlnet_target_style_blocks,
    content_model_resampler=True,
    style_model_resampler=True,
)

# Inclusive upper bound for randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` unchanged, or a fresh seed in ``[0, MAX_SEED]`` when asked to randomize."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
def get_example():
    """Return the demo's example rows as a fresh list of lists.

    Each row holds, in order, the arguments of ``run_for_examples``:
    content image path (or None), style image path, target task, prompt,
    scale_c, scale_s, guidance_scale, seed.
    """
    house_image = "./assets/img_0.png"
    image_transfer = "Image-Driven Style Transfer"
    text_synthesis = "Text-Driven Style Synthesis"
    text_edit = "Text Edit-Driven Style Synthesis"

    rows = []
    rows.append([house_image, './assets/img_1.png', image_transfer,
                 "there is a small house with a sheep statue on top of it", 0.6, 1.0, 7.0, 42])
    for style_path in ('./assets/img_1.png', './assets/img_2.png'):
        rows.append([None, style_path, text_synthesis, "a cat", 0.01, 1.0, 7.0, 42])
    rows.append([house_image, './assets/img_1.png', text_edit,
                 "there is a small house", 0.4, 1.0, 7.0, 42])
    return rows
|
|
|
def run_for_examples(content_image_pil, style_image_pil, target, prompt, scale_c, scale_s, guidance_scale, seed):
    """Run ``create_image`` for a clicked example row.

    Example rows only carry the arguments listed in this signature; the sample
    count is fixed at 2 and the number of inference steps at 50.
    """
    fixed_settings = {"num_samples": 2, "num_inference_steps": 50}
    return create_image(
        content_image_pil=content_image_pil,
        style_image_pil=style_image_pil,
        target=target,
        prompt=prompt,
        scale_c=scale_c,
        scale_s=scale_s,
        guidance_scale=guidance_scale,
        seed=seed,
        **fixed_settings,
    )
|
|
|
def image_grid(imgs, rows, cols):
    """Paste ``rows * cols`` images into a single RGB grid image.

    Images are placed row-major (left to right, then top to bottom). Every
    cell has the size of ``imgs[0]``.

    Args:
        imgs: Sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of ``assert``: assertions are stripped under ``python -O``.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}")
    if not imgs:
        # rows * cols == 0: there is no first image to take the cell size from.
        raise ValueError("image_grid needs at least one image")

    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, img in enumerate(imgs):
        row, col = divmod(index, cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid
|
|
|
# Negative prompt shared by every task.
_NEGATIVE_PROMPT = ('text, watermark, lowres, low quality, worst quality, deformed, glitch, '
                    'low contrast, noisy, saturation, blurry')

# Per-task settings: target name -> (content_scale passed to csgo.generate,
# whether the content image is replaced by a blank canvas).
_TARGET_SETTINGS = {
    "Image-Driven Style Transfer": (1.0, False),
    "Text-Driven Style Synthesis": (0.5, True),
    "Text Edit-Driven Style Synthesis": (1.0, False),
}


def _blank_canvas():
    """Return a 1024x1024 all-black RGB image."""
    return Image.fromarray(np.zeros((1024, 1024, 3), dtype=np.uint8)).convert('RGB')


@spaces.GPU
def create_image(content_image_pil,
                 style_image_pil,
                 prompt,
                 scale_c,
                 scale_s,
                 guidance_scale,
                 num_samples,
                 num_inference_steps,
                 seed,
                 target="Image-Driven Style Transfer",
                 ):
    """Run CSGO for one request and return the samples laid out in one row.

    Args:
        content_image_pil: Content image, or None (a 1024x1024 black canvas is used).
        style_image_pil: Style reference image; required.
        prompt: Text prompt. If empty/blank/None, BLIP captions the content
            image and that caption is used as the prompt.
        scale_c: Forwarded as ``controlnet_conditioning_scale``.
        scale_s: Forwarded as ``style_scale``.
        guidance_scale: Forwarded as ``guidance_scale``.
        num_samples: Number of images generated; also the grid's column count.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        seed: Forwarded as ``seed``.
        target: One of the keys of ``_TARGET_SETTINGS``.

    Returns:
        A one-element list holding the grid image built by ``image_grid``.

    Raises:
        ValueError: If ``target`` is unknown or no style image was given.
    """
    # Validate before spending any GPU time on captioning or generation.
    if target not in _TARGET_SETTINGS:
        raise ValueError(
            f"Unknown target {target!r}; expected one of {sorted(_TARGET_SETTINGS)}")
    if style_image_pil is None:
        raise ValueError("A style image is required.")
    content_scale, use_blank_content = _TARGET_SETTINGS[target]

    if content_image_pil is None:
        content_image_pil = _blank_canvas()

    if not prompt or not prompt.strip():
        inputs = blip_processor(content_image_pil, return_tensors="pt").to(device)
        out = blip_model.generate(**inputs)
        prompt = blip_processor.decode(out[0], skip_special_tokens=True)

    # Output size follows the resized user content image, even when the
    # content image itself is swapped for a blank canvas below.
    width, height, content_image = resize_content(content_image_pil)
    if use_blank_content:
        content_image = _blank_canvas()

    # Every task honors the caller's guidance_scale and seed. Previously the
    # text-driven task silently forced guidance_scale=7 and seed=42.
    images = csgo.generate(pil_content_image=content_image,
                           pil_style_image=style_image_pil,
                           prompt=prompt,
                           negative_prompt=_NEGATIVE_PROMPT,
                           height=height,
                           width=width,
                           content_scale=content_scale,
                           style_scale=scale_s,
                           guidance_scale=guidance_scale,
                           num_images_per_prompt=num_samples,
                           num_inference_steps=num_inference_steps,
                           num_samples=1,
                           seed=seed,
                           image=content_image.convert('RGB'),
                           controlnet_conditioning_scale=scale_c)

    return [image_grid(images, 1, num_samples)]
|
|
|
# Page text shown by the demo. `title` is the HTML heading (an <h1> element).



title = r""" |

<h1 align="center">CSGO: Content-Style Composition in Text-to-Image Generation</h1> |

""" |
|
|
|
# Markdown/HTML usage instructions for the demo page. The stray "Advanced
# usage" lines and dangling triple quote that followed the previous literal
# are folded into this single, properly closed string.
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/instantX-research/CSGO' target='_blank'><b>CSGO: Content-Style Composition in Text-to-Image Generation</b></a>.<br>
How to use:<br>
1. Upload a content image if you want to use image-driven style transfer.
2. Upload a style image.
3. Set the type of task to perform; image-driven style transfer is the default. Options are <b>Image-driven style transfer, Text-driven style synthesis, and Text edit-driven style synthesis</b>.
4. <b>If you choose a text-driven task, enter your desired prompt</b>.
5. If you don't provide a prompt, the BLIP model captions the content image and that caption is used as the prompt. Detailed prompts describing the content image help CSGO preserve the content.

Advanced usage:<br>
1. Click advanced options.
2. Choose different guidance and steps.
"""
|
|