Spaces:
Runtime error
Runtime error
SCHEDULERS dropdown added
Browse filesAVAILABLE_SCHEDULERS = {
"DDIM": DDIMScheduler,
"DDPM": DDPMScheduler,
"PNDM": PNDMScheduler,
"LMS Discrete": LMSDiscreteScheduler,
"Euler Discrete": EulerDiscreteScheduler,
"Euler Ancestral Discrete": EulerAncestralDiscreteScheduler,
"DPM Solver Multistep": DPMSolverMultistepScheduler,
"DPM Solver Singlestep": DPMSolverSinglestepScheduler,
}
app.py
CHANGED
@@ -9,6 +9,9 @@ import io
|
|
9 |
import tempfile
|
10 |
import zipfile
|
11 |
import PIL
|
|
|
|
|
|
|
12 |
from dataclasses import dataclass
|
13 |
from io import BytesIO
|
14 |
def sanitize_filename(filename):
|
@@ -16,14 +19,27 @@ def sanitize_filename(filename):
|
|
16 |
return re.sub(r'[\\/*?:"<>|]', "_", filename)
|
17 |
|
18 |
from typing import Optional, Literal, Union
|
19 |
-
from diffusers import DiffusionPipeline, DDIMScheduler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
21 |
import streamlit as st
|
22 |
st.set_page_config(layout="wide")
|
23 |
import torch
|
24 |
from diffusers import (
|
25 |
StableDiffusionPipeline,
|
26 |
-
EulerDiscreteScheduler,
|
27 |
StableDiffusionInpaintPipeline,
|
28 |
StableDiffusionImg2ImgPipeline,
|
29 |
)
|
@@ -44,6 +60,7 @@ from st_clickable_images import clickable_images
|
|
44 |
|
45 |
import streamlit.components.v1 as components
|
46 |
|
|
|
47 |
prefix = 'image_generation'
|
48 |
|
49 |
def dict_to_style(d):
|
@@ -78,7 +95,7 @@ def display_and_download_images(output_images, metadata):
|
|
78 |
|
79 |
PIPELINE_NAMES = Literal["txt2img", "inpaint", "img2img"]
|
80 |
|
81 |
-
DEFAULT_PROMPT = "
|
82 |
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
|
83 |
OUTPUT_IMAGE_KEY = "output_img"
|
84 |
LOADED_IMAGE_KEY = "loaded_image"
|
@@ -97,7 +114,8 @@ def set_image(key: str, img: Image.Image):
|
|
97 |
|
98 |
@st.cache_resource(max_entries=1)
|
99 |
def get_pipeline(
|
100 |
-
name:
|
|
|
101 |
) -> Union[
|
102 |
StableDiffusionPipeline,
|
103 |
StableDiffusionImg2ImgPipeline,
|
@@ -107,10 +125,18 @@ def get_pipeline(
|
|
107 |
model_id = "FFusion/FFusion-BaSE"
|
108 |
|
109 |
pipeline = DiffusionPipeline.from_pretrained(model_id)
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
pipeline = pipeline.to("cuda")
|
115 |
return pipeline
|
116 |
|
@@ -142,7 +168,7 @@ def generate(
|
|
142 |
|
143 |
if enable_xformers:
|
144 |
pipe.enable_xformers_memory_efficient_attention()
|
145 |
-
|
146 |
kwargs = dict(
|
147 |
prompt=prompt,
|
148 |
negative_prompt=negative_prompt,
|
@@ -151,7 +177,6 @@ def generate(
|
|
151 |
guidance_scale=guidance_scale,
|
152 |
guidance_rescale=0.7
|
153 |
)
|
154 |
-
print("kwargs", kwargs)
|
155 |
|
156 |
if pipeline_name == "txt2img":
|
157 |
kwargs.update(width=width, height=height)
|
@@ -164,6 +189,13 @@ def generate(
|
|
164 |
f"Cannot generate image for pipeline {pipeline_name} and {prompt}"
|
165 |
)
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
output_images = [] # list to hold output image objects
|
168 |
for _ in range(num_images): # loop over number of images
|
169 |
result = pipe(**kwargs) # generate one image at a time
|
@@ -177,6 +209,10 @@ def generate(
|
|
177 |
image.save(f"{filename}.png")
|
178 |
output_images.append(image) # add the image object to the list
|
179 |
|
|
|
|
|
|
|
|
|
180 |
for image in output_images:
|
181 |
with open(f"{filename}.txt", "w") as f:
|
182 |
f.write(prompt)
|
@@ -185,6 +221,9 @@ def generate(
|
|
185 |
|
186 |
|
187 |
|
|
|
|
|
|
|
188 |
def prompt_and_generate_button(prefix, pipeline_name: PIPELINE_NAMES, **kwargs):
|
189 |
prompt = st.text_area(
|
190 |
"Prompt",
|
@@ -203,9 +242,21 @@ def prompt_and_generate_button(prefix, pipeline_name: PIPELINE_NAMES, **kwargs):
|
|
203 |
guidance_scale = st.slider(
|
204 |
"Guidance scale", min_value=0.0, max_value=20.0, value=7.5, step=0.5, key=f"{prefix}-guidance-scale"
|
205 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
# enable_attention_slicing = st.checkbox('Enable attention slicing (enables higher resolutions but is slower)', key=f"{prefix}-attention-slicing", value=True)
|
207 |
# enable_xformers = st.checkbox('Enable xformers library (better memory usage)', key=f"{prefix}-xformers", value=True)
|
208 |
-
|
209 |
|
210 |
images = []
|
211 |
|
|
|
9 |
import tempfile
|
10 |
import zipfile
|
11 |
import PIL
|
12 |
+
import subprocess
|
13 |
+
from huggingface_hub import Repository
|
14 |
+
from utils import save_to_hub, save_to_local
|
15 |
from dataclasses import dataclass
|
16 |
from io import BytesIO
|
17 |
def sanitize_filename(filename):
|
|
|
19 |
return re.sub(r'[\\/*?:"<>|]', "_", filename)
|
20 |
|
21 |
from typing import Optional, Literal, Union
|
22 |
+
from diffusers import (DiffusionPipeline, DDIMScheduler, DDPMScheduler, PNDMScheduler,
|
23 |
+
LMSDiscreteScheduler, EulerDiscreteScheduler,
|
24 |
+
EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler,
|
25 |
+
DPMSolverSinglestepScheduler)
|
26 |
+
|
27 |
+
AVAILABLE_SCHEDULERS = {
|
28 |
+
"DDIM": DDIMScheduler,
|
29 |
+
"DDPM": DDPMScheduler,
|
30 |
+
"PNDM": PNDMScheduler,
|
31 |
+
"LMS Discrete": LMSDiscreteScheduler,
|
32 |
+
"Euler Discrete": EulerDiscreteScheduler,
|
33 |
+
"Euler Ancestral Discrete": EulerAncestralDiscreteScheduler,
|
34 |
+
"DPM Solver Multistep": DPMSolverMultistepScheduler,
|
35 |
+
"DPM Solver Singlestep": DPMSolverSinglestepScheduler,
|
36 |
+
}
|
37 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
38 |
import streamlit as st
|
39 |
st.set_page_config(layout="wide")
|
40 |
import torch
|
41 |
from diffusers import (
|
42 |
StableDiffusionPipeline,
|
|
|
43 |
StableDiffusionInpaintPipeline,
|
44 |
StableDiffusionImg2ImgPipeline,
|
45 |
)
|
|
|
60 |
|
61 |
import streamlit.components.v1 as components
|
62 |
|
63 |
+
|
64 |
prefix = 'image_generation'
|
65 |
|
66 |
def dict_to_style(d):
|
|
|
95 |
|
96 |
PIPELINE_NAMES = Literal["txt2img", "inpaint", "img2img"]
|
97 |
|
98 |
+
DEFAULT_PROMPT = "sprinkled donut sitting on top of a purple cherry apple, colorful hyperrealism"
|
99 |
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
|
100 |
OUTPUT_IMAGE_KEY = "output_img"
|
101 |
LOADED_IMAGE_KEY = "loaded_image"
|
|
|
114 |
|
115 |
@st.cache_resource(max_entries=1)
|
116 |
def get_pipeline(
|
117 |
+
name: str,
|
118 |
+
scheduler_name: str = None,
|
119 |
) -> Union[
|
120 |
StableDiffusionPipeline,
|
121 |
StableDiffusionImg2ImgPipeline,
|
|
|
125 |
model_id = "FFusion/FFusion-BaSE"
|
126 |
|
127 |
pipeline = DiffusionPipeline.from_pretrained(model_id)
|
128 |
+
|
129 |
+
# Use specified scheduler if provided, else use DDIMScheduler
|
130 |
+
if scheduler_name:
|
131 |
+
SchedulerClass = AVAILABLE_SCHEDULERS[scheduler_name]
|
132 |
+
pipeline.scheduler = SchedulerClass.from_config(
|
133 |
+
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
pipeline.scheduler = DDIMScheduler.from_config(
|
137 |
+
pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
|
138 |
+
)
|
139 |
+
|
140 |
pipeline = pipeline.to("cuda")
|
141 |
return pipeline
|
142 |
|
|
|
168 |
|
169 |
if enable_xformers:
|
170 |
pipe.enable_xformers_memory_efficient_attention()
|
171 |
+
|
172 |
kwargs = dict(
|
173 |
prompt=prompt,
|
174 |
negative_prompt=negative_prompt,
|
|
|
177 |
guidance_scale=guidance_scale,
|
178 |
guidance_rescale=0.7
|
179 |
)
|
|
|
180 |
|
181 |
if pipeline_name == "txt2img":
|
182 |
kwargs.update(width=width, height=height)
|
|
|
189 |
f"Cannot generate image for pipeline {pipeline_name} and {prompt}"
|
190 |
)
|
191 |
|
192 |
+
# Save images to Hugging Face Hub or locally
|
193 |
+
current_datetime = datetime.now()
|
194 |
+
metadata = {
|
195 |
+
"prompt": prompt,
|
196 |
+
"timestamp": str(current_datetime),
|
197 |
+
}
|
198 |
+
|
199 |
output_images = [] # list to hold output image objects
|
200 |
for _ in range(num_images): # loop over number of images
|
201 |
result = pipe(**kwargs) # generate one image at a time
|
|
|
209 |
image.save(f"{filename}.png")
|
210 |
output_images.append(image) # add the image object to the list
|
211 |
|
212 |
+
# Save image to Hugging Face Hub
|
213 |
+
output_path = f"images/{i}.png"
|
214 |
+
save_to_hub(image, current_datetime, metadata, output_path)
|
215 |
+
|
216 |
for image in output_images:
|
217 |
with open(f"{filename}.txt", "w") as f:
|
218 |
f.write(prompt)
|
|
|
221 |
|
222 |
|
223 |
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
def prompt_and_generate_button(prefix, pipeline_name: PIPELINE_NAMES, **kwargs):
|
228 |
prompt = st.text_area(
|
229 |
"Prompt",
|
|
|
242 |
guidance_scale = st.slider(
|
243 |
"Guidance scale", min_value=0.0, max_value=20.0, value=7.5, step=0.5, key=f"{prefix}-guidance-scale"
|
244 |
)
|
245 |
+
# Add a select box for the schedulers
|
246 |
+
scheduler_name = st.selectbox(
|
247 |
+
"Choose a Scheduler",
|
248 |
+
options=list(AVAILABLE_SCHEDULERS.keys()),
|
249 |
+
index=0, # Default index
|
250 |
+
key=f"{prefix}-scheduler",
|
251 |
+
)
|
252 |
+
scheduler_class = AVAILABLE_SCHEDULERS[scheduler_name] # Get the selected scheduler class
|
253 |
+
|
254 |
+
|
255 |
+
pipe = get_pipeline(pipeline_name, scheduler_name=scheduler_name)
|
256 |
+
|
257 |
# enable_attention_slicing = st.checkbox('Enable attention slicing (enables higher resolutions but is slower)', key=f"{prefix}-attention-slicing", value=True)
|
258 |
# enable_xformers = st.checkbox('Enable xformers library (better memory usage)', key=f"{prefix}-xformers", value=True)
|
259 |
+
num_images = st.slider("Number of images to generate", min_value=1, max_value=4, value=1, key=f"{prefix}-num-images")
|
260 |
|
261 |
images = []
|
262 |
|
utils.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import gc
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
import tempfile
|
6 |
+
import zipfile
|
7 |
+
from datetime import datetime
|
8 |
+
from threading import Thread
|
9 |
+
from huggingface_hub import Repository
|
10 |
+
import subprocess
|
11 |
+
|
12 |
+
import requests
|
13 |
+
import streamlit as st
|
14 |
+
import torch
|
15 |
+
from huggingface_hub import HfApi
|
16 |
+
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
17 |
+
from huggingface_hub.utils._validators import HFValidationError
|
18 |
+
from loguru import logger
|
19 |
+
from PIL.PngImagePlugin import PngInfo
|
20 |
+
from st_clickable_images import clickable_images
|
21 |
+
|
22 |
+
|
23 |
+
no_safety_checker = None
|
24 |
+
|
25 |
+
|
26 |
+
CODE_OF_CONDUCT = """
|
27 |
+
## Code of conduct
|
28 |
+
The app should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
29 |
+
|
30 |
+
Using the app to generate content that is cruel to individuals is a misuse of this app. One shall not use this app to generate content that is intended to be cruel to individuals, or to generate content that is intended to be cruel to individuals in a way that is not obvious to the viewer.
|
31 |
+
This includes, but is not limited to:
|
32 |
+
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
33 |
+
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
34 |
+
- Impersonating individuals without their consent.
|
35 |
+
- Sexual content without consent of the people who might see it.
|
36 |
+
- Mis- and disinformation
|
37 |
+
- Representations of egregious violence and gore
|
38 |
+
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
39 |
+
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
40 |
+
|
41 |
+
By using this app, you agree to the above code of conduct.
|
42 |
+
|
43 |
+
"""
|
44 |
+
|
45 |
+
|
46 |
+
def use_auth_token():
|
47 |
+
token_path = os.path.join(os.path.expanduser("~"), ".huggingface", "token")
|
48 |
+
if os.path.exists(token_path):
|
49 |
+
return True
|
50 |
+
if "HF_TOKEN" in os.environ:
|
51 |
+
return os.environ["HF_TOKEN"]
|
52 |
+
return False
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
def download_file(file_url):
|
57 |
+
r = requests.get(file_url, stream=True)
|
58 |
+
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
59 |
+
for chunk in r.iter_content(chunk_size=1024):
|
60 |
+
if chunk: # filter out keep-alive new chunks
|
61 |
+
tmp.write(chunk)
|
62 |
+
return tmp.name
|
63 |
+
|
64 |
+
|
65 |
+
def cache_folder():
|
66 |
+
_cache_folder = os.path.join(os.path.expanduser("~"), ".ffusion")
|
67 |
+
os.makedirs(_cache_folder, exist_ok=True)
|
68 |
+
return _cache_folder
|
69 |
+
|
70 |
+
|
71 |
+
def clear_memory(preserve):
|
72 |
+
torch.cuda.empty_cache()
|
73 |
+
gc.collect()
|
74 |
+
to_clear = ["inpainting", "text2img", "img2text"]
|
75 |
+
for key in to_clear:
|
76 |
+
if key not in preserve and key in st.session_state:
|
77 |
+
del st.session_state[key]
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
import subprocess
|
82 |
+
|
83 |
+
from huggingface_hub import Repository
|
84 |
+
|
85 |
+
def save_to_hub(image, current_datetime, metadata, output_path):
|
86 |
+
"""Saves an image to Hugging Face Hub"""
|
87 |
+
try:
|
88 |
+
# Convert image to byte array
|
89 |
+
byte_arr = io.BytesIO()
|
90 |
+
|
91 |
+
# Check if the image has metadata
|
92 |
+
if image.info:
|
93 |
+
# Save as PNG
|
94 |
+
image.save(byte_arr, format='PNG')
|
95 |
+
else:
|
96 |
+
# Save as JPG
|
97 |
+
image.save(byte_arr, format='JPEG')
|
98 |
+
|
99 |
+
byte_arr = byte_arr.getvalue()
|
100 |
+
|
101 |
+
# Create a repository object
|
102 |
+
token = os.getenv("HF_TOKEN")
|
103 |
+
api = HfApi()
|
104 |
+
|
105 |
+
username = "FFusion"
|
106 |
+
repo_name = "FF"
|
107 |
+
try:
|
108 |
+
repo = Repository(f"{username}/{repo_name}", clone_from=f"{username}/{repo_name}", use_auth_token=token, repo_type="dataset")
|
109 |
+
except RepositoryNotFoundError:
|
110 |
+
repo = Repository(f"{username}/{repo_name}", clone_from=f"{username}/{repo_name}", use_auth_token=token, repo_type="dataset")
|
111 |
+
|
112 |
+
# Create the directory if it does not exist
|
113 |
+
os.makedirs(os.path.dirname(f"{repo.local_dir}/{output_path}"), exist_ok=True)
|
114 |
+
|
115 |
+
# Write image to repository
|
116 |
+
with open(f"{repo.local_dir}/{output_path}", "wb") as f:
|
117 |
+
f.write(byte_arr)
|
118 |
+
|
119 |
+
# Set Git username and email
|
120 |
+
subprocess.run(["git", "config", "user.name", "idle stoev"], check=True, cwd=repo.local_dir)
|
121 |
+
subprocess.run(["git", "config", "user.email", "[email protected]"], check=True, cwd=repo.local_dir)
|
122 |
+
|
123 |
+
# Commit and push changes
|
124 |
+
repo.git_add(pattern=".")
|
125 |
+
repo.git_commit(f"Add image at {current_datetime}")
|
126 |
+
print(f"Pushing changes to {username}/{repo_name}...")
|
127 |
+
repo.git_push()
|
128 |
+
print(f"Image saved to {username}/{repo_name}/{output_path}")
|
129 |
+
|
130 |
+
except Exception as e:
|
131 |
+
print(f"Failed to save image to Hugging Face Hub: {e}")
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
def save_to_local(images, module, current_datetime, metadata, output_path):
|
150 |
+
_metadata = PngInfo()
|
151 |
+
_metadata.add_text("text2img", metadata)
|
152 |
+
os.makedirs(output_path, exist_ok=True)
|
153 |
+
os.makedirs(f"{output_path}/{module}", exist_ok=True)
|
154 |
+
os.makedirs(f"{output_path}/{module}/{current_datetime}", exist_ok=True)
|
155 |
+
|
156 |
+
for i, img in enumerate(images):
|
157 |
+
img.save(
|
158 |
+
f"{output_path}/{module}/{current_datetime}/{i}.png",
|
159 |
+
pnginfo=_metadata,
|
160 |
+
)
|
161 |
+
|
162 |
+
# save metadata as text file
|
163 |
+
with open(f"{output_path}/{module}/{current_datetime}/metadata.txt", "w") as f:
|
164 |
+
f.write(metadata)
|
165 |
+
logger.info(f"Saved images to {output_path}/{module}/{current_datetime}")
|
166 |
+
|
167 |
+
|
168 |
+
def save_images(images, module, metadata, output_path):
|
169 |
+
if output_path is None:
|
170 |
+
logger.warning("No output path specified, skipping saving images")
|
171 |
+
return
|
172 |
+
|
173 |
+
api = HfApi()
|
174 |
+
dset_info = None
|
175 |
+
try:
|
176 |
+
dset_info = api.dataset_info(output_path)
|
177 |
+
except (HFValidationError, RepositoryNotFoundError):
|
178 |
+
logger.warning("No valid hugging face repo. Saving locally...")
|
179 |
+
|
180 |
+
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
181 |
+
|
182 |
+
if not dset_info:
|
183 |
+
save_to_local(images, module, current_datetime, metadata, output_path)
|
184 |
+
else:
|
185 |
+
Thread(target=save_to_hub, args=(api, images, module, current_datetime, metadata, output_path)).start()
|
186 |
+
|
187 |
+
|
188 |
+
def display_and_download_images(output_images, metadata, download_col=None):
|
189 |
+
# st.image(output_images, width=128, output_format="PNG")
|
190 |
+
|
191 |
+
with st.spinner("Preparing images for download..."):
|
192 |
+
# save images to a temporary directory
|
193 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
194 |
+
gallery_images = []
|
195 |
+
for i, image in enumerate(output_images):
|
196 |
+
image.save(os.path.join(tmpdir, f"{i + 1}.png"), pnginfo=metadata)
|
197 |
+
with open(os.path.join(tmpdir, f"{i + 1}.png"), "rb") as img:
|
198 |
+
encoded = base64.b64encode(img.read()).decode()
|
199 |
+
gallery_images.append(f"data:image/jpeg;base64,{encoded}")
|
200 |
+
|
201 |
+
# zip the images
|
202 |
+
zip_path = os.path.join(tmpdir, "images.zip")
|
203 |
+
with zipfile.ZipFile(zip_path, "w") as zip:
|
204 |
+
for filename in os.listdir(tmpdir):
|
205 |
+
if filename.endswith(".png"):
|
206 |
+
zip.write(os.path.join(tmpdir, filename), filename)
|
207 |
+
|
208 |
+
# convert zip to base64
|
209 |
+
with open(zip_path, "rb") as f:
|
210 |
+
encoded = base64.b64encode(f.read()).decode()
|
211 |
+
|
212 |
+
_ = clickable_images(
|
213 |
+
gallery_images,
|
214 |
+
titles=[f"Image #{str(i)}" for i in range(len(gallery_images))],
|
215 |
+
div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
|
216 |
+
img_style={"margin": "5px", "height": "200px"},
|
217 |
+
)
|
218 |
+
|
219 |
+
# add download link
|
220 |
+
st.markdown(
|
221 |
+
f"""
|
222 |
+
<a href="data:application/zip;base64,{encoded}" download="images.zip">
|
223 |
+
Download Images
|
224 |
+
</a>
|
225 |
+
""",
|
226 |
+
unsafe_allow_html=True,
|
227 |
+
)
|