diff --git a/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu b/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu index 2bde2f0..43b4370 100644 --- a/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu +++ b/awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu @@ -247,6 +247,7 @@ torch::Tensor gemv_forward_cuda_new( dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE); dim3 num_threads(BLOCK_SIZE); + constexpr int kSmemByteSizePerBatch = N_PER_BLOCK * K_INTERLEAVE * BLOCK_SIZE; // if (group_size == 64) // { // gemv_kernel_g64<<>>( @@ -261,37 +262,37 @@ torch::Tensor gemv_forward_cuda_new( switch (m) { case 1: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 2: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 3: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 4: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 5: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 6: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break; case 7: - gemv_kernel<<>>( + gemv_kernel<<>>( in_feats, kernel, scaling_factors, zeros, out_feats, k, n ); break;