Skip to content

Commit

Permalink
Faiss GPU: bfloat16 brute-force kNN support (facebookresearch#4018)
Browse files Browse the repository at this point in the history
Summary:


This diff adds support for bfloat16 vector/query data types with the GPU brute-force k-nearest neighbor function (`bfKnn`).

The change is largely just plumbing the new data type through the template hierarchy (so distances can be computed in bfloat16).

Of note, by design, all final distance results are produced in float32 regardless of input data type (float32, float16, bfloat16). This is because the true nearest neighbors in many data sets can often differ by only ~1000 float32 ULPs in terms of distance which will result in possible false equivalency. This seems to be one area where lossy compression/quantization thoughout does not work as well (and is also why `CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION` is set in `StandardGpuResources.cpp`. However, given that there is native bf16 x bf16 = fp32 tensor core support on Ampere+ architectures, the matrix multiplication itself should use them.

As bfloat16 support is quite lacking on AMD/ROCm (see [here](https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Device_API_supported_by_HIP.html), very few bf16 functions implemented), bf16 functionality is completely disabled / not compiled for AMD ROCm.

Reviewed By: mdouze

Differential Revision: D65459723
  • Loading branch information
Jeff Johnson authored and facebook-github-bot committed Nov 19, 2024
1 parent 3c25a68 commit 97467dc
Show file tree
Hide file tree
Showing 20 changed files with 913 additions and 90 deletions.
17 changes: 15 additions & 2 deletions contrib/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def swig_ptr_from_FloatTensor(x):
return faiss.cast_integer_to_float_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 4)

def swig_ptr_from_BFloat16Tensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
assert x.is_contiguous()
assert x.dtype == torch.bfloat16
return faiss.cast_integer_to_void_ptr(
x.untyped_storage().data_ptr() + x.storage_offset() * 2)


def swig_ptr_from_IntTensor(x):
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
Expand Down Expand Up @@ -606,8 +613,11 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
elif xb.dtype == torch.float16:
xb_type = faiss.DistanceDataType_F16
xb_ptr = swig_ptr_from_HalfTensor(xb)
elif xb.dtype == torch.bfloat16:
xb_type = faiss.DistanceDataType_BF16
xb_ptr = swig_ptr_from_BFloat16Tensor(xb)
else:
raise TypeError('xb must be f32 or f16')
raise TypeError('xq must be float32, float16 or bfloat16')

nq, d2 = xq.size()
assert d2 == d
Expand All @@ -625,8 +635,11 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
elif xq.dtype == torch.float16:
xq_type = faiss.DistanceDataType_F16
xq_ptr = swig_ptr_from_HalfTensor(xq)
elif xq.dtype == torch.bfloat16:
xq_type = faiss.DistanceDataType_BF16
xq_ptr = swig_ptr_from_BFloat16Tensor(xq)
else:
raise TypeError('xq must be f32 or f16')
raise TypeError('xq must be float32, float16 or bfloat16')

if D is None:
D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
Expand Down
26 changes: 21 additions & 5 deletions faiss/gpu/GpuDistance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <optional>

#if defined USE_NVIDIA_CUVS
Expand Down Expand Up @@ -231,7 +232,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
FAISS_THROW_IF_NOT_MSG(
args.vectorType == args.queryType,
"limitation: both vectorType and queryType must currently "
"be the same (F32 or F16");
"be the same (F32 / F16 / BF16");

#if defined USE_NVIDIA_CUVS
// Note: For now, cuVS bfknn requires queries and vectors to be same layout
Expand Down Expand Up @@ -400,6 +401,17 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
bfKnnConvert<float>(prov, args);
} else if (args.vectorType == DistanceDataType::F16) {
bfKnnConvert<half>(prov, args);
} else if (args.vectorType == DistanceDataType::BF16) {
// no bf16 support for AMD
#ifndef USE_AMD_ROCM
if (prov->getResources()->supportsBFloat16CurrentDevice()) {
bfKnnConvert<__nv_bfloat16>(prov, args);
} else {
FAISS_THROW_MSG("not compiled with bfloat16 support");
}
#else
FAISS_THROW_MSG("no AMD bfloat16 support");
#endif
} else {
FAISS_THROW_MSG("unknown vectorType");
}
Expand Down Expand Up @@ -466,8 +478,10 @@ void bfKnn_single_query_shard(
args.k > 0,
"bfKnn_tiling: tiling vectors is only supported for k > 0");
size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
: args.vectorType == DistanceDataType::F16 ? 2
: 0;
: (args.vectorType == DistanceDataType::F16 ||
args.vectorType == DistanceDataType::BF16)
? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown vectorType");
size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
Expand Down Expand Up @@ -524,8 +538,10 @@ void bfKnn_tiling(
args.k > 0,
"bfKnn_tiling: tiling queries is only supported for k > 0");
size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
: args.queryType == DistanceDataType::F16 ? 2
: 0;
: (args.queryType == DistanceDataType::F16 ||
args.queryType == DistanceDataType::BF16)
? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown queryType");
size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8
Expand Down
1 change: 1 addition & 0 deletions faiss/gpu/GpuDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class GpuResourcesProvider;
enum class DistanceDataType {
F32 = 1,
F16,
BF16,
};

// Scalar type of the indices data
Expand Down
4 changes: 4 additions & 0 deletions faiss/gpu/GpuResources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {

GpuResources::~GpuResources() = default;

bool GpuResources::supportsBFloat16CurrentDevice() {
return supportsBFloat16(getCurrentDevice());
}

cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
return getBlasHandle(getCurrentDevice());
}
Expand Down
6 changes: 6 additions & 0 deletions faiss/gpu/GpuResources.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ class GpuResources {
/// of demand
virtual void initializeForDevice(int device) = 0;

/// Does the given GPU support bfloat16?
virtual bool supportsBFloat16(int device) = 0;

/// Returns the cuBLAS handle that we use for the given device
virtual cublasHandle_t getBlasHandle(int device) = 0;

Expand Down Expand Up @@ -252,6 +255,9 @@ class GpuResources {
/// Functions provided by default
///

/// Does the current GPU support bfloat16?
bool supportsBFloat16CurrentDevice();

/// Calls getBlasHandle with the current device
cublasHandle_t getBlasHandleCurrentDevice();

Expand Down
15 changes: 15 additions & 0 deletions faiss/gpu/StandardGpuResources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
return requested;
}

/// Does the given GPU support bfloat16?
bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
initializeForDevice(device);
auto& prop = getDeviceProperties(device);
return prop.major >= 8;
}

void StandardGpuResourcesImpl::noTempMemory() {
setTempMemory(0);
}
Expand Down Expand Up @@ -701,6 +708,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
return res_;
}

bool StandardGpuResources::supportsBFloat16(int device) {
return res_->supportsBFloat16(device);
}

bool StandardGpuResources::supportsBFloat16CurrentDevice() {
return res_->supportsBFloat16CurrentDevice();
}

void StandardGpuResources::noTempMemory() {
res_->noTempMemory();
}
Expand Down
9 changes: 9 additions & 0 deletions faiss/gpu/StandardGpuResources.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {

~StandardGpuResourcesImpl() override;

/// Does the given GPU support bfloat16?
bool supportsBFloat16(int device) override;

/// Disable allocation of temporary memory; all temporary memory
/// requests will call cudaMalloc / cudaFree at the point of use
void noTempMemory();
Expand Down Expand Up @@ -199,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {

std::shared_ptr<GpuResources> getResources() override;

/// Whether or not the given device supports native bfloat16 arithmetic
bool supportsBFloat16(int device);

/// Whether or not the current device supports native bfloat16 arithmetic
bool supportsBFloat16CurrentDevice();

/// Disable allocation of temporary memory; all temporary memory
/// requests will call cudaMalloc / cudaFree at the point of use
void noTempMemory();
Expand Down
101 changes: 101 additions & 0 deletions faiss/gpu/impl/Distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,30 @@ void runAllPairwiseL2Distance(
outDistances);
}

// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runAllPairwiseL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances) {
runAllPairwiseDistance<__nv_bfloat16>(
true,
res,
stream,
vectors,
vectorsRowMajor,
vectorNorms,
queries,
queriesRowMajor,
outDistances);
}
#endif // USE_AMD_ROCM

void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Expand Down Expand Up @@ -544,6 +568,29 @@ void runAllPairwiseIPDistance(
outDistances);
}

// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances) {
runAllPairwiseDistance<__nv_bfloat16>(
false,
res,
stream,
vectors,
vectorsRowMajor,
nullptr,
queries,
queriesRowMajor,
outDistances);
}
#endif // USE_AMD_ROCM

void runL2Distance(
GpuResources* res,
cudaStream_t stream,
Expand Down Expand Up @@ -596,6 +643,35 @@ void runL2Distance(
ignoreOutDistances);
}

// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices,
bool ignoreOutDistances) {
runL2Distance<__nv_bfloat16>(
res,
stream,
vectors,
vectorsRowMajor,
vectorNorms,
queries,
queriesRowMajor,
k,
outDistances,
outIndices,
ignoreOutDistances);
}
#endif // USE_AMD_ROCM

void runIPDistance(
GpuResources* res,
cudaStream_t stream,
Expand Down Expand Up @@ -640,5 +716,30 @@ void runIPDistance(
outIndices);
}

// no bf16 support for AMD
#ifndef USE_AMD_ROCM
void runIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices) {
runIPDistance<__nv_bfloat16>(
res,
stream,
vectors,
vectorsRowMajor,
queries,
queriesRowMajor,
k,
outDistances,
outIndices);
}
#endif // USE_AMD_ROCM

} // namespace gpu
} // namespace faiss
Loading

0 comments on commit 97467dc

Please sign in to comment.