import os
import os.path as osp
import subprocess
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import imageio
from pydub import AudioSegment
from rich.progress import track

import stf_alternative

import spaces


def exec_cmd(cmd):
    """Run a shell command, raising CalledProcessError if it fails."""
    subprocess.run(
        cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    )
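
# Usage sketch (hypothetical command):
#
#     exec_cmd("ffmpeg -version")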


def images2video(images, wfp, **kwargs):
    fps = kwargs.get("fps", 24)
    video_format = kwargs.get("format", "mp4")  # default is mp4 format
    codec = kwargs.get("codec", "libx264")  # default is libx264 encoding
    quality = kwargs.get("quality")  # None lets the -crf flag below control quality
    pixelformat = kwargs.get("pixelformat", "yuv420p")  # video pixel format
    image_mode = kwargs.get("image_mode", "rgb")  # channel order of the input frames
    macro_block_size = kwargs.get("macro_block_size", 2)  # frame dims are padded to multiples of this
    ffmpeg_params = ["-crf", str(kwargs.get("crf", 18))]  # constant rate factor; lower means higher quality

    writer = imageio.get_writer(
        wfp,
        fps=fps,
        format=video_format,
        codec=codec,
        quality=quality,
        ffmpeg_params=ffmpeg_params,
        pixelformat=pixelformat,
        macro_block_size=macro_block_size,
    )

    n = len(images)
    for i in track(range(n), description="writing", transient=True):
        if image_mode.lower() == "bgr":
            writer.append_data(images[i][..., ::-1])
        else:
            writer.append_data(images[i])

    writer.close()

    print(f"Dump to {wfp}\n")


def merge_audio_video(video_fp, audio_fp, wfp):
    """Mux a video file and an audio file into wfp using ffmpeg."""
    if osp.exists(video_fp) and osp.exists(audio_fp):
        cmd = f'ffmpeg -i "{video_fp}" -i "{audio_fp}" -c:v copy -c:a aac "{wfp}" -y'
        exec_cmd(cmd)
        print(f"Merged {video_fp} and {audio_fp} into {wfp}")
    else:
        print(f"video_fp: {video_fp} or audio_fp: {audio_fp} does not exist!")




class STFPipeline:
    def __init__(
        self,
        stf_path: str = "/home/user/app/stf/",
        template_video_path: str = "templates/front_one_piece_dress_nodded_cut.webm",
        config_path: str = "front_config.json",
        checkpoint_path: str = "089.pth",
        root_path: str = "works",
        wavlm_path: str = "microsoft/wavlm-large",
        device: str = "cuda"
    ):
        self.device = device
        self.stf_path = stf_path
        self.config_path = os.path.join(stf_path, config_path)
        self.checkpoint_path = os.path.join(stf_path, checkpoint_path)
        self.work_root_path = os.path.join(stf_path, root_path)
        self.wavlm_path = wavlm_path
        self.template_video_path = template_video_path

        # Load the model and build the reusable template up front
        self.model = self.load_model()
        self.template = self.create_template()

    @spaces.GPU(duration=120)
    def load_model(self):
        """모델을 생성하고 GPU에 할당."""
        model = stf_alternative.create_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            work_root_path=self.work_root_path,
            device=self.device,
            wavlm_path=self.wavlm_path
        )
        return model

    @spaces.GPU(duration=120)
    def create_template(self):
        """템플릿 생성."""
        template = stf_alternative.Template(
            model=self.model,
            config_path=self.config_path,
            template_video_path=self.template_video_path
        )
        return template

    def execute(self, audio: str) -> str:
        """오디오를 입력 받아 비디오를 생성."""
        # 폴더 생성
        Path("dubbing").mkdir(exist_ok=True)
        save_path = os.path.join("dubbing", Path(audio).stem + "--lip.mp4")

        reader = iter(self.template._get_reader(num_skip_frames=0))
        audio_segment = AudioSegment.from_file(audio)
        results = []

        # Generate frames concurrently
        with ThreadPoolExecutor(max_workers=4) as executor:
            try:
                gen_infer = self.template.gen_infer_concurrent(
                    executor, audio_segment, 0
                )
                for idx, (it, _) in enumerate(gen_infer):
                    frame = next(reader)
                    # The composed frame is not used further; only the raw
                    # prediction is collected for writing.
                    self.template.compose(idx, frame, it)
                    results.append(it["pred"])
            except StopIteration:
                pass

        self.images_to_video(results, save_path)
        return save_path

    @staticmethod
    def images_to_video(images, output_path, fps=24):
        """이미지 배열을 비디오로 변환."""
        writer = imageio.get_writer(output_path, fps=fps, format="mp4", codec="libx264")
        for i in track(range(len(images)), description="writing video"):
            writer.append_data(images[i])
        writer.close()
        print(f"비디오 저장 완료: {output_path}")
        