shlomihod committed
Commit · 87ef96c
1 Parent(s): c7feace

add local models
Files changed:
- app.py (+52, -30)
- requirements.txt (+1, -0)
app.py CHANGED
@@ -30,7 +30,7 @@ from sklearn.metrics import (
 from sklearn.model_selection import StratifiedShuffleSplit
 from spacy.lang.en import English
 from tenacity import retry, stop_after_attempt, wait_random_exponential
-
+from transformers import pipeline
 
 LOGGER = logging.getLogger(__name__)
 
@@ -123,6 +123,38 @@ def get_processing_tokenizer():
 PROCESSING_TOKENIZER = get_processing_tokenizer()
 
 
+def prepare_huggingface_generation_config(generation_config):
+    generation_config = generation_config.copy()
+
+    # Reference for decoding strategies:
+    # https://huggingface.co/docs/transformers/generation_strategies
+
+    # `text_generation_interface`
+    # currently supports only the `greedy` and `sampling` decoding strategies.
+    # Following the implementation below, we add `do_sample` if any of the
+    # other sampling-related parameters are set:
+    # https://github.com/huggingface/text-generation-inference/blob/e943a294bca239e26828732dd6ab5b6f95dadd0a/server/text_generation_server/utils/tokens.py#L46
+
+    # `transformers`
+    # According to experimentation, it seems that `transformers` behaves similarly.
+
+    # I'm not sure what the right behavior is here, but it is better to be explicit.
+    for name, params in GENERATION_CONFIG_PARAMS.items():
+        # Checking for START to examine the slider parameters only
+        if (
+            "START" in params
+            and params["SAMPLING"]
+            and name in generation_config
+            and generation_config[name] is not None
+        ):
+            if generation_config[name] == params["DEFAULT"]:
+                generation_config[name] = None
+            else:
+                assert generation_config["do_sample"]
+
+    return generation_config
+
+
 def escape_markdown(text):
     escape_dict = {
         "*": r"\*",
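For context, a runnable sketch of what the new helper does. The GENERATION_CONFIG_PARAMS values here are hypothetical stand-ins using the keys the helper inspects ("START", "SAMPLING", "DEFAULT"); they are not the app's actual table.

# Hypothetical parameter table, not the app's real one.
GENERATION_CONFIG_PARAMS = {
    "temperature": {"START": 0.05, "SAMPLING": True, "DEFAULT": 1.0},
    "top_p": {"START": 0.05, "SAMPLING": True, "DEFAULT": 1.0},
}

def prepare_huggingface_generation_config(generation_config):
    generation_config = generation_config.copy()
    for name, params in GENERATION_CONFIG_PARAMS.items():
        if (
            "START" in params
            and params["SAMPLING"]
            and name in generation_config
            and generation_config[name] is not None
        ):
            if generation_config[name] == params["DEFAULT"]:
                # A slider left at its default is dropped, so the backend
                # falls back to greedy decoding instead of silently sampling.
                generation_config[name] = None
            else:
                # A non-default sampling value only makes sense with sampling on.
                assert generation_config["do_sample"]
    return generation_config

print(prepare_huggingface_generation_config(
    {"do_sample": False, "temperature": 1.0, "top_p": None}
))
# -> {'do_sample': False, 'temperature': None, 'top_p': None}

print(prepare_huggingface_generation_config(
    {"do_sample": True, "temperature": 0.7, "top_p": None}
))
# -> unchanged: 0.7 is forwarded and sampling stays explicit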
@@ -274,6 +306,16 @@ def build_api_call_function(model):
 
         return output, length
 
+    elif model.startswith("#"):
+        model = model[1:]
+        pipe = pipeline("text-generation", model=model, trust_remote_code=True)
+
+        async def api_call_function(prompt, generation_config):
+            generation_config = prepare_huggingface_generation_config(generation_config)
+            return pipe(prompt, return_text=True, **generation_config)[0][
+                "generated_text"
+            ]
+
     else:
 
         @retry(
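A hedged usage sketch of the new branch: per the diff, a model id prefixed with "#" is stripped and loaded through a local transformers pipeline instead of the Inference API. The model id and generation parameters below are illustrative placeholders only.

from transformers import pipeline

# "gpt2" stands in for whatever local model id follows the "#" prefix.
pipe = pipeline("text-generation", model="gpt2", trust_remote_code=True)

# Mirrors the call shape in the diff: return_text=True plus forwarded
# generation parameters, then unpack the first result's generated_text.
result = pipe("The movie was", return_text=True, max_new_tokens=16, do_sample=False)
print(result[0]["generated_text"])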
@@ -284,31 +326,7 @@ def build_api_call_function(model):
         async def api_call_function(prompt, generation_config):
             hf_client = AsyncInferenceClient(token=HF_TOKEN, model=model)
 
-
-            # https://huggingface.co/docs/transformers/generation_strategies
-
-            # `text_generation_interface`
-            # Currenly supports only `greedy` amd `sampling` decoding strategies
-            # Following , we add `do_sample` if any of the other
-            # samling related parameters are set
-            # https://github.com/huggingface/text-generation-inference/blob/e943a294bca239e26828732dd6ab5b6f95dadd0a/server/text_generation_server/utils/tokens.py#L46
-
-            # `transformers`
-            # According to experimentations, it seems that `transformers` behave similarly
-
-            # I'm not sure what is the right behavior here, but it is better to be explicit
-            for name, params in GENERATION_CONFIG_PARAMS.items():
-                # Checking for START to examine the a slider parameters only
-                if (
-                    "START" in params
-                    and params["SAMPLING"]
-                    and name in generation_config
-                    and generation_config[name] is not None
-                ):
-                    if generation_config[name] == params["DEFAULT"]:
-                        generation_config[name] = None
-                    else:
-                        assert generation_config["do_sample"]
+            generation_config = prepare_huggingface_generation_config(generation_config)
 
             response = await hf_client.text_generation(
                 prompt, stream=False, details=True, **generation_config
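For comparison, a minimal sketch of the remote path that the refactor leaves in place. The model id and token are placeholders (the real app reads HF_TOKEN from its environment):

import asyncio

from huggingface_hub import AsyncInferenceClient

async def demo():
    # Placeholder model id and token; details=True makes text_generation
    # return a details object rather than a bare string.
    client = AsyncInferenceClient(token="hf_...", model="HuggingFaceH4/zephyr-7b-beta")
    response = await client.text_generation(
        "The movie was", stream=False, details=True, max_new_tokens=16
    )
    print(response.generated_text)

asyncio.run(demo())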
@@ -768,15 +786,19 @@ def main():
 
     with st.expander("Info"):
         try:
-            st.
-            st.write(dataset_info(st.session_state.dataset_name).cardData)
+            data_card = dataset_info(st.session_state.dataset_name).cardData
         except (HFValidationError, RepositoryNotFoundError):
             pass
+        else:
+            st.caption("Dataset")
+            st.write(data_card)
         try:
-
-            st.write(model_info(model).cardData)
+            model_card = model_info(model).cardData
         except (HFValidationError, RepositoryNotFoundError):
             pass
+        else:
+            st.caption("Model")
+            st.write(model_card)
 
     # st.write(f"Model max length: {AutoTokenizer.from_pretrained(model).model_max_length}")
 
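The Info-expander rewrite switches to try/except/else, so the caption and card are only rendered when the hub lookup actually succeeds. A small sketch of the same pattern, with a hypothetical helper name:

import streamlit as st
from huggingface_hub import model_info
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError

# Hypothetical helper: the `else` block runs only if no exception was
# raised, so invalid or missing repos render nothing at all.
def show_model_card(model_id: str) -> None:
    try:
        card = model_info(model_id).cardData
    except (HFValidationError, RepositoryNotFoundError):
        pass
    else:
        st.caption("Model")
        st.write(card)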
requirements.txt CHANGED

@@ -11,3 +11,4 @@ scikit-learn
 spacy
 streamlit
 tenacity
+transformers