update modeling_qwen.py

modeling_qwen.py  CHANGED  (+11 -4)
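This commit makes the quantized-KV-cache kernel truly optional: the import of cache_autogptq_cuda_256 from kernels.cpp_kernels is wrapped in try/except (falling back to None), and the two attention call sites only take the kernel path when cache_autogptq_cuda_256 is not None. A minimal sketch of the resulting guarded-import pattern follows the diff.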
@@ -31,7 +31,11 @@ try:
 except ImportError:
     rearrange = None
 from torch import nn
-
+
+try:
+    from kernels.cpp_kernels import cache_autogptq_cuda_256
+except ImportError:
+    cache_autogptq_cuda_256 = None
 
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()

@@ -293,7 +297,7 @@ class QWenAttention(nn.Module):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel:
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                 cache_autogptq_cuda_256.vecquant8matmul_batched_faster_old(

@@ -348,7 +352,7 @@ class QWenAttention(nn.Module):
 
         if self.use_cache_quantization:
             qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel:
+            if self.use_cache_kernel and cache_autogptq_cuda_256 is not None:
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                 cache_autogptq_cuda_256.vecquant8matmul_batched_column_compression_faster_old(

@@ -1021,7 +1025,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if hasattr(config, 'use_cache_quantization') and config.use_cache_quantization:
            config.use_flash_attn = False
        if hasattr(config, 'use_cache_kernel') and config.use_cache_kernel:
-
+            try:
+                from kernels.cpp_kernels import cache_autogptq_cuda_256
+            except ImportError:
+                cache_autogptq_cuda_256 = None
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
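For illustration only, a minimal, self-contained sketch of the guarded-import pattern this commit applies. The attend_fast_or_fallback helper below is hypothetical (it is not part of modeling_qwen.py); it just demonstrates how the None check keeps the model usable when the compiled cache_autogptq_cuda_256 extension cannot be imported.

# Minimal sketch (not from modeling_qwen.py): optional compiled kernel with a plain fallback.
try:
    from kernels.cpp_kernels import cache_autogptq_cuda_256  # compiled CUDA extension, may be absent
except ImportError:
    cache_autogptq_cuda_256 = None  # same fallback value the commit introduces

def attend_fast_or_fallback(use_cache_kernel: bool) -> str:
    # Mirrors the guarded dispatch added at new lines 300 and 355: the kernel path is taken
    # only when the config flag is set AND the extension actually imported.
    if use_cache_kernel and cache_autogptq_cuda_256 is not None:
        return "quantized-cache kernel path"
    return "plain PyTorch path"

print(attend_fast_or_fallback(use_cache_kernel=True))  # prints the fallback path unless the kernel is built

The design point is that a missing or unbuilt kernels/cpp_kernels extension no longer raises at import or call time; the model silently falls back to the non-kernel code path.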