Abir66 commited on
Commit
e18c8d1
·
verified ·
1 Parent(s): c52c03c

Update backend_model.py

Browse files
Files changed (1) hide show
  1. backend_model.py +3 -4
backend_model.py CHANGED
@@ -110,9 +110,8 @@ def load_model_and_tokenizer(model_architecture, model_path):
110
  model = CodeBERTBinaryClassifier(base_model)
111
  # model = model.to(device)
112
 
113
- # Load the model
114
- # model = CodeBERTBinaryClassifier(base_model)
115
- model.load_state_dict(torch.load(model_path))
116
- model = model.to(device)
117
 
118
  return model, tokenizer
 
110
  model = CodeBERTBinaryClassifier(base_model)
111
  # model = model.to(device)
112
 
113
+ map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
114
+ model.load_state_dict(torch.load(model_path, map_location=map_location))
115
+ model = model.to(map_location)
 
116
 
117
  return model, tokenizer