From 40f820541031dec924870b5531495685c2805e90 Mon Sep 17 00:00:00 2001 From: tmxu Date: Wed, 12 Jun 2024 13:49:15 +0800 Subject: [PATCH] fix illegal memory access of GEMV kernel --- .../csrc/quantization_new/gemv/gemv_cuda.cu | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu b/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu index 2bde2f0..43b4370 100644 --- a/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu +++ b/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu @@ -247,6 +247,7 @@ torch::Tensor gemv_forward_cuda_new( dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE); dim3 num_threads(BLOCK_SIZE); + constexpr int kSmemByteSizePerBatch = N_PER_BLOCK * K_INTERLEAVE * BLOCK_SIZE; // if (group_size == 64) // { // gemv_kernel_g64<<>>( @@ -261,37 +262,37 @@ torch::Tensor gemv_forward_cuda_new( switch (m) { case 1: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 2: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 3: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 4: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 5: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 6: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 7: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break;