BeardedMonster committed on
Commit f9affc4 · verified · 1 Parent(s): a7135a3

Update app.py

Files changed (1):
  1. app.py +2 -40
app.py CHANGED
@@ -8,7 +8,6 @@ import json
 import torch
 import re
 import nest_asyncio
-from hashlib import md5
 
 nest_asyncio.apply()
 
@@ -106,21 +105,6 @@ st.write("**It might take a while (~25s) to return an output on the first 'gener
 st.write("**For convenience, you can use chatgpt to copy text and evaluate model output.**")
 st.write("-" * 50)
 
-# async def generate_from_api(user_input, generation_config):
-#     url = "https://pauljeffrey--sabiyarn-fastapi-app.modal.run/predict"
-
-#     payload = {
-#         "prompt": user_input,
-#         "config": generation_config
-#     }
-
-#     headers = {
-#         'Content-Type': 'application/json'
-#     }
-
-#     async with aiohttp.ClientSession() as session:
-#         async with session.post(url, headers=headers, json=payload) as response:
-#             return await response.text()
 
 async def generate_from_api(user_input, generation_config):
     urls = [
@@ -151,15 +135,6 @@ async def generate_from_api(user_input, generation_config):
 
     return "FAILED"
 
-def generate_cache_key(user_input, generation_config):
-    key_data = f"{user_input}_{json.dumps(generation_config, sort_keys=True)}"
-    return md5(key_data.encode()).hexdigest()
-
-@st.cache_data(show_spinner=False)
-def get_cached_response(user_input, generation_config):
-    return asyncio.run(generate_from_api(user_input, generation_config))
-
-
 # Sample texts
 sample_texts = {
     "select":"",
@@ -253,22 +228,9 @@ if st.button("Generate"):
     print("wrapped_input: ", wrapped_input)
     generation_config["max_new_tokens"]= min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
     start_time = time.time()
-    # try:
-    # Attempt the asynchronous API call
-    generation_config["max_new_tokens"] = min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
-    # generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
-    cache_key = generate_cache_key(wrapped_input, generation_config)
-    generated_text = get_cached_response(wrapped_input, generation_config)
 
-    # loop = asyncio.new_event_loop()
-    # asyncio.set_event_loop(loop)
-    # generated_text = loop.run_until_complete(generate_from_api(wrapped_input, generation_config))
-    # except Exception as e:
-    #     print(f"API call failed: {e}. Using local model for text generation.")
-    # Use the locally loaded model for text generation
-    # input_ids = tokenizer(wrapped_input, return_tensors="pt")["input_ids"].to(device)
-    # output = model.generate(input_ids, **generation_config)
-    # generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    generation_config["max_new_tokens"] = min(max_new_tokens, 1024 - len(tokenizer.tokenize(wrapped_input)))
+    generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
 
     if generated_text == "FAILED":
         input_ids = tokenizer(wrapped_input, return_tensors="pt")["input_ids"].to(device)
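
For reference, the @@ -151,15 @@ hunk deletes a small caching layer: generate_cache_key derived an MD5 digest from the prompt plus a deterministic JSON dump of the config, and get_cached_response memoized the API call with st.cache_data. Below is a minimal, self-contained sketch of that pattern, reconstructed from the deleted lines; the generate_from_api stub is hypothetical and stands in for the app's real coroutine:

import asyncio
import json
from hashlib import md5

import streamlit as st

async def generate_from_api(user_input, generation_config):
    # Hypothetical stub for the app's real coroutine, which posts the
    # prompt and config to the hosted model endpoints.
    return "stubbed response"

def generate_cache_key(user_input, generation_config):
    # sort_keys=True makes the JSON serialization deterministic, so an
    # identical prompt + config always produces the same MD5 digest.
    key_data = f"{user_input}_{json.dumps(generation_config, sort_keys=True)}"
    return md5(key_data.encode()).hexdigest()

@st.cache_data(show_spinner=False)
def get_cached_response(user_input, generation_config):
    # st.cache_data memoizes on the function arguments themselves, so the
    # MD5 key above was computed by the caller but never consulted here.
    return asyncio.run(generate_from_api(user_input, generation_config))

Because responses were memoized, re-submitting an identical prompt and config returned the stored output instead of a fresh generation; that staleness is a plausible motivation for this commit's switch to a direct call.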
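
With the caching layer gone, the "Generate" handler reduces to a direct asyncio.run call plus the pre-existing local-model fallback (the same fallback the commented-out block removed in the last hunk described). A sketch of the resulting flow, reusing the generate_from_api stub above and assuming tokenizer, model, and device are initialized earlier in app.py as the diff context suggests:

import asyncio

def generate_text(wrapped_input, generation_config, tokenizer, model, device):
    # nest_asyncio.apply() at the top of app.py allows nested event loops,
    # so asyncio.run() is safe even if a loop is already running.
    generated_text = asyncio.run(generate_from_api(wrapped_input, generation_config))
    if generated_text == "FAILED":
        # generate_from_api returns "FAILED" when every endpoint fails;
        # fall back to the locally loaded model, as in the removed comments.
        input_ids = tokenizer(wrapped_input, return_tensors="pt")["input_ids"].to(device)
        output = model.generate(input_ids, **generation_config)
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text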