Commit
·
66c268e
1
Parent(s):
12af9b4
added float16 loading change to whisper audio tower
Browse files- modeling_bunny_phi.py +2 -4
modeling_bunny_phi.py
CHANGED
@@ -615,7 +615,7 @@ class WhisperAudioTower(nn.Module):
|
|
615 |
if self.is_loaded:
|
616 |
return
|
617 |
|
618 |
-
self.audio_tower = WhisperModel.from_pretrained(self.audio_tower_name)
|
619 |
|
620 |
self.audio_tower.requires_grad_(False)
|
621 |
self.audio_tower.eval()
|
@@ -2627,10 +2627,8 @@ class BunnyPhiForCausalLM(PhiForCausalLM, BunnyMetaForCausalLM):
|
|
2627 |
audio_tower = self.get_audio_tower()
|
2628 |
if not audio_tower.is_loaded:
|
2629 |
audio_tower.load_model()
|
2630 |
-
audio_tower.to(device='cuda', dtype=torch.float16)
|
2631 |
audio_processor = audio_tower.audio_processor
|
2632 |
-
audio_processor
|
2633 |
-
features = audio_processor(audio, sampling_rate=16000, return_tensors="pt").input_features # replace 16k with arg later
|
2634 |
audio_tensor = features.to(self.device, dtype=self.dtype)
|
2635 |
return audio_tensor
|
2636 |
|
|
|
615 |
if self.is_loaded:
|
616 |
return
|
617 |
|
618 |
+
self.audio_tower = WhisperModel.from_pretrained(self.audio_tower_name, torch_dtype=torch.float16)
|
619 |
|
620 |
self.audio_tower.requires_grad_(False)
|
621 |
self.audio_tower.eval()
|
|
|
2627 |
audio_tower = self.get_audio_tower()
|
2628 |
if not audio_tower.is_loaded:
|
2629 |
audio_tower.load_model()
|
|
|
2630 |
audio_processor = audio_tower.audio_processor
|
2631 |
+
features = audio_processor(audio, sampling_rate=16000, return_tensors="pt", device='cuda').input_features # replace 16k with arg later
|
|
|
2632 |
audio_tensor = features.to(self.device, dtype=self.dtype)
|
2633 |
return audio_tensor
|
2634 |
|