File size: 2,242 Bytes
2741722
 
 
 
 
 
 
 
 
c5e29f0
 
 
2741722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdf9332
2741722
 
c5e29f0
2741722
 
 
 
c5e29f0
 
 
 
 
 
 
 
 
 
 
 
2741722
 
 
c5e29f0
2741722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from all_models import models
from externalmod import gr_Interface_load
import asyncio
import os
from datetime import datetime

# Optional Hugging Face access token; None when the HF_TOKEN env var is unset.
# Passed both when loading models and on every inference call.
HF_TOKEN = os.getenv("HF_TOKEN", None)
from PIL import Image
import io


def load_models(models):
    """Load each named model, skipping (and reporting) any that fail.

    Args:
        models: iterable of model identifiers; each is loaded from the
            ``models/<id>`` path using ``HF_TOKEN``.

    Returns:
        dict mapping identifier -> loaded interface, containing only the
        models that loaded without raising.
    """
    registry = {}
    for name in models:
        try:
            interface = gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)
        except Exception as err:
            # Best effort: one broken model should not stop the app starting.
            print(f"Error loading {name}: {err}")
            continue
        registry[name] = interface
    return registry

# Load every configured model once, at import time. Models that failed to load
# are absent from this dict, so looking them up in infer() raises KeyError.
models_load = load_models(models)

# Generate image function
async def infer(model_str, prompt, seed=-1, timeout=600):
    """Run one model's blocking inference call in a worker thread.

    Args:
        model_str: key into ``models_load`` naming the model to run.
        prompt: text prompt forwarded to the model's ``fn``.
        seed: forwarded unchanged to the model's ``fn`` (default -1).
        timeout: seconds to wait for the result before giving up (default 600).

    Returns:
        Whatever the model's ``fn`` returns.

    Raises:
        KeyError: if ``model_str`` is not among the successfully loaded models.
        asyncio.TimeoutError: if ``fn`` does not finish within ``timeout``.
    """
    try:
        model = models_load[model_str]
    except KeyError:
        # Models that failed in load_models() are still offered by the UI;
        # give a clear message instead of a bare key.
        raise KeyError(
            f"Model {model_str!r} is not loaded (unknown or failed to load)"
        ) from None
    # On timeout the awaiting task is cancelled, but the worker thread itself
    # cannot be interrupted and keeps running until fn returns.
    return await asyncio.wait_for(
        asyncio.to_thread(model.fn, prompt=prompt, seed=seed, token=HF_TOKEN),
        timeout=timeout,
    )


def _coerce_to_image(result):
    """Normalize a model's raw output to a PIL image.

    A tuple is unwrapped to its first element, raw ``bytes`` are decoded with
    PIL, and PIL images pass through unchanged.

    Raises:
        ValueError: for an empty tuple or any other output type.
    """
    if isinstance(result, tuple):
        if not result:
            raise ValueError("Unexpected output: empty tuple")
        # Assumes the image is the first element of a tuple output.
        result = result[0]
    if isinstance(result, bytes):
        # BytesIO is kept alive by the image object, so lazy decoding is safe.
        return Image.open(io.BytesIO(result))
    if isinstance(result, Image.Image):
        return result
    raise ValueError(f"Unexpected output type: {type(result)}")


def generate_image(model_name, prompt, seed):
    """Synchronously generate an image with the selected model.

    Runs ``infer`` on a fresh event loop (this is called from a synchronous
    Gradio callback) and converts the model output to a PIL image.

    Args:
        model_name: identifier of a model in ``models_load``.
        prompt: text prompt for the model.
        seed: seed value forwarded to the model.

    Returns:
        PIL.Image.Image with the generated picture.

    Raises:
        ValueError: if the model output cannot be interpreted as an image.
        KeyError / asyncio.TimeoutError: propagated from ``infer``.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(infer(model_name, prompt, seed))
    finally:
        # Unregister before closing so this thread is not left with a closed
        # loop as its current event loop (same cleanup asyncio.run performs).
        asyncio.set_event_loop(None)
        loop.close()
    return _coerce_to_image(result)

        
# Interface
# Build the Gradio UI. Components are laid out top-to-bottom in one column in
# the order they are created below, so statement order defines the layout.
with gr.Blocks() as demo:
    with gr.Column():
        # Lists every configured model, including any that load_models() skipped
        # because they failed to load; selecting one of those fails in infer().
        # value=models[0] assumes `models` is a non-empty sequence.
        model_choice = gr.Dropdown(
            choices=models, label="Select Model", value=models[0]
        )
        prompt_input = gr.Textbox(label="Enter your prompt")
        # NOTE(review): -1 is the default and is forwarded to the model as-is;
        # presumably it means "random seed" — confirm in externalmod.
        seed_input = gr.Slider(
            label="Seed", minimum=-1, maximum=100000, step=1, value=-1
        )
        generate_button = gr.Button("Generate Image")
        output_image = gr.Image(label="Generated Image", show_download_button=True)

        # Inputs are passed positionally, matching
        # generate_image(model_name, prompt, seed).
        generate_button.click(
            fn=generate_image,
            inputs=[model_choice, prompt_input, seed_input],
            outputs=[output_image],
        )

# Launches the server at module import time (there is no __main__ guard).
demo.launch()