minnehwg committed on
Commit 65b1238 · verified · 1 Parent(s): 74fcca9

Update util.py

Files changed (1)
  1. util.py +142 -0
util.py CHANGED
@@ -1,2 +1,144 @@
+from datasets import Dataset
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, TrainingArguments
+from youtube_transcript_api import YouTubeTranscriptApi
+from deepmultilingualpunctuation import PunctuationModel
+from googletrans import Translator
+import time
+import torch
+import re
+
+def load_model(cp):
+    # Tokenizer comes from the base VietAI/vit5-base model; weights from the checkpoint `cp`.
+    tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
+    model = AutoModelForSeq2SeqLM.from_pretrained(cp)
+    return tokenizer, model
+
+
+def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
+    # Beam-search summarization; inputs are truncated to 1024 tokens.
+    model.to(device)
+    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True, padding=True).to(device)
+
+    with torch.no_grad():
+        summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
+        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+    return summary
+
+
+def processed(text):
+    # Flatten newlines and lowercase the text before summarization.
+    processed_text = text.replace('\n', ' ')
+    processed_text = processed_text.lower()
+    return processed_text
+
+
+def get_subtitles(video_url):
+    try:
+        # Take everything after "v=" and drop any trailing query parameters.
+        video_id = video_url.split("v=")[1].split("&")[0]
+        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
+        subs = " ".join(entry['text'] for entry in transcript)
+
+        return transcript, subs
+
+    except Exception as e:
+        return [], f"An error occurred: {e}"
+
+
+def restore_punctuation(text):
+    # Re-insert punctuation into the raw, unpunctuated YouTube transcript.
+    model = PunctuationModel()
+    result = model.restore_punctuation(text)
+    return result
+
+
+def translate_long(text, language='vi'):
+    # Translate sentence-grouped chunks, keeping each request safely under
+    # the translation service's ~5000-character limit.
+    translator = Translator()
+    limit = 4700
+    chunks = []
+    current_chunk = ''
+
+    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+
+    for sentence in sentences:
+        if len(current_chunk) + len(sentence) <= limit:
+            current_chunk += sentence.strip() + ' '
+        else:
+            chunks.append(current_chunk.strip())
+            current_chunk = sentence.strip() + ' '
+
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+
+    translated_text = ''
+
+    for chunk in chunks:
+        try:
+            # Throttle requests to avoid being rate limited.
+            time.sleep(1)
+            translation = translator.translate(chunk, dest=language)
+            translated_text += translation.text + ' '
+        except Exception:
+            # If a request fails, fall back to the untranslated chunk.
+            translated_text += chunk + ' '
+
+    return translated_text.strip()
+
+def split_into_chunks(text, max_words=800, overlap_sentences=2):
+    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
+
+    chunks = []
+    current_chunk = []
+    current_word_count = 0
+
+    for sentence in sentences:
+        word_count = len(sentence.split())
+        if current_word_count + word_count <= max_words:
+            current_chunk.append(sentence)
+            current_word_count += word_count
+        else:
+            chunks.append(' '.join(current_chunk))
+            # Carry the last few sentences over so consecutive chunks overlap.
+            current_chunk = current_chunk[-overlap_sentences:] + [sentence]
+            current_word_count = sum(len(sent.split()) for sent in current_chunk)
+    if current_chunk:
+        chunks.append(' '.join(current_chunk))
+
+    return chunks
+
+
+def post_processing(text):
+    # Capitalize the first letter of every sentence.
+    sentences = re.split(r'(?<=[.!?])\s*', text)
+    for i in range(len(sentences)):
+        if sentences[i]:
+            sentences[i] = sentences[i][0].upper() + sentences[i][1:]
+    text = " ".join(sentences)
+    return text
+
+
+def display(text):
+    # Split into sentences, drop the empty trailing piece left by the split,
+    # de-duplicate while preserving order, and format as bullet points.
+    sentences = re.split(r'(?<=[.!?])\s*', text)
+    unique_sentences = list(dict.fromkeys(sentences[:-1]))
+    formatted_sentences = [f"• {sentence}" for sentence in unique_sentences]
+    return formatted_sentences
+
+
+def pipeline(url, model, tokenizer):
+    # Subtitles -> punctuation restoration -> Vietnamese translation -> normalization.
+    trans, sub = get_subtitles(url)
+    sub = restore_punctuation(sub)
+    vie_sub = translate_long(sub)
+    vie_sub = processed(vie_sub)
+    # Summarize each overlapping chunk, then join the partial summaries.
+    chunks = split_into_chunks(vie_sub, 700, 2)
+    sum_para = []
+    for chunk in chunks:
+        sum_para.append(summarize(chunk, model, tokenizer, num_beams=3))
+    suma = ' '.join(sum_para)
+    del sub, vie_sub, sum_para, chunks
+    suma = post_processing(suma)
+    # Format the cleaned summary as de-duplicated bullet points.
+    result = display(suma)
+    return result
+
 def update(name):
     return f"Welcome to Gradio, {name}!"
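
For context, a minimal usage sketch of the new helpers, assuming a fine-tuned ViT5 checkpoint; the checkpoint path and video URL below are placeholders, not values from this commit:

from util import load_model, pipeline

# Placeholder path; substitute the actual fine-tuned ViT5 summarization checkpoint.
tokenizer, model = load_model("path/to/vit5-summarization-checkpoint")

# Fetch English subtitles, restore punctuation, translate to Vietnamese,
# then summarize chunk by chunk; returns a list of bullet-point strings.
bullets = pipeline("https://www.youtube.com/watch?v=VIDEO_ID", model, tokenizer)
print("\n".join(bullets))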