diff --git a/faiss/gpu/GpuDistance.cu b/faiss/gpu/GpuDistance.cu
index f515067889..e80477f1a0 100644
--- a/faiss/gpu/GpuDistance.cu
+++ b/faiss/gpu/GpuDistance.cu
@@ -30,6 +30,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #if defined USE_NVIDIA_CUVS
@@ -231,7 +232,7 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
     FAISS_THROW_IF_NOT_MSG(
             args.vectorType == args.queryType,
             "limitation: both vectorType and queryType must currently "
-            "be the same (F32 or F16");
+            "be the same (F32 / F16 / BF16)");
 
 #if defined USE_NVIDIA_CUVS
     // Note: For now, cuVS bfknn requires queries and vectors to be same layout
@@ -400,6 +401,17 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
         bfKnnConvert<float>(prov, args);
     } else if (args.vectorType == DistanceDataType::F16) {
         bfKnnConvert<half>(prov, args);
+    } else if (args.vectorType == DistanceDataType::BF16) {
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+        if (prov->getResources()->supportsBFloat16CurrentDevice()) {
+            bfKnnConvert<__nv_bfloat16>(prov, args);
+        } else {
+            FAISS_THROW_MSG("not compiled with bfloat16 support");
+        }
+#else
+        FAISS_THROW_MSG("no AMD bfloat16 support");
+#endif
     } else {
         FAISS_THROW_MSG("unknown vectorType");
     }
@@ -466,8 +478,10 @@ void bfKnn_single_query_shard(
             args.k > 0,
             "bfKnn_tiling: tiling vectors is only supported for k > 0");
     size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
-            : args.vectorType == DistanceDataType::F16 ? 2
-                                                       : 0;
+            : (args.vectorType == DistanceDataType::F16 ||
+               args.vectorType == DistanceDataType::BF16)
+                    ? 2
+                    : 0;
     FAISS_THROW_IF_NOT_MSG(
             distance_size > 0, "bfKnn_tiling: unknown vectorType");
     size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
@@ -524,8 +538,10 @@ void bfKnn_tiling(
             args.k > 0,
             "bfKnn_tiling: tiling queries is only supported for k > 0");
     size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
-            : args.queryType == DistanceDataType::F16 ? 2
-                                                      : 0;
+            : (args.queryType == DistanceDataType::F16 ||
+               args.queryType == DistanceDataType::BF16)
+                    ? 2
+                    : 0;
     FAISS_THROW_IF_NOT_MSG(
             distance_size > 0, "bfKnn_tiling: unknown queryType");
     size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8
diff --git a/faiss/gpu/GpuDistance.h b/faiss/gpu/GpuDistance.h
index 7052fc68b0..e4daf5e296 100644
--- a/faiss/gpu/GpuDistance.h
+++ b/faiss/gpu/GpuDistance.h
@@ -19,6 +19,7 @@ class GpuResourcesProvider;
 enum class DistanceDataType {
     F32 = 1,
     F16,
+    BF16,
 };
 
 // Scalar type of the indices data
diff --git a/faiss/gpu/GpuResources.cpp b/faiss/gpu/GpuResources.cpp
index 1f0f2541f1..74df9b96d3 100644
--- a/faiss/gpu/GpuResources.cpp
+++ b/faiss/gpu/GpuResources.cpp
@@ -161,6 +161,10 @@ GpuMemoryReservation::~GpuMemoryReservation() {
 
 GpuResources::~GpuResources() = default;
 
+bool GpuResources::supportsBFloat16CurrentDevice() {
+    return supportsBFloat16(getCurrentDevice());
+}
+
 cublasHandle_t GpuResources::getBlasHandleCurrentDevice() {
     return getBlasHandle(getCurrentDevice());
 }
diff --git a/faiss/gpu/GpuResources.h b/faiss/gpu/GpuResources.h
index 3fec634fef..d914dae7ae 100644
--- a/faiss/gpu/GpuResources.h
+++ b/faiss/gpu/GpuResources.h
@@ -205,6 +205,9 @@ class GpuResources {
     /// of demand
     virtual void initializeForDevice(int device) = 0;
 
+    /// Does the given GPU support bfloat16?
+    virtual bool supportsBFloat16(int device) = 0;
+
     /// Returns the cuBLAS handle that we use for the given device
     virtual cublasHandle_t getBlasHandle(int device) = 0;
 
@@ -252,6 +255,9 @@ class GpuResources {
     /// Functions provided by default
     ///
 
+    /// Does the current GPU support bfloat16?
+    bool supportsBFloat16CurrentDevice();
+
     /// Calls getBlasHandle with the current device
     cublasHandle_t getBlasHandleCurrentDevice();
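With `DistanceDataType::BF16` in place, the dispatch above is reachable straight from the public `bfKnn()` API. A minimal host-side sketch of a caller (not part of the patch; the `searchBF16` wrapper and the preloaded device buffers are assumptions for illustration):

```cpp
#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

// d_vecs / d_queries are assumed to already hold row-major bf16 data on the
// GPU; d_dist / d_idx are preallocated output buffers.
void searchBF16(
        const void* d_vecs, int numVecs,
        const void* d_queries, int numQueries,
        int dim, int k,
        float* d_dist, faiss::idx_t* d_idx) {
    faiss::gpu::StandardGpuResources res;
    if (!res.supportsBFloat16CurrentDevice()) {
        return; // fall back to DistanceDataType::F32 on pre-Ampere parts
    }

    faiss::gpu::GpuDistanceParams args;
    args.metric = faiss::METRIC_L2;
    args.k = k;
    args.dims = dim;
    args.vectors = d_vecs;
    args.vectorType = faiss::gpu::DistanceDataType::BF16;
    args.numVectors = numVecs;
    args.queries = d_queries;
    args.queryType = faiss::gpu::DistanceDataType::BF16;
    args.numQueries = numQueries;
    args.outDistances = d_dist; // distances always come back as float32
    args.outIndices = d_idx;    // int64 indices by default
    faiss::gpu::bfKnn(&res, args);
}
```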
diff --git a/faiss/gpu/StandardGpuResources.cpp b/faiss/gpu/StandardGpuResources.cpp
index a91c7f693c..39ee38efa9 100644
--- a/faiss/gpu/StandardGpuResources.cpp
+++ b/faiss/gpu/StandardGpuResources.cpp
@@ -206,6 +206,13 @@ size_t StandardGpuResourcesImpl::getDefaultTempMemForGPU(
     return requested;
 }
 
+/// Does the given GPU support bfloat16?
+bool StandardGpuResourcesImpl::supportsBFloat16(int device) {
+    initializeForDevice(device);
+    auto& prop = getDeviceProperties(device);
+    return prop.major >= 8;
+}
+
 void StandardGpuResourcesImpl::noTempMemory() {
     setTempMemory(0);
 }
@@ -701,6 +708,14 @@ std::shared_ptr<GpuResources> StandardGpuResources::getResources() {
     return res_;
 }
 
+bool StandardGpuResources::supportsBFloat16(int device) {
+    return res_->supportsBFloat16(device);
+}
+
+bool StandardGpuResources::supportsBFloat16CurrentDevice() {
+    return res_->supportsBFloat16CurrentDevice();
+}
+
 void StandardGpuResources::noTempMemory() {
     res_->noTempMemory();
 }
diff --git a/faiss/gpu/StandardGpuResources.h b/faiss/gpu/StandardGpuResources.h
index 322a341a00..9c8cf4d55d 100644
--- a/faiss/gpu/StandardGpuResources.h
+++ b/faiss/gpu/StandardGpuResources.h
@@ -48,6 +48,9 @@ class StandardGpuResourcesImpl : public GpuResources {
 
     ~StandardGpuResourcesImpl() override;
 
+    /// Does the given GPU support bfloat16?
+    bool supportsBFloat16(int device) override;
+
     /// Disable allocation of temporary memory; all temporary memory
     /// requests will call cudaMalloc / cudaFree at the point of use
     void noTempMemory();
@@ -199,6 +202,12 @@ class StandardGpuResources : public GpuResourcesProvider {
 
     std::shared_ptr<GpuResources> getResources() override;
 
+    /// Whether or not the given device supports native bfloat16 arithmetic
+    bool supportsBFloat16(int device);
+
+    /// Whether or not the current device supports native bfloat16 arithmetic
+    bool supportsBFloat16CurrentDevice();
+
     /// Disable allocation of temporary memory; all temporary memory
     /// requests will call cudaMalloc / cudaFree at the point of use
    void noTempMemory();
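The `prop.major >= 8` test gates bf16 on compute capability 8.0 (Ampere) and newer, the first generation with native bfloat16 ALUs. A standalone equivalent of that probe, using only the plain CUDA runtime API (hypothetical helper, not part of the patch):

```cpp
#include <cuda_runtime.h>

// Same criterion as StandardGpuResourcesImpl::supportsBFloat16 above.
bool deviceSupportsBFloat16(int device) {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        return false;
    }
    return prop.major >= 8; // SM 8.0+ (Ampere and later)
}
```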
diff --git a/faiss/gpu/impl/Distance.cu b/faiss/gpu/impl/Distance.cu
index 3ac99b2576..eb2e91e93e 100644
--- a/faiss/gpu/impl/Distance.cu
+++ b/faiss/gpu/impl/Distance.cu
@@ -504,6 +504,30 @@ void runAllPairwiseL2Distance(
             outDistances);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseL2Distance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances) {
+    runAllPairwiseDistance<__nv_bfloat16>(
+            true,
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            vectorNorms,
+            queries,
+            queriesRowMajor,
+            outDistances);
+}
+#endif // USE_AMD_ROCM
+
 void runAllPairwiseIPDistance(
         GpuResources* res,
         cudaStream_t stream,
@@ -544,6 +568,29 @@ void runAllPairwiseIPDistance(
             outDistances);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseIPDistance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances) {
+    runAllPairwiseDistance<__nv_bfloat16>(
+            false,
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            nullptr,
+            queries,
+            queriesRowMajor,
+            outDistances);
+}
+#endif // USE_AMD_ROCM
+
 void runL2Distance(
         GpuResources* res,
         cudaStream_t stream,
@@ -596,6 +643,35 @@ void runL2Distance(
             ignoreOutDistances);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runL2Distance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices,
+        bool ignoreOutDistances) {
+    runL2Distance<__nv_bfloat16>(
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            vectorNorms,
+            queries,
+            queriesRowMajor,
+            k,
+            outDistances,
+            outIndices,
+            ignoreOutDistances);
+}
+#endif // USE_AMD_ROCM
+
 void runIPDistance(
         GpuResources* res,
         cudaStream_t stream,
@@ -640,5 +716,30 @@ void runIPDistance(
             outIndices);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runIPDistance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices) {
+    runIPDistance<__nv_bfloat16>(
+            res,
+            stream,
+            vectors,
+            vectorsRowMajor,
+            queries,
+            queriesRowMajor,
+            k,
+            outDistances,
+            outIndices);
+}
+#endif // USE_AMD_ROCM
+
 } // namespace gpu
 } // namespace faiss
diff --git a/faiss/gpu/impl/Distance.cuh b/faiss/gpu/impl/Distance.cuh
index 17d21f4d9a..d8e1d5c239 100644
--- a/faiss/gpu/impl/Distance.cuh
+++ b/faiss/gpu/impl/Distance.cuh
@@ -41,6 +41,19 @@ void runAllPairwiseL2Distance(
         bool queriesRowMajor,
         Tensor<float, 2, true>& outDistances);
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseL2Distance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances);
+#endif // USE_AMD_ROCM
+
 void runAllPairwiseIPDistance(
         GpuResources* res,
         cudaStream_t stream,
@@ -59,6 +72,18 @@ void runAllPairwiseIPDistance(
         bool queriesRowMajor,
         Tensor<float, 2, true>& outDistances);
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseIPDistance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances);
+#endif // USE_AMD_ROCM
+
 /// Calculates brute-force L2 distance between `vectors` and
 /// `queries`, returning the k closest results seen
 void runL2Distance(
@@ -91,6 +116,22 @@ void runL2Distance(
         Tensor<idx_t, 2, true>& outIndices,
         bool ignoreOutDistances = false);
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runL2Distance(
+        GpuResources* resources,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices,
+        bool ignoreOutDistances = false);
+#endif // USE_AMD_ROCM
+
 /// Calculates brute-force inner product distance between `vectors`
 /// and `queries`, returning the k closest results seen
 void runIPDistance(
@@ -115,6 +156,20 @@ void runIPDistance(
         Tensor<float, 2, true>& outDistances,
         Tensor<idx_t, 2, true>& outIndices);
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runIPDistance(
+        GpuResources* resources,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices);
+#endif // USE_AMD_ROCM
+
 //
 // General distance implementation, assumes that all arguments are on the
 // device. This is the top-level internal distance function to call to dispatch
diff --git a/faiss/gpu/impl/GeneralDistance.cuh b/faiss/gpu/impl/GeneralDistance.cuh
index cc60794a04..208d3a81bf 100644
--- a/faiss/gpu/impl/GeneralDistance.cuh
+++ b/faiss/gpu/impl/GeneralDistance.cuh
@@ -151,10 +151,10 @@ __launch_bounds__(TILE_SIZE* TILE_SIZE) __global__ void generalDistance(
             bool kInBounds = k < query.getSize(1);
 
             queryTileBase[threadIdx.x + i * TILE_SIZE] =
-                    kInBounds ? queryBase[k] : ConvertTo<T>::to(0);
+                    kInBounds ? queryBase[k] : ConvertTo<T>::to(0.0f);
 
             vecTileBase[threadIdx.x + i * TILE_SIZE] =
-                    kInBounds ? vecBase[k] : ConvertTo<T>::to(0);
+                    kInBounds ? vecBase[k] : ConvertTo<T>::to(0.0f);
         }
 
         __syncthreads();
@@ -185,10 +185,10 @@ __launch_bounds__(TILE_SIZE* TILE_SIZE) __global__ void generalDistance(
         for (idx_t k = threadIdx.x; k < limit; k += TILE_SIZE) {
             // Load query tile
             queryTileBase[threadIdx.x] =
-                    queryThreadInBounds ? queryBase[k] : ConvertTo<T>::to(0);
+                    queryThreadInBounds ? queryBase[k] : ConvertTo<T>::to(0.0f);
 
             vecTileBase[threadIdx.x] =
-                    vecThreadInBoundsLoad ? vecBase[k] : ConvertTo<T>::to(0);
+                    vecThreadInBoundsLoad ? vecBase[k] : ConvertTo<T>::to(0.0f);
 
             __syncthreads();
@@ -211,11 +211,11 @@ __launch_bounds__(TILE_SIZE* TILE_SIZE) __global__ void generalDistance(
             // Load query tile
             queryTileBase[threadIdx.x] = queryThreadInBounds && kInBounds
                     ? queryBase[k]
-                    : ConvertTo<T>::to(0);
+                    : ConvertTo<T>::to(0.0f);
 
             vecTileBase[threadIdx.x] = vecThreadInBoundsLoad && kInBounds
                     ? vecBase[k]
-                    : ConvertTo<T>::to(0);
+                    : ConvertTo<T>::to(0.0f);
 
             __syncthreads();
diff --git a/faiss/gpu/impl/GpuScalarQuantizer.cuh b/faiss/gpu/impl/GpuScalarQuantizer.cuh
index cb7454cf11..c2d781419d 100644
--- a/faiss/gpu/impl/GpuScalarQuantizer.cuh
+++ b/faiss/gpu/impl/GpuScalarQuantizer.cuh
@@ -154,7 +154,7 @@ struct Codec {
     inline __device__ void decode(void* data, idx_t vec, int d, float* out)
             const {
         half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
-        out[0] = Convert<half, float>()(p[d]);
+        out[0] = ConvertTo<float>::to(p[d]);
     }
 
     inline __device__ float decodePartial(
@@ -172,7 +172,7 @@ struct Codec {
             int d,
             float v[kDimPerIter]) const {
         half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
-        p[d] = Convert<float, half>()(v[0]);
+        p[d] = ConvertTo<half>::to(v[0]);
     }
 
     inline __device__ void encodePartial(
@@ -191,11 +191,11 @@ struct Codec {
     static constexpr int kEncodeBits = 16;
 
     inline __device__ EncodeT encodeNew(int dim, float v) const {
-        return Convert<float, half>()(v);
+        return ConvertTo<half>::to(v);
     }
 
     inline __device__ float decodeNew(int dim, EncodeT v) const {
-        return Convert<half, float>()(v);
+        return ConvertTo<float>::to(v);
     }
 
     int bytesPerVec;
diff --git a/faiss/gpu/impl/L2Norm.cu b/faiss/gpu/impl/L2Norm.cu
index 66eb06d0d7..262fa19153 100644
--- a/faiss/gpu/impl/L2Norm.cu
+++ b/faiss/gpu/impl/L2Norm.cu
@@ -11,7 +11,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -276,5 +275,18 @@ void runL2Norm(
     runL2Norm<half, half2>(input, inputRowMajor, output, normSquared, stream);
 }
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runL2Norm(
+        Tensor<__nv_bfloat16, 2, true>& input,
+        bool inputRowMajor,
+        Tensor<float, 1, true>& output,
+        bool normSquared,
+        cudaStream_t stream) {
+    runL2Norm<__nv_bfloat16, __nv_bfloat162>(
+            input, inputRowMajor, output, normSquared, stream);
+}
+#endif
+
 } // namespace gpu
 } // namespace faiss
diff --git a/faiss/gpu/impl/L2Norm.cuh b/faiss/gpu/impl/L2Norm.cuh
index fa798b75b7..79aef4f131 100644
--- a/faiss/gpu/impl/L2Norm.cuh
+++ b/faiss/gpu/impl/L2Norm.cuh
@@ -7,7 +7,7 @@
 
 #pragma once
 
-#include <cuda_fp16.h>
+#include <faiss/gpu/utils/Float16.cuh>
 #include <faiss/gpu/utils/Tensor.cuh>
 
 namespace faiss {
@@ -27,5 +27,15 @@ void runL2Norm(
         bool normSquared,
         cudaStream_t stream);
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runL2Norm(
+        Tensor<__nv_bfloat16, 2, true>& input,
+        bool inputRowMajor,
+        Tensor<float, 1, true>& output,
+        bool normSquared,
+        cudaStream_t stream);
+#endif
+
 } // namespace gpu
 } // namespace faiss
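Throughout these files, the destination-keyed `ConvertTo<T>::to(...)`, which overloads on the source type, replaces the old two-parameter `Convert<From, To>` functor. A toy kernel showing the call pattern the QT_fp16 codec now uses (illustrative only, not from the patch):

```cpp
#include <faiss/gpu/utils/ConversionOperators.cuh>

// Decode a half-precision buffer to float32; the right overload of
// ConvertTo<float>::to is picked by the argument type.
__global__ void decodeHalfToFloat(const half* in, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = faiss::gpu::ConvertTo<float>::to(in[i]);
    }
}
```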
diff --git a/faiss/gpu/impl/VectorResidual.cu b/faiss/gpu/impl/VectorResidual.cu
index cba7d9073c..425036552d 100644
--- a/faiss/gpu/impl/VectorResidual.cu
+++ b/faiss/gpu/impl/VectorResidual.cu
@@ -114,10 +114,8 @@ __global__ void gatherReconstructByIds(
     auto vec = vecs[id];
     auto outVec = out[blockIdx.x];
 
-    Convert<T, float> conv;
-
     for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
-        outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
+        outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
    }
 }
 
@@ -131,10 +129,8 @@ __global__ void gatherReconstructByRange(
     auto vec = vecs[id];
     auto outVec = out[blockIdx.x];
 
-    Convert<T, float> conv;
-
     for (idx_t i = threadIdx.x; i < vecs.getSize(1); i += blockDim.x) {
-        outVec[i] = id == idx_t(-1) ? 0.0f : conv(vec[i]);
+        outVec[i] = id == idx_t(-1) ? 0.0f : ConvertTo<float>::to(vec[i]);
     }
 }
diff --git a/faiss/gpu/test/TestGpuDistance.cu b/faiss/gpu/test/TestGpuDistance.cu
index 3915055480..5dbca4ec9a 100644
--- a/faiss/gpu/test/TestGpuDistance.cu
+++ b/faiss/gpu/test/TestGpuDistance.cu
@@ -32,6 +32,13 @@
 #include
 #include
 
+enum class TestThresholds {
+    Normal,
+    BF16,
+    // Linf has worse error than the other metrics for bf16
+    BF16_Linf,
+};
+
 void evaluate_bfknn(
         faiss::gpu::GpuDistanceParams& args,
         faiss::gpu::GpuResourcesProvider* res,
@@ -43,16 +50,39 @@ void evaluate_bfknn(
         int k,
         bool colMajorVecs,
         bool colMajorQueries,
-        faiss::MetricType metric) {
+        faiss::MetricType metric,
+        TestThresholds thresh = TestThresholds::Normal) {
     using namespace faiss::gpu;
 
     bfKnn(res, args);
 
     std::stringstream str;
-    str << "using cuVS " << args.use_cuvs << "metric " << metric
+    str << "using cuVS " << args.use_cuvs << " metric " << metric
         << " colMajorVecs " << colMajorVecs << " colMajorQueries "
         << colMajorQueries;
 
+    float maxRelativeError;
+    float pctMaxDiff1;
+    float pctMaxDiffN;
+
+    switch (thresh) {
+        case TestThresholds::Normal:
+            maxRelativeError = 6e-3f;
+            pctMaxDiff1 = 0.1f;
+            pctMaxDiffN = 0.015f;
+            break;
+        case TestThresholds::BF16:
+            maxRelativeError = 1.5e-2f;
+            pctMaxDiff1 = 0.3f;
+            pctMaxDiffN = 0.1f;
+            break;
+        case TestThresholds::BF16_Linf:
+            maxRelativeError = 1.5e-2f;
+            pctMaxDiff1 = 0.53f;
+            pctMaxDiffN = 0.2f;
+            break;
+    }
+
     compareLists(
             cpuDistance.data(),
             cpuIndices.data(),
@@ -64,9 +94,9 @@ void evaluate_bfknn(
             false,
             false,
             true,
-            6e-3f,
-            0.1f,
-            0.015f);
+            maxRelativeError,
+            pctMaxDiff1,
+            pctMaxDiffN);
 }
 
 void testTransposition(
@@ -82,6 +112,10 @@ void testTransposition(
     StandardGpuResources res;
     res.noTempMemory();
 
+    // The transpose and distance code assumes the desired device is already set
+    DeviceScope scope(device);
+    auto stream = res.getDefaultStream(device);
+
     int dim = randVal(20, 150);
     int numVecs = randVal(10, 30000);
     int numQuery = randVal(1, 1024);
@@ -120,10 +154,6 @@ void testTransposition(
     cpuIndex.search(
             numQuery, queries.data(), k, cpuDistance.data(), cpuIndices.data());
 
-    // The transpose and distance code assumes the desired device is already set
-    DeviceScope scope(device);
-    auto stream = res.getDefaultStream(device);
-
     // Copy input data to GPU, and pre-transpose both vectors and queries for
     // passing
     auto gpuVecs = toDeviceNonTemporary<float, 2>(
@@ -191,12 +221,161 @@ void testTransposition(
             metric);
 }
 
+void testTransposition_bf16(
+        bool colMajorVecs,
+        bool colMajorQueries,
+        faiss::MetricType metric,
+        bool use_raft = false,
+        float metricArg = 0) {
+    using namespace faiss::gpu;
+
+#ifdef USE_AMD_ROCM
+    std::cout << "skipping bfloat16 test (no bfloat16 support on AMD)\n";
+    EXPECT_TRUE(true);
+    return;
+#else
+    int device = randVal(0, getNumDevices() - 1);
+
+    StandardGpuResources res;
+    if (!res.supportsBFloat16(device)) {
+        std::cout << "skipping bfloat16 test (no bfloat16 support on device)\n";
+        return;
+    }
+
+    res.noTempMemory();
+
+    // The transpose and distance code assumes the desired device is already set
+    DeviceScope scope(device);
+    auto stream = res.getDefaultStream(device);
+
+    int dim = randVal(20, 150);
+    int numVecs = randVal(10, 30000);
+    int numQuery = randVal(1, 1024);
+    int k = std::min(numVecs, randVal(20, 70));
+
+    // Input data for CPU
+    std::vector<float> vecs = randVecs(numVecs, dim);
+    std::vector<float> queries = randVecs(numQuery, dim);
+
+    if ((metric == faiss::MetricType::METRIC_JensenShannon) ||
+        (metric == faiss::MetricType::METRIC_Jaccard)) {
+        // make values positive
+        for (auto& v : vecs) {
+            v = std::abs(v);
+            if (v == 0) {
+                v = 1e-6;
+            }
+        }
+
+        for (auto& q : queries) {
+            q = std::abs(q);
+            if (q == 0) {
+                q = 1e-6;
+            }
+        }
+    }
+
+    // The CPU index is our reference for the results
+    faiss::IndexFlat cpuIndex(dim, metric);
+    cpuIndex.metric_arg = metricArg;
+    cpuIndex.add(numVecs, vecs.data());
+
+    std::vector<float> cpuDistance(numQuery * k, 0);
+    std::vector<faiss::idx_t> cpuIndices(numQuery * k, -1);
+
+    cpuIndex.search(
+            numQuery, queries.data(), k, cpuDistance.data(), cpuIndices.data());
+
+    // Convert float32 data to bfloat16 via truncation not rounding
+    // (just copy high 2 bytes)
+    std::vector<uint16_t> bf16_vecs(vecs.size());
+    std::vector<uint16_t> bf16_queries(queries.size());
+
+    auto fn_f32_bf16 = [](float v) {
+        uint32_t vi;
+        std::memcpy(&vi, &v, sizeof(uint32_t));
+        return uint16_t(vi >> 16);
+    };
+
+    std::transform(vecs.begin(), vecs.end(), bf16_vecs.begin(), fn_f32_bf16);
+    std::transform(
+            queries.begin(), queries.end(), bf16_queries.begin(), fn_f32_bf16);
+
+    // Copy input data to GPU, and pre-transpose both vectors and queries for
+    // passing. Just use uint16_t in lieu of __nv_bfloat16
+    auto gpuVecs = toDeviceNonTemporary<uint16_t, 2>(
+            res.getResources().get(),
+            device,
+            bf16_vecs.data(),
+            stream,
+            {numVecs, dim});
+    auto gpuQueries = toDeviceNonTemporary<uint16_t, 2>(
+            res.getResources().get(),
+            device,
+            bf16_queries.data(),
+            stream,
+            {numQuery, dim});
+
+    DeviceTensor<uint16_t, 2, true> vecsT(
+            res.getResources().get(),
+            makeDevAlloc(AllocType::Other, stream),
+            {dim, numVecs});
+    runTransposeAny(gpuVecs, 0, 1, vecsT, stream);
+
+    DeviceTensor<uint16_t, 2, true> queriesT(
+            res.getResources().get(),
+            makeDevAlloc(AllocType::Other, stream),
+            {dim, numQuery});
+    runTransposeAny(gpuQueries, 0, 1, queriesT, stream);
+
+    std::vector<float> gpuDistance(numQuery * k, 0);
+    std::vector<faiss::idx_t> gpuIndices(numQuery * k, -1);
+
+    GpuDistanceParams args;
+    args.metric = metric;
+    args.metricArg = metricArg;
+    args.k = k;
+    args.dims = dim;
+    args.vectors = colMajorVecs ? vecsT.data() : gpuVecs.data();
+    args.vectorType = DistanceDataType::BF16;
+    args.vectorsRowMajor = !colMajorVecs;
+    args.numVectors = numVecs;
+    args.queries = colMajorQueries ? queriesT.data() : gpuQueries.data();
+    args.queryType = DistanceDataType::BF16;
+    args.queriesRowMajor = !colMajorQueries;
+    args.numQueries = numQuery;
+    args.outDistances = gpuDistance.data();
+    args.outIndices = gpuIndices.data();
+    args.device = device;
+
+    evaluate_bfknn(
+            args,
+            &res,
+            cpuDistance,
+            cpuIndices,
+            gpuDistance,
+            gpuIndices,
+            numQuery,
+            k,
+            colMajorVecs,
+            colMajorQueries,
+            metric,
+            metric == faiss::MetricType::METRIC_Linf
+                    ? TestThresholds::BF16_Linf
+                    : TestThresholds::BF16);
+#endif
+}
+
 // Test different memory layouts for brute-force k-NN
 TEST(TestGpuDistance, Transposition_RR) {
     testTransposition(false, false, faiss::MetricType::METRIC_L2);
     testTransposition(false, false, faiss::MetricType::METRIC_INNER_PRODUCT);
 }
 
+TEST(TestGpuDistance, Transposition_RR_BF16) {
+    testTransposition_bf16(false, false, faiss::MetricType::METRIC_L2);
+    testTransposition_bf16(
+            false, false, faiss::MetricType::METRIC_INNER_PRODUCT);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, Transposition_RR) {
     testTransposition(false, false, faiss::MetricType::METRIC_L2, true);
@@ -209,6 +388,10 @@ TEST(TestGpuDistance, Transposition_RC) {
     testTransposition(false, true, faiss::MetricType::METRIC_L2);
 }
 
+TEST(TestGpuDistance, Transposition_RC_BF16) {
+    testTransposition_bf16(false, true, faiss::MetricType::METRIC_L2);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, Transposition_RC) {
     testTransposition(false, true, faiss::MetricType::METRIC_L2, true);
@@ -219,6 +402,10 @@ TEST(TestGpuDistance, Transposition_CR) {
     testTransposition(true, false, faiss::MetricType::METRIC_L2);
 }
 
+TEST(TestGpuDistance, Transposition_CR_BF16) {
+    testTransposition_bf16(true, false, faiss::MetricType::METRIC_L2);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, Transposition_CR) {
     testTransposition(true, false, faiss::MetricType::METRIC_L2, true);
@@ -229,6 +416,10 @@ TEST(TestGpuDistance, Transposition_CC) {
     testTransposition(true, true, faiss::MetricType::METRIC_L2);
 }
 
+TEST(TestGpuDistance, Transposition_CC_BF16) {
+    testTransposition_bf16(true, true, faiss::MetricType::METRIC_L2);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, Transposition_CC) {
     testTransposition(true, true, faiss::MetricType::METRIC_L2, true);
@@ -239,6 +430,10 @@ TEST(TestGpuDistance, L1) {
     testTransposition(false, false, faiss::MetricType::METRIC_L1);
 }
 
+TEST(TestGpuDistance, L1_BF16) {
+    testTransposition_bf16(false, false, faiss::MetricType::METRIC_L1);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, L1) {
     testTransposition(false, false, faiss::MetricType::METRIC_L1, true);
@@ -257,10 +452,18 @@ TEST(TestCuvsGpuDistance, L1_RC) {
 }
 #endif
 
+TEST(TestGpuDistance, L1_RC_BF16) {
+    testTransposition_bf16(false, true, faiss::MetricType::METRIC_L1);
+}
+
 TEST(TestGpuDistance, L1_CR) {
     testTransposition(true, false, faiss::MetricType::METRIC_L1);
 }
 
+TEST(TestGpuDistance, L1_CR_BF16) {
+    testTransposition_bf16(true, false, faiss::MetricType::METRIC_L1);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, L1_CR) {
     testTransposition(true, false, faiss::MetricType::METRIC_L1, true);
@@ -271,6 +474,10 @@ TEST(TestGpuDistance, L1_CC) {
     testTransposition(true, true, faiss::MetricType::METRIC_L1);
 }
 
+TEST(TestGpuDistance, L1_CC_BF16) {
+    testTransposition_bf16(true, true, faiss::MetricType::METRIC_L1);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, L1_CC) {
     testTransposition(true, true, faiss::MetricType::METRIC_L1, true);
@@ -286,53 +493,82 @@ TEST(TestGpuDistance, Linf) {
     testTransposition(false, false, faiss::MetricType::METRIC_Linf);
 }
 
+TEST(TestGpuDistance, Linf_BF16) {
+    testTransposition_bf16(false, false, faiss::MetricType::METRIC_Linf);
+}
+
 #if defined USE_NVIDIA_CUVS
 // Test remainder of metric types
 TEST(TestCuvsGpuDistance, Linf) {
     testTransposition(false, false, faiss::MetricType::METRIC_Linf, true);
 }
 #endif
 
 TEST(TestGpuDistance, Lp) {
     testTransposition(false, false, faiss::MetricType::METRIC_Lp, false, 3);
 }
 
+TEST(TestGpuDistance, Lp_BF16) {
+    testTransposition_bf16(
+            false, false, faiss::MetricType::METRIC_Lp, false, 3);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, Lp) {
     testTransposition(false, false, faiss::MetricType::METRIC_Lp, true, 3);
 }
 #endif
 
 TEST(TestGpuDistance, Canberra) {
     testTransposition(false, false, faiss::MetricType::METRIC_Canberra);
 }
 
+TEST(TestGpuDistance, Canberra_BF16) {
+    testTransposition_bf16(false, false, faiss::MetricType::METRIC_Canberra);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, Canberra) {
     testTransposition(false, false, faiss::MetricType::METRIC_Canberra, true);
 }
 #endif
 
 TEST(TestGpuDistance, BrayCurtis) {
     testTransposition(false, false, faiss::MetricType::METRIC_BrayCurtis);
 }
 
+TEST(TestGpuDistance, BrayCurtis_BF16) {
+    testTransposition_bf16(false, false, faiss::MetricType::METRIC_BrayCurtis);
+}
+
 TEST(TestGpuDistance, JensenShannon) {
     testTransposition(false, false, faiss::MetricType::METRIC_JensenShannon);
 }
 
+TEST(TestGpuDistance, JensenShannon_BF16) {
+    testTransposition_bf16(
+            false, false, faiss::MetricType::METRIC_JensenShannon);
+}
+
 #if defined USE_NVIDIA_CUVS
 TEST(TestCuvsGpuDistance, JensenShannon) {
     testTransposition(
             false, false, faiss::MetricType::METRIC_JensenShannon, true);
 }
 #endif
 
 TEST(TestGpuDistance, Jaccard) {
     testTransposition(false, false, faiss::MetricType::METRIC_Jaccard);
 }
 
+TEST(TestGpuDistance, Jaccard_BF16) {
+    testTransposition_bf16(false, false, faiss::MetricType::METRIC_Jaccard);
+}
+
 int main(int argc, char** argv) {
     testing::InitGoogleTest(&argc, argv);
 
     // just run with a fixed test seed
     faiss::gpu::setTestSeed(100);
 
     return RUN_ALL_TESTS();
 }
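The test converts float32 inputs to bfloat16 by truncation (dropping the low 16 bits) rather than round-to-nearest-even, which is part of why it gets the looser `TestThresholds::BF16` bounds. Since bf16 is exactly the top half of an IEEE-754 binary32, the round trip is trivial on the host; a self-contained sketch of the same conversion pair:

```cpp
#include <cstdint>
#include <cstring>

// Truncating float32 -> bf16, matching fn_f32_bf16 in the test above.
uint16_t f32ToBf16Trunc(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    return uint16_t(bits >> 16);
}

// Widening bf16 -> float32 is exact: pad the mantissa with zeros.
float bf16ToF32(uint16_t h) {
    uint32_t bits = uint32_t(h) << 16;
    float v;
    std::memcpy(&v, &bits, sizeof(v));
    return v;
}
```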
diff --git a/faiss/gpu/utils/ConversionOperators.cuh b/faiss/gpu/utils/ConversionOperators.cuh
index bbaac78f64..f0ab1ea1fd 100644
--- a/faiss/gpu/utils/ConversionOperators.cuh
+++ b/faiss/gpu/utils/ConversionOperators.cuh
@@ -22,30 +22,14 @@ namespace gpu {
 // Conversion utilities
 //
 
-template <typename From, typename To>
-struct Convert {
-    inline __device__ To operator()(From v) const {
-        return (To)v;
-    }
-};
-
-template <>
-struct Convert<float, half> {
-    inline __device__ half operator()(float v) const {
-        return __float2half(v);
-    }
-};
-
-template <>
-struct Convert<half, float> {
-    inline __device__ float operator()(half v) const {
-        return __half2float(v);
+template <typename T>
+struct ConvertTo {
+    template <typename U>
+    static inline __device__ T to(U v) {
+        return T(v);
     }
 };
 
-template <typename T>
-struct ConvertTo {};
-
 template <>
 struct ConvertTo<float> {
     static inline __device__ float to(float v) {
@@ -54,6 +38,12 @@ struct ConvertTo<float> {
     static inline __device__ float to(half v) {
         return __half2float(v);
     }
+
+#ifndef USE_AMD_ROCM
+    static inline __device__ float to(__nv_bfloat16 v) {
+        return __bfloat162float(v);
+    }
+#endif // !USE_AMD_ROCM
 };
 
 template <>
@@ -106,6 +96,31 @@ struct ConvertTo {
     }
 };
 
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+
+template <>
+struct ConvertTo<__nv_bfloat16> {
+    static inline __device__ __nv_bfloat16 to(float v) {
+        return __float2bfloat16(v);
+    }
+    static inline __device__ __nv_bfloat16 to(half v) {
+        return __float2bfloat16(__half2float(v));
+    }
+    static inline __device__ __nv_bfloat16 to(__nv_bfloat16 v) {
+        return v;
+    }
+};
+
+#endif // USE_AMD_ROCM
+
+template <typename From, typename To>
+struct Convert {
+    inline __device__ To operator()(From v) const {
+        return ConvertTo<To>::to(v);
+    }
+};
+
 // Tensor conversion
 template <typename From, typename To>
 void runConvert(const From* in, To* out, size_t num, cudaStream_t stream) {
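With `ConvertTo<T>` as the single conversion entry point (and `Convert<From, To>` reduced to a shim over it), a whole-tensor conversion collapses to one templated kernel. A simplified grid-stride version of that pattern (a sketch, not the patch's actual `runConvert` kernel):

```cpp
#include <faiss/gpu/utils/ConversionOperators.cuh>

// One kernel covers every (From, To) pair ConvertTo knows about, including
// the new __nv_bfloat16 overloads on CUDA builds.
template <typename From, typename To>
__global__ void convertKernel(const From* in, To* out, size_t num) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
         i += size_t(gridDim.x) * blockDim.x) {
        out[i] = faiss::gpu::ConvertTo<To>::to(in[i]);
    }
}
```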
diff --git a/faiss/gpu/utils/Float16.cuh b/faiss/gpu/utils/Float16.cuh
index 449829de66..6a1f779eab 100644
--- a/faiss/gpu/utils/Float16.cuh
+++ b/faiss/gpu/utils/Float16.cuh
@@ -16,7 +16,21 @@
 #define FAISS_USE_FULL_FLOAT16 1
 #endif // __CUDA_ARCH__ types
 
+// Some compute capabilities have full bfloat16 ALUs.
+// FIXME: no support in ROCm yet
+#if __CUDA_ARCH__ >= 800 // || defined(USE_AMD_ROCM)
+#define FAISS_USE_FULL_BFLOAT16 1
+#endif // __CUDA_ARCH__ types
+
 #include <cuda_fp16.h>
+#if !defined(USE_AMD_ROCM)
+#include <cuda_bf16.h>
+#endif
+// #else
+// FIXME: no support in ROCm yet
+// #include <amd_hip_bf16.h>
+// #include <amd_hip_fp16.h>
+// #endif // !defined(USE_AMD_ROCM)
 
 namespace faiss {
 namespace gpu {
diff --git a/faiss/gpu/utils/MathOperators.cuh b/faiss/gpu/utils/MathOperators.cuh
index d825233c0d..9239c735f6 100644
--- a/faiss/gpu/utils/MathOperators.cuh
+++ b/faiss/gpu/utils/MathOperators.cuh
@@ -13,7 +13,7 @@
 //
 // Templated wrappers to express math for different scalar and vector
 // types, so kernels can have the same written form but can operate
-// over half and float, and on vector types transparently
+// over half, bfloat16 and float, and on vector types transparently
 //
 
 namespace faiss {
@@ -556,5 +556,240 @@ struct Math {
     }
 };
 
+#ifndef USE_AMD_ROCM
+
+template <>
+struct Math<__nv_bfloat16> {
+    typedef __nv_bfloat16 ScalarType;
+
+    static inline __device__ __nv_bfloat16
+    add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hadd(a, b);
+#else
+        return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16
+    sub(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hsub(a, b);
+#else
+        return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b));
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16
+    mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hmul(a, b);
+#else
+        return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b));
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16 neg(__nv_bfloat16 v) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hneg(v);
+#else
+        return __float2bfloat16(-__bfloat162float(v));
+#endif
+    }
+
+    static inline __device__ float reduceAdd(__nv_bfloat16 v) {
+        return ConvertTo<float>::to(v);
+    }
+
+    static inline __device__ bool lt(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hlt(a, b);
+#else
+        return __bfloat162float(a) < __bfloat162float(b);
+#endif
+    }
+
+    static inline __device__ bool gt(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hgt(a, b);
+#else
+        return __bfloat162float(a) > __bfloat162float(b);
+#endif
+    }
+
+    static inline __device__ bool eq(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __heq(a, b);
+#else
+        return __bfloat162float(a) == __bfloat162float(b);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16 zero() {
+#if CUDA_VERSION >= 9000
+        return 0.0f;
+#else
+        __nv_bfloat16 h;
+        h.x = 0;
+        return h;
+#endif
+    }
+};
+
+template <>
+struct Math<__nv_bfloat162> {
+    typedef __nv_bfloat16 ScalarType;
+
+#ifndef FAISS_USE_FULL_BFLOAT16
+    // define a few conversion functions that don't exist on CUDA 11;
+    // this overrides their definition in CUDA 12, but we use native bf16 on
+    // that platform anyway
+    static inline __device__ float2 __bfloat1622float2(__nv_bfloat162 a) {
+        float2 af;
+        af.x = __bfloat162float(a.x);
+        af.y = __bfloat162float(a.y);
+        return af;
+    }
+
+    static inline __device__ __nv_bfloat162 __float22bfloat162_rn(float2 af) {
+        __nv_bfloat162 a;
+        a.x = __float2bfloat16_rn(af.x);
+        a.y = __float2bfloat16_rn(af.y);
+        return a;
+    }
+
+    static inline __device__ __nv_bfloat162
+    __bfloat162bfloat162(__nv_bfloat16 v) {
+        __nv_bfloat162 a;
+        a.x = v;
+        a.y = v;
+        return a;
+    }
+#endif
+
+    static inline __device__ __nv_bfloat162
+    add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hadd2(a, b);
+#else
+        float2 af = __bfloat1622float2(a);
+        float2 bf = __bfloat1622float2(b);
+
+        af.x += bf.x;
+        af.y += bf.y;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    sub(__nv_bfloat162 a, __nv_bfloat162 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hsub2(a, b);
+#else
+        float2 af = __bfloat1622float2(a);
+        float2 bf = __bfloat1622float2(b);
+
+        af.x -= bf.x;
+        af.y -= bf.y;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    add(__nv_bfloat162 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        __nv_bfloat162 b2 = __bfloat162bfloat162(b);
+        return __hadd2(a, b2);
+#else
+        float2 af = __bfloat1622float2(a);
+        float bf = __bfloat162float(b);
+
+        af.x += bf;
+        af.y += bf;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    sub(__nv_bfloat162 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        __nv_bfloat162 b2 = __bfloat162bfloat162(b);
+        return __hsub2(a, b2);
+#else
+        float2 af = __bfloat1622float2(a);
+        float bf = __bfloat162float(b);
+
+        af.x -= bf;
+        af.y -= bf;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hmul2(a, b);
+#else
+        float2 af = __bfloat1622float2(a);
+        float2 bf = __bfloat1622float2(b);
+
+        af.x *= bf.x;
+        af.y *= bf.y;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    mul(__nv_bfloat162 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        __nv_bfloat162 b2 = __bfloat162bfloat162(b);
+        return __hmul2(a, b2);
+#else
+        float2 af = __bfloat1622float2(a);
+        float bf = __bfloat162float(b);
+
+        af.x *= bf;
+        af.y *= bf;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162 neg(__nv_bfloat162 v) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hneg2(v);
+#else
+        float2 vf = __bfloat1622float2(v);
+        vf.x = -vf.x;
+        vf.y = -vf.y;
+
+        return __float22bfloat162_rn(vf);
+#endif
+    }
+
+    static inline __device__ float reduceAdd(__nv_bfloat162 v) {
+        float2 vf = __bfloat1622float2(v);
+        vf.x += vf.y;
+
+        return vf.x;
+    }
+
+    // not implemented for vector types
+    // static inline __device__ bool lt(__nv_bfloat162 a, __nv_bfloat162 b);
+    // static inline __device__ bool gt(__nv_bfloat162 a, __nv_bfloat162 b);
+    // static inline __device__ bool eq(__nv_bfloat162 a, __nv_bfloat162 b);
+
+    static inline __device__ __nv_bfloat162 zero() {
+        return __bfloat162bfloat162(Math<__nv_bfloat16>::zero());
+    }
+};
+
+#endif // !USE_AMD_ROCM
+
 } // namespace gpu
 } // namespace faiss
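These wrappers are what let the distance kernels stay type-agnostic: the same written form compiles to native bf16 instructions under `FAISS_USE_FULL_BFLOAT16` and to float32 emulation elsewhere. A toy kernel written directly against them (illustrative only; `n` counts packed bf16 pairs):

```cpp
#include <faiss/gpu/utils/MathOperators.cuh>

// Per-thread (a - b)^2 over two bf16 lanes, reduced to a float32 sum --
// the same Math<...> calls the L2 kernels issue, shown outside any template.
__global__ void squaredDiffBF16x2(
        const __nv_bfloat162* a,
        const __nv_bfloat162* b,
        float* out,
        int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        using M = faiss::gpu::Math<__nv_bfloat162>;
        __nv_bfloat162 d = M::sub(a[i], b[i]);
        out[i] = M::reduceAdd(M::mul(d, d));
    }
}
```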
diff --git a/faiss/gpu/utils/MatrixMult-inl.cuh b/faiss/gpu/utils/MatrixMult-inl.cuh
index 98fd0956cd..2c85d7244d 100644
--- a/faiss/gpu/utils/MatrixMult-inl.cuh
+++ b/faiss/gpu/utils/MatrixMult-inl.cuh
@@ -21,6 +21,7 @@ template <typename T>
 struct GetCudaType;
 
 #ifdef USE_AMD_ROCM
+
 template <>
 struct GetCudaType<float> {
     static constexpr hipblasDatatype_t Type = HIPBLAS_R_32F;
@@ -30,7 +31,15 @@ template <>
 struct GetCudaType<half> {
     static constexpr hipblasDatatype_t Type = HIPBLAS_R_16F;
 };
+
+// FIXME: no AMD support for bf16
+// template <>
+// struct GetCudaType<__nv_bfloat16> {
+//     static constexpr hipblasDatatype_t Type = HIPBLAS_R_16B;
+// };
+
 #else
+
 template <>
 struct GetCudaType<float> {
     static constexpr cudaDataType_t Type = CUDA_R_32F;
@@ -40,6 +49,12 @@ template <>
 struct GetCudaType<half> {
     static constexpr cudaDataType_t Type = CUDA_R_16F;
 };
+
+template <>
+struct GetCudaType<__nv_bfloat16> {
+    static constexpr cudaDataType_t Type = CUDA_R_16BF;
+};
+
 #endif
 
 template <typename AT, typename BT>
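`GetCudaType<__nv_bfloat16>` maps the new scalar type to `CUDA_R_16BF`, which is what the GEMM wrappers ultimately hand to cuBLAS. Roughly what that boils down to for bf16 inputs with float32 accumulation (a hand-written sketch under those assumptions, not the patch's actual GEMM wrapper; column-major, no transposes):

```cpp
#include <cublas_v2.h>
#include <cuda_bf16.h>

// C (m x n, float32) = A (m x k, bf16) * B (k x n, bf16)
cublasStatus_t gemmBF16(
        cublasHandle_t handle,
        int m, int n, int k,
        const __nv_bfloat16* A,
        const __nv_bfloat16* B,
        float* C) {
    float alpha = 1.0f;
    float beta = 0.0f;
    return cublasGemmEx(
            handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
            &alpha,
            A, CUDA_R_16BF, m,
            B, CUDA_R_16BF, k,
            &beta,
            C, CUDA_R_32F, m,
            CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
}
```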