Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ import json
|
|
8 |
import torch
|
9 |
import re
|
10 |
import nest_asyncio
|
11 |
-
from hashlib import md5
|
12 |
|
13 |
nest_asyncio.apply()
|
14 |
|
@@ -106,21 +105,6 @@ st.write("**It might take a while (~25s) to return an output on the first 'gener
|
|
106 |
st.write("**For convenience, you can use chatgpt to copy text and evaluate model output.**")
|
107 |
st.write("-" * 50)
|
108 |
|
109 |
-
# async def generate_from_api(user_input, generation_config):
|
110 |
-
# url = "https://pauljeffrey--sabiyarn-fastapi-app.modal.run/predict"
|
111 |
-
|
112 |
-
# payload = {
|
113 |
-
# "prompt": user_input,
|
114 |
-
# "config": generation_config
|
115 |
-
# }
|
116 |
-
|
117 |
-
# headers = {
|
118 |
-
# 'Content-Type': 'application/json'
|
119 |
-
# }
|
120 |
-
|
121 |
-
# async with aiohttp.ClientSession() as session:
|
122 |
-
# async with session.post(url, headers=headers, json=payload) as response:
|
123 |
-
# return await response.text()
|
124 |
|
125 |
async def generate_from_api(user_input, generation_config):
|
126 |
urls = [
|
@@ -151,15 +135,6 @@ async def generate_from_api(user_input, generation_config):
|
|
151 |
|
152 |
return "FAILED"
|
153 |
|
154 |
-
def generate_cache_key(user_input, generation_config):
|
155 |
-
key_data = f"{user_input}_{json.dumps(generation_config, sort_keys=True)}"
|
156 |
-
return md5(key_data.encode()).hexdigest()
|
157 |
-
|
158 |
-
@st.cache_data(show_spinner=False)
|
159 |
-
def get_cached_response(user_input, generation_config):
|
160 |
-
return asyncio.run(generate_from_api(user_input, generation_config))
|
161 |
-
|
162 |
-
|
163 |
# Sample texts
|
164 |
sample_texts = {
|
165 |
"select":"",
|
@@ -253,22 +228,9 @@ if st.button("Generate"):
|
|
253 |
print("wrapped_input: ", wrapped_input)
|
254 |
generation_config["max_new_tokens"]= min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
|
255 |
start_time = time.time()
|
256 |
-
# try:
|
257 |
-
# Attempt the asynchronous API call
|
258 |
-
generation_config["max_new_tokens"] = min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
|
259 |
-
# generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
|
260 |
-
cache_key = generate_cache_key(wrapped_input, generation_config)
|
261 |
-
generated_text = get_cached_response(wrapped_input, generation_config)
|
262 |
|
263 |
-
|
264 |
-
|
265 |
-
# generated_text = loop.run_until_complete(generate_from_api(wrapped_input, generation_config))
|
266 |
-
# except Exception as e:
|
267 |
-
# print(f"API call failed: {e}. Using local model for text generation.")
|
268 |
-
# Use the locally loaded model for text generation
|
269 |
-
# input_ids = tokenizer(wrapped_input, return_tensors="pt")["input_ids"].to(device)
|
270 |
-
# output = model.generate(input_ids, **generation_config)
|
271 |
-
# generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
272 |
|
273 |
if generated_text == "FAILED":
|
274 |
input_ids = tokenizer(wrapped_input, return_tensors="pt")["input_ids"].to(device)
|
|
|
8 |
import torch
|
9 |
import re
|
10 |
import nest_asyncio
|
|
|
11 |
|
12 |
nest_asyncio.apply()
|
13 |
|
|
|
105 |
st.write("**For convenience, you can use chatgpt to copy text and evaluate model output.**")
|
106 |
st.write("-" * 50)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
async def generate_from_api(user_input, generation_config):
|
110 |
urls = [
|
|
|
135 |
|
136 |
return "FAILED"
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
# Sample texts
|
139 |
sample_texts = {
|
140 |
"select":"",
|
|
|
228 |
print("wrapped_input: ", wrapped_input)
|
229 |
generation_config["max_new_tokens"]= min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
|
230 |
start_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
+
generation_config["max_new_tokens"] = min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
|
233 |
+
generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
if generated_text == "FAILED":
|
236 |
input_ids = tokenizer(wrapped_input, return_tensors="pt")["input_ids"].to(device)
|