Skip to content

Commit

Permalink
Faiss GPU: bfloat16 brute-force kNN support (facebookresearch#4018)
Browse files Browse the repository at this point in the history
Summary:


This diff adds support for bfloat16 vector/query data types with the GPU brute-force k-nearest neighbor function (`bfKnn`).

The change is largely just plumbing the new data type through the template hierarchy (so distances can be computed in bfloat16).

Of note, by design, all final distance results are produced in float32 regardless of input data type (float32, float16, bfloat16). This is because the true nearest neighbors in many data sets can often differ by only ~1000 float32 ULPs in terms of distance which will result in possible false equivalency. This seems to be one area where lossy compression/quantization thoughout does not work as well (and is also why `CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION` is set in `StandardGpuResources.cpp`. However, given that there is native bf16 x bf16 = fp32 tensor core support on Ampere+ architectures, the matrix multiplication itself should 

WARNING: The one thing this diff does not yet handle properly is header inclusion / compilation for GPUs older than Ampere. This will need to be fixed before landing (so that compiling with an older CUDA SDK or compiling for the Volta architecture will simply error out at runtime properly with lack of support, instead of failing to compile (?)

Differential Revision: D65459723
  • Loading branch information
Jeff Johnson authored and facebook-github-bot committed Nov 6, 2024
1 parent cfd4804 commit 841d313
Show file tree
Hide file tree
Showing 13 changed files with 725 additions and 46 deletions.
20 changes: 15 additions & 5 deletions faiss/gpu/GpuDistance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
FAISS_THROW_IF_NOT_MSG(
args.vectorType == args.queryType,
"limitation: both vectorType and queryType must currently "
"be the same (F32 or F16");
"be the same (F32 / F16 / BF16");

#if defined USE_NVIDIA_RAFT
// Note: For now, RAFT bfknn requires queries and vectors to be same layout
Expand Down Expand Up @@ -374,6 +374,12 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
bfKnnConvert<float>(prov, args);
} else if (args.vectorType == DistanceDataType::F16) {
bfKnnConvert<half>(prov, args);
} else if (args.vectorType == DistanceDataType::BF16) {
#ifdef FAISS_USE_FULL_BFLOAT16
bfKnnConvert<__nv_bfloat16>(prov, args);
#else
FAISS_THROW_MSG("not compiled with bfloat16 support");
#endif
} else {
FAISS_THROW_MSG("unknown vectorType");
}
Expand Down Expand Up @@ -440,8 +446,10 @@ void bfKnn_single_query_shard(
args.k > 0,
"bfKnn_tiling: tiling vectors is only supported for k > 0");
size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
: args.vectorType == DistanceDataType::F16 ? 2
: 0;
: (args.vectorType == DistanceDataType::F16 ||
args.vectorType == DistanceDataType::BF16)
? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown vectorType");
size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
Expand Down Expand Up @@ -498,8 +506,10 @@ void bfKnn_tiling(
args.k > 0,
"bfKnn_tiling: tiling queries is only supported for k > 0");
size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
: args.queryType == DistanceDataType::F16 ? 2
: 0;
: (args.queryType == DistanceDataType::F16 ||
args.queryType == DistanceDataType::BF16)
? 2
: 0;
FAISS_THROW_IF_NOT_MSG(
distance_size > 0, "bfKnn_tiling: unknown queryType");
size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8
Expand Down
1 change: 1 addition & 0 deletions faiss/gpu/GpuDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class GpuResourcesProvider;
enum class DistanceDataType {
F32 = 1,
F16,
BF16,
};

// Scalar type of the indices data
Expand Down
97 changes: 97 additions & 0 deletions faiss/gpu/impl/Distance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,29 @@ void runAllPairwiseL2Distance(
outDistances);
}

#ifdef FAISS_USE_FULL_BFLOAT16
void runAllPairwiseL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances) {
runAllPairwiseDistance<__nv_bfloat16>(
true,
res,
stream,
vectors,
vectorsRowMajor,
vectorNorms,
queries,
queriesRowMajor,
outDistances);
}
#endif // FAISS_USE_FULL_BFLOAT16

void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Expand Down Expand Up @@ -544,6 +567,28 @@ void runAllPairwiseIPDistance(
outDistances);
}

#ifdef FAISS_USE_FULL_BFLOAT16
void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances) {
runAllPairwiseDistance<__nv_bfloat16>(
false,
res,
stream,
vectors,
vectorsRowMajor,
nullptr,
queries,
queriesRowMajor,
outDistances);
}
#endif // FAISS_USE_FULL_BFLOAT16

void runL2Distance(
GpuResources* res,
cudaStream_t stream,
Expand Down Expand Up @@ -596,6 +641,34 @@ void runL2Distance(
ignoreOutDistances);
}

#ifdef FAISS_USE_FULL_BFLOAT16
void runL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices,
bool ignoreOutDistances) {
runL2Distance<__nv_bfloat16>(
res,
stream,
vectors,
vectorsRowMajor,
vectorNorms,
queries,
queriesRowMajor,
k,
outDistances,
outIndices,
ignoreOutDistances);
}
#endif // FAISS_USE_FULL_BFLOAT16

void runIPDistance(
GpuResources* res,
cudaStream_t stream,
Expand Down Expand Up @@ -640,5 +713,29 @@ void runIPDistance(
outIndices);
}

#ifdef FAISS_USE_FULL_BFLOAT16
void runIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices) {
runIPDistance<__nv_bfloat16>(
res,
stream,
vectors,
vectorsRowMajor,
queries,
queriesRowMajor,
k,
outDistances,
outIndices);
}
#endif // FAISS_USE_FULL_BFLOAT16

} // namespace gpu
} // namespace faiss
51 changes: 51 additions & 0 deletions faiss/gpu/impl/Distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ void runAllPairwiseL2Distance(
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);

#ifdef FAISS_USE_FULL_BFLOAT16
void runAllPairwiseL2Distance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);
#endif // FAISS_USE_FULL_BFLOAT16

void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Expand All @@ -59,6 +71,17 @@ void runAllPairwiseIPDistance(
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);

#ifdef FAISS_USE_FULL_BFLOAT16
void runAllPairwiseIPDistance(
GpuResources* res,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
Tensor<float, 2, true>& outDistances);
#endif // FAISS_USE_FULL_BFLOAT16

/// Calculates brute-force L2 distance between `vectors` and
/// `queries`, returning the k closest results seen
void runL2Distance(
Expand Down Expand Up @@ -91,6 +114,21 @@ void runL2Distance(
Tensor<idx_t, 2, true>& outIndices,
bool ignoreOutDistances = false);

#ifdef FAISS_USE_FULL_BFLOAT16
void runL2Distance(
GpuResources* resources,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<float, 1, true>* vectorNorms,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices,
bool ignoreOutDistances = false);
#endif // FAISS_USE_FULL_BFLOAT16

/// Calculates brute-force inner product distance between `vectors`
/// and `queries`, returning the k closest results seen
void runIPDistance(
Expand All @@ -115,6 +153,19 @@ void runIPDistance(
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices);

#ifdef FAISS_USE_FULL_BFLOAT16
void runIPDistance(
GpuResources* resources,
cudaStream_t stream,
Tensor<__nv_bfloat16, 2, true>& vectors,
bool vectorsRowMajor,
Tensor<__nv_bfloat16, 2, true>& queries,
bool queriesRowMajor,
int k,
Tensor<float, 2, true>& outDistances,
Tensor<idx_t, 2, true>& outIndices);
#endif // FAISS_USE_FULL_BFLOAT16

//
// General distance implementation, assumes that all arguments are on the
// device. This is the top-level internal distance function to call to dispatch
Expand Down
8 changes: 4 additions & 4 deletions faiss/gpu/impl/GpuScalarQuantizer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
inline __device__ void decode(void* data, idx_t vec, int d, float* out)
const {
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
out[0] = Convert<half, float>()(p[d]);
out[0] = ConvertTo<float>::to(p[d]);
}

inline __device__ float decodePartial(
Expand All @@ -172,7 +172,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
int d,
float v[kDimPerIter]) const {
half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
p[d] = Convert<float, half>()(v[0]);
p[d] = ConvertTo<half>::to(v[0]);
}

inline __device__ void encodePartial(
Expand All @@ -191,11 +191,11 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
static constexpr int kEncodeBits = 16;

inline __device__ EncodeT encodeNew(int dim, float v) const {
return Convert<float, half>()(v);
return ConvertTo<half>::to(v);
}

inline __device__ float decodeNew(int dim, EncodeT v) const {
return Convert<half, float>()(v);
return ConvertTo<float>::to(v);
}

int bytesPerVec;
Expand Down
15 changes: 14 additions & 1 deletion faiss/gpu/impl/L2Norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <faiss/gpu/impl/L2Norm.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/Reductions.cuh>
Expand Down Expand Up @@ -276,5 +275,19 @@ void runL2Norm(
runL2Norm<half, half2>(input, inputRowMajor, output, normSquared, stream);
}

#ifdef FAISS_USE_FULL_BFLOAT16

void runL2Norm(
Tensor<__nv_bfloat16, 2, true>& input,
bool inputRowMajor,
Tensor<float, 1, true>& output,
bool normSquared,
cudaStream_t stream) {
runL2Norm<__nv_bfloat16, __nv_bfloat162>(
input, inputRowMajor, output, normSquared, stream);
}

#endif // FAISS_USE_FULL_BFLOAT16

} // namespace gpu
} // namespace faiss
13 changes: 12 additions & 1 deletion faiss/gpu/impl/L2Norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#pragma once

#include <cuda_fp16.h>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/Tensor.cuh>

namespace faiss {
Expand All @@ -27,5 +27,16 @@ void runL2Norm(
bool normSquared,
cudaStream_t stream);

#ifdef FAISS_USE_FULL_BFLOAT16

void runL2Norm(
Tensor<__nv_bfloat16, 2, true>& input,
bool inputRowMajor,
Tensor<float, 1, true>& output,
bool normSquared,
cudaStream_t stream);

#endif // FAISS_USE_FULL_BFLOAT16

} // namespace gpu
} // namespace faiss
8 changes: 2 additions & 6 deletions faiss/gpu/impl/VectorResidual.cu
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,8 @@ __global__ void gatherReconstructByIds(
auto vec = vecs[id];
auto outVec = out[blockIdx.x];

Convert<T, float> conv;

for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
}
}

Expand All @@ -131,10 +129,8 @@ __global__ void gatherReconstructByRange(
auto vec = vecs[id];
auto outVec = out[blockIdx.x];

Convert<T, float> conv;

for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
}
}

Expand Down
Loading

0 comments on commit 841d313

Please sign in to comment.