yeq6x committed on
Commit
00e1057
·
1 Parent(s): 247ffab
Files changed (1) hide show
  1. app.py +20 -23
app.py CHANGED
@@ -15,16 +15,6 @@ import utils
15
 
16
  import spaces
17
 
18
- image_size = 112
19
- batch_size = 32
20
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
-
22
- models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
23
- {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
24
- {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
25
- {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
26
- models = []
27
-
28
  def load_model(model_path, feature_dim):
29
  model = AutoencoderModule(feature_dim=feature_dim)
30
  state_dict = torch.load(model_path)
@@ -42,10 +32,25 @@ def load_model(model_path, feature_dim):
42
  model.eval()
43
 
44
  model.to(device)
45
- print("Model loaded successfully.")
46
  return model
47
 
48
- @spaces.GPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
50
  filenames = load_filenames(img_dir)
51
  train_X = filenames[:1000]
@@ -62,7 +67,6 @@ def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
62
  print("Data loaded successfully.")
63
  return x
64
 
65
- @spaces.GPU
66
  def load_keypoints(img_dir="resources/trainB/", image_size=112, batch_size=32):
67
  filenames = load_filenames(img_dir)
68
  train_X = filenames[:1000]
@@ -98,7 +102,7 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
98
  if uploaded_image is not None:
99
  uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
100
  else:
101
- uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
102
  target_feature_map, _ = model(uploaded_image)
103
  img = torch.cat((x, uploaded_image))
104
  feature_map = torch.cat((feature_map, target_feature_map))
@@ -135,7 +139,7 @@ def setup(model_info, input_image=None):
135
 
136
  index = models_info.index(model_info)
137
  model = models[index]
138
-
139
  x = load_data()
140
  test_imgs, points = load_keypoints()
141
 
@@ -195,14 +199,7 @@ with gr.Blocks() as demo:
195
  inputs=[input_image],
196
  )
197
 
198
-
199
- if __name__ == "__main__":
200
- for model_info in models_info:
201
- model_name = model_info["name"]
202
- feature_dim = model_info["feature_dim"]
203
- model_path = f"checkpoints/{model_name}"
204
- model = load_model(model_path, feature_dim)
205
- models.append(model)
206
 
207
  setup(models_info[0])
208
 
 
15
 
16
  import spaces
17
 
 
 
 
 
 
 
 
 
 
 
18
  def load_model(model_path, feature_dim):
19
  model = AutoencoderModule(feature_dim=feature_dim)
20
  state_dict = torch.load(model_path)
 
32
  model.eval()
33
 
34
  model.to(device)
35
+ print(f"{model_path} loaded successfully.")
36
  return model
37
 
38
+ image_size = 112
39
+ batch_size = 32
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+
42
+ models_info = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
43
+ {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
44
+ {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
45
+ {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
46
+ models = []
47
+ for model_info in models_info:
48
+ model_name = model_info["name"]
49
+ feature_dim = model_info["feature_dim"]
50
+ model_path = f"checkpoints/{model_name}"
51
+ model = load_model(model_path, feature_dim)
52
+ models.append(model)
53
+
54
  def load_data(img_dir="resources/trainB/", image_size=112, batch_size=256):
55
  filenames = load_filenames(img_dir)
56
  train_X = filenames[:1000]
 
67
  print("Data loaded successfully.")
68
  return x
69
 
 
70
  def load_keypoints(img_dir="resources/trainB/", image_size=112, batch_size=32):
71
  filenames = load_filenames(img_dir)
72
  train_X = filenames[:1000]
 
102
  if uploaded_image is not None:
103
  uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
104
  else:
105
+ uploaded_image = torch.zeros(1, 3, image_size, image_size).to(device)
106
  target_feature_map, _ = model(uploaded_image)
107
  img = torch.cat((x, uploaded_image))
108
  feature_map = torch.cat((feature_map, target_feature_map))
 
139
 
140
  index = models_info.index(model_info)
141
  model = models[index]
142
+
143
  x = load_data()
144
  test_imgs, points = load_keypoints()
145
 
 
199
  inputs=[input_image],
200
  )
201
 
202
+ if __name__ == "__main__":
 
 
 
 
 
 
 
203
 
204
  setup(models_info[0])
205