shlomihod committed on
Commit
dd27b4a
·
1 Parent(s): 87ef96c

improve support with local huggingface models

Browse files
Files changed (2) hide show
  1. app.py +14 -5
  2. requirements.txt +2 -0
app.py CHANGED
@@ -306,15 +306,23 @@ def build_api_call_function(model):
306
 
307
  return output, length
308
 
309
- elif model.startswith("#"):
310
  model = model[1:]
311
- pipe = pipeline("text-generation", model=model, trust_remote_code=True)
 
 
312
 
313
  async def api_call_function(prompt, generation_config):
314
  generation_config = prepare_huggingface_generation_config(generation_config)
315
- return pipe(prompt, return_text=True, **generation_config)[0][
 
316
  "generated_text"
317
  ]
 
 
 
 
 
318
 
319
  else:
320
 
@@ -904,8 +912,9 @@ def main():
904
  st.error(e)
905
  st.stop()
906
  st.markdown(escape_markdown(output))
907
- with st.expander("Stats"):
908
- st.metric("#Tokens", length)
 
909
 
910
 
911
  if __name__ == "__main__":
 
306
 
307
  return output, length
308
 
309
+ elif model.startswith("@"):
310
  model = model[1:]
311
+ pipe = pipeline(
312
+ "text-generation", model=model, trust_remote_code=True, device_map="auto"
313
+ )
314
 
315
  async def api_call_function(prompt, generation_config):
316
  generation_config = prepare_huggingface_generation_config(generation_config)
317
+
318
+ output = pipe(prompt, return_text=True, **generation_config)[0][
319
  "generated_text"
320
  ]
321
+ output = output[len(prompt) :]
322
+
323
+ length = None
324
+
325
+ return output, length
326
 
327
  else:
328
 
 
912
  st.error(e)
913
  st.stop()
914
  st.markdown(escape_markdown(output))
915
+ if length is not None:
916
+ with st.expander("Stats"):
917
+ st.metric("#Tokens", length)
918
 
919
 
920
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  aiohttp
2
  cohere
3
  datasets
@@ -11,4 +12,5 @@ scikit-learn
11
  spacy
12
  streamlit
13
  tenacity
 
14
  transformers
 
1
+ accelerate
2
  aiohttp
3
  cohere
4
  datasets
 
12
  spacy
13
  streamlit
14
  tenacity
15
+ torch
16
  transformers