weijiang2024's picture
Upload folder using huggingface_hub
c2f9a5b verified
raw
history blame
4.74 kB
import gradio as gr
from io import BytesIO
import json
import os
from PIL import Image
import requests
import time
import getpass
from PIL import ImageDraw, ImageFont
# Stability AI API key taken from the process environment.
# os.getenv yields None when the variable is not set.
STABILITY_API_KEY = os.getenv("STABILITY_API_KEY")
def send_generation_request(
    host,
    params,
):
    """POST a generation request to the Stability endpoint and return the response.

    Args:
        host: Full endpoint URL the request is POSTed to.
        params: Form fields sent as multipart data. Optional "image" and
            "mask" keys are treated as file paths and uploaded as files
            instead of form fields. The caller's dict is not modified.

    Returns:
        The ``requests.Response`` whose body is the image bytes
        (the request sends ``Accept: image/*``).

    Raises:
        Exception: if the response status indicates an error
            (``response.ok`` is False, i.e. status >= 400).
    """
    headers = {
        "Accept": "image/*",
        # NOTE(review): sends "Bearer None" if STABILITY_API_KEY is unset.
        "Authorization": f"Bearer {STABILITY_API_KEY}"
    }

    # Work on a copy so popping the file fields does not mutate the caller's dict.
    data = dict(params)
    image = data.pop("image", None)
    mask = data.pop("mask", None)

    files = {}
    try:
        for field, path in (("image", image), ("mask", mask)):
            if path is not None and path != '':
                files[field] = open(path, 'rb')
        if not files:
            # requests only builds a multipart/form-data body when `files` is
            # non-empty; this placeholder forces multipart encoding, which the
            # endpoint presumably requires even for text-only requests.
            files["none"] = ''

        print(f"Sending REST request to {host}...")
        response = requests.post(
            host,
            headers=headers,
            files=files,
            data=data
        )
    finally:
        # Close every file handle opened above, even if the request failed.
        for handle in files.values():
            if hasattr(handle, "close"):
                handle.close()

    if not response.ok:
        raise Exception(f"HTTP {response.status_code}: {response.text}")
    return response
def _load_brand_font(size=20):
    """Return the font for the brand label, falling back to Pillow's default font.

    A missing or unreadable ``./arial-font/arial.ttf`` (OSError) degrades the
    label styling instead of failing a generation whose API call already ran.
    """
    try:
        return ImageFont.truetype("./arial-font/arial.ttf", size)
    except OSError:
        return ImageFont.load_default()


def generate_image(brand_name, type, coupon_info, address, prompt, negative_prompt, seed, aspect_ratio, output_format, model):
    """Generate a coupon image with the SD3 API and stamp the brand name on it.

    Args:
        brand_name: Text drawn in the upper-left corner of the image.
        type: Category selected in the UI (not used yet; see TODOs).
            The name shadows the builtin but is kept for interface stability.
        coupon_info: Discount text (not used yet).
        address: Physical store address (not used yet).
        prompt: Positive prompt sent to the API.
        negative_prompt: Negative prompt sent to the API.
        seed: Random seed. Coerced to int because a Gradio Number widget can
            deliver floats (e.g. 0.0). None is sent as 0.
        aspect_ratio: Aspect ratio string such as "16:9".
        output_format: Output image format, e.g. "jpeg" or "png".
        model: Model id, e.g. "sd3-medium".

    Returns:
        A PIL image with the brand name drawn on it.

    Raises:
        KeyError: if the STABILITY_HOST environment variable is not set.
        Warning: if the API reports finish-reason CONTENT_FILTERED.
        Exception: propagated from send_generation_request on an HTTP error.
    """
    host = os.environ["STABILITY_HOST"]
    params = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "aspect_ratio": aspect_ratio,
        # Send an integer seed ("0", not "0.0") in the form data.
        "seed": int(seed) if seed is not None else 0,
        "output_format": output_format,
        "model": model,  # model chosen in the UI dropdown
        "mode": "text-to-image"
    }
    response = send_generation_request(host, params)

    # The response body is the encoded image; metadata comes back in headers.
    finish_reason = response.headers.get("finish-reason")
    if finish_reason == 'CONTENT_FILTERED':
        raise Warning("Generation failed NSFW classifier")

    image = Image.open(BytesIO(response.content))

    # Stamp the brand name in the upper-left corner (white text).
    if brand_name:
        draw = ImageDraw.Draw(image)
        draw.text((10, 10), brand_name, fill=(255, 255, 255), font=_load_brand_font(20))

    # TODO: add the category
    # TODO: add the discount information
    # TODO: add the physical store address
    # TODO: add serial-number control
    # TODO: add a smart customer-service contact QR code
    return image
"""
Cost Rule
- SD3 Large: 8B params, 6.5 credits each
- SD3 Large Turbo: 8B params, 4 credits each
- SD3 Medium: 2B params, 3.5 credits each
"""
# Build the Gradio UI. Gradio passes the input components' values to
# generate_image positionally, so this list must stay in the same order as its
# parameters: brand_name, type, coupon_info, address, prompt, negative_prompt,
# seed, aspect_ratio, output_format, model.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="品牌名称", placeholder="样例: AI原优舍"),  # brand name
        gr.Dropdown(label="类别", choices=["服饰", "饮食", "健身", "按摩", "公益"], value="服饰"),  # category
        gr.Textbox(label="折扣信息", placeholder="样例: 凭券全场任意两件正价八折"),  # discount info
        gr.Textbox(label="实体店地址", placeholder="样例: 广州市麦栏街20号野隐大楼3层"),  # store address
        # TODO: add a serial-number control
        # TODO: add a smart customer-service contact QR code
        gr.Textbox(label="广告整体正向描述", placeholder="This dreamlike digital art captures a vibrant, kaleidoscopic bird in a lush rainforest"),  # positive prompt
        gr.Textbox(label="广告整体负向描述", placeholder="Optional"),  # negative prompt
        # precision=0 makes Gradio round the value and return an int instead
        # of a float, so the API receives an integer seed.
        gr.Number(label="算法随机种子", value=0, precision=0),  # random seed
        gr.Dropdown(label="折扣券长宽比", choices=["21:9", "16:9", "3:2", "5:4", "1:1", "4:5", "2:3", "9:16", "9:21"], value="16:9"),  # aspect ratio
        gr.Dropdown(label="折扣券输出格式", choices=["jpeg", "png"], value="png"),  # output format
        gr.Dropdown(label="生图模型选择", choices=["sd3-large", "sd3-large-turbo", "sd3-medium"], value="sd3-medium"),  # model selection
    ],
    outputs=[
        gr.Image(label="生成的折扣券"),  # generated coupon
        # TODO: add a generation-history gallery output, e.g.
        # gr.Gallery(label="历史生成记录", show_label=True, elem_id="gallery")
    ],
    title="妈妈折扣券生成器",
    description="基于认知计算广告理论,为广告主一键生成兼具品牌影响力及传播力的品牌专属折扣券。",
)
# Helper for a (not yet wired) history gallery.
def update_gallery(image, gallery):
    """Return a new gallery list with *image* appended at the end.

    Args:
        image: The newly generated item to add.
        gallery: Previous gallery items as any iterable, or None (e.g. an
            unset Gallery value), which is treated as empty. It is not
            modified in place.

    Returns:
        A new list: the previous items followed by ``image``.
    """
    return [*(gallery or []), image]
# Connect the generate_image function to the gallery update
#interface.outputs[0].click(update_gallery, inputs=[interface.outputs[0], interface.outputs[1]])
#interface.load(lambda: None).then(lambda: interface.getElementById("generated_image").click(update_gallery, inputs=[interface.getElementById("generated_image"), interface.getElementById("gallery")]))
# Launch the interface only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    interface.launch(share=False)