sagar007 committed
Commit 44fe76b · verified · 1 Parent(s): bc1fb47

Update app.py

Files changed (1):
  1. app.py +75 -201
app.py CHANGED
@@ -1,29 +1,25 @@
- import gradio as gr
- import PIL
+ import os
  import torch
- import numpy as np
- from PIL import Image
+ import gradio as gr
  from tqdm import tqdm
+ from PIL import Image
  import torch.nn.functional as F
- import torchvision.transforms as T
- from diffusers import LMSDiscreteScheduler, DiffusionPipeline
-
- # configurations
- torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
- height, width = 512,512
- guidance_scale = 8
- loss_scale = 200
- num_inference_steps = 50
 
+ torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
 
+ # Load the pipeline
  model_path = "CompVis/stable-diffusion-v1-4"
  sd_pipeline = DiffusionPipeline.from_pretrained(
      model_path,
-     low_cpu_mem_usage = True,
+     low_cpu_mem_usage=True,
      torch_dtype=torch.float32
  ).to(torch_device)
 
-
+ # Load textual inversions
  sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
  sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
  sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
@@ -32,199 +28,77 @@ sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
  sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
  sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
 
-
- styles_mapping = {
-     "Illustration Style": '<illustration-style>', "Line Art":'<line-art>',
-     "Hitokomoru Style":'<hitokomoru-style-nao>', "Marc Allante": '<Marc_Allante>',
-     "Midjourney":'<midjourney-style>', "Hanfu Anime": '<hanfu-anime-style>',
+ # Update style token dictionary
+ style_token_dict = {
+     "Illustration Style": '<illustration-style>',
+     "Line Art":'<line-art>',
+     "Hitokomoru Style":'<hitokomoru-style-nao>',
+     "Marc Allante": '<Marc_Allante>',
+     "Midjourney":'<midjourney-style>',
+     "Hanfu Anime": '<hanfu-anime-style>',
      "Birb Style": '<birb-style>'
  }
 
- # Define seeds for all the styles
- seed_list = [11, 56, 110, 65, 5, 29, 47]
-
- # Loss Function based on Edge Detection
- def edge_detection(image):
-     channels = image.shape[1]
-
-     # Define the kernels for Edge Detection
-     ed_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
-     ed_y = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
-
-     # Replicate the Edge detection kernels for each channel
-     ed_x = ed_x.repeat(channels, 1, 1, 1).to(image.device)
-     ed_y = ed_y.repeat(channels, 1, 1, 1).to(image.device)
-
-     # ed_x = ed_x.to(torch.float16)
-     # ed_y = ed_y.to(torch.float16)
-
-     # Convolve the image with the Edge detection kernels
-     conv_ed_x = F.conv2d(image, ed_x, padding=1, groups=channels)
-     conv_ed_y = F.conv2d(image, ed_y, padding=1, groups=channels)
-
-     # Combine the x and y gradients after convolution
-     ed_value = torch.sqrt(conv_ed_x**2 + conv_ed_y**2)
-
-     return ed_value
-
- def edge_loss(image):
-     ed_value = edge_detection(image)
-     ed_capped = (ed_value > 0.5).to(torch.float32)
-     return F.mse_loss(ed_value, ed_capped)
-
- def compute_loss(original_image, loss_type):
-
-     if loss_type == 'blue':
-         # blue loss
-         # [:,2] -> all images in batch, only the blue channel
-         error = torch.abs(original_image[:,2] - 0.9).mean()
-     elif loss_type == 'edge':
-         # edge loss
-         error = edge_loss(original_image)
-     elif loss_type == 'contrast':
-         # RGB to Gray loss
-         transformed_image = T.functional.adjust_contrast(original_image, contrast_factor = 2)
-         error = torch.abs(transformed_image - original_image).mean()
-     elif loss_type == 'brightness':
-         # brightnesss loss
-         transformed_image = T.functional.adjust_brightness(original_image, brightness_factor = 2)
-         error = torch.abs(transformed_image - original_image).mean()
-     elif loss_type == 'sharpness':
-         # sharpness loss
-         transformed_image = T.functional.adjust_sharpness(original_image, sharpness_factor = 2)
-         error = torch.abs(transformed_image - original_image).mean()
-     elif loss_type == 'saturation':
-         # saturation loss
-         transformed_image = T.functional.adjust_saturation(original_image, saturation_factor = 10)
-         error = torch.abs(transformed_image - original_image).mean()
-     else:
-         print("error. Loss not defined")
-
-     return error
-
- def get_examples():
-     examples = [
-         ['A bird sitting on a tree', 'Midjourney', 'edge']
-     ]
-     return examples
-
- # Existing functions (latents_to_pil, show_image, generate_image)
- # ... (Copy all the existing functions here)
- def latents_to_pil(latents):
-     # bath of latents -> list of images
-     latents = (1 / 0.18215) * latents
-     with torch.no_grad():
-         image = sd_pipeline.vae.decode(latents).sample
-     image = (image / 2 + 0.5).clamp(0, 1) # 0 to 1
-     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-     image = (image * 255).round().astype("uint8")
-     return Image.fromarray(image[0])
-
-
- def show_image(prompt, concept, guidance_type):
-
-     for idx, sd in enumerate(styles_mapping.keys()):
-         if(sd == concept):
-             break
-     seed = seed_list[idx]
-     prompt = f"{prompt} in the style of {styles_mapping[sd]}"
-     styled_image_without_loss = latents_to_pil(generate_image(seed, prompt, guidance_type, loss_flag=False))
-     styled_image_with_loss = latents_to_pil(generate_image(seed, prompt, guidance_type, loss_flag=True))
-     return([styled_image_without_loss, styled_image_with_loss])
-
 
- def generate_image(seed, prompt, loss_type, loss_flag=False):
-
-     generator = torch.manual_seed(seed)
-     batch_size = 1
-
-     # scheduler
-     scheduler = LMSDiscreteScheduler(beta_start = 0.00085, beta_end = 0.012, beta_schedule = "scaled_linear", num_train_timesteps = 1000)
+ def set_timesteps(scheduler, num_inference_steps):
      scheduler.set_timesteps(num_inference_steps)
      scheduler.timesteps = scheduler.timesteps.to(torch.float32)
 
-     # text embeddings of the prompt
-     text_input = sd_pipeline.tokenizer(prompt, padding='max_length', max_length = sd_pipeline.tokenizer.model_max_length, truncation= True, return_tensors="pt")
-     input_ids = text_input.input_ids.to(torch_device)
-
+ def pil_to_latent(input_im):
      with torch.no_grad():
-         text_embeddings = sd_pipeline.text_encoder(text_input.input_ids.to(torch_device))[0]
-
-     max_length = text_input.input_ids.shape[-1]
-     uncond_input = sd_pipeline.tokenizer(
-         [""] * batch_size, padding="max_length", max_length= max_length, return_tensors="pt"
-     )
+         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
+     return 0.18215 * latent.latent_dist.sample()
 
+ def latents_to_pil(latents):
+     latents = (1 / 0.18215) * latents
      with torch.no_grad():
-         uncond_embeddings = sd_pipeline.text_encoder(uncond_input.input_ids.to(torch_device))[0]
-
-     text_embeddings = torch.cat([uncond_embeddings,text_embeddings]) # shape: 2,77,768
-
-     # random latent
-     latents = torch.randn(
-         (batch_size, sd_pipeline.unet.config.in_channels, height// 8, width //8),
-         generator = generator,
-     ) .to(torch.float32)
-
-
-     latents = latents.to(torch_device)
-     latents = latents * scheduler.init_noise_sigma
-
-     for i, t in tqdm(enumerate(scheduler.timesteps), total = len(scheduler.timesteps)):
-
-         latent_model_input = torch.cat([latents] * 2)
-         sigma = scheduler.sigmas[i]
-         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
-
-         with torch.no_grad():
-             noise_pred = sd_pipeline.unet(latent_model_input.to(torch.float32), t, encoder_hidden_states=text_embeddings)["sample"]
-
-         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-         if loss_flag and i%5 == 0:
-
-             latents = latents.detach().requires_grad_()
-             # the following line alone does not work, it requires change to reduce step only once
-             # hence commenting it out
-             #latents_x0 = scheduler.step(noise_pred,t, latents).pred_original_sample
-             latents_x0 = latents - sigma * noise_pred
-
-             # use vae to decode the image
-             denoised_images = sd_pipeline.vae.decode((1/ 0.18215) * latents_x0).sample / 2 + 0.5 # range(0,1)
-
-             loss = compute_loss(denoised_images, loss_type) * loss_scale
-             #loss = loss.to(torch.float16)
-             print(f"{i} loss {loss}")
-
-             cond_grad = torch.autograd.grad(loss, latents)[0]
-             latents = latents.detach() - cond_grad * sigma**2
-
-         latents = scheduler.step(noise_pred,t, latents).prev_sample
-
-     return latents
-
- # Gradio interface function
- def generate_images(prompt, style, guidance_type):
-     images = show_image(prompt, style, guidance_type)
-     return images[0], images[1]
-
- # Create Gradio interface
- iface = gr.Interface(
-     fn=generate_images,
-     inputs=[
-         gr.Textbox(label="Prompt"),
-         gr.Dropdown(list(styles_mapping.keys()), label="Style"),
-         gr.Dropdown(["blue", "edge", "contrast", "brightness", "sharpness", "saturation"], label="Guidance Type"),
-     ],
-     outputs=[
-         gr.Image(label="Image without Loss"),
-         gr.Image(label="Image with Loss"),
-     ],
-     examples=get_examples(),
-     title="Text Inversion Image Generation",
-     description="Generate images using text inversion with different styles and guidance types.",
- )
-
- # Launch the app
- iface.launch()
+         image = vae.decode(latents).sample
+     image = (image / 2 + 0.5).clamp(0, 1)
+     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
+     images = (image * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+     return pil_images
+
+ def generate_with_pipeline(prompt, num_inference_steps, guidance_scale, seed):
+     generator = torch.Generator(device=torch_device).manual_seed(seed)
+     image = sd_pipeline(
+         prompt,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         generator=generator
+     ).images[0]
+     return image
+
+ def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
+     prompt = text + " " + style_token_dict[style]
+
+     # Generate image with pipeline
+     image_pipeline = generate_with_pipeline(prompt, inference_step, guidance_scale, seed)
+
+     # For the guided image, we'll need to implement a custom pipeline or modify the existing one
+     # This is a placeholder and would need to be implemented
+     image_guide = image_pipeline # This should be replaced with actual guided generation
+
+     return image_pipeline, image_guide
+
+ title = "Stable Diffusion with Textual Inversion"
+ description = "A simple Gradio interface to infer Stable Diffusion and generate images with different art styles"
+ examples = [["A sweet potato farm", 'Illustration Style', 10, 4.5, 1, 'Grayscale', 100],
+             ["Sky full of cotton candy", 'Line Art', 10, 9.5, 2, 'Bright', 200]]
+
+ demo = gr.Interface(inference,
+                     inputs = [gr.Textbox(label="Prompt", type="text"),
+                               gr.Dropdown(label="Style", choices=list(style_token_dict.keys()), value="Illustration Style"),
+                               gr.Slider(10, 30, 10, step = 1, label="Inference steps"),
+                               gr.Slider(1, 10, 7.5, step = 0.1, label="Guidance scale"),
+                               gr.Slider(0, 10000, 1, step = 1, label="Seed"),
+                               gr.Dropdown(label="Guidance method", choices=['Grayscale', 'Bright', 'Contrast',
+                                           'Symmetry', 'Saturation'], value="Grayscale"),
+                               gr.Slider(100, 10000, 200, step = 100, label="Loss scale")],
+                     outputs= [gr.Image(width=320, height=320, label="Generated art"),
+                               gr.Image(width=320, height=320, label="Generated art with guidance")],
+                     title=title,
+                     description=description,
+                     examples=examples)
+
+ demo.launch()
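
Note: in the new inference function the guided image is an acknowledged placeholder (image_guide = image_pipeline). The loss-guidance loop removed by this commit could be adapted to fill that gap. The sketch below is one possible adaptation, not part of the commit: it assumes the globals defined in app.py (sd_pipeline, torch_device), and the per-method losses for the new Guidance-method choices ('Grayscale', 'Bright', 'Contrast', 'Symmetry', 'Saturation') are illustrative guesses, since the committed code does not define them.

# Sketch only: loss-guided generation adapted from the removed generate_image().
# Assumes sd_pipeline and torch_device from app.py; guidance_loss is a guess.
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from diffusers import LMSDiscreteScheduler

def guidance_loss(images, method):
    # images: (B, 3, H, W), values roughly in [0, 1]
    if method == 'Grayscale':    # pull every channel toward the per-pixel mean
        gray = images.mean(dim=1, keepdim=True)
        return F.mse_loss(images, gray.expand_as(images))
    if method == 'Bright':       # push pixel values toward a bright target
        return torch.abs(images - 0.9).mean()
    if method == 'Contrast':     # reward larger spatial standard deviation
        return -images.std(dim=(2, 3)).mean()
    if method == 'Symmetry':     # penalise left/right mirror mismatch
        return F.mse_loss(images, torch.flip(images, dims=[3]))
    if method == 'Saturation':   # reward channel spread as a saturation proxy
        return -(images.max(dim=1).values - images.min(dim=1).values).mean()
    raise ValueError(f"Unknown guidance method: {method}")

def generate_guided(prompt, num_inference_steps, guidance_scale, seed, method, loss_scale):
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                     beta_schedule="scaled_linear",
                                     num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps)

    # Classifier-free guidance embeddings (unconditional + conditional)
    tok = sd_pipeline.tokenizer
    text_ids = tok(prompt, padding="max_length", max_length=tok.model_max_length,
                   truncation=True, return_tensors="pt").input_ids.to(torch_device)
    uncond_ids = tok("", padding="max_length", max_length=tok.model_max_length,
                     return_tensors="pt").input_ids.to(torch_device)
    with torch.no_grad():
        text_emb = sd_pipeline.text_encoder(text_ids)[0]
        uncond_emb = sd_pipeline.text_encoder(uncond_ids)[0]
    embeddings = torch.cat([uncond_emb, text_emb])

    # Start from seeded noise in latent space (512x512 image -> 64x64 latents)
    generator = torch.manual_seed(seed)
    latents = torch.randn((1, sd_pipeline.unet.config.in_channels, 64, 64),
                          generator=generator).to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        latent_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
        with torch.no_grad():
            noise_pred = sd_pipeline.unet(latent_input, t,
                                          encoder_hidden_states=embeddings).sample
        uncond, cond = noise_pred.chunk(2)
        noise_pred = uncond + guidance_scale * (cond - uncond)

        if i % 5 == 0:  # every few steps, nudge the latents against the guidance loss
            latents = latents.detach().requires_grad_()
            sigma = scheduler.sigmas[i]
            pred_x0 = latents - sigma * noise_pred
            decoded = sd_pipeline.vae.decode((1 / 0.18215) * pred_x0).sample / 2 + 0.5
            loss = guidance_loss(decoded.clamp(0, 1), method) * loss_scale
            grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - grad * sigma ** 2

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Decode the final latents to a PIL image
    with torch.no_grad():
        image = sd_pipeline.vae.decode((1 / 0.18215) * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()
    return Image.fromarray((image[0] * 255).round().astype("uint8"))

Wiring this in would amount to replacing the placeholder line with something like image_guide = generate_guided(prompt, inference_step, guidance_scale, seed, guidance_method, loss_scale).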