idlebg commited on
Commit
2b28767
1 Parent(s): d46d634

SCHEDULERS dropdown added

Browse files

AVAILABLE_SCHEDULERS = {
"DDIM": DDIMScheduler,
"DDPM": DDPMScheduler,
"PNDM": PNDMScheduler,
"LMS Discrete": LMSDiscreteScheduler,
"Euler Discrete": EulerDiscreteScheduler,
"Euler Ancestral Discrete": EulerAncestralDiscreteScheduler,
"DPM Solver Multistep": DPMSolverMultistepScheduler,
"DPM Solver Singlestep": DPMSolverSinglestepScheduler,
}

Files changed (2) hide show
  1. app.py +62 -11
  2. utils.py +227 -0
app.py CHANGED
@@ -9,6 +9,9 @@ import io
9
  import tempfile
10
  import zipfile
11
  import PIL
 
 
 
12
  from dataclasses import dataclass
13
  from io import BytesIO
14
  def sanitize_filename(filename):
@@ -16,14 +19,27 @@ def sanitize_filename(filename):
16
  return re.sub(r'[\\/*?:"<>|]', "_", filename)
17
 
18
  from typing import Optional, Literal, Union
19
- from diffusers import DiffusionPipeline, DDIMScheduler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  HF_TOKEN = os.environ.get("HF_TOKEN")
21
  import streamlit as st
22
  st.set_page_config(layout="wide")
23
  import torch
24
  from diffusers import (
25
  StableDiffusionPipeline,
26
- EulerDiscreteScheduler,
27
  StableDiffusionInpaintPipeline,
28
  StableDiffusionImg2ImgPipeline,
29
  )
@@ -44,6 +60,7 @@ from st_clickable_images import clickable_images
44
 
45
  import streamlit.components.v1 as components
46
 
 
47
  prefix = 'image_generation'
48
 
49
  def dict_to_style(d):
@@ -78,7 +95,7 @@ def display_and_download_images(output_images, metadata):
78
 
79
  PIPELINE_NAMES = Literal["txt2img", "inpaint", "img2img"]
80
 
81
- DEFAULT_PROMPT = "a sprinkled donut sitting on top of a purple cherry apple with ice cubes, colorful hyperrealism, digital explosion of vibrant colors and abstract digital elements"
82
  DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
83
  OUTPUT_IMAGE_KEY = "output_img"
84
  LOADED_IMAGE_KEY = "loaded_image"
@@ -97,7 +114,8 @@ def set_image(key: str, img: Image.Image):
97
 
98
  @st.cache_resource(max_entries=1)
99
  def get_pipeline(
100
- name: PIPELINE_NAMES,
 
101
  ) -> Union[
102
  StableDiffusionPipeline,
103
  StableDiffusionImg2ImgPipeline,
@@ -107,10 +125,18 @@ def get_pipeline(
107
  model_id = "FFusion/FFusion-BaSE"
108
 
109
  pipeline = DiffusionPipeline.from_pretrained(model_id)
110
- # switch the scheduler in the pipeline to use the DDIMScheduler
111
- pipeline.scheduler = DDIMScheduler.from_config(
112
- pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
113
- )
 
 
 
 
 
 
 
 
114
  pipeline = pipeline.to("cuda")
115
  return pipeline
116
 
@@ -142,7 +168,7 @@ def generate(
142
 
143
  if enable_xformers:
144
  pipe.enable_xformers_memory_efficient_attention()
145
-
146
  kwargs = dict(
147
  prompt=prompt,
148
  negative_prompt=negative_prompt,
@@ -151,7 +177,6 @@ def generate(
151
  guidance_scale=guidance_scale,
152
  guidance_rescale=0.7
153
  )
154
- print("kwargs", kwargs)
155
 
156
  if pipeline_name == "txt2img":
157
  kwargs.update(width=width, height=height)
@@ -164,6 +189,13 @@ def generate(
164
  f"Cannot generate image for pipeline {pipeline_name} and {prompt}"
165
  )
166
 
 
 
 
 
 
 
 
167
  output_images = [] # list to hold output image objects
168
  for _ in range(num_images): # loop over number of images
169
  result = pipe(**kwargs) # generate one image at a time
@@ -177,6 +209,10 @@ def generate(
177
  image.save(f"{filename}.png")
178
  output_images.append(image) # add the image object to the list
179
 
 
 
 
 
180
  for image in output_images:
181
  with open(f"{filename}.txt", "w") as f:
182
  f.write(prompt)
@@ -185,6 +221,9 @@ def generate(
185
 
186
 
187
 
 
 
 
188
  def prompt_and_generate_button(prefix, pipeline_name: PIPELINE_NAMES, **kwargs):
189
  prompt = st.text_area(
190
  "Prompt",
@@ -203,9 +242,21 @@ def prompt_and_generate_button(prefix, pipeline_name: PIPELINE_NAMES, **kwargs):
203
  guidance_scale = st.slider(
204
  "Guidance scale", min_value=0.0, max_value=20.0, value=7.5, step=0.5, key=f"{prefix}-guidance-scale"
205
  )
 
 
 
 
 
 
 
 
 
 
 
 
206
  # enable_attention_slicing = st.checkbox('Enable attention slicing (enables higher resolutions but is slower)', key=f"{prefix}-attention-slicing", value=True)
207
  # enable_xformers = st.checkbox('Enable xformers library (better memory usage)', key=f"{prefix}-xformers", value=True)
208
- num_images = st.slider("Number of images to generate", min_value=1, max_value=4, value=1, key=f"{prefix}-num-images")
209
 
210
  images = []
211
 
 
9
  import tempfile
10
  import zipfile
11
  import PIL
12
+ import subprocess
13
+ from huggingface_hub import Repository
14
+ from utils import save_to_hub, save_to_local
15
  from dataclasses import dataclass
16
  from io import BytesIO
17
  def sanitize_filename(filename):
 
19
  return re.sub(r'[\\/*?:"<>|]', "_", filename)
20
 
21
  from typing import Optional, Literal, Union
22
+ from diffusers import (DiffusionPipeline, DDIMScheduler, DDPMScheduler, PNDMScheduler,
23
+ LMSDiscreteScheduler, EulerDiscreteScheduler,
24
+ EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler,
25
+ DPMSolverSinglestepScheduler)
26
+
27
+ AVAILABLE_SCHEDULERS = {
28
+ "DDIM": DDIMScheduler,
29
+ "DDPM": DDPMScheduler,
30
+ "PNDM": PNDMScheduler,
31
+ "LMS Discrete": LMSDiscreteScheduler,
32
+ "Euler Discrete": EulerDiscreteScheduler,
33
+ "Euler Ancestral Discrete": EulerAncestralDiscreteScheduler,
34
+ "DPM Solver Multistep": DPMSolverMultistepScheduler,
35
+ "DPM Solver Singlestep": DPMSolverSinglestepScheduler,
36
+ }
37
  HF_TOKEN = os.environ.get("HF_TOKEN")
38
  import streamlit as st
39
  st.set_page_config(layout="wide")
40
  import torch
41
  from diffusers import (
42
  StableDiffusionPipeline,
 
43
  StableDiffusionInpaintPipeline,
44
  StableDiffusionImg2ImgPipeline,
45
  )
 
60
 
61
  import streamlit.components.v1 as components
62
 
63
+
64
  prefix = 'image_generation'
65
 
66
  def dict_to_style(d):
 
95
 
96
  PIPELINE_NAMES = Literal["txt2img", "inpaint", "img2img"]
97
 
98
+ DEFAULT_PROMPT = "sprinkled donut sitting on top of a purple cherry apple, colorful hyperrealism"
99
  DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
100
  OUTPUT_IMAGE_KEY = "output_img"
101
  LOADED_IMAGE_KEY = "loaded_image"
 
114
 
115
  @st.cache_resource(max_entries=1)
116
  def get_pipeline(
117
+ name: str,
118
+ scheduler_name: str = None,
119
  ) -> Union[
120
  StableDiffusionPipeline,
121
  StableDiffusionImg2ImgPipeline,
 
125
  model_id = "FFusion/FFusion-BaSE"
126
 
127
  pipeline = DiffusionPipeline.from_pretrained(model_id)
128
+
129
+ # Use specified scheduler if provided, else use DDIMScheduler
130
+ if scheduler_name:
131
+ SchedulerClass = AVAILABLE_SCHEDULERS[scheduler_name]
132
+ pipeline.scheduler = SchedulerClass.from_config(
133
+ pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
134
+ )
135
+ else:
136
+ pipeline.scheduler = DDIMScheduler.from_config(
137
+ pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
138
+ )
139
+
140
  pipeline = pipeline.to("cuda")
141
  return pipeline
142
 
 
168
 
169
  if enable_xformers:
170
  pipe.enable_xformers_memory_efficient_attention()
171
+
172
  kwargs = dict(
173
  prompt=prompt,
174
  negative_prompt=negative_prompt,
 
177
  guidance_scale=guidance_scale,
178
  guidance_rescale=0.7
179
  )
 
180
 
181
  if pipeline_name == "txt2img":
182
  kwargs.update(width=width, height=height)
 
189
  f"Cannot generate image for pipeline {pipeline_name} and {prompt}"
190
  )
191
 
192
+ # Save images to Hugging Face Hub or locally
193
+ current_datetime = datetime.now()
194
+ metadata = {
195
+ "prompt": prompt,
196
+ "timestamp": str(current_datetime),
197
+ }
198
+
199
  output_images = [] # list to hold output image objects
200
  for _ in range(num_images): # loop over number of images
201
  result = pipe(**kwargs) # generate one image at a time
 
209
  image.save(f"{filename}.png")
210
  output_images.append(image) # add the image object to the list
211
 
212
+ # Save image to Hugging Face Hub
213
+ output_path = f"images/{i}.png"
214
+ save_to_hub(image, current_datetime, metadata, output_path)
215
+
216
  for image in output_images:
217
  with open(f"{filename}.txt", "w") as f:
218
  f.write(prompt)
 
221
 
222
 
223
 
224
+
225
+
226
+
227
  def prompt_and_generate_button(prefix, pipeline_name: PIPELINE_NAMES, **kwargs):
228
  prompt = st.text_area(
229
  "Prompt",
 
242
  guidance_scale = st.slider(
243
  "Guidance scale", min_value=0.0, max_value=20.0, value=7.5, step=0.5, key=f"{prefix}-guidance-scale"
244
  )
245
+ # Add a select box for the schedulers
246
+ scheduler_name = st.selectbox(
247
+ "Choose a Scheduler",
248
+ options=list(AVAILABLE_SCHEDULERS.keys()),
249
+ index=0, # Default index
250
+ key=f"{prefix}-scheduler",
251
+ )
252
+ scheduler_class = AVAILABLE_SCHEDULERS[scheduler_name] # Get the selected scheduler class
253
+
254
+
255
+ pipe = get_pipeline(pipeline_name, scheduler_name=scheduler_name)
256
+
257
  # enable_attention_slicing = st.checkbox('Enable attention slicing (enables higher resolutions but is slower)', key=f"{prefix}-attention-slicing", value=True)
258
  # enable_xformers = st.checkbox('Enable xformers library (better memory usage)', key=f"{prefix}-xformers", value=True)
259
+ num_images = st.slider("Number of images to generate", min_value=1, max_value=4, value=1, key=f"{prefix}-num-images")
260
 
261
  images = []
262
 
utils.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import gc
3
+ import io
4
+ import os
5
+ import tempfile
6
+ import zipfile
7
+ from datetime import datetime
8
+ from threading import Thread
9
+ from huggingface_hub import Repository
10
+ import subprocess
11
+
12
+ import requests
13
+ import streamlit as st
14
+ import torch
15
+ from huggingface_hub import HfApi
16
+ from huggingface_hub.utils._errors import RepositoryNotFoundError
17
+ from huggingface_hub.utils._validators import HFValidationError
18
+ from loguru import logger
19
+ from PIL.PngImagePlugin import PngInfo
20
+ from st_clickable_images import clickable_images
21
+
22
+
23
+ no_safety_checker = None
24
+
25
+
26
+ CODE_OF_CONDUCT = """
27
+ ## Code of conduct
28
+ The app should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
29
+
30
+ Using the app to generate content that is cruel to individuals is a misuse of this app. One shall not use this app to generate content that is intended to be cruel to individuals, or to generate content that is intended to be cruel to individuals in a way that is not obvious to the viewer.
31
+ This includes, but is not limited to:
32
+ - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
33
+ - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
34
+ - Impersonating individuals without their consent.
35
+ - Sexual content without consent of the people who might see it.
36
+ - Mis- and disinformation
37
+ - Representations of egregious violence and gore
38
+ - Sharing of copyrighted or licensed material in violation of its terms of use.
39
+ - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
40
+
41
+ By using this app, you agree to the above code of conduct.
42
+
43
+ """
44
+
45
+
46
+ def use_auth_token():
47
+ token_path = os.path.join(os.path.expanduser("~"), ".huggingface", "token")
48
+ if os.path.exists(token_path):
49
+ return True
50
+ if "HF_TOKEN" in os.environ:
51
+ return os.environ["HF_TOKEN"]
52
+ return False
53
+
54
+
55
+
56
+ def download_file(file_url):
57
+ r = requests.get(file_url, stream=True)
58
+ with tempfile.NamedTemporaryFile(delete=False) as tmp:
59
+ for chunk in r.iter_content(chunk_size=1024):
60
+ if chunk: # filter out keep-alive new chunks
61
+ tmp.write(chunk)
62
+ return tmp.name
63
+
64
+
65
+ def cache_folder():
66
+ _cache_folder = os.path.join(os.path.expanduser("~"), ".ffusion")
67
+ os.makedirs(_cache_folder, exist_ok=True)
68
+ return _cache_folder
69
+
70
+
71
+ def clear_memory(preserve):
72
+ torch.cuda.empty_cache()
73
+ gc.collect()
74
+ to_clear = ["inpainting", "text2img", "img2text"]
75
+ for key in to_clear:
76
+ if key not in preserve and key in st.session_state:
77
+ del st.session_state[key]
78
+
79
+
80
+
81
+ import subprocess
82
+
83
+ from huggingface_hub import Repository
84
+
85
+ def save_to_hub(image, current_datetime, metadata, output_path):
86
+ """Saves an image to Hugging Face Hub"""
87
+ try:
88
+ # Convert image to byte array
89
+ byte_arr = io.BytesIO()
90
+
91
+ # Check if the image has metadata
92
+ if image.info:
93
+ # Save as PNG
94
+ image.save(byte_arr, format='PNG')
95
+ else:
96
+ # Save as JPG
97
+ image.save(byte_arr, format='JPEG')
98
+
99
+ byte_arr = byte_arr.getvalue()
100
+
101
+ # Create a repository object
102
+ token = os.getenv("HF_TOKEN")
103
+ api = HfApi()
104
+
105
+ username = "FFusion"
106
+ repo_name = "FF"
107
+ try:
108
+ repo = Repository(f"{username}/{repo_name}", clone_from=f"{username}/{repo_name}", use_auth_token=token, repo_type="dataset")
109
+ except RepositoryNotFoundError:
110
+ repo = Repository(f"{username}/{repo_name}", clone_from=f"{username}/{repo_name}", use_auth_token=token, repo_type="dataset")
111
+
112
+ # Create the directory if it does not exist
113
+ os.makedirs(os.path.dirname(f"{repo.local_dir}/{output_path}"), exist_ok=True)
114
+
115
+ # Write image to repository
116
+ with open(f"{repo.local_dir}/{output_path}", "wb") as f:
117
+ f.write(byte_arr)
118
+
119
+ # Set Git username and email
120
+ subprocess.run(["git", "config", "user.name", "idle stoev"], check=True, cwd=repo.local_dir)
121
+ subprocess.run(["git", "config", "user.email", "[email protected]"], check=True, cwd=repo.local_dir)
122
+
123
+ # Commit and push changes
124
+ repo.git_add(pattern=".")
125
+ repo.git_commit(f"Add image at {current_datetime}")
126
+ print(f"Pushing changes to {username}/{repo_name}...")
127
+ repo.git_push()
128
+ print(f"Image saved to {username}/{repo_name}/{output_path}")
129
+
130
+ except Exception as e:
131
+ print(f"Failed to save image to Hugging Face Hub: {e}")
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
+ def save_to_local(images, module, current_datetime, metadata, output_path):
150
+ _metadata = PngInfo()
151
+ _metadata.add_text("text2img", metadata)
152
+ os.makedirs(output_path, exist_ok=True)
153
+ os.makedirs(f"{output_path}/{module}", exist_ok=True)
154
+ os.makedirs(f"{output_path}/{module}/{current_datetime}", exist_ok=True)
155
+
156
+ for i, img in enumerate(images):
157
+ img.save(
158
+ f"{output_path}/{module}/{current_datetime}/{i}.png",
159
+ pnginfo=_metadata,
160
+ )
161
+
162
+ # save metadata as text file
163
+ with open(f"{output_path}/{module}/{current_datetime}/metadata.txt", "w") as f:
164
+ f.write(metadata)
165
+ logger.info(f"Saved images to {output_path}/{module}/{current_datetime}")
166
+
167
+
168
+ def save_images(images, module, metadata, output_path):
169
+ if output_path is None:
170
+ logger.warning("No output path specified, skipping saving images")
171
+ return
172
+
173
+ api = HfApi()
174
+ dset_info = None
175
+ try:
176
+ dset_info = api.dataset_info(output_path)
177
+ except (HFValidationError, RepositoryNotFoundError):
178
+ logger.warning("No valid hugging face repo. Saving locally...")
179
+
180
+ current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
181
+
182
+ if not dset_info:
183
+ save_to_local(images, module, current_datetime, metadata, output_path)
184
+ else:
185
+ Thread(target=save_to_hub, args=(api, images, module, current_datetime, metadata, output_path)).start()
186
+
187
+
188
+ def display_and_download_images(output_images, metadata, download_col=None):
189
+ # st.image(output_images, width=128, output_format="PNG")
190
+
191
+ with st.spinner("Preparing images for download..."):
192
+ # save images to a temporary directory
193
+ with tempfile.TemporaryDirectory() as tmpdir:
194
+ gallery_images = []
195
+ for i, image in enumerate(output_images):
196
+ image.save(os.path.join(tmpdir, f"{i + 1}.png"), pnginfo=metadata)
197
+ with open(os.path.join(tmpdir, f"{i + 1}.png"), "rb") as img:
198
+ encoded = base64.b64encode(img.read()).decode()
199
+ gallery_images.append(f"data:image/jpeg;base64,{encoded}")
200
+
201
+ # zip the images
202
+ zip_path = os.path.join(tmpdir, "images.zip")
203
+ with zipfile.ZipFile(zip_path, "w") as zip:
204
+ for filename in os.listdir(tmpdir):
205
+ if filename.endswith(".png"):
206
+ zip.write(os.path.join(tmpdir, filename), filename)
207
+
208
+ # convert zip to base64
209
+ with open(zip_path, "rb") as f:
210
+ encoded = base64.b64encode(f.read()).decode()
211
+
212
+ _ = clickable_images(
213
+ gallery_images,
214
+ titles=[f"Image #{str(i)}" for i in range(len(gallery_images))],
215
+ div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
216
+ img_style={"margin": "5px", "height": "200px"},
217
+ )
218
+
219
+ # add download link
220
+ st.markdown(
221
+ f"""
222
+ <a href="data:application/zip;base64,{encoded}" download="images.zip">
223
+ Download Images
224
+ </a>
225
+ """,
226
+ unsafe_allow_html=True,
227
+ )