Autorestart space: CUDA error

#15
by doevent - opened
Files changed (1)
app.py +88 -80
app.py CHANGED
@@ -119,90 +119,98 @@ def build_html_error_message(error):
 @GPU_DECORATOR
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):
-    # Parse reference audio aka prompt
-    refs = req.references
-
-    prompt_tokens = [
-        encode_reference(
-            decoder_model=decoder_model,
-            reference_audio=ref.audio,
-            enable_reference_audio=True,
-        )
-        for ref in refs
-    ]
-    prompt_texts = [ref.text for ref in refs]
-
-    if req.seed is not None:
-        set_seed(req.seed)
-        logger.warning(f"set seed: {req.seed}")
-
-    # LLAMA Inference
-    request = dict(
-        device=decoder_model.device,
-        max_new_tokens=req.max_new_tokens,
-        text=(
-            req.text
-            if not req.normalize
-            else ChnNormedText(raw_text=req.text).normalize()
-        ),
-        top_p=req.top_p,
-        repetition_penalty=req.repetition_penalty,
-        temperature=req.temperature,
-        compile=args.compile,
-        iterative_prompt=req.chunk_length > 0,
-        chunk_length=req.chunk_length,
-        max_length=4096,
-        prompt_tokens=prompt_tokens,
-        prompt_text=prompt_texts,
-    )
-
-    response_queue = queue.Queue()
-    llama_queue.put(
-        GenerateRequest(
-            request=request,
-            response_queue=response_queue,
-        )
-    )
-
-    segments = []
-
-    while True:
-        result: WrappedGenerateResponse = response_queue.get()
-        if result.status == "error":
-            yield None, None, build_html_error_message(result.response)
-            break
-
-        result: GenerateResponse = result.response
-        if result.action == "next":
-            break
-
-        with autocast_exclude_mps(
-            device_type=decoder_model.device.type, dtype=args.precision
-        ):
-            fake_audios = decode_vq_tokens(
-                decoder_model=decoder_model,
-                codes=result.codes,
-            )
-
-        fake_audios = fake_audios.float().cpu().numpy()
-        segments.append(fake_audios)
-
-    if len(segments) == 0:
-        return (
-            None,
-            None,
-            build_html_error_message(
-                i18n("No audio generated, please check the input text.")
-            ),
-        )
-
-    # No matter streaming or not, we need to return the final audio
-    audio = np.concatenate(segments, axis=0)
-    yield None, (decoder_model.spec_transform.sample_rate, audio), None
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        gc.collect()
+    try:
+        # Parse reference audio aka prompt
+        refs = req.references
+
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=ref.audio,
+                enable_reference_audio=True,
+            )
+            for ref in refs
+        ]
+        prompt_texts = [ref.text for ref in refs]
+
+        if req.seed is not None:
+            set_seed(req.seed)
+            logger.warning(f"set seed: {req.seed}")
+
+        # LLAMA Inference
+        request = dict(
+            device=decoder_model.device,
+            max_new_tokens=req.max_new_tokens,
+            text=(
+                req.text
+                if not req.normalize
+                else ChnNormedText(raw_text=req.text).normalize()
+            ),
+            top_p=req.top_p,
+            repetition_penalty=req.repetition_penalty,
+            temperature=req.temperature,
+            compile=args.compile,
+            iterative_prompt=req.chunk_length > 0,
+            chunk_length=req.chunk_length,
+            max_length=4096,
+            prompt_tokens=prompt_tokens,
+            prompt_text=prompt_texts,
+        )
+
+        response_queue = queue.Queue()
+        llama_queue.put(
+            GenerateRequest(
+                request=request,
+                response_queue=response_queue,
+            )
+        )
+
+        segments = []
+
+        while True:
+            result: WrappedGenerateResponse = response_queue.get()
+            if result.status == "error":
+                yield None, None, build_html_error_message(result.response)
+                break
+
+            result: GenerateResponse = result.response
+            if result.action == "next":
+                break
+
+            with autocast_exclude_mps(
+                device_type=decoder_model.device.type, dtype=args.precision
+            ):
+                fake_audios = decode_vq_tokens(
+                    decoder_model=decoder_model,
+                    codes=result.codes,
+                )
+
+            fake_audios = fake_audios.float().cpu().numpy()
+            segments.append(fake_audios)
+
+        if len(segments) == 0:
+            return (
+                None,
+                None,
+                build_html_error_message(
+                    i18n("No audio generated, please check the input text.")
+                ),
+            )
+
+        # No matter streaming or not, we need to return the final audio
+        audio = np.concatenate(segments, axis=0)
+        yield None, (decoder_model.spec_transform.sample_rate, audio), None
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    except Exception as e:
+        er = "CUDA error: device-side assert triggered"
+        if er in str(e):
+            app.close()
+        else:
+            raise
 
 n_audios = 4
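
For context, here is the pattern this diff adds, distilled into a standalone sketch. This is a minimal illustration, not the PR itself: guard_cuda and its wiring are hypothetical, and it assumes app is the launched Gradio Blocks instance, whose close() method shuts the server down so the Space supervisor relaunches a fresh process. Recycling the process matters because a device-side assert leaves the CUDA context in an unrecoverable state, and every later kernel launch in the same process keeps failing.

import gradio as gr

CUDA_ASSERT = "CUDA error: device-side assert triggered"

def guard_cuda(app: gr.Blocks, fn, *args, **kwargs):
    # Run a handler; if CUDA hit a device-side assert, close the server
    # so the Space restarts, otherwise let the exception propagate.
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:  # PyTorch surfaces CUDA asserts as RuntimeError
        if CUDA_ASSERT in str(e):
            app.close()  # frees the port; the Space relaunches the app
        raise

When reproducing the failure, setting CUDA_LAUNCH_BLOCKING=1 makes kernels report errors synchronously at the offending call, which helps locate the root cause instead of only auto-restarting around it.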