import os
from fastapi import FastAPI, HTTPException, Depends, Body
from fastapi.responses import JSONResponse
from pydantic import BaseModel, field_validator, ValidationError
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline, StoppingCriteria
import boto3
import uvicorn
import soundfile as sf
import imageio
from typing import Dict, Optional, List
import torch  # Import torch
import logging

# Root logger: INFO and above, one "<time> - <level> - <message>" line per record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Deployment settings are read from the process environment at import time.
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
AWS_REGION = os.getenv("AWS_REGION")
S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
# The Hugging Face token is optional: it is not part of the check below.
HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

# Refuse to start when any of the four AWS settings is unset or empty.
if not (AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and AWS_REGION and S3_BUCKET_NAME):
    raise ValueError("Missing one or more AWS environment variables.")

# One S3 client for the whole process, built from the explicit credentials above.
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION,
)

# The ASGI application every route below is registered on.
app = FastAPI()

# Special-token strings handed to `tokenizer.add_special_tokens` when a model
# is loaded; each key names the tokenizer slot the string fills.
SPECIAL_TOKENS = dict(
    bos_token="<|startoftext|>",
    eos_token="<|endoftext|>",
    pad_token="[PAD]",
    unk_token="[UNK]",
)

class GenerateRequest(BaseModel):
    """Request body shared by every generation endpoint in this module.

    Validation rules enforced here:
      * ``model_name`` must be a non-empty string.
      * ``task_type`` must be one of the four supported task names.
      * ``max_new_tokens`` must lie in the range 1..500 inclusive.
    """

    # Pydantic v2 reserves the ``model_`` prefix for its own API and, on the
    # versions that enforce it, warns about the ``model_name`` field at class
    # creation. Clearing the protected namespaces keeps the field name (and
    # therefore the wire format) unchanged without the warning.
    model_config = {"protected_namespaces": ()}

    # Name of the model to load; '/' is replaced by '-' when the S3 location is built.
    model_name: str
    # Prompt text; replaced by the stored output when `continuation_id` is used.
    input_text: str = ""
    # Which kind of generation is requested (see `task_type_must_be_valid`).
    task_type: str
    # The following values are copied onto the generation config by /generate.
    temperature: float = 1.0
    max_new_tokens: int = 10
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.1
    num_return_sequences: int = 1
    do_sample: bool = True
    # Generation stops once the decoded output ends with any of these strings.
    stop_sequences: List[str] = []
    no_repeat_ngram_size: int = 2
    # Id returned by an earlier /generate call whose output should be continued.
    continuation_id: Optional[str] = None

    @field_validator("model_name")
    @classmethod
    def model_name_cannot_be_empty(cls, v):
        """Reject an empty model name."""
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    @classmethod
    def task_type_must_be_valid(cls, v):
        """Reject any task type outside the supported set."""
        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v

    @field_validator("max_new_tokens")
    @classmethod
    def max_new_tokens_must_be_within_limit(cls, v):
        """Keep the requested generation length within 1..500 tokens."""
        # Zero or negative budgets cannot produce output and would only fail
        # later, inside generation, as an internal server error.
        if v < 1:
            raise ValueError("max_new_tokens must be at least 1.")
        if v > 500:
            raise ValueError("max_new_tokens cannot be greater than 500.")
        return v

class S3ModelLoader:
    """Loads a causal-LM model and its tokenizer from a per-model location in one bucket."""

    def __init__(self, bucket_name, s3_client):
        # The client is kept on the instance; no method of this class calls it.
        self.s3_client = s3_client
        self.bucket_name = bucket_name

    def _get_s3_uri(self, model_name):
        """Return ``s3://<bucket>/<name>`` where every '/' in the name becomes '-'."""
        flattened_name = model_name.replace('/', '-')
        return f"s3://{self.bucket_name}/{flattened_name}"

    async def load_model_and_tokenizer(self, model_name):
        """Load config, model and tokenizer for ``model_name`` and return ``(model, tokenizer)``.

        The tokenizer receives the module-level SPECIAL_TOKENS and the model's
        embedding matrix is resized to the resulting vocabulary size.

        Raises:
            HTTPException: status 500 when any loading step raises.
        """
        # NOTE(review): the location is an ``s3://`` URI passed straight to
        # ``from_pretrained``; confirm the installed transformers build can
        # resolve that scheme.
        location = self._get_s3_uri(model_name)
        try:
            config = AutoConfig.from_pretrained(location, local_files_only=False)
            model = AutoModelForCausalLM.from_pretrained(
                location,
                config=config,
                local_files_only=False,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                location,
                config=config,
                local_files_only=False,
            )
            tokenizer.add_special_tokens(SPECIAL_TOKENS)
            vocabulary_size = len(tokenizer)
            model.resize_token_embeddings(vocabulary_size)
            # Fallback for a tokenizer that still has no padding token.
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
        except Exception as exc:
            message = f"Error loading model from S3: {exc}"
            logging.error(message)
            raise HTTPException(status_code=500, detail=message)
        return model, tokenizer

# Single shared loader bound to the configured bucket and S3 client.
model_loader = S3ModelLoader(bucket_name=S3_BUCKET_NAME, s3_client=s3_client)

# In-memory record of finished generations keyed by continuation id; each value
# holds the "model_name" and "output" of that generation. Entries are only ever
# added by the endpoints in this module.
active_generations: Dict[str, Dict] = dict()

async def get_model_and_tokenizer(model_name: str):
    """Dependency that loads the ``(model, tokenizer)`` pair for ``model_name``.

    NOTE(review): when used through ``Depends``, this ``model_name`` parameter
    is resolved by FastAPI separately from the ``model_name`` field of the
    request body; confirm callers supply the same value in both places.

    Raises:
        HTTPException: the loader's own HTTPException is passed through
            unchanged; any other failure is reported as status 500.
    """
    try:
        return await model_loader.load_model_and_tokenizer(model_name)
    except HTTPException:
        # The loader has already logged the failure and chosen a status and
        # detail. Re-wrapping it would log twice and nest the message as
        # "Error loading model: 500: Error loading model from S3: ...".
        raise
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise HTTPException(status_code=500, detail=f"Error loading model: {e}") from e

@app.post("/generate")
async def generate(request: GenerateRequest = Body(...), model_resources: tuple = Depends(get_model_and_tokenizer)):
    """Generate text for the request, optionally continuing an earlier generation.

    When ``continuation_id`` is given, the stored output of that generation is
    used as the prompt instead of ``input_text`` and the result is stored back
    under the same id; otherwise a fresh random id is created.

    Args:
        request: validated generation parameters.
        model_resources: ``(model, tokenizer)`` pair supplied by the
            ``get_model_and_tokenizer`` dependency.

    Returns:
        JSONResponse with ``text``, ``continuation_id`` and ``model_name``.

    Raises:
        HTTPException: 404 for an unknown continuation id, 400 when the
            continuation was produced by a different model, 500 for any other
            failure.

    NOTE(review): unlike the other endpoints, ``task_type`` is not checked
    here. NOTE(review): the dependency receives its own ``model_name``
    argument, separate from ``request.model_name`` used below; confirm they
    always agree.
    """
    import copy  # local import: only needed for the generation-config fallback

    model, tokenizer = model_resources
    try:
        model_name = request.model_name
        input_text = request.input_text
        continuation_id = request.continuation_id

        if continuation_id:
            previous_data = active_generations.get(continuation_id)
            if previous_data is None:
                raise HTTPException(status_code=404, detail="Continuation ID not found.")
            if previous_data["model_name"] != model_name:
                raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
            # Continue from what was generated last time, not from the new input.
            input_text = previous_data["output"]

        # Start from the model's default generation settings, then override.
        try:
            generation_config = GenerationConfig.from_pretrained(model_name)
        except (OSError, ValueError):
            # No generation config could be resolved under this name. The model
            # is already loaded, so fall back to a private copy of the defaults
            # attached to it instead of failing the whole request.
            generation_config = copy.deepcopy(model.generation_config)

        generation_config.temperature = request.temperature
        generation_config.max_new_tokens = request.max_new_tokens
        generation_config.top_p = request.top_p
        generation_config.top_k = request.top_k
        generation_config.repetition_penalty = request.repetition_penalty
        generation_config.do_sample = request.do_sample
        generation_config.num_return_sequences = request.num_return_sequences
        generation_config.no_repeat_ngram_size = request.no_repeat_ngram_size
        generation_config.pad_token_id = tokenizer.pad_token_id

        generated_text = generate_text_internal(
            model, tokenizer, input_text, generation_config, request.stop_sequences
        )

        new_continuation_id = continuation_id if continuation_id else os.urandom(16).hex()
        active_generations[new_continuation_id] = {"model_name": model_name, "output": generated_text}

        return JSONResponse({"text": generated_text, "continuation_id": new_continuation_id, "model_name": model_name})

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        logging.error(f"Internal server error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
    """Run one generation pass and return the decoded first output sequence.

    Args:
        model: model exposing ``config``, ``device`` and ``generate``.
        tokenizer: tokenizer used to encode the prompt and decode the output.
        input_text: prompt text.
        generation_config: generation settings; ``max_new_tokens`` and
            ``pad_token_id`` are read from it here.
        stop_sequences: strings that end generation once the decoded output
            ends with any of them; empty strings are ignored.

    Returns:
        The text decoded from ``outputs[0]`` with special tokens skipped.
    """
    # Prompt plus newly generated tokens must fit in the model's context, so
    # the prompt is truncated to what is left after reserving the new-token
    # budget rather than to the full context length.
    max_model_length = getattr(model.config, "max_position_embeddings", None)
    max_input_length = None  # lets the tokenizer apply its own limit, if any
    if max_model_length is not None:
        reserved_for_output = getattr(generation_config, "max_new_tokens", None) or 0
        max_input_length = max(1, max_model_length - reserved_for_output)

    encoded_input = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
    ).to(model.device)  # keep the inputs on the same device as the model

    class CustomStoppingCriteria(StoppingCriteria):
        """Stops generation once the decoded first sequence ends with a stop string."""

        def __init__(self, stop_sequences, tokenizer):
            super().__init__()
            self.stop_sequences = tuple(stop_sequences)
            self.tokenizer = tokenizer

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            # Only the first sequence of the batch is inspected.
            decoded_output = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
            return decoded_output.endswith(self.stop_sequences)

    # An empty stop string is a suffix of every output and would end generation
    # after the very first token, so such entries are dropped.
    active_stops = [stop for stop in (stop_sequences or []) if stop]

    stopping_criteria = StoppingCriteriaList()
    if active_stops:
        stopping_criteria.append(CustomStoppingCriteria(active_stops, tokenizer))

    # Forward the attention mask when the tokenizer produced one, so generation
    # does not have to guess it from the pad token.
    extra_inputs = {}
    attention_mask = encoded_input.get("attention_mask")
    if attention_mask is not None:
        extra_inputs["attention_mask"] = attention_mask

    outputs = model.generate(
        encoded_input.input_ids,
        generation_config=generation_config,
        stopping_criteria=stopping_criteria,
        pad_token_id=generation_config.pad_token_id,
        **extra_inputs,
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

async def load_pipeline_from_s3(task, model_name):
    """Build a ``pipeline`` for ``task`` from the bucket location of ``model_name``.

    The location is ``s3://<S3_BUCKET_NAME>/<model_name>`` with every '/' in the
    model name replaced by '-'. The Hugging Face token from the environment is
    forwarded to the pipeline factory.

    Raises:
        HTTPException: status 500 when constructing the pipeline raises.
    """
    # NOTE(review): confirm the pipeline factory can resolve an ``s3://`` URI.
    flattened_name = model_name.replace('/', '-')
    location = f"s3://{S3_BUCKET_NAME}/{flattened_name}"
    try:
        loaded_pipeline = pipeline(task, model=location, token=HUGGINGFACE_HUB_TOKEN)
    except Exception as exc:
        message = f"Error loading {task} model from S3: {exc}"
        logging.error(message)
        raise HTTPException(status_code=500, detail=message)
    return loaded_pipeline

@app.post("/generate-image")
async def generate_image(request: GenerateRequest = Body(...)):
    """Generate an image for the prompt and save it as a PNG in the working directory.

    Returns:
        JSONResponse whose ``url`` is the local file path, together with a new
        ``continuation_id`` and the ``model_name``.

    Raises:
        HTTPException: 400 when ``task_type`` is not "text-to-image", 500 for
            any other failure.
    """
    # Guard: this endpoint serves exactly one task type.
    if request.task_type != "text-to-image":
        raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")

    try:
        model_name = request.model_name
        image_generator = await load_pipeline_from_s3("text-to-image", model_name)
        # The first element of the pipeline result is used as the image.
        results = image_generator(request.input_text)
        image = results[0]

        random_suffix = os.urandom(8).hex()
        image_path = f"generated_image_{random_suffix}.png"
        image.save(image_path)

        new_continuation_id = os.urandom(16).hex()
        active_generations[new_continuation_id] = {
            "model_name": model_name,
            "output": f"Image saved to {image_path}",
        }
        payload = {
            "url": image_path,
            "continuation_id": new_continuation_id,
            "model_name": model_name,
        }
        return JSONResponse(payload)
    except HTTPException:
        # Errors raised while loading the pipeline keep their own status.
        raise
    except Exception as exc:
        message = f"Internal server error: {str(exc)}"
        logging.error(message)
        raise HTTPException(status_code=500, detail=message)

@app.post("/generate-text-to-speech")
async def generate_text_to_speech(request: GenerateRequest = Body(...)):
    """Synthesize speech for the prompt and save it as a WAV file in the working directory.

    Returns:
        JSONResponse whose ``url`` is the local file path, together with a new
        ``continuation_id`` and the ``model_name``.

    Raises:
        HTTPException: 400 when ``task_type`` is not "text-to-speech", 500 for
            any other failure.
    """
    try:
        if request.task_type != "text-to-speech":
            raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")

        tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
        audio_output = tts_pipeline(request.input_text)
        audio_path = f"generated_audio_{os.urandom(8).hex()}.wav"
        # soundfile.write takes (file, data, samplerate); the samples go second
        # and the rate third. Keywords make the pairing explicit.
        # NOTE(review): soundfile expects samples shaped (frames,) or
        # (frames, channels); confirm the pipeline's "audio" array matches.
        sf.write(
            audio_path,
            data=audio_output["audio"],
            samplerate=audio_output["sampling_rate"],
        )
        new_continuation_id = os.urandom(16).hex()
        active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Audio saved to {audio_path}"}
        return JSONResponse({"url": audio_path, "continuation_id": new_continuation_id, "model_name": request.model_name})

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        logging.error(f"Internal server error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

@app.post("/generate-video")
async def generate_video(request: GenerateRequest = Body(...)):
    """Generate video frames for the prompt and save them as an MP4 in the working directory.

    Returns:
        JSONResponse whose ``url`` is the local file path, together with a new
        ``continuation_id`` and the ``model_name``.

    Raises:
        HTTPException: 400 when ``task_type`` is not "text-to-video", 500 for
            any other failure.
    """
    # Guard: this endpoint serves exactly one task type.
    if request.task_type != "text-to-video":
        raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")

    try:
        model_name = request.model_name
        video_pipeline = await load_pipeline_from_s3("text-to-video", model_name)
        pipeline_result = video_pipeline(request.input_text)
        video_frames = pipeline_result.frames

        random_suffix = os.urandom(8).hex()
        video_path = f"generated_video_{random_suffix}.mp4"
        # Frames are written at a fixed 30 frames per second.
        imageio.mimsave(video_path, video_frames, fps=30)

        new_continuation_id = os.urandom(16).hex()
        active_generations[new_continuation_id] = {
            "model_name": model_name,
            "output": f"Video saved to {video_path}",
        }
        payload = {
            "url": video_path,
            "continuation_id": new_continuation_id,
            "model_name": model_name,
        }
        return JSONResponse(payload)
    except HTTPException:
        # Errors raised while loading the pipeline keep their own status.
        raise
    except Exception as exc:
        message = f"Internal server error: {str(exc)}"
        logging.error(message)
        raise HTTPException(status_code=500, detail=message)

# Handler for pydantic ValidationError raised while a request is being served.
@app.exception_handler(ValidationError)
async def validation_exception_handler(request, exc):
    """Log a pydantic ValidationError and answer with status 422 and the error list.

    NOTE(review): FastAPI reports request-body validation failures through its
    own RequestValidationError; confirm this handler is reached on the paths
    where it is expected to apply.
    """
    import json  # local import: only needed to re-parse the error document

    logging.error(f"Validation Error: {exc}")
    # `exc.errors()` can carry raw exception objects (the ValueError raised by
    # a custom validator ends up in the entry's "ctx"), which JSONResponse
    # cannot serialize. `exc.json()` is pydantic's own JSON rendering of the
    # same list, so parsing it back yields plain JSON-safe data.
    error_details = json.loads(exc.json())
    return JSONResponse({"detail": error_details}, status_code=422)

if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )