shlomihod committed
Commit 87ef96c · 1 Parent(s): c7feace

add local models

Files changed (2):
  1. app.py +52 -30
  2. requirements.txt +1 -0
app.py CHANGED
@@ -30,7 +30,7 @@ from sklearn.metrics import (
 from sklearn.model_selection import StratifiedShuffleSplit
 from spacy.lang.en import English
 from tenacity import retry, stop_after_attempt, wait_random_exponential
-
+from transformers import pipeline
 
 LOGGER = logging.getLogger(__name__)
 
@@ -123,6 +123,38 @@ def get_processing_tokenizer():
 PROCESSING_TOKENIZER = get_processing_tokenizer()
 
 
+def prepare_huggingface_generation_config(generation_config):
+    generation_config = generation_config.copy()
+
+    # Reference for decoding strategies:
+    # https://huggingface.co/docs/transformers/generation_strategies
+
+    # `text-generation-inference`
+    # currently supports only the `greedy` and `sampling` decoding strategies.
+    # Following the implementation linked below, we add `do_sample` if any
+    # of the other sampling-related parameters are set:
+    # https://github.com/huggingface/text-generation-inference/blob/e943a294bca239e26828732dd6ab5b6f95dadd0a/server/text_generation_server/utils/tokens.py#L46
+
+    # `transformers`
+    # In our experiments, `transformers` appears to behave similarly.
+
+    # I'm not sure what the right behavior is here, but it is better to be explicit.
+    for name, params in GENERATION_CONFIG_PARAMS.items():
+        # Check for START to examine the slider parameters only
+        if (
+            "START" in params
+            and params["SAMPLING"]
+            and name in generation_config
+            and generation_config[name] is not None
+        ):
+            if generation_config[name] == params["DEFAULT"]:
+                generation_config[name] = None
+            else:
+                assert generation_config["do_sample"]
+
+    return generation_config
+
+
 def escape_markdown(text):
     escape_dict = {
         "*": r"\*",
@@ -274,6 +306,16 @@ def build_api_call_function(model):
 
             return output, length
 
+    elif model.startswith("#"):
+        model = model[1:]
+        pipe = pipeline("text-generation", model=model, trust_remote_code=True)
+
+        async def api_call_function(prompt, generation_config):
+            generation_config = prepare_huggingface_generation_config(generation_config)
+            return pipe(prompt, return_text=True, **generation_config)[0][
+                "generated_text"
+            ]
+
     else:
 
         @retry(
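The new elif branch above routes model names prefixed with "#" to a local transformers pipeline instead of the Inference API. In isolation, the routing looks roughly like this (a sketch using "gpt2" as a placeholder checkpoint, and omitting trust_remote_code and the config normalization):

```python
import asyncio

from transformers import pipeline

model = "#gpt2"  # the "#" prefix marks a locally-run model
pipe = pipeline("text-generation", model=model[1:])  # strip the marker


async def api_call_function(prompt, generation_config):
    # The pipeline call is synchronous; async only mirrors the remote
    # branch's interface, so callers can await either one uniformly.
    return pipe(prompt, **generation_config)[0]["generated_text"]


print(asyncio.run(api_call_function("Hello", {"max_new_tokens": 5, "do_sample": False})))
```

Wrapping the synchronous pipeline in an async def blocks the event loop while the local model generates, but it keeps the call site identical across both branches.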
@@ -284,31 +326,7 @@
         async def api_call_function(prompt, generation_config):
             hf_client = AsyncInferenceClient(token=HF_TOKEN, model=model)
 
-            # Reference for decoding stratagies:
-            # https://huggingface.co/docs/transformers/generation_strategies
-
-            # `text_generation_interface`
-            # Currenly supports only `greedy` amd `sampling` decoding strategies
-            # Following , we add `do_sample` if any of the other
-            # samling related parameters are set
-            # https://github.com/huggingface/text-generation-inference/blob/e943a294bca239e26828732dd6ab5b6f95dadd0a/server/text_generation_server/utils/tokens.py#L46
-
-            # `transformers`
-            # According to experimentations, it seems that `transformers` behave similarly
-
-            # I'm not sure what is the right behavior here, but it is better to be explicit
-            for name, params in GENERATION_CONFIG_PARAMS.items():
-                # Checking for START to examine the a slider parameters only
-                if (
-                    "START" in params
-                    and params["SAMPLING"]
-                    and name in generation_config
-                    and generation_config[name] is not None
-                ):
-                    if generation_config[name] == params["DEFAULT"]:
-                        generation_config[name] = None
-                    else:
-                        assert generation_config["do_sample"]
+            generation_config = prepare_huggingface_generation_config(generation_config)
 
             response = await hf_client.text_generation(
                 prompt, stream=False, details=True, **generation_config
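With the normalization hoisted into the shared helper, the remote branch reduces to a thin wrapper around the Inference API. A sketch of that path, assuming HF_TOKEN is set in the environment and again using a placeholder model id:

```python
import asyncio
import os

from huggingface_hub import AsyncInferenceClient


async def api_call_function(prompt, generation_config):
    hf_client = AsyncInferenceClient(token=os.environ["HF_TOKEN"], model="gpt2")
    # details=True returns a structured response rather than a bare string.
    response = await hf_client.text_generation(
        prompt, stream=False, details=True, **generation_config
    )
    return response.generated_text


print(asyncio.run(api_call_function("Hello", {"max_new_tokens": 5})))
```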
@@ -768,15 +786,19 @@ def main():
 
         with st.expander("Info"):
             try:
-                st.caption("Dataset")
-                st.write(dataset_info(st.session_state.dataset_name).cardData)
+                data_card = dataset_info(st.session_state.dataset_name).cardData
             except (HFValidationError, RepositoryNotFoundError):
                 pass
+            else:
+                st.caption("Dataset")
+                st.write(data_card)
             try:
-                st.caption("Model")
-                st.write(model_info(model).cardData)
+                model_card = model_info(model).cardData
             except (HFValidationError, RepositoryNotFoundError):
                 pass
+            else:
+                st.caption("Model")
+                st.write(model_card)
 
     # st.write(f"Model max length: {AutoTokenizer.from_pretrained(model).model_max_length}")
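The reordering above fixes a small rendering bug: previously st.caption() ran inside the try block before the hub lookup that could raise, so a failed lookup left a stray caption with nothing under it. Moving the rendering into the else clause means it happens only after the lookup succeeds, as this minimal sketch shows:

```python
def fetch_card():
    # Stand-in for dataset_info(...).cardData; simulate a failed lookup.
    raise LookupError("repository not found")


try:
    card = fetch_card()  # may raise before anything has been rendered
except LookupError:
    pass  # failure: render nothing, not even the caption
else:
    print("Dataset")  # the caption, emitted only on success
    print(card)
```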
 
 
requirements.txt CHANGED
@@ -11,3 +11,4 @@ scikit-learn
 spacy
 streamlit
 tenacity
+transformers