marksverdhei commited on
Commit
c4e154c
·
1 Parent(s): 6b30d5d

Use correct device

Browse files
Files changed (1) hide show
  1. views.py +3 -3
views.py CHANGED
@@ -22,7 +22,7 @@ def diffs(embeddings: np.ndarray, corrector, encoder: PreTrainedModel, tokenizer
22
  st.markdown("This application lets you freely explore to which extent that property applies to embedding inversion models given the other factors of inaccuracy")
23
 
24
  generated_sentence = ""
25
-
26
 
27
  with st.form(key="foo") as form:
28
  submit_button = st.form_submit_button("Synthesize")
@@ -35,10 +35,10 @@ def diffs(embeddings: np.ndarray, corrector, encoder: PreTrainedModel, tokenizer
35
  st.latex("=")
36
 
37
  if submit_button:
38
- v1, v2, v3 = get_gtr_embeddings([sent1, sent2, sent3], encoder, tokenizer, device=encoder.device).to("cpu")
39
  v4 = v1 - v2 + v3
40
  generated_sentence, = vec2text.invert_embeddings(
41
- embeddings=v4.unsqueeze(0).cuda(),
42
  corrector=corrector,
43
  num_steps=20,
44
  )
 
22
  st.markdown("This application lets you freely explore to which extent that property applies to embedding inversion models given the other factors of inaccuracy")
23
 
24
  generated_sentence = ""
25
+ device = encoder.device
26
 
27
  with st.form(key="foo") as form:
28
  submit_button = st.form_submit_button("Synthesize")
 
35
  st.latex("=")
36
 
37
  if submit_button:
38
+ v1, v2, v3 = get_gtr_embeddings([sent1, sent2, sent3], encoder, tokenizer, device=encoder.device).to(device)
39
  v4 = v1 - v2 + v3
40
  generated_sentence, = vec2text.invert_embeddings(
41
+ embeddings=v4.unsqueeze(0).to(device),
42
  corrector=corrector,
43
  num_steps=20,
44
  )