jhj0517 committed on
Commit
9cf3b54
·
1 Parent(s): c0bbe98

Add file writer classes

Browse files
Files changed (1) hide show
  1. modules/utils/subtitle_manager.py +275 -58
modules/utils/subtitle_manager.py CHANGED
@@ -1,66 +1,283 @@
 
 
 
 
1
  import re
 
 
 
2
 
3
  from modules.whisper.data_classes import Segment
4
 
5
 
6
def timeformat_srt(time):
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    Args:
        time: Offset in seconds (int or float), expected to be non-negative.

    Returns:
        The timestamp string with a comma before the milliseconds.
    """
    # Round to whole milliseconds first: truncating the float remainder
    # turned values such as 2.3 into ",299" instead of ",300".
    total_ms = int(round(time * 1000))
    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    seconds, milliseconds = divmod(total_ms, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
12
-
13
-
14
def timeformat_vtt(time):
    """Format a time offset in seconds as a WebVTT timestamp ``HH:MM:SS.mmm``.

    Args:
        time: Offset in seconds (int or float), expected to be non-negative.

    Returns:
        The timestamp string with a dot before the milliseconds.
    """
    # Round to whole milliseconds first: truncating the float remainder
    # turned values such as 2.3 into ".299" instead of ".300".
    total_ms = int(round(time * 1000))
    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    seconds, milliseconds = divmod(total_ms, 1_000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
20
-
21
-
22
def write_file(subtitle, output_file):
    """Write the ``subtitle`` string to ``output_file`` as UTF-8 text.

    The file is opened in write mode, so existing content is replaced.
    """
    with open(output_file, mode='w', encoding='utf-8') as handle:
        handle.write(subtitle)
25
-
26
-
27
def get_srt(segments):
    """Render segments as SRT text.

    Args:
        segments: A list of ``Segment`` objects or of dicts with ``start``,
            ``end`` and ``text`` keys.

    Returns:
        A string of numbered cues (starting at 1), each followed by a blank
        line. An empty input yields an empty string.
    """
    if segments and isinstance(segments[0], Segment):
        segments = [seg.dict() for seg in segments]

    cues = []
    for index, segment in enumerate(segments, start=1):
        text = segment['text']
        # Drop a single leading space on a local copy; the caller's dict is
        # left untouched so repeated calls produce the same output.
        if text.startswith(' '):
            text = text[1:]
        cues.append(
            f"{index}\n"
            f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
            f"{text}\n\n"
        )
    return "".join(cues)
39
-
40
-
41
def get_vtt(segments):
    """Render segments as WebVTT text.

    Args:
        segments: A list of ``Segment`` objects or of dicts with ``start``,
            ``end`` and ``text`` keys.

    Returns:
        A string starting with the ``WEBVTT`` header and a blank line,
        followed by one cue per segment, each followed by a blank line.
    """
    if segments and isinstance(segments[0], Segment):
        segments = [seg.dict() for seg in segments]

    parts = ["WEBVTT\n\n"]
    for segment in segments:
        text = segment['text']
        # Drop a single leading space on a local copy; the caller's dict is
        # left untouched so repeated calls produce the same output.
        if text.startswith(' '):
            text = text[1:]
        parts.append(
            f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
            f"{text}\n\n"
        )
    return "".join(parts)
52
-
53
-
54
def get_txt(segments):
    """Render segments as plain text, one segment per line.

    Args:
        segments: A list of ``Segment`` objects or of dicts with a ``text``
            key.

    Returns:
        The segment texts, each terminated by a newline. An empty input
        yields an empty string.
    """
    if segments and isinstance(segments[0], Segment):
        segments = [seg.dict() for seg in segments]

    lines = []
    for segment in segments:
        text = segment['text']
        # Drop a single leading space on a local copy; the caller's dict is
        # left untouched so repeated calls produce the same output.
        if text.startswith(' '):
            text = text[1:]
        lines.append(f"{text}\n")
    return "".join(lines)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  def parse_srt(file_path):
 
1
+ # Ported from https://github.com/openai/whisper/blob/main/whisper/utils.py
2
+
3
+ import json
4
+ import os
5
  import re
6
+ import sys
7
+ import zlib
8
+ from typing import Callable, List, Optional, TextIO, Union, Dict
9
 
10
  from modules.whisper.data_classes import Segment
11
 
12
 
13
def format_timestamp(
    seconds: float, always_include_hours: bool = True, decimal_marker: str = ","
):
    """Render a non-negative offset in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: Offset in seconds; must not be negative.
        always_include_hours: When False, the ``HH:`` field is emitted only
            if the value is at least one hour.
        decimal_marker: Separator placed between seconds and milliseconds.

    Returns:
        The formatted timestamp string.

    Raises:
        AssertionError: If ``seconds`` is negative.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    # Convert once to whole milliseconds, then peel off each field with
    # exact integer arithmetic.
    remainder_ms = round(seconds * 1000.0)
    hours, remainder_ms = divmod(remainder_ms, 3_600_000)
    minutes, remainder_ms = divmod(remainder_ms, 60_000)
    whole_seconds, millis = divmod(remainder_ms, 1_000)

    if always_include_hours or hours > 0:
        hours_marker = f"{hours:02d}:"
    else:
        hours_marker = ""
    return f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
32
+
33
+
34
def get_start(segments: List[dict]) -> Optional[float]:
    """Return the start time of the first word, or of the first segment.

    Args:
        segments: Segment dicts; each may carry a ``"words"`` list of word
            timing dicts with a ``"start"`` key.

    Returns:
        The ``"start"`` of the first word found in segment order. If no
        segment has any words, the first segment's ``"start"``. ``None``
        when ``segments`` is empty.
    """
    for segment in segments:
        # "words" may be absent or None when word timestamps were not
        # produced; treat both the same as an empty list.
        for word in segment.get("words") or ():
            return word["start"]
    return segments[0]["start"] if segments else None
39
+
40
+
41
def get_end(segments: List[dict]) -> Optional[float]:
    """Return the end time of the last word, or of the last segment.

    Args:
        segments: Segment dicts; each may carry a ``"words"`` list of word
            timing dicts with an ``"end"`` key.

    Returns:
        The ``"end"`` of the last word found when scanning segments and
        words in reverse. If no segment has any words, the last segment's
        ``"end"``. ``None`` when ``segments`` is empty.
    """
    for segment in reversed(segments):
        # "words" may be absent or None when word timestamps were not
        # produced; treat both the same as an empty list.
        for word in reversed(segment.get("words") or ()):
            return word["end"]
    return segments[-1]["end"] if segments else None
46
+
47
+
48
class ResultWriter:
    """Base class for writers that save a transcription result to a file.

    The output path is ``<output_dir>/<output_file_name>.<extension>``.
    Subclasses set ``extension`` and implement :meth:`write_result`.
    """

    # File extension (without the leading dot) appended to the output name.
    extension: str

    def __init__(self, output_dir: str):
        """Remember the directory that output files are written into."""
        self.output_dir = output_dir

    def __call__(
        self,
        result: Union[dict, List["Segment"]],
        output_file_name: str,
        options: Optional[dict] = None,
        **kwargs,
    ):
        """Write ``result`` to ``<output_dir>/<output_file_name>.<extension>``.

        Args:
            result: Either a dict with a ``"segments"`` list, or a list of
                segments. List items that are not dicts are converted with
                their ``dict()`` method.
            output_file_name: File name without extension.
            options: Optional formatting options forwarded to
                :meth:`write_result`.
            **kwargs: Extra keyword arguments forwarded to
                :meth:`write_result`.
        """
        if isinstance(result, list):
            # Every writer reads ``result["segments"]``, so a bare list has
            # to be wrapped; passing the list through raised TypeError.
            result = {
                "segments": [
                    seg if isinstance(seg, dict) else seg.dict() for seg in result
                ]
            }

        output_path = os.path.join(
            self.output_dir, output_file_name + "." + self.extension
        )

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options, **kwargs)

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        """Serialise ``result`` into the open text ``file``; subclasses override."""
        raise NotImplementedError
71
+
72
+
73
class WriteTXT(ResultWriter):
    """Plain-text writer: one segment text per line, without timestamps."""

    extension: str = "txt"

    def write_result(
        self, result: Union[Dict, List[Segment]], file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        """Print each segment's whitespace-stripped text on its own line.

        ``options`` and ``kwargs`` are accepted for interface parity and are
        not used.
        """
        stripped_texts = (entry["text"].strip() for entry in result["segments"])
        for line in stripped_texts:
            print(line, file=file, flush=True)
81
+
82
+
83
class SubtitlesWriter(ResultWriter):
    """Shared cue-generation logic for timestamped subtitle formats.

    Subclasses set ``always_include_hours`` and ``decimal_marker`` to control
    how timestamps are rendered, and implement ``write_result``.
    """

    # Whether rendered timestamps always carry the ``HH:`` field.
    always_include_hours: bool
    # Separator between the seconds and milliseconds fields.
    decimal_marker: str

    def iterate_result(
        self,
        result: dict,
        options: Optional[dict] = None,
        *,
        max_line_width: Optional[int] = None,
        max_line_count: Optional[int] = None,
        highlight_words: bool = False,
        max_words_per_line: Optional[int] = None,
    ):
        """Yield ``(start, end, text)`` cues for ``result``.

        Each keyword argument falls back to the same-named entry of
        ``options`` when it is not given (or falsy).

        Args:
            result: Dict with a ``"segments"`` list of segment dicts.
            options: Optional dict of fallback values for the keyword
                arguments below.
            max_line_width: Maximum number of characters per line.
            max_line_count: Maximum number of lines per cue. Segment
                boundaries are only re-flowed when both this and
                ``max_line_width`` are set.
            highlight_words: Emit one cue per word with that word wrapped
                in ``<u>...</u>``.
            max_words_per_line: Maximum number of words taken from a
                segment at a time.

        Yields:
            Tuples of formatted start timestamp, formatted end timestamp and
            cue text. Word timings are used when at least one segment has
            words; otherwise one cue is produced per segment.
        """
        options = options or {}
        max_line_width = max_line_width or options.get("max_line_width")
        max_line_count = max_line_count or options.get("max_line_count")
        highlight_words = highlight_words or options.get("highlight_words", False)
        max_words_per_line = max_words_per_line or options.get("max_words_per_line")
        preserve_segments = max_line_count is None or max_line_width is None
        max_line_width = max_line_width or 1000
        max_words_per_line = max_words_per_line or 1000

        segments = result["segments"]

        def iterate_subtitles():
            line_len = 0
            line_count = 1
            # the next subtitle to yield (a list of word timings with whitespace)
            subtitle: List[dict] = []
            # Start of the very first word; computed here so that segments
            # whose "words" is missing or None are skipped safely.
            first_word_start = next(
                (w["start"] for s in segments for w in (s.get("words") or ())),
                None,
            )
            last: float = first_word_start or 0.0
            for segment in segments:
                words = segment.get("words") or []
                # Consume the segment's words in chunks of max_words_per_line;
                # slicing clamps the final, shorter chunk automatically.
                for chunk_start in range(0, len(words), max_words_per_line):
                    chunk = words[chunk_start : chunk_start + max_words_per_line]
                    for i, original_timing in enumerate(chunk):
                        timing = original_timing.copy()
                        long_pause = (
                            not preserve_segments and timing["start"] - last > 3.0
                        )
                        has_room = line_len + len(timing["word"]) <= max_line_width
                        # The first word of every chunk starts a new cue when
                        # segment boundaries are preserved.
                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                        if (
                            line_len > 0
                            and has_room
                            and not long_pause
                            and not seg_break
                        ):
                            # line continuation
                            line_len += len(timing["word"])
                        else:
                            # new line
                            timing["word"] = timing["word"].strip()
                            if (
                                len(subtitle) > 0
                                and max_line_count is not None
                                and (long_pause or line_count >= max_line_count)
                            ) or seg_break:
                                # subtitle break
                                yield subtitle
                                subtitle = []
                                line_count = 1
                            elif line_len > 0:
                                # line break
                                line_count += 1
                                timing["word"] = "\n" + timing["word"]
                            line_len = len(timing["word"].strip())
                        subtitle.append(timing)
                        last = timing["start"]
            if len(subtitle) > 0:
                yield subtitle

        # Use word timings only when some segment actually has words. A
        # "words" key that is None or an empty list must not select this
        # path, otherwise it would crash or produce no cues at all.
        if any(segment.get("words") for segment in segments):
            for subtitle in iterate_subtitles():
                subtitle_start = self.format_timestamp(subtitle[0]["start"])
                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
                subtitle_text = "".join([word["word"] for word in subtitle])
                if highlight_words:
                    last = subtitle_start
                    all_words = [timing["word"] for timing in subtitle]
                    for i, this_word in enumerate(subtitle):
                        start = self.format_timestamp(this_word["start"])
                        end = self.format_timestamp(this_word["end"])
                        if last != start:
                            # Fill the gap before this word with the plain,
                            # un-highlighted cue text.
                            yield last, start, subtitle_text

                        yield start, end, "".join(
                            [
                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                if j == i
                                else word
                                for j, word in enumerate(all_words)
                            ]
                        )
                        last = end
                else:
                    yield subtitle_start, subtitle_end, subtitle_text
        else:
            for segment in segments:
                segment_start = self.format_timestamp(segment["start"])
                segment_end = self.format_timestamp(segment["end"])
                # "-->" is the cue timing separator, so it is rewritten
                # inside cue text.
                segment_text = segment["text"].strip().replace("-->", "->")
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        """Format ``seconds`` using this writer's hour and decimal settings."""
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )
198
+
199
+
200
class WriteVTT(SubtitlesWriter):
    """WebVTT writer: dot decimal marker, hours shown only when non-zero."""

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        """Write the ``WEBVTT`` header, then one blank-line-terminated cue each."""
        # The literal newline plus print's own newline leaves a blank line
        # after the header.
        print("WEBVTT\n", file=file)
        cues = self.iterate_result(result, options, **kwargs)
        for cue_start, cue_end, cue_text in cues:
            cue = f"{cue_start} --> {cue_end}\n{cue_text}\n"
            print(cue, file=file, flush=True)
211
+
212
+
213
class WriteSRT(SubtitlesWriter):
    """SRT writer: comma decimal marker, hours always shown, cues numbered from 1."""

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        """Write each cue as index line, timing line, text and a blank line."""
        cues = self.iterate_result(result, options, **kwargs)
        for index, cue in enumerate(cues, start=1):
            cue_start, cue_end, cue_text = cue
            entry = f"{index}\n{cue_start} --> {cue_end}\n{cue_text}\n"
            print(entry, file=file, flush=True)
225
+
226
+
227
class WriteTSV(ResultWriter):
    """
    Write a transcript as tab-separated values, one segment per row:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>

    Integer milliseconds keep the timestamps free of any decimal separator, so
    the output does not depend on locale settings, and they are cheap to parse
    and store in other languages.
    """

    extension: str = "tsv"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        """Write a header row followed by one row per segment."""
        print("start", "end", "text", sep="\t", file=file)
        for row in result["segments"]:
            start_ms = round(1000 * row["start"])
            end_ms = round(1000 * row["end"])
            # Tabs inside the text would break the column layout.
            text = row["text"].strip().replace("\t", " ")
            print(start_ms, end_ms, text, sep="\t", file=file, flush=True)
247
+
248
+
249
class WriteJSON(ResultWriter):
    """JSON writer: dumps the whole result dict with default ``json`` settings."""

    extension: str = "json"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        """Serialise ``result`` as JSON into ``file``; ``options`` is unused."""
        serialised = json.dumps(result)
        file.write(serialised)
256
+
257
+
258
def get_writer(
    output_format: str, output_dir: str
) -> Callable[..., None]:
    """Return a callable that writes a result in ``output_format``.

    Args:
        output_format: One of ``"txt"``, ``"vtt"``, ``"srt"``, ``"tsv"``,
            ``"json"``, or ``"all"`` to write every format.
        output_dir: Directory the output files are written into.

    Returns:
        A callable invoked as ``writer(result, output_file_name, options,
        **kwargs)``. For ``"all"`` it runs every writer in turn with the same
        arguments.

    Raises:
        KeyError: If ``output_format`` is not a supported format.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(
            result: Union[dict, List["Segment"]],
            file: str,
            options: Optional[dict] = None,
            **kwargs,
        ):
            # ``file`` keeps its historical name for keyword callers, but it
            # is the output file name (no extension) passed to each writer.
            for writer in all_writers:
                writer(result, file, options, **kwargs)

        return write_all

    try:
        writer_cls = writers[output_format]
    except KeyError:
        supported = ", ".join([*writers, "all"])
        raise KeyError(
            f"unsupported output format {output_format!r}; expected one of: {supported}"
        ) from None
    return writer_cls(output_dir)
281
 
282
 
283
  def parse_srt(file_path):