Vijish commited on
Commit
236fbce
1 Parent(s): 4ac5261

Update src/core.py

Browse files
Files changed (1) hide show
  1. src/core.py +2 -3
src/core.py CHANGED
@@ -12,7 +12,6 @@ import cv2
12
 
13
  import numpy as np
14
  import pandas as pd
15
- #import streamlit as st
16
  from PIL import Image
17
  #from streamlit_drawable_canvas import st_canvas
18
 
@@ -64,9 +63,9 @@ ENERGY_MASK_CONST = 100000.0 # large energy value for protective ma
64
  MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
65
  USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
66
 
67
- device = torch.device("cpu")
68
  model_path = "./assets/big-lama.pt"
69
- model = torch.jit.load(model_path, map_location="cpu")
70
  model = model.to(device)
71
  model.eval()
72
 
 
12
 
13
  import numpy as np
14
  import pandas as pd
 
15
  from PIL import Image
16
  #from streamlit_drawable_canvas import st_canvas
17
 
 
63
  MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
64
  USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
65
 
66
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
  model_path = "./assets/big-lama.pt"
68
+ model = torch.jit.load(model_path, map_location=device)
69
  model = model.to(device)
70
  model.eval()
71