import getpass  # NOTE(review): currently unused; kept in case an interactive key prompt is added
import json  # NOTE(review): currently unused
import os
import time  # NOTE(review): currently unused
from contextlib import ExitStack
from io import BytesIO

import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont

# Stability AI API key, read from the environment at import time.
STABILITY_API_KEY = os.environ.get('STABILITY_API_KEY')

# Endpoint used when the STABILITY_HOST environment variable is not set.
# The request fields sent below (model, mode, aspect_ratio, ...) are those of the SD3 endpoint.
DEFAULT_STABILITY_HOST = "https://api.stability.ai/v2beta/stable-image/generate/sd3"

# Font used to stamp the brand name. Override it with the COUPON_FONT_PATH environment variable.
# NOTE: standard Arial does not contain CJK glyphs, so Chinese brand names (such as the
# default "AI原优舍") render as empty boxes. Point COUPON_FONT_PATH at a CJK-capable font
# (e.g. a Noto Sans CJK .otf/.ttf) to render them correctly.
FONT_PATH = os.environ.get("COUPON_FONT_PATH", "./arial-font/arial.ttf")
BRAND_FONT_SIZE = 66


def send_generation_request(
    host,
    params,
):
    """POST a multipart generation request to the Stability AI REST API.

    Args:
        host: Full endpoint URL to POST to.
        params: Form fields for the request. Optional "image" and "mask" entries are
            treated as local file paths. They are uploaded as files and excluded from
            the form fields. The caller's dict is not modified.

    Returns:
        The successful ``requests.Response``. With ``Accept: image/*`` its body is
        the image bytes.

    Raises:
        RuntimeError: If STABILITY_API_KEY is not set.
        Exception: If the server answers with a non-2xx status.
    """
    if not STABILITY_API_KEY:
        # Fail early with a clear message instead of sending "Bearer None".
        raise RuntimeError("STABILITY_API_KEY environment variable is not set")

    headers = {
        "Accept": "image/*",
        "Authorization": f"Bearer {STABILITY_API_KEY}"
    }

    # Work on a copy so popping "image"/"mask" does not mutate the caller's dict.
    params = dict(params)
    image = params.pop("image", None)
    mask = params.pop("mask", None)

    # ExitStack guarantees any opened upload files are closed once the request is done.
    with ExitStack() as stack:
        files = {}
        if image:
            files["image"] = stack.enter_context(open(image, 'rb'))
        if mask:
            files["mask"] = stack.enter_context(open(mask, 'rb'))
        if not files:
            # A placeholder file field makes requests encode the body as
            # multipart/form-data, which this endpoint expects.
            files["none"] = ''

        print(f"Sending REST request to {host}...")
        response = requests.post(
            host,
            headers=headers,
            files=files,
            data=params
        )

    if not response.ok:
        raise Exception(f"HTTP {response.status_code}: {response.text}")

    return response


def _load_font(size):
    """Return FONT_PATH at `size`, or Pillow's built-in font if it cannot be loaded."""
    try:
        return ImageFont.truetype(FONT_PATH, size)
    except OSError:
        # A missing or unreadable font file should not break image generation.
        print(f"Could not load font {FONT_PATH!r}; falling back to Pillow's default font.")
        return ImageFont.load_default()


def generate_image(brand_name, type, coupon_info, address, prompt, negative_prompt, seed, aspect_ratio, output_format, model):
    """Generates an image using the SD3 API and stamps the brand name on it.

    Args:
        brand_name: Text drawn in the upper-left corner of the image.
        type: Coupon category. Not rendered yet (see TODOs). The name shadows the
            builtin but is kept for interface compatibility.
        coupon_info: Discount text. Not rendered yet.
        address: Physical store address. Not rendered yet.
        prompt: Positive text prompt sent to the API.
        negative_prompt: Negative text prompt sent to the API.
        seed: Random seed. Gradio's Number component may deliver a float, so it is
            converted to int; None is sent as 0.
        aspect_ratio: Aspect ratio string, e.g. "16:9".
        output_format: Output image format, e.g. "png" or "jpeg".
        model: Model identifier, e.g. "sd3-medium".

    Returns:
        A ``PIL.Image.Image`` with the brand name drawn on it.

    Raises:
        Warning: If the API reports finish-reason CONTENT_FILTERED.
    """
    host = os.environ.get("STABILITY_HOST", DEFAULT_STABILITY_HOST)
    params = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "aspect_ratio": aspect_ratio,
        # Send an integer seed: gr.Number yields floats such as 0.0 by default.
        "seed": int(seed) if seed is not None else 0,
        "output_format": output_format,
        "model": model,  # Use the selected model
        "mode": "text-to-image"
    }

    response = send_generation_request(host, params)

    # The response body is the raw image; generation metadata comes back in headers.
    finish_reason = response.headers.get("finish-reason")
    if finish_reason == 'CONTENT_FILTERED':
        raise Warning("Generation failed NSFW classifier")

    image = Image.open(BytesIO(response.content))
    # Drawing with an RGB fill tuple requires an RGB(A) canvas.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")

    # Add brand info to the upper-left corner (adjust position/color as needed).
    draw = ImageDraw.Draw(image)
    font = _load_font(BRAND_FONT_SIZE)
    draw.text((10, 10), brand_name, fill=(255, 255, 255), font=font)

    # TODO: render the category (`type`)
    # TODO: render the discount info (`coupon_info`)
    # TODO: render the physical store address (`address`)
    # TODO: add serial-number control
    # TODO: add a QR code for the smart customer-service contact
    return image


# Cost reference (Stability AI credits per generated image; parameter counts per
# Stability AI's SD3 model announcements):
#   - SD3 Large:       8B parameters, 6.5 credits each
#   - SD3 Large Turbo: 8B parameters, 4 credits each
#   - SD3 Medium:      2B parameters, 3.5 credits each

# Create the Gradio interface. Inputs are passed to generate_image positionally, in order.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="品牌名称", value="AI原优舍", lines=1, placeholder="样例: AI原优舍"),
        gr.Dropdown(label="类别", choices=["服饰", "饮食", "健身", "按摩", "公益"], value="服饰"),
        gr.Textbox(label="折扣信息", value="凭券全场任意两件正价八折", lines=1, placeholder="样例: 凭券全场任意两件正价八折"),
        gr.Textbox(label="实体店地址", value="广州市麦栏街20号野隐大楼3层", lines=1, placeholder="样例: 广州市麦栏街20号野隐大楼3层"),
        # TODO: add a serial-number control input
        # TODO: add a smart customer-service contact QR code input
        gr.Textbox(label="广告整体正向描述", value="A lovely girl is trying hard to make her dream come true.", lines=4, placeholder="A lovely girl is trying hard to make her dream come true."),
        gr.Textbox(label="广告整体负向描述", lines=1, placeholder="Optional"),
        # precision=0 makes the UI deliver integer seeds.
        gr.Number(label="算法随机种子", value=0, precision=0),
        gr.Dropdown(label="折扣券长宽比", choices=["21:9", "16:9", "3:2", "5:4", "1:1", "4:5", "2:3", "9:16", "9:21"], value="16:9"),
        gr.Dropdown(label="折扣券输出格式", choices=["jpeg", "png"], value="png"),
        gr.Dropdown(label="生图模型选择", choices=["sd3-large", "sd3-large-turbo", "sd3-medium"], value="sd3-medium"),  # Model selection dropdown
    ],
    outputs=[
        gr.Image(label="生成的折扣券"),
        # TODO: add a history gallery output, e.g.
        # gr.Gallery(label="历史生成记录", show_label=True, elem_id="gallery")
    ],
    title="妈妈折扣券生成器",
    description="基于认知计算广告理论,为广告主一键生成兼具品牌影响力及传播力的品牌专属折扣券。"
)


def update_gallery(image, gallery):
    """Append `image` to the `gallery` list in place and return that list.

    Not connected to the interface yet (see the gallery TODO above).
    """
    gallery.append(image)
    return gallery


# Launch the interface
interface.launch(share=False)