iaooo-shivprasad committed on
Commit
66c268e
·
1 Parent(s): 12af9b4

Added float16 loading to the Whisper audio tower

Browse files
Files changed (1) hide show
  1. modeling_bunny_phi.py +2 -4
modeling_bunny_phi.py CHANGED
@@ -615,7 +615,7 @@ class WhisperAudioTower(nn.Module):
615
  if self.is_loaded:
616
  return
617
 
618
- self.audio_tower = WhisperModel.from_pretrained(self.audio_tower_name)
619
 
620
  self.audio_tower.requires_grad_(False)
621
  self.audio_tower.eval()
@@ -2627,10 +2627,8 @@ class BunnyPhiForCausalLM(PhiForCausalLM, BunnyMetaForCausalLM):
2627
  audio_tower = self.get_audio_tower()
2628
  if not audio_tower.is_loaded:
2629
  audio_tower.load_model()
2630
- audio_tower.to(device='cuda', dtype=torch.float16)
2631
  audio_processor = audio_tower.audio_processor
2632
- audio_processor.to(device='cuda')
2633
- features = audio_processor(audio, sampling_rate=16000, return_tensors="pt").input_features # replace 16k with arg later
2634
  audio_tensor = features.to(self.device, dtype=self.dtype)
2635
  return audio_tensor
2636
 
 
615
  if self.is_loaded:
616
  return
617
 
618
+ self.audio_tower = WhisperModel.from_pretrained(self.audio_tower_name, torch_dtype=torch.float16)
619
 
620
  self.audio_tower.requires_grad_(False)
621
  self.audio_tower.eval()
 
2627
  audio_tower = self.get_audio_tower()
2628
  if not audio_tower.is_loaded:
2629
  audio_tower.load_model()
 
2630
  audio_processor = audio_tower.audio_processor
2631
+ features = audio_processor(audio, sampling_rate=16000, return_tensors="pt", device='cuda').input_features # replace 16k with arg later
 
2632
  audio_tensor = features.to(self.device, dtype=self.dtype)
2633
  return audio_tensor
2634