Hjgugugjhuhjggg committed: Update app.py
app.py CHANGED
@@ -130,32 +130,36 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
-    encoded_input = tokenizer(input_text, return_tensors="pt")
+    encoded_input = tokenizer(input_text, return_tensors="pt").to(device)
 
     def stop_criteria(input_ids, scores):
-        decoded_output = tokenizer.decode(
-        [deleted line not captured]
+        decoded_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        for stop in stop_sequences:
+            if decoded_output.endswith(stop):
+                return True
+        return False
 
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
     token_buffer = []
-    [17 deleted lines not captured]
+    output_ids = encoded_input.input_ids
+    while True:
+        outputs = model.generate(
+            output_ids,
+            do_sample=generation_config.do_sample,
+            max_new_tokens=generation_config.max_new_tokens,
+            temperature=generation_config.temperature,
+            top_p=generation_config.top_p,
+            top_k=generation_config.top_k,
+            repetition_penalty=generation_config.repetition_penalty,
+            num_return_sequences=generation_config.num_return_sequences,
+            stopping_criteria=stopping_criteria,
+            output_scores=True,
+            return_dict_in_generate=True
+        )
+        new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
+
+        for token_id in new_token_ids:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
             token_buffer.append(token)
             if len(token_buffer) >= 10:
@@ -167,26 +171,13 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
                 yield "".join(token_buffer)
                 token_buffer = []
 
-    if
-    [deleted line not captured]
+        if stop_criteria(outputs.sequences, None):
+            break
 
-    [2 deleted lines not captured]
+        if len(new_token_ids) < generation_config.max_new_tokens:
+            break
 
-    [deleted line not captured]
-        **encoded_input,
-        do_sample=generation_config.do_sample,
-        max_new_tokens=generation_config.max_new_tokens,
-        temperature=generation_config.temperature,
-        top_p=generation_config.top_p,
-        top_k=generation_config.top_k,
-        repetition_penalty=generation_config.repetition_penalty,
-        num_return_sequences=generation_config.num_return_sequences,
-        stopping_criteria=stopping_criteria,
-        output_scores=True,
-        return_dict_in_generate=True,
-        streamer=None
-    )
+        output_ids = outputs.sequences
 
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
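The rewritten stop_criteria is a plain function passed to StoppingCriteriaList, which works as long as the list does nothing more than call each entry with (input_ids, scores). The interface transformers documents for custom criteria is a StoppingCriteria subclass; below is a minimal sketch of the same stop-sequence check in that form. The class name StopOnSequences is invented here, and tokenizer and stop_sequences are assumed to be the objects already in scope inside stream_text.

from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    # Subclass form of the diff's stop_criteria closure (illustrative only).
    def __init__(self, tokenizer, stop_sequences):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences

    def __call__(self, input_ids, scores, **kwargs):
        # Decode everything generated so far and stop once it ends with any stop string.
        decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(decoded.endswith(stop) for stop in self.stop_sequences)

stopping_criteria = StoppingCriteriaList([StopOnSequences(tokenizer, stop_sequences)])

The behaviour matches the committed closure; the subclass form simply follows the documented API and keeps the check reusable outside stream_text.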
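The hunk headers indicate that stream_text sits next to an async def generate(request: GenerateRequest) route, but that route's body is not part of the captured diff. The sketch below shows one common way an async generator like this is returned from FastAPI via StreamingResponse; the GenerateRequest fields and the module-level model, tokenizer, generation_config, and device objects are assumptions for illustration, not the actual app.py definitions.

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class GenerateRequest(BaseModel):
    # Hypothetical request schema; the real app.py defines its own fields.
    input_text: str
    stop_sequences: list[str] = []
    chunk_delay: float = 0.0

@app.post("/generate")
async def generate(request: GenerateRequest):
    try:
        # stream_text yields strings of up to 10 decoded tokens at a time,
        # so it can be handed straight to StreamingResponse.
        token_stream = stream_text(
            model, tokenizer, request.input_text, generation_config,
            request.stop_sequences, device, request.chunk_delay,
        )
        return StreamingResponse(token_stream, media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

Clients that read the response incrementally (for example curl with -N, or an httpx stream) then receive each chunk as soon as stream_text yields it.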