camparchimedes committed on
Commit 0a1c65b · verified · 1 Parent(s): 84b5bed

Update app.py

Files changed (1)
  1. app.py +32 -75
app.py CHANGED
@@ -1,75 +1,32 @@
- import gradio as gr
- import warnings
- import torch
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
- import soundfile as sf
-
- warnings.filterwarnings("ignore")
-
- # Load tokenizer + model
- processor = AutoProcessor.from_pretrained("NbAiLab/nb-whisper-large-verbatim")
- model = AutoModelForSpeechSeq2Seq.from_pretrained("NbAiLab/nb-whisper-large-verbatim")
-
- # set up device
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- torch_dtype = torch.float32
-
- # move model to device
- model.to(device)
-
- def transcribe_audio(audio_file, batch_size=4):
-     audio_input, sample_rate = sf.read(audio_file)
-     chunk_size = 16000 * 30
-     chunks = [audio_input[i:i + chunk_size] for i in range(0, len(audio_input), chunk_size)]
-
-     transcription = ""
-     for i in range(0, len(chunks), batch_size):
-         batch_chunks = chunks[i:i + batch_size]
-         inputs = processor(batch_chunks, sampling_rate=16000, return_tensors="pt", padding=False)
-         inputs = inputs.to(device)
-
-         # Manually define the attention mask
-         attention_mask = torch.ones(inputs.input_features.shape[:2], dtype=torch.long)
-         attention_mask = attention_mask.to(device)
-
-         tokenizer.pad_token != tokenizer.eos_token
-
-         with torch.no_grad():
-             output = model.generate(
-                 inputs.input_features,
-                 max_length=1024,  # Increase max_length for longer outputs
-                 num_beams=7,
-                 task="transcribe",
-                 attention_mask=attention_mask,
-                 forced_decoder_ids=None,  # forced_decoder_ids must not be set
-                 language="no"
-             )
-         transcription += " ".join(processor.batch_decode(output, skip_special_tokens=True)) + " "
-
-     return transcription.strip()
-
-
-
-
- # HTML |banner image
- banner_html = """
- <div style="text-align: center;">
-     <img src="https://huggingface.co/spaces/camparchimedes/ola_s-audioshop/resolve/main/Olas_AudioSwitch_Shop.png" width="87%" height="auto"/>
- </div>
- """
-
- # Gradio interface
- iface = gr.Blocks()
-
- with iface:
-     gr.HTML(banner_html)
-     gr.Markdown("# 𝐍𝐯𝐢𝐝𝐢𝐚 𝐀𝟏𝟎𝟎 👋🏼👾🦾⚡ @{NbAiLab/whisper-norwegian-medium}\nUpload audio file:☕")
-     audio_input = gr.Audio(type="filepath")
-     batch_size_input = gr.Slider(minimum=1, maximum=16, step=1, label="Batch Size")
-     transcription_output = gr.Textbox()
-     transcribe_button = gr.Button("Transcribe")
-
-     transcribe_button.click(fn=transcribe_audio, inputs=[audio_input, batch_size_input], outputs=transcription_output)
-
- # Launch interface
- iface.launch(share=True, debug=True)
 
+ def test_eos_pad():
+     from datasets import load_dataset
+     import torch
+     from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+     raw_text_batch = 'a'
+
+     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+     # print(f'{tokenizer.eos_token=}')
+     # print(f'{tokenizer.eos_token_id=}')
+     # print(f'{tokenizer.pad_token=}')
+     # print(f'{tokenizer.pad_token_id=}')
+
+     # print(f'{raw_text_batch=}')
+     # tokenize_batch = tokenizer(raw_text_batch, padding="max_length", max_length=5, truncation=True, return_tensors="pt")
+     # print(f'{tokenize_batch=}')
+
+     if tokenizer.pad_token_id is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     probe_network = GPT2LMHeadModel.from_pretrained("gpt2")
+     device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
+     probe_network = probe_network.to(device)
+
+     print(f'{tokenizer.eos_token=}')
+     print(f'{tokenizer.eos_token_id=}')
+     print(f'{tokenizer.pad_token=}')
+     print(f'{tokenizer.pad_token_id=}')
+
+     print(f'{raw_text_batch=}')
+     tokenize_batch = tokenizer(raw_text_batch, padding="max_length", max_length=5, truncation=True, return_tensors="pt")
+     print(f'{tokenize_batch=}')
+     print('Done')
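
Note: the updated app.py only defines test_eos_pad() and never calls it. A minimal sketch of a script entry point (the __main__ guard below is an assumption, not part of the commit) would be:

    if __name__ == "__main__":
        test_eos_pad()

With the stock GPT-2 tokenizer, this is expected to print eos_token='<|endoftext|>' (id 50256); since pad_token_id starts out as None, the fallback sets pad_token to the same value, so the max_length=5 encoding of 'a' is padded with id 50256 and the attention mask marks those positions as padding.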