DawnC commited on
Commit
83bc690
·
1 Parent(s): 282a8a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -7
app.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
  import time
7
  import traceback
8
  import spaces
9
- from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights
10
  from torchvision.ops import nms, box_iou
11
  import torch.nn.functional as F
12
  from torchvision import transforms
@@ -98,29 +98,61 @@ class MultiHeadAttention(nn.Module):
98
  return out
99
 
100
  class BaseModel(nn.Module):
 
101
  def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
102
  super().__init__()
103
  self.device = device
104
- self.backbone = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)
105
- self.feature_dim = self.backbone.classifier[1].in_features
106
- self.backbone.classifier = nn.Identity()
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  self.num_heads = max(1, min(8, self.feature_dim // 64))
109
  self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
110
 
 
111
  self.classifier = nn.Sequential(
112
  nn.LayerNorm(self.feature_dim),
113
  nn.Dropout(0.3),
114
  nn.Linear(self.feature_dim, num_classes)
115
  )
116
 
117
- self.to(device)
118
-
119
  def forward(self, x):
 
 
 
 
 
 
 
120
  x = x.to(self.device)
 
 
121
  features = self.backbone(x)
 
 
 
 
 
 
 
 
122
  attended_features = self.attention(features)
 
 
123
  logits = self.classifier(attended_features)
 
124
  return logits, attended_features
125
 
126
 
@@ -179,7 +211,7 @@ class ModelManager:
179
  ).to(self.device)
180
 
181
  checkpoint = torch.load(
182
- '124_best_model_dog.pth',
183
  map_location=self.device # 確保checkpoint加載到正確的設備
184
  )
185
  self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)
 
6
  import time
7
  import traceback
8
  import spaces
9
+ from torchvision.models import convnext_base, ConvNeXt_Base_Weights
10
  from torchvision.ops import nms, box_iou
11
  import torch.nn.functional as F
12
  from torchvision import transforms
 
98
  return out
99
 
100
  class BaseModel(nn.Module):
101
+
102
  def __init__(self, num_classes, device='cuda' if torch.cuda.is_available() else 'cpu'):
103
  super().__init__()
104
  self.device = device
 
 
 
105
 
106
+ # 1. 初始化 backbone
107
+ self.backbone = convnext_base(weights=ConvNeXt_Base_Weights.IMAGENET1K_V1)
108
+ self.backbone.classifier = nn.Identity() # 移除原始分類器
109
+
110
+ # 2. 使用測試數據確定實際的特徵維度
111
+ with torch.no_grad(): # 不需要計算梯度
112
+ dummy_input = torch.randn(1, 3, 224, 224) # 創建示例輸入
113
+ features = self.backbone(dummy_input)
114
+ if len(features.shape) > 2: # 如果特徵是多維的
115
+ features = features.mean([-2, -1]) # 進行全局平均池化
116
+ self.feature_dim = features.shape[1] # 獲取正確的特徵維度
117
+
118
+ print(f"Feature Dim: {self.feature_dim}") # 幫助調試
119
+
120
+ # 3. 設置多頭注意力層
121
  self.num_heads = max(1, min(8, self.feature_dim // 64))
122
  self.attention = MultiHeadAttention(self.feature_dim, num_heads=self.num_heads)
123
 
124
+ # 4. 設置分類器
125
  self.classifier = nn.Sequential(
126
  nn.LayerNorm(self.feature_dim),
127
  nn.Dropout(0.3),
128
  nn.Linear(self.feature_dim, num_classes)
129
  )
130
 
 
 
131
  def forward(self, x):
132
+ """
133
+ 模型的前向傳播過程
134
+ Args:
135
+ x (Tensor): 輸入圖像張量,形狀為 [batch_size, channels, height, width]
136
+ Returns:
137
+ Tuple[Tensor, Tensor]: 分類邏輯值和注意力特徵
138
+ """
139
  x = x.to(self.device)
140
+
141
+ # 1. 提取基礎特徵
142
  features = self.backbone(x)
143
+
144
+ # 2. 處理特徵維度
145
+ if len(features.shape) > 2:
146
+ # 如果特徵維度是 [batch_size, channels, height, width]
147
+ # 轉換為 [batch_size, channels]
148
+ features = features.mean([-2, -1]) # 使用全局平均池化
149
+
150
+ # 3. 應用注意力機制
151
  attended_features = self.attention(features)
152
+
153
+ # 4. 最終分類
154
  logits = self.classifier(attended_features)
155
+
156
  return logits, attended_features
157
 
158
 
 
211
  ).to(self.device)
212
 
213
  checkpoint = torch.load(
214
+ 'ConvNextBase_best_model_dog.pth',
215
  map_location=self.device # 確保checkpoint加載到正確的設備
216
  )
217
  self._breed_model.load_state_dict(checkpoint['base_model'], strict=False)