anvilarth commited on
Commit
7a5d391
·
verified ·
1 Parent(s): c99b3ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -46
app.py CHANGED
@@ -163,54 +163,10 @@ class GradioWindow():
163
  outputs=[augmented_img, generated_prompt],
164
  )
165
 
166
- def download_weights(self):
167
- models = [
168
- "https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/",
169
- "https://huggingface.co/llava-hf/llava-1.5-7b-hf",
170
- "https://huggingface.co/danulkin/llama",
171
- ]
172
-
173
- destinations = [
174
- "Garage/models/checkpoints/ppt-v2-1",
175
- "Garage/models/checkpoints/llava-1.5-7b-hf",
176
- "Garage/models/checkpoints/llama-3-8b-Instruct",
177
- ]
178
-
179
- if not os.path.exists("Garage/models/checkpoints"):
180
- os.makedirs("Garage/models/checkpoints")
181
-
182
- for model, destination in zip(models, destinations):
183
- # Git LFS clone command
184
- command = ["git", "lfs", "clone", model, destination]
185
- try:
186
- result = subprocess.run(command, check=True, text=True, capture_output=True)
187
- print("Command Output:", result.stdout)
188
-
189
- except subprocess.CalledProcessError as e:
190
- print(f"Error: {e}")
191
- print("Command Output:", e.output)
192
-
193
- models = [
194
- "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
195
- "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
196
- ]
197
-
198
- destinations = [
199
- "Garage/models/checkpoints/GroundedSegmentAnything/sam_vit_h_4b8939.pth",
200
- "Garage/models/checkpoints/GroundedSegmentAnything/groundingdino_swint_ogc.pth",
201
- ]
202
- if not os.path.exists("Garage/models/checkpoints/GroundedSegmentAnything"):
203
- os.makedirs("Garage/models/checkpoints/GroundedSegmentAnything")
204
-
205
- for model, destination in zip(models, destinations):
206
- if not os.path.exists(destination):
207
- urllib.request.urlretrieve(model, destination)
208
- print(f"Downloaded {model} to {destination}")
209
- else:
210
- print(f"Model {model} already exists")
211
 
212
  def setup_model(self) -> SamPredictor:
213
- self.sam = build_sam(checkpoint='sam_vit_h_4b8939.pth')
 
214
  self.sam.to(device=self.device)
215
  self.sam_predictor = SamPredictor(self.sam)
216
 
 
163
  outputs=[augmented_img, generated_prompt],
164
  )
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  def setup_model(self) -> SamPredictor:
168
+ self.sam = sam_model_registry["vit_h"]()
169
+ self.sam.load_state_dict(torch.utils.model_zoo.load_url(MODEL_DICT["vit_h"]))
170
  self.sam.to(device=self.device)
171
  self.sam_predictor = SamPredictor(self.sam)
172