yeq6x commited on
Commit
8d302a0
·
1 Parent(s): fddf24a
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -93,8 +93,6 @@ model_index = 0
93
  # ヒートマップの生成関数
94
  @spaces.GPU
95
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
96
- global model, mean_vector_list
97
-
98
  if type(uploaded_image) == str:
99
  uploaded_image = Image.open(uploaded_image)
100
  if type(source_num) == str:
@@ -104,14 +102,14 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
104
  if type(y_coords) == str:
105
  y_coords = int(y_coords)
106
 
107
- dec5, _ = model(x)
108
  feature_map = dec5
109
  # アップロード画像の前処理
110
  if uploaded_image is not None:
111
  uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
112
  else:
113
  uploaded_image = torch.zeros(1, 3, image_size, image_size).to(device)
114
- target_feature_map, _ = model(uploaded_image)
115
  img = torch.cat((x, uploaded_image))
116
  feature_map = torch.cat((feature_map, target_feature_map))
117
 
 
93
  # ヒートマップの生成関数
94
  @spaces.GPU
95
  def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
 
 
96
  if type(uploaded_image) == str:
97
  uploaded_image = Image.open(uploaded_image)
98
  if type(source_num) == str:
 
102
  if type(y_coords) == str:
103
  y_coords = int(y_coords)
104
 
105
+ dec5, _ = models[model_index](x)
106
  feature_map = dec5
107
  # アップロード画像の前処理
108
  if uploaded_image is not None:
109
  uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
110
  else:
111
  uploaded_image = torch.zeros(1, 3, image_size, image_size).to(device)
112
+ target_feature_map, _ = models[model_index](uploaded_image)
113
  img = torch.cat((x, uploaded_image))
114
  feature_map = torch.cat((feature_map, target_feature_map))
115