import functools
import os
import threading

import gradio as gr
import torch
# from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL
from diffusers import StableDiffusionPipeline




# Hugging Face repository served by this app.
_MODEL_ID = "jayparmr/icbinp"

# The cached pipeline object is shared by every request. Diffusers does not
# document pipelines as thread-safe, so loading and inference are serialized.
_PIPELINE_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Load the fp16 Stable Diffusion pipeline onto CUDA once per process.

    The access token is read from the ``HF_TOKEN`` environment variable
    rather than being embedded in source; when it is unset ``None`` is
    passed and the Hub client uses its default credentials (presumably a
    cached ``huggingface-cli login``, if any) — sufficient for public repos.
    """
    # SECURITY: a token used to be hard-coded here. It is exposed in version
    # control history and must be revoked on huggingface.co.
    pipeline = StableDiffusionPipeline.from_pretrained(
        _MODEL_ID,
        use_auth_token=os.environ.get("HF_TOKEN"),
        torch_dtype=torch.float16,
    )
    pipeline.to("cuda")
    return pipeline


def _round_down_to_multiple_of_8(value):
    """Return ``value`` as an int rounded down to a multiple of 8 (at least 8).

    Stable Diffusion pipelines reject widths/heights not divisible by 8.
    """
    return max(8, int(value) // 8 * 8)


def generate(prompt, negative_prompts, samples, steps, scale, seed, width, height):
    """Generate ``samples`` images for ``prompt`` using the cached pipeline.

    Args:
        prompt: Text prompt, repeated once per requested image.
        negative_prompts: Negative prompt string, repeated once per image.
        samples: Number of images to generate. Coerced to ``int`` because
            UI sliders may deliver numbers as floats, and list repetition
            requires an int.
        steps: Number of inference steps (coerced to ``int``).
        scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA ``torch.Generator`` (coerced to ``int``).
        width: Output width in pixels; rounded down to a multiple of 8.
        height: Output height in pixels; rounded down to a multiple of 8.

    Returns:
        A list of the images returned by the pipeline, in order.
    """
    samples = int(samples)
    with _PIPELINE_LOCK:
        pipeline = _load_pipeline()
        generator = torch.Generator(device="cuda").manual_seed(int(seed))
        result = pipeline(
            [prompt] * samples,
            negative_prompt=[negative_prompts] * samples,
            num_inference_steps=int(steps),
            guidance_scale=scale,
            generator=generator,
            width=_round_down_to_multiple_of_8(width),
            height=_round_down_to_multiple_of_8(height),
        )
    return list(result.images)

# Gradio UI: prompt inputs, an image gallery, and generation settings, all
# wired to ``generate``. The server is launched at import time, as before.
block = gr.Blocks()

with block:
    with gr.Group():
        with gr.Box():
            with gr.Row().style(equal_height=True):
                text = gr.Textbox(
                    label="Enter your prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                negative_text = gr.Textbox(
                    value="",
                    label="Enter your negative prompt",
                    show_label=False,
                    max_lines=1,
                    placeholder="Enter your negative prompt",
                )
                btn = gr.Button("Generate image")
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery", width=512
        ).style(columns=[2], rows=[2], object_fit="contain", height="auto")

        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=500, value=100, step=1)
            # Stable Diffusion rejects dimensions not divisible by 8, so the
            # sliders only offer multiples of 8 (the default 512 stays valid).
            width = gr.Slider(label="width", minimum=64, maximum=2048, value=512, step=8)
            height = gr.Slider(label="height", minimum=64, maximum=2048, value=512, step=8)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=2147483647,
                step=1,
            )

        # Shared by both triggers; order must match generate()'s positional
        # parameters (prompt, negative_prompts, samples, steps, scale, seed,
        # width, height).
        generate_inputs = [text, negative_text, samples, steps, scale, seed, width, height]
        text.submit(generate, inputs=generate_inputs, outputs=gallery)
        btn.click(generate, inputs=generate_inputs, outputs=gallery)


block.launch()