Update app.py
Browse files
app.py
CHANGED
@@ -163,54 +163,10 @@ class GradioWindow():
|
|
163 |
outputs=[augmented_img, generated_prompt],
|
164 |
)
|
165 |
|
166 |
-
def download_weights(self):
|
167 |
-
models = [
|
168 |
-
"https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/",
|
169 |
-
"https://huggingface.co/llava-hf/llava-1.5-7b-hf",
|
170 |
-
"https://huggingface.co/danulkin/llama",
|
171 |
-
]
|
172 |
-
|
173 |
-
destinations = [
|
174 |
-
"Garage/models/checkpoints/ppt-v2-1",
|
175 |
-
"Garage/models/checkpoints/llava-1.5-7b-hf",
|
176 |
-
"Garage/models/checkpoints/llama-3-8b-Instruct",
|
177 |
-
]
|
178 |
-
|
179 |
-
if not os.path.exists("Garage/models/checkpoints"):
|
180 |
-
os.makedirs("Garage/models/checkpoints")
|
181 |
-
|
182 |
-
for model, destination in zip(models, destinations):
|
183 |
-
# Git LFS clone command
|
184 |
-
command = ["git", "lfs", "clone", model, destination]
|
185 |
-
try:
|
186 |
-
result = subprocess.run(command, check=True, text=True, capture_output=True)
|
187 |
-
print("Command Output:", result.stdout)
|
188 |
-
|
189 |
-
except subprocess.CalledProcessError as e:
|
190 |
-
print(f"Error: {e}")
|
191 |
-
print("Command Output:", e.output)
|
192 |
-
|
193 |
-
models = [
|
194 |
-
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
195 |
-
"https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
|
196 |
-
]
|
197 |
-
|
198 |
-
destinations = [
|
199 |
-
"Garage/models/checkpoints/GroundedSegmentAnything/sam_vit_h_4b8939.pth",
|
200 |
-
"Garage/models/checkpoints/GroundedSegmentAnything/groundingdino_swint_ogc.pth",
|
201 |
-
]
|
202 |
-
if not os.path.exists("Garage/models/checkpoints/GroundedSegmentAnything"):
|
203 |
-
os.makedirs("Garage/models/checkpoints/GroundedSegmentAnything")
|
204 |
-
|
205 |
-
for model, destination in zip(models, destinations):
|
206 |
-
if not os.path.exists(destination):
|
207 |
-
urllib.request.urlretrieve(model, destination)
|
208 |
-
print(f"Downloaded {model} to {destination}")
|
209 |
-
else:
|
210 |
-
print(f"Model {model} already exists")
|
211 |
|
212 |
def setup_model(self) -> SamPredictor:
|
213 |
-
self.sam =
|
|
|
214 |
self.sam.to(device=self.device)
|
215 |
self.sam_predictor = SamPredictor(self.sam)
|
216 |
|
|
|
163 |
outputs=[augmented_img, generated_prompt],
|
164 |
)
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
def setup_model(self) -> SamPredictor:
|
168 |
+
self.sam = sam_model_registry["vit_h"]()
|
169 |
+
self.sam.load_state_dict(torch.utils.model_zoo.load_url(MODEL_DICT["vit_h"]))
|
170 |
self.sam.to(device=self.device)
|
171 |
self.sam_predictor = SamPredictor(self.sam)
|
172 |
|