File size: 4,013 Bytes
4008bf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
#!/usr/bin/env python3
# make sure to use branch: https://github.com/huggingface/transformers/pull/26701
import copy
import time
import torch
from datasets import load_dataset
from transformers import (
AutoProcessor,
WhisperForConditionalGeneration,
)
# --- Benchmark configuration ---
DEVICE = "cuda"
DTYPE = torch.float16
SAMPLING_RATE = 16_000
BATCH_SIZE = 1
USE_FLASH_ATTN_2 = True

# TO DEBUG
GAMMAS = [5, 7, 6, 5, 4, 3, 5]
COUNT = 0  # number of batches whose assisted transcript differs from the baseline

# --- Model loading ---
# Teacher and student are loaded with the same options; only the checkpoint
# directory differs. Local loading is faster than pulling from the Hub.
_LOAD_KWARGS = dict(
    torch_dtype=DTYPE,
    variant="fp16",
    low_cpu_mem_usage=True,
    use_flash_attention_2=USE_FLASH_ATTN_2,
)
teacher = WhisperForConditionalGeneration.from_pretrained("/home/patrick/distil_whisper/", **_LOAD_KWARGS)
student = WhisperForConditionalGeneration.from_pretrained("/home/patrick/distil_whisper_student/", **_LOAD_KWARGS)
# Decoder-only alternative for the student:
# student = WhisperForCausalLM.from_pretrained("/home/patrick/distil_whisper_student", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True, use_flash_attention_2=USE_FLASH_ATTN_2)

# The student reuses a private copy of the teacher's generation settings,
# with the assistant-token schedule set to "constant".
_student_generation_config = copy.deepcopy(teacher.generation_config)
_student_generation_config.num_assistant_tokens_schedule = "constant"
student.generation_config = _student_generation_config

# Hub checkpoints that can be used instead of the local directories:
# teacher = WhisperForConditionalGeneration.from_pretrained(
#     "openai/whisper-large-v2", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True
# )
# student = WhisperForConditionalGeneration.from_pretrained(
#     "sanchit-gandhi/large-32-2-gpu-flat-lr", torch_dtype=DTYPE, variant="fp16", low_cpu_mem_usage=True
# )

# Move both models onto the benchmark device.
for _model in (teacher, student):
    _model.to(DEVICE)
# Processor (feature extraction + decoding) and the dummy LibriSpeech ASR
# validation split that the benchmark iterates over.
processor = AutoProcessor.from_pretrained("sanchit-gandhi/large-32-2-gpu-flat-lr")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# Accumulated wall-clock seconds per decoding strategy.
total_time_default = 0
total_time_spec = 0
total_time_spec_2 = 0

# Warm-up on the first sample so one-time initialisation cost is not charged
# to the first timed iteration.
input_values = ds[0]["audio"]["array"]
inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE)
input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE)
_ = teacher.generate(input_features, max_length=100)

# Warm up the assisted path as well. Otherwise the first timed assisted call
# also pays the student's one-time setup and the reported speed-up is biased
# against assisted decoding.
with torch.no_grad():
    _warmup_encoder_outputs = teacher.get_encoder()(input_features)
_ = teacher.generate(
    assistant_model=student,
    assistant_encoder_outputs=_warmup_encoder_outputs,
    encoder_outputs=_warmup_encoder_outputs,
    max_length=100,
)

# Number of rows in the dataset; upper bound for the benchmark loop.
end_idx = ds.shape[0]
def _now() -> float:
    """Return a monotonic timestamp after pending CUDA work has finished.

    CUDA kernels are launched asynchronously, so reading the clock without
    synchronising can leave queued GPU work out of the measured interval.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter()


# Time plain teacher decoding against assisted decoding (student as the
# assistant model) on every batch, and check that the transcripts agree.
for audio_idx in range(0, end_idx, BATCH_SIZE):
    batch = ds[audio_idx : audio_idx + BATCH_SIZE]
    input_values = [audio["array"] for audio in batch["audio"]]
    inputs = processor(input_values, return_tensors="pt", sampling_rate=SAMPLING_RATE)
    input_features = inputs.input_features.to(device=DEVICE, dtype=DTYPE)

    # Baseline: the teacher decodes on its own.
    start_time = _now()
    out = teacher.generate(input_features, max_length=100)
    run_time = _now() - start_time
    print(f"Normal Decoding: {run_time}")
    total_time_default += run_time
    default_out = processor.batch_decode(out, skip_special_tokens=True)

    # NOTE: a hand-rolled `speculative_decoding(teacher, student, ...)` variant
    # (gamma=5) used to be timed here into `total_time_spec`; it is disabled.

    # Assisted decoding: the encoder runs once and its output is handed to
    # both the teacher and the assistant. The encoder pass is inside the timed
    # region so the comparison with the baseline covers the same work.
    start_time = _now()
    with torch.no_grad():
        encoder_outputs = teacher.get_encoder()(input_features)
    out = teacher.generate(
        assistant_model=student,
        assistant_encoder_outputs=encoder_outputs,
        encoder_outputs=encoder_outputs,
        max_length=100,
    )
    run_time = _now() - start_time
    spec_out_2 = processor.batch_decode(out, skip_special_tokens=True)
    print(f"Speculative Decoding 2: {run_time}")
    total_time_spec_2 += run_time

    # Count and report every batch where the two transcripts differ.
    if spec_out_2 != default_out:
        COUNT += 1
        print(f"Audio {audio_idx} does not match. Spec: {spec_out_2}, True: {default_out}")

    print(20 * "=")

print("Total time", total_time_default)
# Report the mismatch tally gathered in the loop.
print(f"Mismatched batches: {COUNT}")
# Guard the ratio: with no timed samples the denominator is zero.
if total_time_spec_2 > 0:
    print(f"Overall speed-up spec 2 {total_time_default / total_time_spec_2}")
else:
    print("Overall speed-up spec 2 n/a (no samples were timed)")
|