Qwen
/

yangapku commited on
Commit
03752a6
·
1 Parent(s): 013d71a

add kernel file check in modeling_qwen.py

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +14 -4
modeling_qwen.py CHANGED
@@ -6,11 +6,13 @@
6
  import copy
7
  import importlib
8
  import math
 
9
  from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
10
 
11
  import torch
12
  import torch.nn.functional as F
13
  import torch.utils.checkpoint
 
14
  from torch.cuda.amp import autocast
15
 
16
  from torch.nn import CrossEntropyLoss
@@ -295,11 +297,19 @@ class QWenAttention(nn.Module):
295
  self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
296
 
297
  if config.use_cache_quantization and config.use_cache_kernel:
298
- from .cpp_kernels import cache_autogptq_cuda_256
299
- try:
300
- self.cache_kernels = cache_autogptq_cuda_256
301
- except ImportError:
 
302
  self.cache_kernels = None
 
 
 
 
 
 
 
303
 
304
  def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
305
  device = query.device
 
6
  import copy
7
  import importlib
8
  import math
9
+ import pathlib
10
  from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
11
 
12
  import torch
13
  import torch.nn.functional as F
14
  import torch.utils.checkpoint
15
+ import warnings
16
  from torch.cuda.amp import autocast
17
 
18
  from torch.nn import CrossEntropyLoss
 
297
  self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
298
 
299
  if config.use_cache_quantization and config.use_cache_kernel:
300
+ # pre check if the support files existing
301
+ module_root = pathlib.Path(__file__).parent
302
+ src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
303
+ if any(not (module_root/src).is_file() for src in src_files):
304
+ warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
305
  self.cache_kernels = None
306
+ else:
307
+ try:
308
+ from .cpp_kernels import cache_autogptq_cuda_256
309
+ self.cache_kernels = cache_autogptq_cuda_256
310
+ except ImportError:
311
+ warnings.warn("Failed to import KV cache kernels.")
312
+ self.cache_kernels = None
313
 
314
  def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
315
  device = query.device