|
import gradio as gr |
|
from io import BytesIO |
|
import json |
|
import os |
|
from PIL import Image |
|
import requests |
|
import time |
|
import getpass |
|
from PIL import ImageDraw, ImageFont |
|
|
|
|
|
# API key for the Stability AI REST API, read once at import time.
# None when the STABILITY_API_KEY environment variable is not set.
STABILITY_API_KEY = os.getenv('STABILITY_API_KEY')
|
|
|
def send_generation_request(
    host,
    params,
    *,
    timeout=None,
):
    """POST a multipart/form-data generation request to the Stability AI REST API.

    Entries of ``params`` are sent as form fields, except ``"image"`` and
    ``"mask"``. When those are present and non-empty, they are treated as
    local file paths and uploaded as file parts instead. The caller's
    ``params`` dict is left untouched.

    Args:
        host: Full endpoint URL the request is POSTed to.
        params: Form fields for the request (see above for "image"/"mask").
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, matching the previous behaviour.

    Returns:
        The ``requests.Response`` of a successful request (``response.ok``).
        Because ``Accept: image/*`` is sent, its body is expected to be the
        raw image bytes.

    Raises:
        Exception: if the server answers with an error status (``response.ok``
            is false). The message carries the status code and response body.
        OSError: if an "image" or "mask" path cannot be opened.
    """
    headers = {
        "Accept": "image/*",
        "Authorization": f"Bearer {STABILITY_API_KEY}",
    }

    # Copy first: popping the upload paths must not mutate the caller's dict.
    params = dict(params)
    upload_paths = (
        ("image", params.pop("image", None)),
        ("mask", params.pop("mask", None)),
    )

    files = {}
    opened = []
    try:
        for field, path in upload_paths:
            if path:  # skips both None and ""
                handle = open(path, "rb")
                opened.append(handle)
                files[field] = handle

        if not files:
            # requests only builds a multipart/form-data body when ``files`` is
            # non-empty. This dummy empty part keeps the body multipart even
            # when nothing is uploaded (presumably required by the endpoint).
            files["none"] = ""

        print(f"Sending REST request to {host}...")
        response = requests.post(
            host,
            headers=headers,
            files=files,
            data=params,
            timeout=timeout,
        )
    finally:
        # Close every upload handle, even if opening a later file or the
        # request itself failed.
        for handle in opened:
            handle.close()

    if not response.ok:
        raise Exception(f"HTTP {response.status_code}: {response.text}")

    return response
|
|
|
def generate_image(brand_name, type, coupon_info, address, prompt, negative_prompt,
                   seed, aspect_ratio, output_format, model):
    """Generate a coupon image with the SD3 API and stamp the brand name on it.

    Args:
        brand_name: Text drawn in white at the top-left corner of the image.
            Nothing is drawn when it is empty or None.
        type: Coupon category from the UI. Currently unused. The name shadows
            the ``type`` builtin but is kept because it is part of the signature.
        coupon_info: Discount text from the UI. Currently unused.
        address: Store address from the UI. Currently unused.
        prompt: Positive prompt sent to the API.
        negative_prompt: Negative prompt sent to the API.
        seed: Random seed sent to the API. Coerced to ``int`` when not None.
        aspect_ratio: Passed through to the API.
        output_format: Passed through to the API.
        model: Passed through to the API.

    Returns:
        The generated ``PIL.Image.Image`` with the brand name drawn on it.

    Raises:
        KeyError: if the ``STABILITY_HOST`` environment variable is unset.
        Warning: if the response's ``finish-reason`` header is
            ``CONTENT_FILTERED``.
    """
    host = os.environ["STABILITY_HOST"]

    # gr.Number without ``precision`` hands over a float (e.g. 0.0), which
    # would be sent as the text "0.0". Send an integer seed instead; None is
    # left as-is, matching the previous behaviour.
    if seed is not None:
        seed = int(seed)

    params = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "aspect_ratio": aspect_ratio,
        "seed": seed,
        "output_format": output_format,
        "model": model,
        "mode": "text-to-image",
    }

    response = send_generation_request(host, params)

    if response.headers.get("finish-reason") == "CONTENT_FILTERED":
        raise Warning("Generation failed NSFW classifier")

    image = Image.open(BytesIO(response.content))

    if brand_name:
        try:
            # Path is relative to the current working directory.
            font = ImageFont.truetype("./arial-font/arial.ttf", 20)
        except OSError:
            # Bundled TTF not found/unreadable: degrade to Pillow's built-in
            # font rather than failing the whole generation.
            font = ImageFont.load_default()
        draw = ImageDraw.Draw(image)
        draw.text((10, 10), brand_name, fill=(255, 255, 255), font=font)

    return image
|
|
|
""" |
|
Cost Rule (Stability AI SD3 API, per generated image)

- SD3 Large: 8B parameters, 6.5 credits each

- SD3 Large Turbo: 8B parameters (distilled from SD3 Large), 4 credits each

- SD3 Medium: 2B parameters, 3.5 credits each
|
""" |
|
|
|
|
|
# Web UI. Input components are passed positionally to generate_image, so
# their order must match its parameter order.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        # Brand name
        gr.Textbox(label="品牌名称", placeholder="样例: AI原优舍"),
        # Category: apparel, food & drink, fitness, massage, charity
        gr.Dropdown(label="类别", choices=["服饰", "饮食", "健身", "按摩", "公益"], value="服饰"),
        # Discount information
        gr.Textbox(label="折扣信息", placeholder="样例: 凭券全场任意两件正价八折"),
        # Physical store address
        gr.Textbox(label="实体店地址", placeholder="样例: 广州市麦栏街20号野隐大楼3层"),
        # Overall positive prompt for the ad
        gr.Textbox(label="广告整体正向描述", placeholder="This dreamlike digital art captures a vibrant, kaleidoscopic bird in a lush rainforest"),
        # Overall negative prompt for the ad
        gr.Textbox(label="广告整体负向描述", placeholder="Optional"),
        # Random seed. precision=0 makes Gradio return an int instead of a
        # float such as 0.0.
        gr.Number(label="算法随机种子", value=0, precision=0),
        # Coupon aspect ratio
        gr.Dropdown(label="折扣券长宽比", choices=["21:9", "16:9", "3:2", "5:4", "1:1", "4:5", "2:3", "9:16", "9:21"], value="16:9"),
        # Coupon output format
        gr.Dropdown(label="折扣券输出格式", choices=["jpeg", "png"], value="png"),
        # Image-generation model
        gr.Dropdown(label="生图模型选择", choices=["sd3-large", "sd3-large-turbo", "sd3-medium"], value="sd3-medium"),
    ],
    outputs=[
        # Generated coupon
        gr.Image(label="生成的折扣券"),
    ],
    title="妈妈折扣券生成器",
    description="基于认知计算广告理论,为广告主一键生成兼具品牌影响力及传播力的品牌专属折扣券。",
)
|
|
|
|
|
def update_gallery(image, gallery):
    """Return a new gallery list with ``image`` appended at the end.

    Args:
        image: Item to append.
        gallery: Existing gallery items, or None for an empty gallery.

    Returns:
        A new list; the ``gallery`` argument itself is not modified.
    """
    return [*(gallery or []), image]
|
|
|
|
|
|
|
|
|
|
|
|
|
# Start the local web server (no public share link) only when this file is
# run directly, not when it is imported as a module.
if __name__ == "__main__":
    interface.launch(share=False)