import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import numpy as np

# Dog breed names used by this module.
# NOTE(review): presumably dog_breeds[i] is the label for classifier output i
# (which would make len(dog_breeds) the expected num_classes) — confirm against
# the training label mapping. The list is NOT globally sorted: a second
# alphabetical run starts at "Affenpinscher". Do not sort or reorder it if the
# order encodes label indices.
dog_breeds = ["Afghan_Hound", "African_Hunting_Dog", "Airedale", "American_Staffordshire_Terrier",
              "Appenzeller", "Australian_Terrier", "Bedlington_Terrier", "Bernese_Mountain_Dog", "Bichon_Frise",
              "Blenheim_Spaniel", "Border_Collie", "Border_Terrier", "Boston_Bull", "Bouvier_Des_Flandres",
              "Brabancon_Griffon", "Brittany_Spaniel", "Cardigan", "Chesapeake_Bay_Retriever",
              "Chihuahua", "Dachshund", "Dandie_Dinmont", "Doberman", "English_Foxhound", "English_Setter",
              "English_Springer", "EntleBucher", "Eskimo_Dog", "French_Bulldog", "German_Shepherd",
              "German_Short-Haired_Pointer", "Gordon_Setter", "Great_Dane", "Great_Pyrenees",
              "Greater_Swiss_Mountain_Dog","Havanese", "Ibizan_Hound", "Irish_Setter", "Irish_Terrier",
              "Irish_Water_Spaniel", "Irish_Wolfhound", "Italian_Greyhound", "Japanese_Spaniel",
              "Kerry_Blue_Terrier", "Labrador_Retriever", "Lakeland_Terrier", "Leonberg", "Lhasa",
              "Maltese_Dog", "Mexican_Hairless", "Newfoundland", "Norfolk_Terrier", "Norwegian_Elkhound",
              "Norwich_Terrier", "Old_English_Sheepdog", "Pekinese", "Pembroke", "Pomeranian",
              "Rhodesian_Ridgeback", "Rottweiler", "Saint_Bernard", "Saluki", "Samoyed",
              "Scotch_Terrier", "Scottish_Deerhound", "Sealyham_Terrier", "Shetland_Sheepdog", "Shiba_Inu",
              "Shih-Tzu", "Siberian_Husky", "Staffordshire_Bullterrier", "Sussex_Spaniel",
              "Tibetan_Mastiff", "Tibetan_Terrier", "Walker_Hound", "Weimaraner",
              "Welsh_Springer_Spaniel", "West_Highland_White_Terrier", "Yorkshire_Terrier",
              "Affenpinscher", "Basenji", "Basset", "Beagle", "Black-and-Tan_Coonhound", "Bloodhound",
              "Bluetick", "Borzoi", "Boxer", "Briard", "Bull_Mastiff", "Cairn", "Chow", "Clumber",
              "Cocker_Spaniel", "Collie", "Curly-Coated_Retriever", "Dhole", "Dingo",
              "Flat-Coated_Retriever", "Giant_Schnauzer", "Golden_Retriever", "Groenendael", "Keeshond",
              "Kelpie", "Komondor", "Kuvasz", "Malamute", "Malinois", "Miniature_Pinscher",
              "Miniature_Poodle", "Miniature_Schnauzer", "Otterhound", "Papillon", "Pug", "Redbone",
              "Schipperke", "Silky_Terrier", "Soft-Coated_Wheaten_Terrier", "Standard_Poodle",
              "Standard_Schnauzer", "Toy_Poodle", "Toy_Terrier", "Vizsla", "Whippet",
              "Wire-Haired_Fox_Terrier"]


class MorphologicalFeatureExtractor(nn.Module):
    """Derive morphology-oriented features from a flat backbone embedding.

    Pipeline, for a batch of size B and spatial size S:

    1. ``dimension_transformer`` projects ``(B, in_features)`` to
       ``(B, 64 * S * S)``. The result is reshaped into a ``(B, 64, S, S)``
       pseudo-spatial map.
    2. Five parallel conv branches (``body_proportion``, ``head_features``,
       ``tail_features``, ``fur_features``, ``color_pattern``) each map the
       map to ``(B, 128, S, S)``. Each output is global-average-pooled to
       ``(B, 128)``.
    3. ``relation_analyzer`` maps the concatenated branch vectors
       ``(B, 640)`` to a ``(B, 128)`` relation vector.
    4. ``feature_integrator`` maps the branch vectors plus the relation vector
       ``(B, 768)`` back to ``(B, in_features)``. The input is then added as
       a residual.

    The branch names describe each branch's intended role only. Nothing in
    this module ties them to explicit body-part supervision.

    Args:
        in_features: Width of the input vector; the output has the same width.
    """

    def __init__(self, in_features):
        super().__init__()

        # Base dimension settings. The sqrt term only exceeds 7 once
        # in_features >= 16384, so S == 7 for smaller backbones.
        self.reduced_dim = in_features // 4
        self.spatial_size = max(7, int(np.sqrt(self.reduced_dim // 64)))
        flat_spatial = self.spatial_size * self.spatial_size * 64

        # 1. Project the 1-D embedding to a flattened 64-channel S x S map.
        self.dimension_transformer = nn.Sequential(
            nn.Linear(in_features, flat_spatial),
            nn.LayerNorm(flat_spatial),
            nn.ReLU()
        )

        # 2. Parallel morphology branches. Each maps 64 -> 128 channels with
        #    the spatial size preserved. The dict keys and the layer order
        #    inside each Sequential determine the state_dict names, so keep
        #    them stable for checkpoint compatibility.
        block = self._conv_bn_relu
        self.morphological_analyzers = nn.ModuleDict({
            # Body proportion: a large 7x7 kernel for overall shape, then a 3x3 refinement.
            'body_proportion': nn.Sequential(*block(64, 128, 7), *block(128, 128, 3)),
            # Head features (ears, face): a medium 5x5 kernel, then 3x3 for detail.
            'head_features': nn.Sequential(*block(64, 128, 5), *block(128, 128, 3)),
            # Tail features.
            'tail_features': nn.Sequential(*block(64, 128, 5), *block(128, 128, 3)),
            # Fur length / texture: stacked small 3x3 kernels.
            'fur_features': nn.Sequential(*block(64, 128, 3), *block(128, 128, 3)),
            # Colour pattern: base distribution (3x3), patterns (3x3), then a 1x1 mix.
            'color_pattern': nn.Sequential(
                *block(64, 128, 3), *block(128, 128, 3), *block(128, 128, 1)
            ),
        })

        # 3. Attention over the five branch vectors.
        #    NOTE: this module is registered so existing checkpoints still load
        #    with strict=True, but forward() does not apply it. Its output was
        #    never consumed downstream, so running it only cost compute.
        #    Consequently its parameters receive no gradients, which matters
        #    e.g. for DDP's find_unused_parameters. Wire it in deliberately
        #    (and retrain) if attention-refined branch features are wanted.
        self.feature_attention = nn.MultiheadAttention(
            embed_dim=128,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        # 4. Relation analyzer over the concatenated outputs of all 5 branches.
        self.relation_analyzer = nn.Sequential(
            nn.Linear(128 * 5, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.ReLU()
        )

        # 5. Integrator: 5 branch vectors + 1 relation vector -> in_features.
        self.feature_integrator = nn.Sequential(
            nn.Linear(128 * 6, in_features),
            nn.LayerNorm(in_features),
            nn.ReLU()
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size):
        """Return [Conv2d, BatchNorm2d, ReLU]; padding k//2 keeps H and W for odd k."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Return morphology-enhanced features with the same shape as ``x``.

        Args:
            x: Tensor of shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, in_features)``: the integrated features
            plus the residual ``x``.
        """
        batch_size = x.size(0)

        # 1. Reshape the projected vector into a (B, 64, S, S) map.
        spatial_features = self.dimension_transformer(x).view(
            batch_size, 64, self.spatial_size, self.spatial_size
        )

        # 2. Each branch: (B, 128, S, S) -> global average pool -> (B, 128).
        branch_vectors = [
            F.adaptive_avg_pool2d(analyzer(spatial_features), (1, 1)).flatten(1)
            for analyzer in self.morphological_analyzers.values()
        ]

        # 3. Relations between branches, computed from their concatenation (B, 640).
        concatenated = torch.cat(branch_vectors, dim=1)
        relation_features = self.relation_analyzer(concatenated)

        # 4. Integrate the branch and relation vectors (B, 768) back to
        #    in_features, then add the input as a residual.
        integrated_features = self.feature_integrator(
            torch.cat([concatenated, relation_features], dim=1)
        )
        return integrated_features + x


class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention across the heads of one feature vector.

    Each sample's vector is projected to ``num_heads * head_dim`` and split
    into ``num_heads`` tokens of size ``head_dim``. Attention is computed among
    those tokens, so the "sequence length" is ``num_heads``. The result is
    projected back to ``in_dim``.
    """

    def __init__(self, in_dim, num_heads=8):
        """
        Initializes the MultiHeadAttention module.

        Args:
            in_dim (int): Dimension of the input features.
            num_heads (int): Number of attention heads. Defaults to 8.
        """
        super().__init__()
        self.num_heads = num_heads
        # Round down so the projected width splits evenly into heads,
        # with at least one dim per head.
        self.head_dim = max(1, in_dim // num_heads)
        self.scaled_dim = self.head_dim * num_heads
        self.fc_in = nn.Linear(in_dim, self.scaled_dim)  # in_dim -> scaled_dim
        self.query = nn.Linear(self.scaled_dim, self.scaled_dim)  # Query projection
        self.key = nn.Linear(self.scaled_dim, self.scaled_dim)  # Key projection
        self.value = nn.Linear(self.scaled_dim, self.scaled_dim)  # Value projection
        self.fc_out = nn.Linear(self.scaled_dim, in_dim)  # Project back to in_dim

    def forward(self, x):
        """
        Apply head-to-head attention to a batch of feature vectors.

        Args:
            x (Tensor): Input of shape (N, D), where N is the batch size and
                D the input feature dimension (``in_dim``).

        Returns:
            Tensor: Output of shape (N, in_dim).
        """
        n = x.shape[0]
        x = self.fc_in(x)
        # Split each projected vector into num_heads tokens of size head_dim.
        q = self.query(x).view(n, self.num_heads, self.head_dim)
        k = self.key(x).view(n, self.num_heads, self.head_dim)
        v = self.value(x).view(n, self.num_heads, self.head_dim)

        # Attention scores between head tokens: (N, num_heads, num_heads).
        energy = torch.einsum("nqd,nkd->nqk", q, k)
        attention = F.softmax(energy / (self.head_dim ** 0.5), dim=2)

        # The value token index must be the same 'k' the attention weights
        # range over. With an independent index, each softmax row (which sums
        # to 1) factors out, and every query would get the plain sum of values.
        out = torch.einsum("nqk,nkd->nqd", attention, v)
        out = out.reshape(n, self.scaled_dim)  # Concatenate all heads
        return self.fc_out(out)


class BaseModel(nn.Module):
    """ConvNeXt V2 backbone + morphological features + attention + linear head.

    Args:
        num_classes: Number of output logits.
        device: Preferred device string, stored as ``self.device``. The
            constructor does not move the model there; call ``.to(...)``.
    """

    def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
        super().__init__()
        self.device = device

        # 1. Pretrained ConvNeXt V2 backbone, created with num_classes=0.
        self.backbone = timm.create_model(
                'convnextv2_base',
                pretrained=True,
                num_classes=0
        )

        # 2. Measure the backbone's actual feature width with a dummy pass.
        self.feature_dim = self._probe_feature_dim(self.backbone)

        print(f"Feature Dimension from V2 backbone: {self.feature_dim}")

        # 3. Head-to-head attention: about one head per 64 dims, between 1 and 8.
        self.num_heads = max(1, min(8, self.feature_dim // 64))
        self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)

        # 4. Classifier head.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes)
        )

        self.morphological_extractor = MorphologicalFeatureExtractor(
            in_features=self.feature_dim
        )

        # Fuses [features, morphological, features * morphological] (3 * feature_dim).
        self.feature_fusion = nn.Sequential(
            nn.Linear(self.feature_dim * 3, self.feature_dim),
            nn.LayerNorm(self.feature_dim),
            nn.ReLU(),
            nn.Linear(self.feature_dim, self.feature_dim),
            nn.LayerNorm(self.feature_dim),
            nn.ReLU()
        )

    @staticmethod
    def _probe_feature_dim(backbone):
        """Return the channel width of ``backbone``'s output for one 3x224x224 image."""
        # Probe in eval mode so train-only layers do not act on (or update
        # state from) the random input. Restore the caller's mode afterwards.
        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                features = backbone(torch.randn(1, 3, 224, 224))
        finally:
            backbone.train(was_training)
        if len(features.shape) > 2:
            # Spatial output: average over the last two axes.
            features = features.mean([-2, -1])
        return features.shape[1]

    def forward(self, x):
        """
        Classify a batch of images.

        Args:
            x (Tensor): Image tensor of shape [batch_size, channels, height, width].

        Returns:
            Tuple[Tensor, Tensor]: ``(logits, attended_features)`` with shapes
            ``(batch_size, num_classes)`` and ``(batch_size, feature_dim)``.
        """
        # Send the input to wherever the weights actually live. The stored
        # self.device string can disagree with that: the constructor never
        # moves the model, and the model may later be moved elsewhere.
        x = x.to(next(self.backbone.parameters()).device)

        # 1. Base features; pool spatial output if the backbone returns one.
        features = self.backbone(x)
        if len(features.shape) > 2:
            features = features.mean([-2, -1])

        # 2. Morphological features (same width as `features`).
        morphological_features = self.morphological_extractor(features)

        # 3. Fuse original, morphological and element-wise interaction features.
        combined_features = torch.cat([
            features,
            morphological_features,
            features * morphological_features
        ], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # 4. Attention.
        attended_features = self.attention(fused_features)

        # 5. Classifier.
        logits = self.classifier(attended_features)

        return logits, attended_features