Spaces:
Runtime error
Runtime error
import base64 | |
import io | |
import random | |
import time | |
from typing import List | |
from PIL import Image | |
import aiohttp | |
import asyncio | |
import requests | |
import streamlit as st | |
import requests | |
import zipfile | |
import io | |
import pandas as pd | |
from utils import icon | |
from streamlit_image_select import image_select | |
from PIL import Image | |
import random | |
import time | |
import base64 | |
from typing import List | |
import aiohttp | |
import asyncio | |
import plotly.express as px | |
from common import set_page_container_style | |
def pil_image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to PNG and return it as a base64 text string.

    Args:
        image: The image to encode.

    Returns:
        The PNG bytes of ``image``, base64-encoded and decoded to ``str`` (UTF-8).
    """
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def get_or_create_eventloop():
    """Return the current thread's asyncio event loop, creating one if needed.

    ``asyncio.get_event_loop()`` raises ``RuntimeError`` ("There is no current
    event loop in thread ...") when called from a thread that has no loop set.
    In that case a fresh loop is created, installed as the thread's current
    loop, and returned.

    Returns:
        An ``asyncio`` event loop bound to the calling thread.

    Raises:
        RuntimeError: any other ``RuntimeError`` raised by
            ``asyncio.get_event_loop()`` is propagated. Previously such errors
            fell through and the function silently returned ``None``, which only
            failed later with an unrelated ``AttributeError``.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError as ex:
        if "There is no current event loop in thread" not in str(ex):
            raise
    # No loop in this thread yet: create one and make it current.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    return loop
def _ratio_presets(square_side, tall_size):
    """Build the aspect-ratio table for one model as (width, height) tuples.

    ``wide`` is ``tall`` with width and height swapped. A new dict is returned
    on every call, so models never share a mutable table.
    """
    tall_width, tall_height = tall_size
    return {
        "square": (square_side, square_side),
        "tall": (tall_width, tall_height),
        "wide": (tall_height, tall_width),
    }


# Per-model generation defaults: output size per aspect ratio
# ("square"/"tall"/"wide"), denoising step count, guidance scale and clip_skip.
model_config = {
    "RealisticVision": {
        "ratio": _ratio_presets(512, (512, 768)),
        "num_inference_steps": 30,
        "guidance_scale": 7.0,
        "clip_skip": 2,
    },
    "AnimeV3": {
        "num_inference_steps": 25,
        "guidance_scale": 7,
        "clip_skip": 2,
        "ratio": _ratio_presets(1024, (672, 1024)),
    },
    "DreamShaper": {
        "num_inference_steps": 35,
        "guidance_scale": 7,
        "clip_skip": 2,
        "ratio": _ratio_presets(512, (512, 768)),
    },
    "RealitiesEdgeXL": {
        "num_inference_steps": 7,
        "guidance_scale": 2.5,
        "clip_skip": 2,
        "ratio": _ratio_presets(1024, (672, 1024)),
    },
}
def base64_to_image(base64_string):
    """Decode a base64 payload and open the resulting bytes with PIL.

    Args:
        base64_string: Base64-encoded image data.

    Returns:
        The ``PIL.Image.Image`` returned by ``Image.open`` on the decoded bytes.
    """
    image_bytes = base64.b64decode(base64_string)
    byte_stream = io.BytesIO(image_bytes)
    return Image.open(byte_stream)
async def call_niche_api(url, data) -> "Image.Image | None":
    """POST one generation request and decode the image in the JSON reply.

    This is best-effort: any failure is printed and reported as ``None``
    instead of raised. ``get_output`` gathers these calls without
    ``return_exceptions``, so an exception here would abort the whole batch.

    Args:
        url: Endpoint that receives ``data`` as a JSON body.
        data: JSON-serialisable request payload.

    Returns:
        The decoded PIL image, or ``None`` on failure. Failures include a
        transport error, a non-2xx HTTP status, a non-JSON body, and a body
        that is not valid base64 image data.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data) as response:
                # Fail fast on HTTP errors instead of trying to decode an
                # error body as an image.
                response.raise_for_status()
                payload = await response.json()
        image = base64_to_image(payload)
        # Image.open is lazy; force a full decode here so corrupt payloads are
        # handled by this best-effort path rather than failing later when the
        # image is displayed or saved.
        image.load()
        return image
    except Exception as e:
        print(f"NicheImage request to {url} failed: {e!r}")
        return None
async def get_output(url, datas):
    """Send every payload in ``datas`` to ``url`` concurrently.

    One ``call_niche_api`` task is scheduled per payload, and all of them are
    awaited together.

    Args:
        url: Endpoint passed unchanged to ``call_niche_api``.
        datas: Iterable of request payloads.

    Returns:
        A list holding one ``call_niche_api`` result per payload, in the same
        order as ``datas``. An empty ``datas`` gives an empty list.
    """
    pending = []
    for payload in datas:
        pending.append(asyncio.create_task(call_niche_api(url, payload)))
    results = await asyncio.gather(*pending)
    return results
# NicheImage proxy endpoint; each POST yields one image.
_NICHE_API_URL = "http://proxy_client_nicheimage.nichetensor.com:10003/generate"
# The result grid always shows at least this many tiles; slots without an
# image are filled with black placeholders.
_MIN_GRID_IMAGES = 4
# Images per row in the result grid.
_GRID_COLUMNS = 2


def _to_int_or(value, default: int) -> int:
    """Return ``int(value)``, or ``default`` if ``value`` cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _make_seeds(seed: int, num_images: int) -> List[int]:
    """Return ``num_images`` seeds: consecutive from ``seed`` if it is >= 0, else random."""
    if seed >= 0:
        return [seed + i for i in range(num_images)]
    # randint needs integer bounds: a float bound such as 1e9 raises
    # TypeError on Python 3.12+.
    return [random.randint(0, 10**9) for _ in range(num_images)]


def _fill_missing(images, width: int, height: int) -> list:
    """Pad ``images`` to the grid minimum and replace ``None`` entries with black images."""
    padded = list(images) + [None] * (_MIN_GRID_IMAGES - len(images))
    return [
        Image.new("RGB", (width, height), (0, 0, 0)) if image is None else image
        for image in padded
    ]


def _render_image_grid(images) -> None:
    """Show ``images`` in a centred, bordered grid with ``_GRID_COLUMNS`` per row."""
    _, main_col, _ = st.columns([0.15, 0.7, 0.15])
    with main_col:
        # Columns are created inside the bordered container so the border
        # actually surrounds the images.
        with st.container(border=True):
            for row_start in range(0, len(images), _GRID_COLUMNS):
                row_cols = st.columns(_GRID_COLUMNS)
                row = images[row_start:row_start + _GRID_COLUMNS]
                for offset, image in enumerate(row):
                    row_cols[offset].image(
                        image,
                        caption=f"Image {row_start + offset + 1} 🎈",
                        use_column_width=True,
                        output_format="PNG",
                    )


def _zip_images(images) -> bytes:
    """Return a ZIP archive containing each image as ``output_file_<n>.png`` (1-based)."""
    zip_io = io.BytesIO()
    with zipfile.ZipFile(zip_io, "w") as zipf:
        for number, image in enumerate(images, start=1):
            image_data = io.BytesIO()
            image.save(image_data, format="PNG")
            zipf.writestr(f"output_file_{number}.png", image_data.getvalue())
    return zip_io.getvalue()


def main_page(
    submitted: bool,
    model_name: str,
    prompt: str,
    negative_prompt: str,
    aspect_ratio: str,
    num_images: int,
    uid: str,
    secret_key: str,
    seed: str,
    conditional_image: str,
    controlnet_conditioning_scale: list,
    pipeline_type: str,
    api_token: str,
    generated_images_placeholder,
) -> None:
    """Generate images through the NicheImage API and render them with a ZIP download.

    Does nothing unless ``submitted`` is true.

    Args:
        submitted: Whether the form was submitted. When false, this returns immediately.
        model_name: Key into ``model_config``. It selects the output size, step
            count, guidance scale and clip_skip.
        prompt: Text prompt sent with every request.
        negative_prompt: Negative prompt sent in ``pipeline_params``.
        aspect_ratio: Case-insensitive key of the model's ``"ratio"`` table
            (``"square"``, ``"tall"`` or ``"wide"``).
        num_images: Number of requests to send, one per seed.
        uid: Miner uid as a string. ``"-1"`` or a non-integer means -1, which
            the payload comment describes as a random miner.
        secret_key: Must equal ``api_token`` unless ``uid == "-1"``.
        seed: A non-negative integer-like value gives consecutive seeds
            starting at it. Any other value gives random seeds.
        conditional_image: Forwarded unchanged in the request payload.
        controlnet_conditioning_scale: Forwarded in ``pipeline_params``.
        pipeline_type: Forwarded unchanged in the request payload.
        api_token: Sent as the request ``"key"`` and compared with ``secret_key``.
        generated_images_placeholder: Object whose ``.container()`` hosts the
            output (presumably an ``st.empty()`` placeholder; confirm with caller).

    Side effects:
        Sets ``st.session_state.generated_image`` and
        ``st.session_state.all_images`` to the displayed images.
    """
    # If not submitted, chill here 🍹
    if not submitted:
        return
    # Targeting a specific miner (uid other than "-1") requires the secret key.
    if secret_key != api_token and uid != "-1":
        st.error("Invalid secret key")
        return
    uid = _to_int_or(uid, -1)
    config = model_config[model_name]
    width, height = (int(side) for side in config["ratio"][aspect_ratio.lower()])
    num_inference_steps = config["num_inference_steps"]
    guidance_scale = config["guidance_scale"]

    with st.status(
        "👩🏾‍🍳 Whipping up your words into art...", expanded=True
    ) as status:
        try:
            start_time = time.time()
            with generated_images_placeholder.container():
                seed = _to_int_or(seed, -1)
                seeds = _make_seeds(seed, num_images)
                data = {
                    "key": api_token,
                    "prompt": prompt,  # prompt
                    "model_name": model_name,  # See available models in https://github.com/NicheTensor/NicheImage/blob/main/configs/model_config.yaml
                    "seed": seed,  # -1 means random seed
                    "miner_uid": int(
                        uid
                    ),  # specify miner uid, -1 means random miner selected by validator
                    "pipeline_type": pipeline_type,
                    "conditional_image": conditional_image,
                    "pipeline_params": {  # params feed to diffusers pipeline, see all params here https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__
                        "width": width,
                        "height": height,
                        "num_inference_steps": num_inference_steps,
                        "guidance_scale": guidance_scale,
                        "negative_prompt": negative_prompt,
                        "controlnet_conditioning_scale": controlnet_conditioning_scale,
                        "clip_skip": config["clip_skip"],
                    },
                }
                # One payload per image; only the top-level seed differs.
                # These are shallow copies, so pipeline_params is shared but never mutated.
                payloads = [{**data, "seed": s} for s in seeds]
                # Call the NicheImage API for all payloads concurrently.
                loop = get_or_create_eventloop()
                asyncio.set_event_loop(loop)
                output = loop.run_until_complete(
                    get_output(_NICHE_API_URL, payloads)
                )
                output = _fill_missing(output, width, height)

                st.toast("Your image has been generated!", icon="😍")
                elapsed = time.time() - start_time
                status.update(
                    label=f"✅ Images generated in {round(elapsed, 3)} seconds",
                    state="complete",
                    expanded=False,
                )
                # Save generated images to session state, then display them.
                st.session_state.generated_image = output
                _render_image_grid(output)
                # Every displayed image goes into the download, not only the first row.
                st.session_state.all_images = list(output)
                st.download_button(
                    ":red[**Download All Images**]",
                    data=_zip_images(st.session_state.all_images),
                    file_name="output_files.zip",
                    mime="application/zip",
                    use_container_width=True,
                )
        except Exception as e:
            print(e)
            st.error(f"Encountered an error: {e}", icon="🚨")
            # Mark the status box as failed instead of leaving its in-progress label.
            status.update(state="error", expanded=True)