// quant_cuda.cpp
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>
// Forward declaration of the launcher that vecquant3matmul() forwards to.
// No definition is visible in this file; presumably it lives in the companion
// CUDA translation unit — confirm against the build sources.
// NOTE(review): tensor shapes, dtypes and which argument receives the result
// are not established here; check the kernel implementation.
void vecquant3matmul_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
// Forward declaration of the launcher that vecquant3matmul_faster() forwards
// to. It takes the same parameter list as vecquant3matmul_cuda.
// No definition is visible in this file; presumably it lives in the companion
// CUDA translation unit — confirm against the build sources.
// NOTE(review): how this variant differs from the regular one (layout, dtype
// requirements) is not established here; check the kernel implementation.
void vecquant3matmul_faster_cuda(
torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor scales, torch::Tensor zeros
);
// Host-side entry point bound to Python as "vecquant3matmul"
// (vector 3-bit quantized matrix multiplication, CUDA).
//
// Makes the device that holds `vec` the current CUDA device for the duration
// of the call, then hands all five tensors, unchanged and in order, to
// vecquant3matmul_cuda. The guard restores the previous device on scope exit.
//
// Returns nothing. Presumably the result is written into one of the tensor
// arguments by the kernel — confirm against the CUDA implementation.
void vecquant3matmul(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros
) {
    // Only `vec` decides the device; the other tensors are not inspected here.
    const auto target_device = at::device_of(vec);
    const at::cuda::OptionalCUDAGuard guard(target_device);

    vecquant3matmul_cuda(vec, mat, mul, scales, zeros);
}
// Host-side entry point bound to Python as "vecquant3matmul_faster"
// (vector 3-bit quantized matrix multiplication, CUDA, faster version).
//
// Makes the device that holds `vec` the current CUDA device for the duration
// of the call, then hands all five tensors, unchanged and in order, to
// vecquant3matmul_faster_cuda. The guard restores the previous device on
// scope exit.
//
// Returns nothing. Presumably the result is written into one of the tensor
// arguments by the kernel — confirm against the CUDA implementation.
void vecquant3matmul_faster(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros
) {
    // Only `vec` decides the device; the other tensors are not inspected here.
    const auto target_device = at::device_of(vec);
    const at::cuda::OptionalCUDAGuard guard(target_device);

    vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros);
}
// Python module definition. The module name is supplied by the build through
// TORCH_EXTENSION_NAME. Each host wrapper is exported under its C++ name, with
// a short docstring as the third argument.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "vecquant3matmul",
        &vecquant3matmul,
        "Vector 3-bit Quantized Matrix Multiplication (CUDA)"
    );
    m.def(
        "vecquant3matmul_faster",
        &vecquant3matmul_faster,
        "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version"
    );
}