outsu commited on
Commit
317198d
·
verified ·
1 Parent(s): 5b43647

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -2
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import gradio as gr
3
  import torch
4
  from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
@@ -9,8 +10,22 @@ from huggingface_hub import hf_hub_download, list_repo_files
9
  MODELS_DIR = "./TeLVE/models"
10
  TOKENIZER_PATH = "./TeLVE/tokenizer"
11
  HF_REPO_ID = "outsu/TeLVE"
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def list_available_models():
15
  """List all .pth models in the models directory"""
16
  if not os.path.exists(MODELS_DIR):
@@ -27,14 +42,19 @@ def get_hf_model_list():
27
  return []
28
 
29
  def download_missing_models():
30
- """Download missing models from HuggingFace"""
31
  if not os.path.exists(MODELS_DIR):
32
  os.makedirs(MODELS_DIR)
33
 
 
 
34
  local_models = set(list_available_models())
35
  hf_models = set(get_hf_model_list())
36
 
37
- for model in hf_models - local_models:
 
 
 
38
  try:
39
  print(f"Downloading missing model: {model}")
40
  hf_hub_download(
@@ -43,8 +63,21 @@ def download_missing_models():
43
  local_dir=os.path.dirname(MODELS_DIR),
44
  local_dir_use_symlinks=False
45
  )
 
46
  except Exception as e:
47
  print(f"Error downloading {model}: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def generate_description(image, model_name):
50
  """Generate image caption using selected model"""
@@ -96,6 +129,8 @@ def create_interface():
96
  return interface
97
 
98
  if __name__ == "__main__":
 
 
99
  print("Checking for missing models...")
100
  download_missing_models()
101
  demo = create_interface()
 
1
  import os
2
+ import json
3
  import gradio as gr
4
  import torch
5
  from TeLVE.imagine import ImageCaptioningModel, load_model, generate_caption
 
10
  MODELS_DIR = "./TeLVE/models"
11
  TOKENIZER_PATH = "./TeLVE/tokenizer"
12
  HF_REPO_ID = "outsu/TeLVE"
13
+ MODEL_STATE_FILE = "./TeLVE/model_state.json"
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ def load_model_state():
17
+ """Load the model state from JSON file"""
18
+ if os.path.exists(MODEL_STATE_FILE):
19
+ with open(MODEL_STATE_FILE, 'r') as f:
20
+ return json.load(f)
21
+ return {"downloaded_models": []}
22
+
23
+ def save_model_state(state):
24
+ """Save the model state to JSON file"""
25
+ os.makedirs(os.path.dirname(MODEL_STATE_FILE), exist_ok=True)
26
+ with open(MODEL_STATE_FILE, 'w') as f:
27
+ json.dump(state, f)
28
+
29
  def list_available_models():
30
  """List all .pth models in the models directory"""
31
  if not os.path.exists(MODELS_DIR):
 
42
  return []
43
 
44
  def download_missing_models():
45
+ """Download missing models from HuggingFace with state management"""
46
  if not os.path.exists(MODELS_DIR):
47
  os.makedirs(MODELS_DIR)
48
 
49
+ state = load_model_state()
50
+ downloaded_models = set(state["downloaded_models"])
51
  local_models = set(list_available_models())
52
  hf_models = set(get_hf_model_list())
53
 
54
+ # Check for models that need downloading
55
+ models_to_download = (hf_models - local_models) - downloaded_models
56
+
57
+ for model in models_to_download:
58
  try:
59
  print(f"Downloading missing model: {model}")
60
  hf_hub_download(
 
63
  local_dir=os.path.dirname(MODELS_DIR),
64
  local_dir_use_symlinks=False
65
  )
66
+ downloaded_models.add(model)
67
  except Exception as e:
68
  print(f"Error downloading {model}: {str(e)}")
69
+ continue
70
+
71
+ # Update state with newly downloaded models
72
+ state["downloaded_models"] = list(downloaded_models)
73
+ save_model_state(state)
74
+
75
+ def verify_model_integrity():
76
+ """Verify that all models in state actually exist"""
77
+ state = load_model_state()
78
+ local_models = set(list_available_models())
79
+ state["downloaded_models"] = list(set(state["downloaded_models"]) & local_models)
80
+ save_model_state(state)
81
 
82
  def generate_description(image, model_name):
83
  """Generate image caption using selected model"""
 
129
  return interface
130
 
131
  if __name__ == "__main__":
132
+ print("Verifying model integrity...")
133
+ verify_model_integrity()
134
  print("Checking for missing models...")
135
  download_missing_models()
136
  demo = create_interface()