marksverdhei commited on
Commit
9c88742
·
1 Parent(s): d1f3499

Add cpu support

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -14,6 +14,10 @@ from streamlit_plotly_events import plotly_events
14
  import plotly.graph_objects as go
15
  import logging
16
  import utils
 
 
 
 
17
  # Activate tqdm with pandas
18
  tqdm.pandas()
19
 
@@ -54,7 +58,7 @@ df = load_data()
54
  # Caching the model and tokenizer to avoid reloading
55
  @st.cache_resource
56
  def load_model_and_tokenizer():
57
- encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to("cuda")
58
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
59
  return encoder, tokenizer
60
 
 
14
  import plotly.graph_objects as go
15
  import logging
16
  import utils
17
+
18
+ use_cpu = torch.cuda.is_available()
19
+ device = "cpu" if use_cpu else "cuda"
20
+
21
  # Activate tqdm with pandas
22
  tqdm.pandas()
23
 
 
58
  # Caching the model and tokenizer to avoid reloading
59
  @st.cache_resource
60
  def load_model_and_tokenizer():
61
+ encoder = AutoModel.from_pretrained("sentence-transformers/gtr-t5-base").encoder.to(device)
62
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/gtr-t5-base")
63
  return encoder, tokenizer
64