|
import contextlib
import getpass
import json
import os
import time
from io import BytesIO

import gradio as gr
import requests
from PIL import Image
|
|
|
|
|
# Stability REST API key, read once from the environment when the module loads.
# Holds None if the STABILITY_API_KEY variable is not set.
STABILITY_API_KEY = os.getenv("STABILITY_API_KEY")
|
|
|
def send_generation_request(
    host,
    params,
    *,
    timeout=None,
):
    """POST a multipart generation request to the Stability REST API.

    Args:
        host: Full endpoint URL to POST to.
        params: Form fields for the request. Optional ``"image"`` and
            ``"mask"`` entries are treated as local file paths and uploaded
            as files instead of form fields; ``None`` or ``""`` means "not
            provided". The caller's dict is not modified.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` (the default) waits
            indefinitely, matching the previous behavior.

    Returns:
        The ``requests.Response`` when ``response.ok`` is true.

    Raises:
        OSError: If an ``image``/``mask`` path cannot be opened.
        Exception: If the server responds with a non-OK status code.
    """
    headers = {
        "Accept": "image/*",
        "Authorization": f"Bearer {STABILITY_API_KEY}",
    }

    # Work on a copy so removing the file entries does not mutate the caller's dict.
    form = dict(params)
    image = form.pop("image", None)
    mask = form.pop("mask", None)

    # ExitStack guarantees every opened upload file is closed, even if the
    # request raises.
    with contextlib.ExitStack() as stack:
        files = {}
        if image is not None and image != "":
            files["image"] = stack.enter_context(open(image, "rb"))
        if mask is not None and mask != "":
            files["mask"] = stack.enter_context(open(mask, "rb"))
        if not files:
            # requests only builds a multipart/form-data body when `files` is
            # non-empty; a dummy empty part forces multipart encoding, which
            # this API is expected to require.
            files["none"] = ""

        print(f"Sending REST request to {host}...")
        response = requests.post(
            host,
            headers=headers,
            files=files,
            data=form,
            timeout=timeout,
        )

    if not response.ok:
        raise Exception(f"HTTP {response.status_code}: {response.text}")

    return response
|
|
|
def generate_image(
    prompt,
    negative_prompt,
    aspect_ratio,
    seed,
    output_format,
    model="sd3-medium",
):
    """Generate an image from a text prompt with the SD3 API.

    The endpoint URL is read from the ``STABILITY_HOST`` environment variable.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what to avoid (may be empty).
        aspect_ratio: Aspect-ratio string such as ``"16:9"``.
        seed: Generation seed. Gradio's Number component can deliver floats
            (e.g. ``0.0``), so the value is coerced to ``int`` before it is
            sent; ``None`` is passed through unchanged.
        output_format: Image encoding requested from the API, e.g. ``"png"``.
        model: Model identifier sent to the API (default ``"sd3-medium"``).

    Returns:
        The generated image as a ``PIL.Image.Image``.

    Raises:
        KeyError: If ``STABILITY_HOST`` is not set.
        Warning: If the API reports ``finish-reason: CONTENT_FILTERED``.
        Exception: Propagated from ``send_generation_request`` on HTTP errors.
    """
    host = os.environ["STABILITY_HOST"]

    params = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "aspect_ratio": aspect_ratio,
        # Send "0" rather than "0.0": the form field should carry an integer seed.
        "seed": int(seed) if seed is not None else None,
        "output_format": output_format,
        "model": model,
        "mode": "text-to-image",
    }

    response = send_generation_request(host, params)

    # The API signals a blocked generation through this response header.
    if response.headers.get("finish-reason") == "CONTENT_FILTERED":
        raise Warning("Generation failed NSFW classifier")

    # Report the seed the server actually used so the result can be reproduced.
    print(f"Generation finished with seed {response.headers.get('seed')}")

    return Image.open(BytesIO(response.content))
|
|
|
""" |
|
Cost reference (Stability API credits per generated image)

- SD3 Large: 8B parameters, 6.5 credits each

- SD3 Large Turbo: 8B parameters, 4 credits each

- SD3 Medium: 2B parameters, 3.5 credits each
|
""" |
|
|
|
|
|
# Gradio UI. The input components are passed to generate_image positionally,
# in the order listed below.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        # Positive prompt.
        gr.Textbox(label="正向提示词", placeholder="This dreamlike digital art captures a vibrant, kaleidoscopic bird in a lush rainforest"),
        # Negative prompt.
        gr.Textbox(label="负向提示词", placeholder="Optional"),
        # Aspect ratio.
        gr.Dropdown(label="长宽比", choices=["21:9", "16:9", "3:2", "5:4", "1:1", "4:5", "2:3", "9:16", "9:21"], value="16:9"),
        # Random seed for the generation algorithm. precision=0 makes Gradio
        # round and deliver an int instead of a float such as 0.0.
        gr.Number(label="生成算法随机种子", value=0, precision=0),
        # Output format.
        gr.Dropdown(label="输出格式", choices=["jpeg", "png"], value="png"),
    ],
    outputs="image",
    title="Stable Diffusion 3 Image Generator",
    description="Generate images with Stable Diffusion 3. Type a prompt and see the magic!",
)

# NOTE(review): share=True asks Gradio for a publicly shareable link, so anyone
# holding the link can spend credits on this STABILITY_API_KEY. The launch also
# runs whenever this module is imported; consider an
# `if __name__ == "__main__":` guard if it is ever imported elsewhere.
interface.launch(share=True)