Hjgugugjhuhjggg committed "Update app.py"
app.py
CHANGED
@@ -5,7 +5,6 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
-    pipeline,
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
@@ -130,15 +129,8 @@ async def generate(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True
-    input_length = encoded_input["input_ids"].shape[1]
-    remaining_tokens = max_length - input_length
-
-    if remaining_tokens <= 0:
-        yield ""
-
-    generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
+async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
 
     def stop_criteria(input_ids, scores):
         decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
@@ -146,7 +138,7 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
 
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
-
+    token_buffer = []
     outputs = model.generate(
         **encoded_input,
         do_sample=generation_config.do_sample,
@@ -158,19 +150,29 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
         num_return_sequences=generation_config.num_return_sequences,
         stopping_criteria=stopping_criteria,
         output_scores=True,
-        return_dict_in_generate=True
+        return_dict_in_generate=True,
+        streamer=None  # Ensure streamer is None for manual token processing
     )
 
     for output in outputs.sequences:
         for token_id in output:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
-
-
+            token_buffer.append(token)
+            if len(token_buffer) >= 10:
+                yield "".join(token_buffer)
+                token_buffer = []
+                await asyncio.sleep(chunk_delay)
+
+        if token_buffer:
+            yield "".join(token_buffer)
+            token_buffer = []
 
-        if stop_sequences and any(stop in
-            yield output_text
+        if stop_sequences and any(stop in tokenizer.decode(output, skip_special_tokens=True) for stop in stop_sequences):
             return
 
+        encoded_input = tokenizer.build_inputs_with_special_tokens(output)
+        encoded_input = {'input_ids': torch.tensor([encoded_input]).to(device)}
+
     outputs = model.generate(
         **encoded_input,
         do_sample=generation_config.do_sample,
@@ -182,7 +184,8 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
         num_return_sequences=generation_config.num_return_sequences,
         stopping_criteria=stopping_criteria,
         output_scores=True,
-        return_dict_in_generate=True
+        return_dict_in_generate=True,
+        streamer=None
     )
 
 @app.post("/generate-image")
@@ -190,7 +193,7 @@ async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
         image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
         image = image_generator(validated_body.input_text)[0]
 
@@ -208,7 +211,7 @@ async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
         audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
         audio = audio_generator(validated_body.input_text)[0]
 
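The core change in this commit is that stream_text now accumulates decoded tokens in token_buffer, flushes a joined chunk every 10 tokens, pauses for chunk_delay between flushes, and emits any leftover partial chunk after the loop. A minimal, self-contained sketch of that buffering pattern (stream_in_chunks, demo, and the sample token list are illustrative names, not part of the commit):

import asyncio

async def stream_in_chunks(tokens, chunk_size=10, chunk_delay=0.1):
    # Buffer tokens and flush every `chunk_size`, then pause, mirroring
    # the token_buffer logic this commit adds to stream_text.
    buffer = []
    for token in tokens:
        buffer.append(token)
        if len(buffer) >= chunk_size:
            yield "".join(buffer)
            buffer = []
            await asyncio.sleep(chunk_delay)
    if buffer:
        # Flush the partial chunk left over after the loop ends.
        yield "".join(buffer)

async def demo():
    async for chunk in stream_in_chunks(["He", "llo", ",", " wor", "ld", "!"], chunk_size=2):
        print(repr(chunk))

asyncio.run(demo())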
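Since the file imports StreamingResponse from fastapi.responses (visible in the first hunk header), an async generator like stream_text is presumably returned as a streaming body by the /generate endpoint. A sketch of that wiring under that assumption, with a stand-in generator since the model call is not shown here (fake_stream_text and /demo-stream are hypothetical):

import asyncio
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def fake_stream_text():
    # Stand-in for stream_text(model, tokenizer, ...); yields text chunks.
    for chunk in ("Hello", ", ", "world", "!"):
        yield chunk
        await asyncio.sleep(0.05)  # simulate per-chunk generation latency

@app.get("/demo-stream")
async def demo_stream():
    # StreamingResponse sends each yielded chunk to the client as it arrives.
    return StreamingResponse(fake_stream_text(), media_type="text/plain")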