import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCasualLM
from diffusers import DiffusionPipeline
import requests
from PIL import Image
from io import BytesIO
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# Initialize models
anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
anime_model = ort.InferenceSession(anime_model_path)
photo_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

# Load labels for the anime model
labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
with open(labels_path, 'r') as f:
    anime_labels = [line.strip().split(',')[0] for line in f.readlines()[1:]]  # Skip header

def preprocess_image(image):
    image = image.convert('RGB')
    image = image.resize((448, 448), Image.LANCZOS)
    image = np.array(image).astype(np.float32)
    image = image[:, :, ::-1]  # RGB -> BGR
    image = np.transpose(image, (2, 0, 1))  # HWC -> CHW
    image = image / 255.0
    return image[np.newaxis, ...]

def transcribe_image(image, image_type, transcriber, booru_tags=None):
    if image_type == "Anime":
        input_image = preprocess_image(image)
        input_name = anime_model.get_inputs()[0].name
        output_name = anime_model.get_outputs()[0].name
        probs = anime_model.run([output_name], {input_name: input_image})[0]
        
        # Get top 50 tags
        top_indices = probs[0].argsort()[-50:][::-1]
        tags = [anime_labels[i] for i in top_indices]
    else:
        prompt = "<MORE_DETAILED_CAPTION>"
        inputs = processor(images=image, text=prompt, return_tensors="pt")
        generated_ids = photo_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        tags = generated_text  # Use generated text as the description
    
    return ", ".join(tags)


def get_booru_image(booru, image_id):
    if booru == "Gelbooru":
        url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    elif booru == "Danbooru":
        url = f"https://danbooru.donmai.us/posts/{image_id}.json"
    elif booru == "rule34.xxx":
        url = f"https://api.rule34.xxx/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
    else:
        raise ValueError("Unsupported booru")

    response = requests.get(url)
    data = response.json()

    # The exact structure of the response will vary depending on the booru
    # You'll need to adjust this part based on each booru's API
    image_url = data[0]['file_url'] if isinstance(data, list) else data['file_url']
    tags = data[0]['tags'].split() if isinstance(data, list) else data['tags'].split()

    img_response = requests.get(image_url)
    img = Image.open(BytesIO(img_response.content))

    return img, tags

def update_image(image_type, booru, image_id, uploaded_image):
    if image_type == "Anime" and booru != "Upload":
        image, booru_tags = get_booru_image(booru, image_id)
        return image, gr.update(visible=True), booru_tags
    elif uploaded_image is not None:
        return uploaded_image, gr.update(visible=True), None
    else:
        return None, gr.update(visible=False), None

def on_image_type_change(image_type):
    if image_type == "Anime":
        return gr.update(visible=True), gr.update(visible=True), gr.update(choices=["Anime", "Photo/Other"])
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(choices=["Photo/Other", "Anime"])

with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")
    
    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        
        image_display = gr.Image(label="Image to transcribe", visible=False)
        booru_tags = gr.State(None)
        
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)
    
    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")
    
    image_type.change(on_image_type_change, inputs=[image_type], 
                      outputs=[anime_options, upload_btn, transcriber])
    
    get_image_btn.click(update_image, 
                        inputs=[image_type, booru, image_id, upload_btn], 
                        outputs=[image_display, transcribe_btn, booru_tags])
    
    upload_btn.upload(update_image, 
                      inputs=[image_type, booru, image_id, upload_btn], 
                      outputs=[image_display, transcribe_btn, booru_tags])
    
    def transcribe_and_update(image, image_type, transcriber, booru_tags):
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags
    
    transcribe_btn.click(transcribe_and_update, 
                         inputs=[image_display, image_type, transcriber, booru_tags], 
                         outputs=[transcribe_image_display, tags_output])
    
    transcribe_with_tags_btn.click(transcribe_and_update, 
                                   inputs=[image_display, image_type, transcriber, booru_tags], 
                                   outputs=[transcribe_image_display, tags_output])
    
    transcribe_btn_final.click(transcribe_image, 
                               inputs=[transcribe_image_display, image_type, transcriber], 
                               outputs=[tags_output])

app.launch()