sanchit-gandhi commited on
Commit
0a7fcda
·
1 Parent(s): 523090a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -5,7 +5,7 @@ import pytube as pt
5
  from transformers import pipeline
6
  from huggingface_hub import model_info
7
 
8
- MODEL_NAME = "openai/whisper-large-v2"
9
 
10
  device = 0 if torch.cuda.is_available() else "cpu"
11
 
@@ -17,7 +17,12 @@ pipe = pipeline(
17
  )
18
 
19
 
20
- def transcribe(microphone, file_upload):
 
 
 
 
 
21
  warn_output = ""
22
  if (microphone is not None) and (file_upload is not None):
23
  warn_output = (
@@ -30,6 +35,8 @@ def transcribe(microphone, file_upload):
30
 
31
  file = microphone if microphone is not None else file_upload
32
 
 
 
33
  text = pipe(file)["text"]
34
 
35
  return warn_output + text
@@ -44,12 +51,14 @@ def _return_yt_html_embed(yt_url):
44
  return HTML_str
45
 
46
 
47
- def yt_transcribe(yt_url):
48
  yt = pt.YouTube(yt_url)
49
  html_embed_str = _return_yt_html_embed(yt_url)
50
  stream = yt.streams.filter(only_audio=True)[0]
51
  stream.download(filename="audio.mp3")
52
 
 
 
53
  text = pipe("audio.mp3")["text"]
54
 
55
  return html_embed_str, text
@@ -62,6 +71,7 @@ mf_transcribe = gr.Interface(
62
  inputs=[
63
  gr.inputs.Audio(source="microphone", type="filepath", optional=True),
64
  gr.inputs.Audio(source="upload", type="filepath", optional=True),
 
65
  ],
66
  outputs="text",
67
  layout="horizontal",
@@ -77,7 +87,7 @@ mf_transcribe = gr.Interface(
77
 
78
  yt_transcribe = gr.Interface(
79
  fn=yt_transcribe,
80
- inputs=[gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
81
  outputs=["html", "text"],
82
  layout="horizontal",
83
  theme="huggingface",
 
5
  from transformers import pipeline
6
  from huggingface_hub import model_info
7
 
8
+ MODEL_NAME = "openai/whisper-tiny"
9
 
10
  device = 0 if torch.cuda.is_available() else "cpu"
11
 
 
17
  )
18
 
19
 
20
+ all_special_ids = pipe.tokenizer.all_special_ids
21
+ transcribe_token_id = all_special_ids[-5]
22
+ translate_token_id = all_special_ids[-6]
23
+
24
+
25
+ def transcribe(microphone, file_upload, task):
26
  warn_output = ""
27
  if (microphone is not None) and (file_upload is not None):
28
  warn_output = (
 
35
 
36
  file = microphone if microphone is not None else file_upload
37
 
38
+ pipe.model.config.forced_decoder_ids = [[2, translate_token_id if do_translate else transcribe_token_id]]
39
+
40
  text = pipe(file)["text"]
41
 
42
  return warn_output + text
 
51
  return HTML_str
52
 
53
 
54
+ def yt_transcribe(yt_url, do_translate):
55
  yt = pt.YouTube(yt_url)
56
  html_embed_str = _return_yt_html_embed(yt_url)
57
  stream = yt.streams.filter(only_audio=True)[0]
58
  stream.download(filename="audio.mp3")
59
 
60
+ pipe.model.config.forced_decoder_ids = [[2, translate_token_id if do_translate else transcribe_token_id]]
61
+
62
  text = pipe("audio.mp3")["text"]
63
 
64
  return html_embed_str, text
 
71
  inputs=[
72
  gr.inputs.Audio(source="microphone", type="filepath", optional=True),
73
  gr.inputs.Audio(source="upload", type="filepath", optional=True),
74
+ gr.Checkbox(label="Translate?", value=False),
75
  ],
76
  outputs="text",
77
  layout="horizontal",
 
87
 
88
  yt_transcribe = gr.Interface(
89
  fn=yt_transcribe,
90
+ inputs=[gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"), gr.Checkbox(label="Translate?", value=False)],
91
  outputs=["html", "text"],
92
  layout="horizontal",
93
  theme="huggingface",