banao-tech commited on
Commit
3f593e3
·
verified ·
1 Parent(s): 51cf94d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -4
main.py CHANGED
@@ -24,12 +24,20 @@ from transformers import AutoProcessor, AutoModelForCausalLM
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
- # main.py (updated YOLO loading)
28
- from utils import get_yolo_model # Ensure this import exists
29
 
30
- # Load YOLO model correctly
 
 
 
 
31
  yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")
32
- yolo_model.to("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
33
 
34
  # Load caption model and processor
35
  try:
 
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
 
 
27
 
28
+ # main.py (YOLO loading fix)
29
+ from utils import get_yolo_model
30
+ import torch
31
+
32
+ # Load YOLO model using official method
33
  yolo_model = get_yolo_model(model_path="weights/icon_detect/best.pt")
34
+
35
+ # Handle device placement
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ if str(device) == "cuda":
38
+ yolo_model = yolo_model.cuda()
39
+ else:
40
+ yolo_model = yolo_model.cpu()
41
 
42
  # Load caption model and processor
43
  try: