jhj0517 committed on
Commit
79e6f08
·
1 Parent(s): a1b32c1

Refactor translation

Browse files
modules/translation/deepl_api.py CHANGED
@@ -139,37 +139,28 @@ class DeepLAPI:
139
  )
140
 
141
  files_info = {}
142
- for fileobj in fileobjs:
143
- file_path = fileobj
144
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
145
-
146
- if file_ext == ".srt":
147
- parsed_dicts = parse_srt(file_path=file_path)
148
-
149
- elif file_ext == ".vtt":
150
- parsed_dicts = parse_vtt(file_path=file_path)
151
 
152
  batch_size = self.max_text_batch_size
153
- for batch_start in range(0, len(parsed_dicts), batch_size):
154
- batch_end = min(batch_start + batch_size, len(parsed_dicts))
155
- sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
156
  translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
157
  target_lang, is_pro)
158
  for i, translated_text in enumerate(translated_texts):
159
- parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
160
- progress(batch_end / len(parsed_dicts), desc="Translating..")
161
-
162
- if file_ext == ".srt":
163
- subtitle = get_serialized_srt(parsed_dicts)
164
- elif file_ext == ".vtt":
165
- subtitle = get_serialized_vtt(parsed_dicts)
166
-
167
- if add_timestamp:
168
- timestamp = datetime.now().strftime("%m%d%H%M%S")
169
- file_name += f"-{timestamp}"
170
-
171
- output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
172
- write_file(subtitle, output_path)
173
 
174
  files_info[file_name] = {"subtitle": subtitle, "path": output_path}
175
 
 
139
  )
140
 
141
  files_info = {}
142
+ for file_path in fileobjs:
143
+ file_name, file_ext = os.path.splitext(os.path.basename(file_path))
144
+ writer = get_writer(file_ext, self.output_dir)
145
+ segments = writer.to_segments(file_path)
 
 
 
 
 
146
 
147
  batch_size = self.max_text_batch_size
148
+ for batch_start in range(0, len(segments), batch_size):
149
+ progress(batch_start / len(segments), desc="Translating..")
150
+ sentences_to_translate = segments[batch_start:batch_start+batch_size]
151
  translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
152
  target_lang, is_pro)
153
  for i, translated_text in enumerate(translated_texts):
154
+ segments[batch_start + i].text = translated_text["text"]
155
+ print("DeepL Segments: ", segments)
156
+
157
+ subtitle, output_path = generate_file(
158
+ output_dir=self.output_dir,
159
+ output_file_name=file_name,
160
+ output_format=file_ext,
161
+ result=segments,
162
+ add_timestamp=add_timestamp
163
+ )
 
 
 
 
164
 
165
  files_info[file_name] = {"subtitle": subtitle, "path": output_path}
166
 
modules/translation/translation_base.py CHANGED
@@ -95,32 +95,22 @@ class TranslationBase(ABC):
95
  files_info = {}
96
  for fileobj in fileobjs:
97
  file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
98
- if file_ext == ".srt":
99
- parsed_dicts = parse_srt(file_path=fileobj)
100
- total_progress = len(parsed_dicts)
101
- for index, dic in enumerate(parsed_dicts):
102
- progress(index / total_progress, desc="Translating..")
103
- translated_text = self.translate(dic["sentence"], max_length=max_length)
104
- dic["sentence"] = translated_text
105
- subtitle = get_serialized_srt(parsed_dicts)
106
-
107
- elif file_ext == ".vtt":
108
- parsed_dicts = parse_vtt(file_path=fileobj)
109
- total_progress = len(parsed_dicts)
110
- for index, dic in enumerate(parsed_dicts):
111
- progress(index / total_progress, desc="Translating..")
112
- translated_text = self.translate(dic["sentence"], max_length=max_length)
113
- dic["sentence"] = translated_text
114
- subtitle = get_serialized_vtt(parsed_dicts)
115
-
116
- if add_timestamp:
117
- timestamp = datetime.now().strftime("%m%d%H%M%S")
118
- file_name += f"-{timestamp}"
119
-
120
- output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
121
- write_file(subtitle, output_path)
122
-
123
- files_info[file_name] = {"subtitle": subtitle, "path": output_path}
124
 
125
  total_result = ''
126
  for file_name, info in files_info.items():
@@ -134,6 +124,8 @@ class TranslationBase(ABC):
134
 
135
  except Exception as e:
136
  print(f"Error: {str(e)}")
 
 
137
  finally:
138
  self.release_cuda_memory()
139
 
 
95
  files_info = {}
96
  for fileobj in fileobjs:
97
  file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
98
+ writer = get_writer(file_ext, self.output_dir)
99
+ segments = writer.to_segments(fileobj)
100
+ for i, segment in enumerate(segments):
101
+ progress(i / len(segments), desc="Translating..")
102
+ translated_text = self.translate(segment.text, max_length=max_length)
103
+ segment.text = translated_text
104
+
105
+ subtitle, file_path = generate_file(
106
+ output_dir=self.output_dir,
107
+ output_file_name=file_name,
108
+ output_format=file_ext,
109
+ result=segments,
110
+ add_timestamp=add_timestamp
111
+ )
112
+
113
+ files_info[file_name] = {"subtitle": subtitle, "path": file_path}
 
 
 
 
 
 
 
 
 
 
114
 
115
  total_result = ''
116
  for file_name, info in files_info.items():
 
124
 
125
  except Exception as e:
126
  print(f"Error: {str(e)}")
127
+ import traceback
128
+ traceback.print_exc()
129
  finally:
130
  self.release_cuda_memory()
131
 
modules/utils/subtitle_manager.py CHANGED
@@ -33,6 +33,18 @@ def format_timestamp(
33
  )
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def get_start(segments: List[dict]) -> Optional[float]:
37
  return next(
38
  (w["start"] for s in segments for w in s["words"]),
@@ -54,16 +66,12 @@ class ResultWriter:
54
  self.output_dir = output_dir
55
 
56
  def __call__(
57
- self, result: Union[dict, List[Segment]], output_file_name: str, add_timestamp: bool = True,
58
  options: Optional[dict] = None, **kwargs
59
  ):
60
  if isinstance(result, List) and result and isinstance(result[0], Segment):
61
  result = {"segments": [seg.dict() for seg in result]}
62
 
63
- if add_timestamp:
64
- timestamp = datetime.now().strftime("%m%d%H%M%S")
65
- output_file_name += f"-{timestamp}"
66
-
67
  output_path = os.path.join(
68
  self.output_dir, output_file_name + "." + self.extension
69
  )
@@ -216,6 +224,26 @@ class WriteVTT(SubtitlesWriter):
216
  for start, end, text in self.iterate_result(result, options, **kwargs):
217
  print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
218
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  class WriteSRT(SubtitlesWriter):
221
  extension: str = "srt"
@@ -230,6 +258,27 @@ class WriteSRT(SubtitlesWriter):
230
  ):
231
  print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  class WriteTSV(ResultWriter):
235
  """
@@ -265,7 +314,7 @@ class WriteJSON(ResultWriter):
265
  def get_writer(
266
  output_format: str, output_dir: str
267
  ) -> Callable[[dict, TextIO, dict], None]:
268
- output_format = output_format.strip().lower()
269
 
270
  writers = {
271
  "txt": WriteTXT,
@@ -292,75 +341,19 @@ def get_writer(
292
  def generate_file(
293
  output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str, add_timestamp: bool = True,
294
  ) -> Tuple[str, str]:
 
 
 
 
 
 
295
  file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
296
  file_writer = get_writer(output_format=output_format, output_dir=output_dir)
297
- file_writer(result=result, output_file_name=output_file_name, add_timestamp=add_timestamp)
298
  content = read_file(file_path)
299
  return content, file_path
300
 
301
 
302
- def parse_srt(file_path):
303
- """Reads SRT file and returns as dict"""
304
- with open(file_path, 'r', encoding='utf-8') as file:
305
- srt_data = file.read()
306
-
307
- data = []
308
- blocks = srt_data.split('\n\n')
309
-
310
- for block in blocks:
311
- if block.strip() != '':
312
- lines = block.strip().split('\n')
313
- index = lines[0]
314
- timestamp = lines[1]
315
- sentence = ' '.join(lines[2:])
316
-
317
- data.append({
318
- "index": index,
319
- "timestamp": timestamp,
320
- "sentence": sentence
321
- })
322
- return data
323
-
324
-
325
- def parse_vtt(file_path):
326
- """Reads WEBVTT file and returns as dict"""
327
- with open(file_path, 'r', encoding='utf-8') as file:
328
- webvtt_data = file.read()
329
-
330
- data = []
331
- blocks = webvtt_data.split('\n\n')
332
-
333
- for block in blocks:
334
- if block.strip() != '' and not block.strip().startswith("WEBVTT"):
335
- lines = block.strip().split('\n')
336
- timestamp = lines[0]
337
- sentence = ' '.join(lines[1:])
338
-
339
- data.append({
340
- "timestamp": timestamp,
341
- "sentence": sentence
342
- })
343
-
344
- return data
345
-
346
-
347
- def get_serialized_srt(dicts):
348
- output = ""
349
- for dic in dicts:
350
- output += f'{dic["index"]}\n'
351
- output += f'{dic["timestamp"]}\n'
352
- output += f'{dic["sentence"]}\n\n'
353
- return output
354
-
355
-
356
- def get_serialized_vtt(dicts):
357
- output = "WEBVTT\n\n"
358
- for dic in dicts:
359
- output += f'{dic["timestamp"]}\n'
360
- output += f'{dic["sentence"]}\n\n'
361
- return output
362
-
363
-
364
  def safe_filename(name):
365
  INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
366
  safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
 
33
  )
34
 
35
 
36
def time_str_to_seconds(time_str: str, decimal_marker: str = ",") -> float:
    """Convert a subtitle timestamp string into a number of seconds.

    Accepts ``HH:MM:SS<marker>fff`` and the shorter ``MM:SS<marker>fff`` form
    in which the hours field is omitted. The fractional part is optional.

    Args:
        time_str: Timestamp text such as ``"00:01:02,500"``. Leading and
            trailing whitespace is ignored.
        decimal_marker: Character that separates whole seconds from the
            fraction (``","`` by default; callers pass their own marker).

    Returns:
        The timestamp expressed in seconds.

    Raises:
        ValueError: If the clock part is not ``[HH:]MM:SS`` or one of its
            fields is not numeric.
    """
    clock, _, fractional = time_str.strip().partition(decimal_marker)

    fields = clock.split(":")
    if len(fields) == 2:
        # Hours omitted: treat them as zero instead of failing to unpack.
        fields.insert(0, "0")
    elif len(fields) != 3:
        raise ValueError(f"Invalid timestamp: {time_str!r}")

    hours, minutes, seconds = (int(field) for field in fields)
    # Prefixing "0." interprets the digits as a fraction whatever their count.
    fractional_seconds = float("0." + fractional) if fractional else 0.0

    return hours * 3600 + minutes * 60 + seconds + fractional_seconds
46
+
47
+
48
  def get_start(segments: List[dict]) -> Optional[float]:
49
  return next(
50
  (w["start"] for s in segments for w in s["words"]),
 
66
  self.output_dir = output_dir
67
 
68
  def __call__(
69
+ self, result: Union[dict, List[Segment]], output_file_name: str,
70
  options: Optional[dict] = None, **kwargs
71
  ):
72
  if isinstance(result, List) and result and isinstance(result[0], Segment):
73
  result = {"segments": [seg.dict() for seg in result]}
74
 
 
 
 
 
75
  output_path = os.path.join(
76
  self.output_dir, output_file_name + "." + self.extension
77
  )
 
224
  for start, end, text in self.iterate_result(result, options, **kwargs):
225
  print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
226
 
227
def to_segments(self, file_path: str) -> List[Segment]:
    """Parse a WebVTT file into a list of ``Segment`` objects.

    The file is split into blank-line separated blocks. A block becomes a
    segment only if it contains a timing line (one with ``-->``); blocks
    without one, such as the ``WEBVTT`` header or ``NOTE``/``STYLE``
    blocks, are skipped. Lines before the timing line (an optional cue
    identifier) are ignored, and cue settings that follow the end
    timestamp are dropped.

    Args:
        file_path: Path of the ``.vtt`` file, read via ``read_file``.

    Returns:
        One ``Segment`` per cue with ``start``/``end`` in seconds and the
        cue's text lines joined by single spaces.

    Raises:
        ValueError: If a timing line holds a malformed timestamp
            (propagated from ``time_str_to_seconds``).
    """
    segments = []

    # Normalise CRLF in case read_file preserved it, then split on blank
    # lines, tolerating separators that contain stray whitespace.
    content = read_file(file_path).replace("\r\n", "\n")
    blocks = re.split(r"\n\s*\n", content)

    for block in blocks:
        lines = block.strip().split('\n')

        # The timing line is not necessarily first: a cue identifier may
        # precede it, and header/comment blocks have no timing line at all.
        timing_index = next(
            (i for i, line in enumerate(lines) if "-->" in line), None
        )
        if timing_index is None:
            continue

        start_text, end_text = lines[timing_index].split("-->", 1)
        # Anything after the end timestamp (cue settings) is not part of it.
        end_token = (end_text.split() or [""])[0]

        start = time_str_to_seconds(start_text.strip(), self.decimal_marker)
        end = time_str_to_seconds(end_token, self.decimal_marker)
        sentence = ' '.join(lines[timing_index + 1:])

        segments.append(Segment(
            start=start,
            end=end,
            text=sentence
        ))

    return segments
246
+
247
 
248
  class WriteSRT(SubtitlesWriter):
249
  extension: str = "srt"
 
258
  ):
259
  print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
260
 
261
def to_segments(self, file_path: str) -> List[Segment]:
    """Parse an SRT file into a list of ``Segment`` objects.

    The file is split into blank-line separated blocks. Within each block
    the timing line is located by its ``-->`` arrow rather than by a fixed
    position, so a missing index line does not break parsing and blocks
    without any timing line are skipped instead of raising ``IndexError``.
    The numeric index line itself is not kept.

    Args:
        file_path: Path of the ``.srt`` file, read via ``read_file``.

    Returns:
        One ``Segment`` per subtitle with ``start``/``end`` in seconds and
        the subtitle's text lines joined by single spaces.

    Raises:
        ValueError: If a timing line holds a malformed timestamp
            (propagated from ``time_str_to_seconds``).
    """
    segments = []

    # Normalise CRLF in case read_file preserved it, then split on blank
    # lines, tolerating separators that contain stray whitespace.
    content = read_file(file_path).replace("\r\n", "\n")
    blocks = re.split(r"\n\s*\n", content)

    for block in blocks:
        lines = block.strip().split('\n')

        timing_index = next(
            (i for i, line in enumerate(lines) if "-->" in line), None
        )
        if timing_index is None:
            # Empty block or stray text with no timestamps: nothing to parse.
            continue

        start_text, end_text = lines[timing_index].split("-->", 1)
        # Drop anything trailing the end timestamp (e.g. position data).
        end_token = (end_text.split() or [""])[0]

        start = time_str_to_seconds(start_text.strip(), self.decimal_marker)
        end = time_str_to_seconds(end_token, self.decimal_marker)
        sentence = ' '.join(lines[timing_index + 1:])

        segments.append(Segment(
            start=start,
            end=end,
            text=sentence
        ))

    return segments
281
+
282
 
283
  class WriteTSV(ResultWriter):
284
  """
 
314
  def get_writer(
315
  output_format: str, output_dir: str
316
  ) -> Callable[[dict, TextIO, dict], None]:
317
+ output_format = output_format.strip().lower().replace(".", "")
318
 
319
  writers = {
320
  "txt": WriteTXT,
 
341
def generate_file(
    output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str, add_timestamp: bool = True,
) -> Tuple[str, str]:
    """Write ``result`` to a file in the requested format and read it back.

    Args:
        output_format: Target format or extension; case, surrounding
            whitespace and dots are ignored (``".SRT"`` -> ``"srt"``).
        output_dir: Directory the file is written to.
        result: Transcription/translation result, either a result dict or a
            list of ``Segment`` objects, passed through to the writer.
        output_file_name: Base file name without extension.
        add_timestamp: When true, ``-<MMDDHHMMSS>`` of the current local
            time is appended to the base name.

    Returns:
        A ``(content, file_path)`` tuple: the text of the written file and
        the path it was read from.
    """
    output_format = output_format.strip().lower().replace(".", "")

    if add_timestamp:
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        # Keep the "-" separator so names stay "<name>-<timestamp>" as they
        # were before the timestamp logic moved into this function.
        output_file_name += f"-{timestamp}"

    # NOTE(review): assumes the writer's extension equals the normalised
    # output_format, so this path matches the file the writer creates.
    file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
    file_writer = get_writer(output_format=output_format, output_dir=output_dir)
    file_writer(result=result, output_file_name=output_file_name)
    content = read_file(file_path)
    return content, file_path
355
 
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  def safe_filename(name):
358
  INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
359
  safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)