Hjgugugjhuhjggg committed
Commit 282a362 · verified · 1 Parent(s): 587b403

Update app.py

Files changed (1): app.py +22 -19
app.py CHANGED
@@ -5,7 +5,6 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
-    pipeline,
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
@@ -130,15 +129,8 @@ async def generate(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-    input_length = encoded_input["input_ids"].shape[1]
-    remaining_tokens = max_length - input_length
-
-    if remaining_tokens <= 0:
-        yield ""
-
-    generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
+async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
 
     def stop_criteria(input_ids, scores):
         decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
@@ -146,7 +138,7 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
 
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
-    output_text = ""
+    token_buffer = []
     outputs = model.generate(
         **encoded_input,
         do_sample=generation_config.do_sample,
@@ -158,19 +150,29 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
         num_return_sequences=generation_config.num_return_sequences,
         stopping_criteria=stopping_criteria,
         output_scores=True,
-        return_dict_in_generate=True
+        return_dict_in_generate=True,
+        streamer=None  # Ensure streamer is None for manual token processing
     )
 
     for output in outputs.sequences:
         for token_id in output:
             token = tokenizer.decode(token_id, skip_special_tokens=True)
-            yield token
-            await asyncio.sleep(chunk_delay)  # Simulate the delay between tokens
+            token_buffer.append(token)
+            if len(token_buffer) >= 10:
+                yield "".join(token_buffer)
+                token_buffer = []
+                await asyncio.sleep(chunk_delay)
+
+        if token_buffer:
+            yield "".join(token_buffer)
+            token_buffer = []
 
-        if stop_sequences and any(stop in output_text for stop in stop_sequences):
-            yield output_text
+        if stop_sequences and any(stop in tokenizer.decode(output, skip_special_tokens=True) for stop in stop_sequences):
             return
 
+        encoded_input = tokenizer.build_inputs_with_special_tokens(output)
+        encoded_input = {'input_ids': torch.tensor([encoded_input]).to(device)}
+
         outputs = model.generate(
             **encoded_input,
             do_sample=generation_config.do_sample,
@@ -182,7 +184,8 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
             num_return_sequences=generation_config.num_return_sequences,
             stopping_criteria=stopping_criteria,
             output_scores=True,
-            return_dict_in_generate=True
+            return_dict_in_generate=True,
+            streamer=None
         )
 
 @app.post("/generate-image")
@@ -190,7 +193,7 @@ async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
         image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
         image = image_generator(validated_body.input_text)[0]
 
@@ -208,7 +211,7 @@ async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
        device = "cuda" if torch.cuda.is_available() else "cpu"
-
+
         audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
         audio = audio_generator(validated_body.input_text)[0]
 
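For context, the new stream_text buffers decoded tokens and flushes them in chunks of 10, sleeping between flushes. Below is a minimal, self-contained sketch of that pattern wired into a FastAPI StreamingResponse; the model name, endpoint path, and request shape are illustrative assumptions, not taken from app.py.

import asyncio

import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

MODEL_NAME = "sshleifer/tiny-gpt2"  # assumption: any small causal LM stands in here
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

async def stream_chunks(input_text: str, chunk_delay: float = 0.05, chunk_size: int = 10):
    # Generate the full sequence once, then re-decode it token by token and
    # yield buffered chunks, mirroring the updated stream_text.
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
    outputs = model.generate(
        **encoded_input,
        max_new_tokens=64,
        do_sample=False,
        return_dict_in_generate=True,
    )
    token_buffer = []
    for output in outputs.sequences:
        for token_id in output:
            token_buffer.append(tokenizer.decode(token_id, skip_special_tokens=True))
            if len(token_buffer) >= chunk_size:
                yield "".join(token_buffer)
                token_buffer = []
                await asyncio.sleep(chunk_delay)  # pace the stream between chunks
        if token_buffer:  # flush whatever remains after the loop
            yield "".join(token_buffer)
            token_buffer = []

@app.get("/stream-demo")  # hypothetical endpoint, not part of this commit
async def stream_demo(input_text: str):
    return StreamingResponse(stream_chunks(input_text), media_type="text/plain")

Note that with this pattern generation completes before the first chunk is yielded; the sleep only paces delivery of already-computed text.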
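The commit also passes streamer=None to model.generate, keeping token handling manual. For comparison, transformers ships TextIteratorStreamer for genuinely incremental output; a rough sketch (model name again an assumption):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a worker thread while the main thread
# consumes decoded text from the streamer as it becomes available.
thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32})
thread.start()
for text_chunk in streamer:
    print(text_chunk, end="", flush=True)
thread.join()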