sagar007 committed
Commit 21e3bc8 · verified · 1 Parent(s): 11ba365

Update app.py

Files changed (1): app.py +86 -103
app.py CHANGED
@@ -1,13 +1,93 @@
  import os
  import torch
  import gradio as gr
- from tqdm import tqdm
  from PIL import Image
  import torch.nn.functional as F
  from torchvision import transforms as tfms
- from transformers import CLIPTextModel, CLIPTokenizer, logging
- from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, DiffusionPipeline
+ from diffusers import DiffusionPipeline

+ # Determine the appropriate device and dtype
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch_device == "cuda" else torch.float32
+
+ # Load the pipeline
+ model_path = "CompVis/stable-diffusion-v1-4"
+ sd_pipeline = DiffusionPipeline.from_pretrained(
+     model_path,
+     torch_dtype=torch_dtype,
+     low_cpu_mem_usage=True if torch_device == "cpu" else False
+ ).to(torch_device)
+
+ # Load textual inversions
+ sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
+ sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
+
+ # Update style token dictionary
+ style_token_dict = {
+     "Illustration Style": '<illustration-style>',
+     "Line Art": '<line-art>',
+     "Hitokomoru Style": '<hitokomoru-style-nao>',
+     "Marc Allante": '<Marc_Allante>',
+     "Midjourney": '<midjourney-style>',
+     "Hanfu Anime": '<hanfu-anime-style>',
+     "Birb Style": '<birb-style>'
+ }
+
+ def apply_guidance(image, guidance_method, loss_scale):
+     # Convert PIL Image to tensor
+     img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
+
+     if guidance_method == 'Grayscale':
+         gray = tfms.Grayscale(3)(img_tensor)
+         guided = img_tensor + (gray - img_tensor) * (loss_scale / 10000)
+     elif guidance_method == 'Bright':
+         bright = F.relu(img_tensor)  # Simple brightness increase
+         guided = img_tensor + (bright - img_tensor) * (loss_scale / 10000)
+     elif guidance_method == 'Contrast':
+         mean = img_tensor.mean()
+         contrast = (img_tensor - mean) * 2 + mean
+         guided = img_tensor + (contrast - img_tensor) * (loss_scale / 10000)
+     elif guidance_method == 'Symmetry':
+         flipped = torch.flip(img_tensor, [3])  # Flip horizontally
+         guided = img_tensor + (flipped - img_tensor) * (loss_scale / 10000)
+     elif guidance_method == 'Saturation':
+         saturated = tfms.functional.adjust_saturation(img_tensor, 2)
+         guided = img_tensor + (saturated - img_tensor) * (loss_scale / 10000)
+     else:
+         return image
+
+     # Convert back to PIL Image
+     guided = guided.squeeze(0).clamp(0, 1)
+     guided = (guided * 255).byte().cpu().permute(1, 2, 0).numpy()
+     return Image.fromarray(guided)
+
+ def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale, image_size):
+     prompt = text + " " + style_token_dict[style]
+
+     # Convert image_size from string to tuple of integers
+     size = tuple(map(int, image_size.split('x')))
+
+     # Generate image with pipeline
+     image_pipeline = sd_pipeline(
+         prompt,
+         num_inference_steps=inference_step,
+         guidance_scale=guidance_scale,
+         generator=torch.Generator(device=torch_device).manual_seed(seed),
+         height=size[1],
+         width=size[0]
+     ).images[0]
+
+     # Apply guidance
+     image_guide = apply_guidance(image_pipeline, guidance_method, loss_scale)
+
+     return image_pipeline, image_guide
+
+ # HTML Template
  HTML_TEMPLATE = """
  <style>
  body {
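
Every branch of the new apply_guidance() reduces to the same operation: a linear blend between the generated image and a transformed target, guided = img + (target - img) * (loss_scale / 10000), so loss_scale acts as a 0-10000 blend knob rather than a true loss. A minimal sketch of that blend in isolation (the lerp_guide helper and the random tensor are illustrative, not part of the commit):

```python
import torch

def lerp_guide(img: torch.Tensor, target: torch.Tensor, loss_scale: float) -> torch.Tensor:
    # Same blend used by every branch of apply_guidance in this commit:
    # loss_scale of 0 returns img unchanged; 10000 returns target exactly.
    return img + (target - img) * (loss_scale / 10000)

img = torch.rand(1, 3, 64, 64)         # stand-in for a decoded image tensor
target = torch.flip(img, [3])          # e.g. the 'Symmetry' target (horizontal flip)
guided = lerp_guide(img, target, 200)  # 200 / 10000 = a 2% pull toward the target
assert torch.allclose(guided, img, atol=0.02)  # a gentle nudge, not a replacement
```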
@@ -145,105 +225,7 @@ HTML_TEMPLATE = """
  </div>
  """

- torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
- if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
-
- # Load the pipeline
- model_path = "CompVis/stable-diffusion-v1-4"
- sd_pipeline = DiffusionPipeline.from_pretrained(
-     model_path,
-     low_cpu_mem_usage=True,
-     torch_dtype=torch.float32
- ).to(torch_device)
-
- # Load textual inversions
- sd_pipeline.load_textual_inversion("sd-concepts-library/illustration-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/line-art")
- sd_pipeline.load_textual_inversion("sd-concepts-library/hitokomoru-style-nao")
- sd_pipeline.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
- sd_pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/hanfu-anime-style")
- sd_pipeline.load_textual_inversion("sd-concepts-library/birb-style")
-
- # Update style token dictionary
- style_token_dict = {
-     "Illustration Style": '<illustration-style>',
-     "Line Art":'<line-art>',
-     "Hitokomoru Style":'<hitokomoru-style-nao>',
-     "Marc Allante": '<Marc_Allante>',
-     "Midjourney":'<midjourney-style>',
-     "Hanfu Anime": '<hanfu-anime-style>',
-     "Birb Style": '<birb-style>'
- }
-
- def apply_guidance(image, guidance_method, loss_scale):
-     # Convert PIL Image to tensor
-     img_tensor = tfms.ToTensor()(image).unsqueeze(0).to(torch_device)
-
-     if guidance_method == 'Grayscale':
-         gray = tfms.Grayscale(3)(img_tensor)
-         guided = img_tensor + (gray - img_tensor) * (loss_scale / 10000)
-     elif guidance_method == 'Bright':
-         bright = F.relu(img_tensor)  # Simple brightness increase
-         guided = img_tensor + (bright - img_tensor) * (loss_scale / 10000)
-     elif guidance_method == 'Contrast':
-         mean = img_tensor.mean()
-         contrast = (img_tensor - mean) * 2 + mean
-         guided = img_tensor + (contrast - img_tensor) * (loss_scale / 10000)
-     elif guidance_method == 'Symmetry':
-         flipped = torch.flip(img_tensor, [3])  # Flip horizontally
-         guided = img_tensor + (flipped - img_tensor) * (loss_scale / 10000)
-     elif guidance_method == 'Saturation':
-         saturated = tfms.functional.adjust_saturation(img_tensor, 2)
-         guided = img_tensor + (saturated - img_tensor) * (loss_scale / 10000)
-     else:
-         return image
-
-     # Convert back to PIL Image
-     guided = guided.squeeze(0).clamp(0, 1)
-     guided = (guided * 255).byte().cpu().permute(1, 2, 0).numpy()
-     return Image.fromarray(guided)
-
- def generate_with_guidance(prompt, num_inference_steps, guidance_scale, seed, guidance_method, loss_scale):
-     # Generate image with pipeline
-     generator = torch.Generator(device=torch_device).manual_seed(seed)
-     image = sd_pipeline(
-         prompt,
-         num_inference_steps=num_inference_steps,
-         guidance_scale=guidance_scale,
-         generator=generator
-     ).images[0]
-
-     # Apply guidance
-     guided_image = apply_guidance(image, guidance_method, loss_scale)
-
-     return guided_image
-
- def inference(text, style, inference_step, guidance_scale, seed, guidance_method, loss_scale):
-     prompt = text + " " + style_token_dict[style]
-
-     # Generate image with pipeline
-     image_pipeline = sd_pipeline(
-         prompt,
-         num_inference_steps=inference_step,
-         guidance_scale=guidance_scale,
-         generator=torch.Generator(device=torch_device).manual_seed(seed)
-     ).images[0]
-
-     # Generate image with guidance
-     image_guide = generate_with_guidance(prompt, inference_step, guidance_scale, seed, guidance_method, loss_scale)
-
-     return image_pipeline, image_guide
-
- title = "Generative with Textual Inversion and Guidance"
- description = "A Gradio interface to infer Stable Diffusion and generate images with different art styles and guidance methods"
- examples = [
-     ["A majestic castle on a floating island", 'Illustration Style', 10, 7.5, 42, 'Grayscale', 200]
- ]
-
- title = "Generative Art with Textual Inversion and Guidance"
- description = "Create unique artworks using Stable Diffusion with various styles and guidance methods."
-
+ # Gradio Interface
  with gr.Blocks(css=HTML_TEMPLATE) as demo:
      gr.HTML(HTML_TEMPLATE)
      with gr.Row():
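
One motivation for this hunk worth spelling out: the old inference() rendered the prompt twice with the same seed, once directly and once inside generate_with_guidance(), but identically seeded torch.Generator objects produce identical latents, so the second render could only duplicate the first. The new version renders once and post-processes. A small sketch of the property the refactor relies on (the latent shape is illustrative):

```python
import torch

shape = (1, 4, 64, 64)  # latent shape SD v1 uses for a 512x512 image
g1 = torch.Generator().manual_seed(42)
g2 = torch.Generator().manual_seed(42)

# Two generators with the same seed yield identical initial noise, so the
# old second sd_pipeline() call could only reproduce the first image.
noise_a = torch.randn(shape, generator=g1)
noise_b = torch.randn(shape, generator=g2)
assert torch.equal(noise_a, noise_b)
```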
@@ -280,4 +262,5 @@ with gr.Blocks(css=HTML_TEMPLATE) as demo:
          cache_examples=True,
      )

- demo.launch()
+ if __name__ == "__main__":
+     demo.launch()
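
A hypothetical smoke test for the updated inference() signature, assuming app.py sits on the import path (its module-level code downloads and loads the SD v1.4 pipeline on import, so this needs the weights and a capable machine); the argument values loosely mirror the example row from the previous version, and the output file names are illustrative:

```python
from app import inference  # triggers pipeline + textual-inversion loading

# Exercises the new image_size argument, which inference() parses
# as "<width>x<height>" into integer height/width for the pipeline.
base, guided = inference(
    text="A majestic castle on a floating island",
    style="Illustration Style",
    inference_step=10,
    guidance_scale=7.5,
    seed=42,
    guidance_method="Grayscale",
    loss_scale=200,
    image_size="512x512",
)
base.save("castle_base.png")
guided.save("castle_guided.png")
```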