diff --git a/xml/ConversionOperators_8cuh.xml b/xml/ConversionOperators_8cuh.xml
index d1f322eee9..fc83b0b5f8 100644
--- a/xml/ConversionOperators_8cuh.xml
+++ b/xml/ConversionOperators_8cuh.xml
@@ -31,30 +31,14 @@
// Conversion utilities
//
-template <typename From, typename To>
-struct Convert {
-    inline __device__ To operator()(From v) const {
-        return (To)v;
-    }
-};
-
-template <>
-struct Convert<float, half> {
-    inline __device__ half operator()(float v) const {
-        return __float2half(v);
-    }
-};
-
-template <>
-struct Convert<half, float> {
-    inline __device__ float operator()(half v) const {
-        return __half2float(v);
+template <typename T>
+struct ConvertTo {
+    template <typename U>
+    static inline __device__ T to(U v) {
+        return T(v);
    }
};
-template <typename T>
-struct ConvertTo {};
-
template <>
struct ConvertTo<float> {
    static inline __device__ float to(float v) {
@@ -63,6 +47,12 @@
    static inline __device__ float to(half v) {
        return __half2float(v);
    }
+
+#ifndef USE_AMD_ROCM
+    static inline __device__ float to(__nv_bfloat16 v) {
+        return __bfloat162float(v);
+    }
+#endif // !USE_AMD_ROCM
};
template <>
@@ -115,6 +105,31 @@
    }
};
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+
+template <>
+struct ConvertTo<__nv_bfloat16> {
+    static inline __device__ __nv_bfloat16 to(float v) {
+        return __float2bfloat16(v);
+    }
+    static inline __device__ __nv_bfloat16 to(half v) {
+        return __float2bfloat16(__half2float(v));
+    }
+    static inline __device__ __nv_bfloat16 to(__nv_bfloat16 v) {
+        return v;
+    }
+};
+
+#endif // USE_AMD_ROCM
+
+template <typename From, typename To>
+struct Convert {
+    inline __device__ To operator()(From v) const {
+        return ConvertTo<To>::to(v);
+    }
+};
+
// Tensor conversion
template <typename From, typename To>
void runConvert(const From* in, To* out, size_t num, cudaStream_t stream) {
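
For context, a minimal sketch of how the unified ConvertTo<T> path is used from device code (illustrative only, not part of this diff; it assumes faiss/gpu/utils/ConversionOperators.cuh is included):

// Illustrative conversion kernel in the spirit of runConvert: ConvertTo<To>
// picks the right intrinsic per destination type (identity, __float2half,
// __float2bfloat16, ...), so one kernel body covers every supported pair.
template <typename From, typename To>
__global__ void convertKernel(const From* in, To* out, size_t num) {
    size_t i = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < num) {
        out[i] = ConvertTo<To>::to(in[i]);
    }
}
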
diff --git a/xml/Distance_8cuh.xml b/xml/Distance_8cuh.xml
index 637d9ebada..f90964e51f 100644
--- a/xml/Distance_8cuh.xml
+++ b/xml/Distance_8cuh.xml
@@ -50,6 +50,19 @@
        bool queriesRowMajor,
        Tensor<float, 2, true>& outDistances);
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseL2Distance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances);
+#endif // USE_AMD_ROCM
+
void runAllPairwiseIPDistance(
        GpuResources* res,
        cudaStream_t stream,
@@ -68,6 +81,18 @@
        bool queriesRowMajor,
        Tensor<float, 2, true>& outDistances);
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runAllPairwiseIPDistance(
+        GpuResources* res,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        Tensor<float, 2, true>& outDistances);
+#endif // USE_AMD_ROCM
+
/// Calculates brute-force L2 distance between `vectors` and
/// `queries`, returning the k closest results seen
void runL2Distance(
@@ -100,6 +125,22 @@
        Tensor<idx_t, 2, true>& outIndices,
        bool ignoreOutDistances = false);
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runL2Distance(
+        GpuResources* resources,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<float, 1, true>* vectorNorms,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices,
+        bool ignoreOutDistances = false);
+#endif // USE_AMD_ROCM
+
/// Calculates brute-force inner product distance between `vectors`
/// and `queries`, returning the k closest results seen
void runIPDistance(
@@ -124,6 +165,20 @@
        Tensor<float, 2, true>& outDistances,
        Tensor<idx_t, 2, true>& outIndices);
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runIPDistance(
+        GpuResources* resources,
+        cudaStream_t stream,
+        Tensor<__nv_bfloat16, 2, true>& vectors,
+        bool vectorsRowMajor,
+        Tensor<__nv_bfloat16, 2, true>& queries,
+        bool queriesRowMajor,
+        int k,
+        Tensor<float, 2, true>& outDistances,
+        Tensor<idx_t, 2, true>& outIndices);
+#endif // USE_AMD_ROCM
+
//
// General distance implementation, assumes that all arguments are on the
// device. This is the top-level internal distance function to call to dispatch
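
As a usage sketch (illustrative, not part of this diff), the bf16 overloads mirror the float and half ones, so a caller that already holds device-resident bf16 tensors forwards them directly. The wrapper name below is hypothetical, and the internal header path faiss/gpu/impl/Distance.cuh is assumed:

// Hypothetical helper: brute-force L2 k-NN over bf16 data via the overload
// declared above. Inputs are row-major; passing nullptr for vectorNorms lets
// the implementation compute the norms on the fly.
void l2SearchBF16(
        faiss::gpu::GpuResources* res,
        cudaStream_t stream,
        faiss::gpu::Tensor<__nv_bfloat16, 2, true>& vectors,
        faiss::gpu::Tensor<__nv_bfloat16, 2, true>& queries,
        int k,
        faiss::gpu::Tensor<float, 2, true>& outDistances,
        faiss::gpu::Tensor<faiss::idx_t, 2, true>& outIndices) {
    faiss::gpu::runL2Distance(
            res,
            stream,
            vectors,
            /*vectorsRowMajor=*/true,
            /*vectorNorms=*/nullptr,
            queries,
            /*queriesRowMajor=*/true,
            k,
            outDistances,
            outIndices);
}
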
diff --git a/xml/Float16_8cuh.xml b/xml/Float16_8cuh.xml
index f46f54a6f8..8e0033b896 100644
--- a/xml/Float16_8cuh.xml
+++ b/xml/Float16_8cuh.xml
@@ -25,7 +25,21 @@
#define FAISS_USE_FULL_FLOAT16 1
#endif // __CUDA_ARCH__ types
+// Some compute capabilities have full bfloat16 ALUs.
+// FIXME: no support in ROCm yet
+#if __CUDA_ARCH__ >= 800 // || defined(USE_AMD_ROCM)
+#define FAISS_USE_FULL_BFLOAT16 1
+#endif // __CUDA_ARCH__ types
+
#include <cuda_fp16.h>
+#if !defined(USE_AMD_ROCM)
+#include <cuda_bf16.h>
+#endif
+// #else
+// FIXME: no support in ROCm yet
+// #include <amd_hip_bf16.h>
+// #include <amd_hip_fp16.h>
+// #endif // !defined(USE_AMD_ROCM)
namespace faiss {
namespace gpu {
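
A minimal sketch of what the new compile-time gate enables (illustrative, not part of this diff; it assumes Float16.cuh above is included): device code can branch on FAISS_USE_FULL_BFLOAT16 to use the native bf16 ALU on CC >= 8.0 and fall back to float arithmetic elsewhere.

// Illustrative only: bf16 add using the same gating pattern as the
// Math<__nv_bfloat16> wrappers elsewhere in this change.
__device__ inline __nv_bfloat16 addBF16(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifdef FAISS_USE_FULL_BFLOAT16
    return __hadd(a, b); // native bfloat16 ALU (SM80 and newer)
#else
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
#endif
}
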
diff --git a/xml/GeneralDistance_8cuh.xml b/xml/GeneralDistance_8cuh.xml
index b899f8eb95..d4374aff2e 100644
--- a/xml/GeneralDistance_8cuh.xml
+++ b/xml/GeneralDistance_8cuh.xml
@@ -160,10 +160,10 @@
bool kInBounds = k < query.getSize(1);
queryTileBase[threadIdx.x + i * TILE_SIZE] =
-        kInBounds ? queryBase[k] : ConvertTo<T>::to(0);
+        kInBounds ? queryBase[k] : ConvertTo<T>::to(0.0f);
vecTileBase[threadIdx.x + i * TILE_SIZE] =
-        kInBounds ? vecBase[k] : ConvertTo<T>::to(0);
+        kInBounds ? vecBase[k] : ConvertTo<T>::to(0.0f);
}
__syncthreads();
@@ -194,10 +194,10 @@
for (idx_t k = threadIdx.x; k < limit; k += TILE_SIZE) {
    // Load query tile
    queryTileBase[threadIdx.x] =
-            queryThreadInBounds ? queryBase[k] : ConvertTo<T>::to(0);
+            queryThreadInBounds ? queryBase[k] : ConvertTo<T>::to(0.0f);
    vecTileBase[threadIdx.x] =
-            vecThreadInBoundsLoad ? vecBase[k] : ConvertTo<T>::to(0);
+            vecThreadInBoundsLoad ? vecBase[k] : ConvertTo<T>::to(0.0f);
    __syncthreads();
@@ -220,11 +220,11 @@
// Load query tile
queryTileBase[threadIdx.x] = queryThreadInBounds && kInBounds
        ? queryBase[k]
-        : ConvertTo<T>::to(0);
+        : ConvertTo<T>::to(0.0f);
vecTileBase[threadIdx.x] = vecThreadInBoundsLoad && kInBounds
        ? vecBase[k]
-        : ConvertTo<T>::to(0);
+        : ConvertTo<T>::to(0.0f);
__syncthreads();
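
The literal changes from 0 to 0.0f presumably so the zero-fill keeps resolving to the float overload of ConvertTo<T> now that half and (on CUDA) bfloat16 overloads coexist. A hedged sketch of the resulting padding idiom; the helper name is hypothetical, not from the diff:

// Hypothetical helper capturing the pattern above: zero-fill a tile element
// of arbitrary scalar type T through the float overload of ConvertTo<T>.
template <typename T>
__device__ inline T zeroPad() {
    return ConvertTo<T>::to(0.0f);
}
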
diff --git a/xml/GpuDistance_8h.xml b/xml/GpuDistance_8h.xml
index d49fa61ffb..7587761d81 100644
--- a/xml/GpuDistance_8h.xml
+++ b/xml/GpuDistance_8h.xml
@@ -83,165 +83,166 @@
enum class DistanceDataType {
    F32 = 1,
    F16,
-};
-
-
-enum class IndicesDataType {
-    I64 = 1,
-    I32,
-};
-
-
-struct[GpuDistanceParams]{
-
-
-
-
-
-[faiss::MetricType][metric]=[METRIC_L2];
-
-
-
-float[metricArg]=0;
-
-
-
-
-int[k]=0;
-
-
-int[dims]=0;
-
-
-
-
-
-
-
-
-constvoid*[vectors]=nullptr;
-DistanceDataTypevectorType=DistanceDataType::F32;
-boolvectorsRowMajor=true;
-[idx_t]numVectors=0;
-
-
-
-constfloat*[vectorNorms]=nullptr;
-
-
-
-
-
-
-
-
-
-constvoid*[queries]=nullptr;
-DistanceDataTypequeryType=DistanceDataType::F32;
-boolqueriesRowMajor=true;
-[idx_t]numQueries=0;
-
-
-
-
-
-
-
-
-float*[outDistances]=nullptr;
-
-
-
-bool[ignoreOutDistances]=false;
-
-
-
-IndicesDataType[outIndicesType]=IndicesDataType::I64;
-void*outIndices=nullptr;
-
-
-
-
-
-
-
-
-
-
-int[device]=-1;
-
-
-#ifdefinedUSE_NVIDIA_CUVS
-bool[use_cuvs]=true;
-#else
-bool[use_cuvs]=false;
-#endif
-};
-
-
-
-boolshould_use_cuvs([GpuDistanceParams]args);
-
-
-
-
-
-
-
-
-
-
-
-
-
-voidbfKnn([GpuResourcesProvider]*resources,const[GpuDistanceParams]&args);
-
-
-
-
-
-
-
-
-
-
-
-
-
-voidbfKnn_tiling(
-[GpuResourcesProvider]*resources,
-const[GpuDistanceParams]&args,
-size_tvectorsMemoryLimit,
-size_tqueriesMemoryLimit);
-
-
-voidbruteForceKnn(
-[GpuResourcesProvider]*resources,
-[faiss::MetricType]metric,
-
-
-
-constfloat*vectors,
-boolvectorsRowMajor,
-[idx_t]numVectors,
-
-
-
-constfloat*queries,
-boolqueriesRowMajor,
-[idx_t]numQueries,
-intdims,
-intk,
-
-
-float*outDistances,
-
-
-[idx_t]*outIndices);
-
-}
-}
-#pragmaGCCvisibilitypop
+    BF16,
+};
+
+
+enum class IndicesDataType {
+    I64 = 1,
+    I32,
+};
+
+
+struct [GpuDistanceParams] {
+
+
+
+
+
+    [faiss::MetricType] [metric] = [METRIC_L2];
+
+
+
+    float [metricArg] = 0;
+
+
+
+
+    int [k] = 0;
+
+
+    int [dims] = 0;
+
+
+
+
+
+
+
+
+    const void* [vectors] = nullptr;
+    DistanceDataType vectorType = DistanceDataType::F32;
+    bool vectorsRowMajor = true;
+    [idx_t] numVectors = 0;
+
+
+
+    const float* [vectorNorms] = nullptr;
+
+
+
+
+
+
+
+
+
+    const void* [queries] = nullptr;
+    DistanceDataType queryType = DistanceDataType::F32;
+    bool queriesRowMajor = true;
+    [idx_t] numQueries = 0;
+
+
+
+
+
+
+
+
+    float* [outDistances] = nullptr;
+
+
+
+    bool [ignoreOutDistances] = false;
+
+
+
+    IndicesDataType [outIndicesType] = IndicesDataType::I64;
+    void* outIndices = nullptr;
+
+
+
+
+
+
+
+
+
+
+    int [device] = -1;
+
+
+#if defined USE_NVIDIA_CUVS
+    bool [use_cuvs] = true;
+#else
+    bool [use_cuvs] = false;
+#endif
+};
+
+
+
+bool should_use_cuvs([GpuDistanceParams] args);
+
+
+
+
+
+
+
+
+
+
+
+
+
+void bfKnn([GpuResourcesProvider]* resources, const [GpuDistanceParams]& args);
+
+
+
+
+
+
+
+
+
+
+
+
+
+void bfKnn_tiling(
+        [GpuResourcesProvider]* resources,
+        const [GpuDistanceParams]& args,
+        size_t vectorsMemoryLimit,
+        size_t queriesMemoryLimit);
+
+
+void bruteForceKnn(
+        [GpuResourcesProvider]* resources,
+        [faiss::MetricType] metric,
+
+
+
+        const float* vectors,
+        bool vectorsRowMajor,
+        [idx_t] numVectors,
+
+
+
+        const float* queries,
+        bool queriesRowMajor,
+        [idx_t] numQueries,
+        int dims,
+        int k,
+
+
+        float* outDistances,
+
+
+        [idx_t]* outIndices);
+
+}
+}
+#pragma GCC visibility pop
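
End to end, the new BF16 enumerator is what user code selects to run brute-force k-NN over bfloat16 data. A minimal sketch, not part of this diff; the device pointers and their allocation are assumed to be handled by the caller:

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

// Illustrative only: L2 k-NN where both the database and the queries are
// __nv_bfloat16 on the device; distances come back as fp32, indices as idx_t.
void bf16Knn(
        faiss::gpu::StandardGpuResources& res,
        const void* d_vectors,        // numVectors x dims, bf16, row-major
        const void* d_queries,        // numQueries x dims, bf16, row-major
        int dims,
        faiss::idx_t numVectors,
        faiss::idx_t numQueries,
        int k,
        float* d_outDistances,        // numQueries x k
        faiss::idx_t* d_outIndices) { // numQueries x k
    faiss::gpu::GpuDistanceParams args;
    args.metric = faiss::METRIC_L2;
    args.k = k;
    args.dims = dims;
    args.vectors = d_vectors;
    args.vectorType = faiss::gpu::DistanceDataType::BF16;
    args.numVectors = numVectors;
    args.queries = d_queries;
    args.queryType = faiss::gpu::DistanceDataType::BF16;
    args.numQueries = numQueries;
    args.outDistances = d_outDistances;
    args.outIndices = d_outIndices;
    faiss::gpu::bfKnn(&res, args);
}
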
diff --git a/xml/GpuResources_8h.xml b/xml/GpuResources_8h.xml
index 8f055a812d..c2f95dd464 100644
--- a/xml/GpuResources_8h.xml
+++ b/xml/GpuResources_8h.xml
@@ -407,105 +407,111 @@
virtualvoid[initializeForDevice](intdevice)=0;
-
-virtualcublasHandle_t[getBlasHandle](intdevice)=0;
+
+virtual bool [supportsBFloat16](int device) = 0;
-
-
-virtualcudaStream_t[getDefaultStream](intdevice)=0;
-
-#ifdefinedUSE_NVIDIA_CUVS
-
-
-virtualraft::device_resources&getRaftHandle(intdevice)=0;
-raft::device_resources&getRaftHandleCurrentDevice();
-#endif
-
-
-
-
-virtualvoid[setDefaultStream](intdevice,cudaStream_tstream)=0;
-
-
-virtualstd::vector<cudaStream_t>[getAlternateStreams](intdevice)=0;
+
+virtualcublasHandle_t[getBlasHandle](intdevice)=0;
+
+
+
+virtualcudaStream_t[getDefaultStream](intdevice)=0;
+
+#ifdefinedUSE_NVIDIA_CUVS
+
+
+virtualraft::device_resources&getRaftHandle(intdevice)=0;
+raft::device_resources&getRaftHandleCurrentDevice();
+#endif
+
+
+
+
+virtualvoid[setDefaultStream](intdevice,cudaStream_tstream)=0;
-
-
-
-
-
-
-virtualvoid*[allocMemory](const[AllocRequest]&req)=0;
-
-
-virtualvoid[deallocMemory](intdevice,void*in)=0;
+
+virtualstd::vector<cudaStream_t>[getAlternateStreams](intdevice)=0;
+
+
+
+
+
+
+
+virtualvoid*[allocMemory](const[AllocRequest]&req)=0;
-
-
-virtualsize_t[getTempMemoryAvailable](intdevice)const=0;
-
-
-virtualstd::pair<void*,size_t>[getPinnedMemory]()=0;
+
+virtualvoid[deallocMemory](intdevice,void*in)=0;
+
+
+
+virtualsize_t[getTempMemoryAvailable](intdevice)const=0;
-
-virtualcudaStream_t[getAsyncCopyStream](intdevice)=0;
+
+virtualstd::pair<void*,size_t>[getPinnedMemory]()=0;
-
-
-
-
-
-cublasHandle_t[getBlasHandleCurrentDevice]();
-
-
-cudaStream_t[getDefaultStreamCurrentDevice]();
+
+virtualcudaStream_t[getAsyncCopyStream](intdevice)=0;
+
+
+
+
+
+
+bool [supportsBFloat16CurrentDevice]();
-
-size_t[getTempMemoryAvailableCurrentDevice]()const;
+
+cublasHandle_t[getBlasHandleCurrentDevice]();
-
-[GpuMemoryReservation][allocMemoryHandle](const[AllocRequest]&req);
+
+cudaStream_t[getDefaultStreamCurrentDevice]();
-
-
-
-void[syncDefaultStream](intdevice);
-
-
-void[syncDefaultStreamCurrentDevice]();
-
-
-std::vector<cudaStream_t>[getAlternateStreamsCurrentDevice]();
+
+size_t[getTempMemoryAvailableCurrentDevice]()const;
+
+
+[GpuMemoryReservation][allocMemoryHandle](const[AllocRequest]&req);
+
+
+
+
+void[syncDefaultStream](intdevice);
-
-cudaStream_t[getAsyncCopyStreamCurrentDevice]();
-};
-
-
-
-class[GpuResourcesProvider]{
-public:
-virtual~[GpuResourcesProvider]();
+
+void[syncDefaultStreamCurrentDevice]();
+
+
+std::vector<cudaStream_t>[getAlternateStreamsCurrentDevice]();
+
+
+cudaStream_t[getAsyncCopyStreamCurrentDevice]();
+};
-
-virtualstd::shared_ptr<GpuResources>[getResources]()=0;
-};
-
-
-
-class[GpuResourcesProviderFromInstance]:public[GpuResourcesProvider]{
-public:
-explicit[GpuResourcesProviderFromInstance](std::shared_ptr<GpuResources>p);
-~[GpuResourcesProviderFromInstance]()override;
-
-std::shared_ptr<GpuResources>[getResources]()override;
-
-private:
-std::shared_ptr<GpuResources>res_;
-};
+
+
+class[GpuResourcesProvider]{
+public:
+virtual~[GpuResourcesProvider]();
+
+
+virtualstd::shared_ptr<GpuResources>[getResources]()=0;
+};
+
+
+
+class[GpuResourcesProviderFromInstance]:public[GpuResourcesProvider]{
+public:
+explicit[GpuResourcesProviderFromInstance](std::shared_ptr<GpuResources>p);
+~[GpuResourcesProviderFromInstance]()override;
-}
-}
+std::shared_ptr<GpuResources>[getResources]()override;
+
+private:
+std::shared_ptr<GpuResources>res_;
+};
+
+}
+}
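
A small sketch of how a caller is expected to consult the new capability hook before choosing bf16 inputs (illustrative, not part of this diff):

#include <faiss/gpu/GpuResources.h>

// Illustrative only: query the interface added above through a provider.
bool deviceHasBF16(faiss::gpu::GpuResourcesProvider* provider, int device) {
    // getResources() returns the shared GpuResources implementation;
    // supportsBFloat16() reports whether that GPU can take bf16 inputs.
    return provider->getResources()->supportsBFloat16(device);
}
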
diff --git a/xml/GpuScalarQuantizer_8cuh.xml b/xml/GpuScalarQuantizer_8cuh.xml
index b2c667484f..c82398b756 100644
--- a/xml/GpuScalarQuantizer_8cuh.xml
+++ b/xml/GpuScalarQuantizer_8cuh.xml
@@ -163,7 +163,7 @@
inline __device__ void decode(void* data, idx_t vec, int d, float* out)
        const {
    half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
-    out[0] = Convert<half, float>()(p[d]);
+    out[0] = ConvertTo<float>::to(p[d]);
}
inline __device__ float decodePartial(
@@ -181,7 +181,7 @@
        int d,
        float v[kDimPerIter]) const {
    half* p = (half*)&((uint8_t*)data)[vec * bytesPerVec];
-    p[d] = Convert<float, half>()(v[0]);
+    p[d] = ConvertTo<half>::to(v[0]);
}
inline __device__ void encodePartial(
@@ -200,11 +200,11 @@
static constexpr int kEncodeBits = 16;
inline __device__ EncodeT encodeNew(int dim, float v) const {
-    return Convert<float, half>()(v);
+    return ConvertTo<half>::to(v);
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
-    return Convert<half, float>()(v);
+    return ConvertTo<float>::to(v);
}
int bytesPerVec;
diff --git a/xml/L2Norm_8cuh.xml b/xml/L2Norm_8cuh.xml
index 94144f1922..53adc2ac4a 100644
--- a/xml/L2Norm_8cuh.xml
+++ b/xml/L2Norm_8cuh.xml
@@ -16,7 +16,7 @@
#pragma once
-#include <cuda_fp16.h>
+#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/Tensor.cuh>
namespace faiss {
@@ -36,6 +36,16 @@
        bool normSquared,
        cudaStream_t stream);
+// no bf16 support for AMD
+#ifndef USE_AMD_ROCM
+void runL2Norm(
+        Tensor<__nv_bfloat16, 2, true>& input,
+        bool inputRowMajor,
+        Tensor<float, 1, true>& output,
+        bool normSquared,
+        cudaStream_t stream);
+#endif
+
} // namespace gpu
} // namespace faiss
diff --git a/xml/MathOperators_8cuh.xml b/xml/MathOperators_8cuh.xml
index b4d422d808..dc7362cae3 100644
--- a/xml/MathOperators_8cuh.xml
+++ b/xml/MathOperators_8cuh.xml
@@ -22,7 +22,7 @@
//
// Templated wrappers to express math for different scalar and vector
// types, so kernels can have the same written form but can operate
-// over half and float, and on vector types transparently
+// over half, bfloat16 and float, and on vector types transparently
//
namespace faiss {
@@ -565,6 +565,241 @@
    }
};
+#ifndef USE_AMD_ROCM
+
+template <>
+struct Math<__nv_bfloat16> {
+    typedef __nv_bfloat16 ScalarType;
+
+    static inline __device__ __nv_bfloat16
+    add(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hadd(a, b);
+#else
+        return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16
+    sub(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hsub(a, b);
+#else
+        return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b));
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16
+    mul(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hmul(a, b);
+#else
+        return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b));
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16 neg(__nv_bfloat16 v) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hneg(v);
+#else
+        return __float2bfloat16(-__bfloat162float(v));
+#endif
+    }
+
+    static inline __device__ float reduceAdd(__nv_bfloat16 v) {
+        return ConvertTo<float>::to(v);
+    }
+
+    static inline __device__ bool lt(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hlt(a, b);
+#else
+        return __bfloat162float(a) < __bfloat162float(b);
+#endif
+    }
+
+    static inline __device__ bool gt(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hgt(a, b);
+#else
+        return __bfloat162float(a) > __bfloat162float(b);
+#endif
+    }
+
+    static inline __device__ bool eq(__nv_bfloat16 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __heq(a, b);
+#else
+        return __bfloat162float(a) == __bfloat162float(b);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat16 zero() {
+#if CUDA_VERSION >= 9000
+        return 0.0f;
+#else
+        __nv_bfloat16 h;
+        h.x = 0;
+        return h;
+#endif
+    }
+};
+
+template <>
+struct Math<__nv_bfloat162> {
+    typedef __nv_bfloat16 ScalarType;
+
+#ifndef FAISS_USE_FULL_BFLOAT16
+    // define a few conversion functions that don't exist on cuda 11
+    // this overrides their definition in cuda 12 but we use native bf16 on this
+    // platform anyways.
+    static inline __device__ float2 __bfloat1622float2(__nv_bfloat162 a) {
+        float2 af;
+        af.x = __bfloat162float(a.x);
+        af.y = __bfloat162float(a.y);
+        return af;
+    }
+
+    static inline __device__ __nv_bfloat162 __float22bfloat162_rn(float2 af) {
+        __nv_bfloat162 a;
+        a.x = __float2bfloat16_rn(af.x);
+        a.y = __float2bfloat16_rn(af.y);
+        return a;
+    }
+
+    static inline __device__ __nv_bfloat162
+    __bfloat162bfloat162(__nv_bfloat16 v) {
+        __nv_bfloat162 a;
+        a.x = v;
+        a.y = v;
+        return a;
+    }
+#endif
+
+    static inline __device__ __nv_bfloat162
+    add(__nv_bfloat162 a, __nv_bfloat162 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hadd2(a, b);
+#else
+        float2 af = __bfloat1622float2(a);
+        float2 bf = __bfloat1622float2(b);
+
+        af.x += bf.x;
+        af.y += bf.y;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    sub(__nv_bfloat162 a, __nv_bfloat162 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hsub2(a, b);
+#else
+        float2 af = __bfloat1622float2(a);
+        float2 bf = __bfloat1622float2(b);
+
+        af.x -= bf.x;
+        af.y -= bf.y;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    add(__nv_bfloat162 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        __nv_bfloat162 b2 = __bfloat162bfloat162(b);
+        return __hadd2(a, b2);
+#else
+        float2 af = __bfloat1622float2(a);
+        float bf = __bfloat162float(b);
+
+        af.x += bf;
+        af.y += bf;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    sub(__nv_bfloat162 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        __nv_bfloat162 b2 = __bfloat162bfloat162(b);
+        return __hsub2(a, b2);
+#else
+        float2 af = __bfloat1622float2(a);
+        float bf = __bfloat162float(b);
+
+        af.x -= bf;
+        af.y -= bf;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    mul(__nv_bfloat162 a, __nv_bfloat162 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hmul2(a, b);
+#else
+        float2 af = __bfloat1622float2(a);
+        float2 bf = __bfloat1622float2(b);
+
+        af.x *= bf.x;
+        af.y *= bf.y;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162
+    mul(__nv_bfloat162 a, __nv_bfloat16 b) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        __nv_bfloat162 b2 = __bfloat162bfloat162(b);
+        return __hmul2(a, b2);
+#else
+        float2 af = __bfloat1622float2(a);
+        float bf = __bfloat162float(b);
+
+        af.x *= bf;
+        af.y *= bf;
+
+        return __float22bfloat162_rn(af);
+#endif
+    }
+
+    static inline __device__ __nv_bfloat162 neg(__nv_bfloat162 v) {
+#ifdef FAISS_USE_FULL_BFLOAT16
+        return __hneg2(v);
+#else
+        float2 vf = __bfloat1622float2(v);
+        vf.x = -vf.x;
+        vf.y = -vf.y;
+
+        return __float22bfloat162_rn(vf);
+#endif
+    }
+
+    static inline __device__ float reduceAdd(__nv_bfloat162 v) {
+        float2 vf = __bfloat1622float2(v);
+        vf.x += vf.y;
+
+        return vf.x;
+    }
+
+    // not implemented for vector types
+    // static inline __device__ bool lt(__nv_bfloat162 a, __nv_bfloat162 b);
+    // static inline __device__ bool gt(__nv_bfloat162 a, __nv_bfloat162 b);
+    // static inline __device__ bool eq(__nv_bfloat162 a, __nv_bfloat162 b);
+
+    static inline __device__ __nv_bfloat162 zero() {
+        return __bfloat162bfloat162(Math<__nv_bfloat16>::zero());
+    }
+};
+
+#endif // !USE_AMD_ROCM
+
} // namespace gpu
} // namespace faiss
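
To see what these specializations buy, here is a sketch of a kernel written once against Math<T>, so the same source instantiates for float, half and (on CUDA) __nv_bfloat16. Illustrative only, not part of this diff; it assumes faiss/gpu/utils/MathOperators.cuh is included:

// Illustrative only: out[i] = x[i]^2 + y[i] through the Math<T> wrappers.
// With FAISS_USE_FULL_BFLOAT16 defined these lower to native __hmul/__hadd;
// otherwise they round-trip through float exactly as in the fallbacks above.
template <typename T>
__global__ void squareAddKernel(const T* x, const T* y, T* out, size_t n) {
    size_t i = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = Math<T>::add(Math<T>::mul(x[i], x[i]), y[i]);
    }
}
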
diff --git a/xml/MatrixMult-inl_8cuh.xml b/xml/MatrixMult-inl_8cuh.xml
index ad6180ba40..de7aede8b1 100644
--- a/xml/MatrixMult-inl_8cuh.xml
+++ b/xml/MatrixMult-inl_8cuh.xml
@@ -30,6 +30,7 @@
struct GetCudaType;
#ifdef USE_AMD_ROCM
+
template <>
struct GetCudaType<float> {
    static constexpr hipblasDatatype_t Type = HIPBLAS_R_32F;
@@ -39,7 +40,15 @@
struct GetCudaType<half> {
    static constexpr hipblasDatatype_t Type = HIPBLAS_R_16F;
};
+
+// FIXME: no AMD support for bf16
+// template <>
+// struct GetCudaType<__nv_bfloat16> {
+//     static constexpr hipblasDatatype_t Type = HIPBLAS_R_16B;
+// };
+
#else
+
template <>
struct GetCudaType<float> {
    static constexpr cudaDataType_t Type = CUDA_R_32F;
@@ -49,6 +58,12 @@
struct GetCudaType<half> {
    static constexpr cudaDataType_t Type = CUDA_R_16F;
};
+
+template <>
+struct GetCudaType<__nv_bfloat16> {
+    static constexpr cudaDataType_t Type = CUDA_R_16BF;
+};
+
#endif
template <typename AT, typename BT>
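
The CUDA_R_16BF tag is what ultimately reaches cuBLAS. A hedged sketch of a bf16-in, fp32-out GEMM built on that tag (illustrative, not part of this diff; handle creation, column-major layout and leading dimensions are the caller's responsibility):

#include <cublas_v2.h>
#include <cuda_bf16.h>

// Illustrative only: C = A * B with bf16 inputs and fp32 accumulation/output,
// using the same cudaDataType_t tag that GetCudaType<__nv_bfloat16> resolves
// to above. Requires CUDA 11+ for CUDA_R_16BF and cublasComputeType_t.
cublasStatus_t gemmBF16(
        cublasHandle_t handle,
        int m, int n, int k,
        const __nv_bfloat16* A, int lda,
        const __nv_bfloat16* B, int ldb,
        float* C, int ldc) {
    const float alpha = 1.0f;
    const float beta = 0.0f;
    return cublasGemmEx(
            handle,
            CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha,
            A, CUDA_R_16BF, lda,
            B, CUDA_R_16BF, ldb,
            &beta,
            C, CUDA_R_32F, ldc,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT);
}
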
diff --git a/xml/StandardGpuResources_8h.xml b/xml/StandardGpuResources_8h.xml
index 03fcebf86e..3a3361566d 100644
--- a/xml/StandardGpuResources_8h.xml
+++ b/xml/StandardGpuResources_8h.xml
@@ -224,216 +224,225 @@
~[StandardGpuResourcesImpl]()override;
-
-
-void[noTempMemory]();
-
-
-
-
-
-
-void[setTempMemory](size_tsize);
-
-
-
-void[setPinnedMemory](size_tsize);
-
-
-
-
-
-
-void[setDefaultStream](intdevice,cudaStream_tstream)override;
-
-
-
-void[revertDefaultStream](intdevice);
-
-
-
-
-
-cudaStream_t[getDefaultStream](intdevice)override;
-
-#ifdefinedUSE_NVIDIA_CUVS
-
-
-raft::device_resources&getRaftHandle(intdevice)override;
-#endif
-
-
-
-void[setDefaultNullStreamAllDevices]();
-
-
-
-void[setLogMemoryAllocations](boolenable);
-
-public:
-
-
-
-void[initializeForDevice](intdevice)override;
-
-cublasHandle_t[getBlasHandle](intdevice)override;
-
-std::vector<cudaStream_t>[getAlternateStreams](intdevice)override;
-
-
-void*[allocMemory](const[AllocRequest]&req)override;
+
+bool [supportsBFloat16](int device) override;
+
+
+
+void[noTempMemory]();
+
+
+
+
+
+
+void[setTempMemory](size_tsize);
+
+
+
+void[setPinnedMemory](size_tsize);
+
+
+
+
+
+
+void[setDefaultStream](intdevice,cudaStream_tstream)override;
+
+
+
+void[revertDefaultStream](intdevice);
+
+
+
+
+
+cudaStream_t[getDefaultStream](intdevice)override;
+
+#ifdefinedUSE_NVIDIA_CUVS
+
+
+raft::device_resources&getRaftHandle(intdevice)override;
+#endif
+
+
+
+void[setDefaultNullStreamAllDevices]();
+
+
+
+void[setLogMemoryAllocations](boolenable);
+
+public:
+
+
+
+void[initializeForDevice](intdevice)override;
+
+cublasHandle_t[getBlasHandle](intdevice)override;
+
+std::vector<cudaStream_t>[getAlternateStreams](intdevice)override;
-
-void[deallocMemory](intdevice,void*in)override;
-
-size_t[getTempMemoryAvailable](intdevice)constoverride;
-
-
-std::map<int,std::map<std::string,std::pair<int,size_t>>>[getMemoryInfo]()
-const;
-
-std::pair<void*,size_t>[getPinnedMemory]()override;
-
-cudaStream_t[getAsyncCopyStream](intdevice)override;
-
-protected:
-
-bool[isInitialized](intdevice)const;
-
-
-
-staticsize_t[getDefaultTempMemForGPU](intdevice,size_trequested);
-
-protected:
-
-
-std::unordered_map<int,std::unordered_map<void*,AllocRequest>>[allocs_];
-
-
-std::unordered_map<int,std::unique_ptr<StackDeviceMemory>>[tempMemory_];
+
+void*[allocMemory](const[AllocRequest]&req)override;
+
+
+void[deallocMemory](intdevice,void*in)override;
+
+size_t[getTempMemoryAvailable](intdevice)constoverride;
+
+
+std::map<int,std::map<std::string,std::pair<int,size_t>>>[getMemoryInfo]()
+const;
+
+std::pair<void*,size_t>[getPinnedMemory]()override;
+
+cudaStream_t[getAsyncCopyStream](intdevice)override;
+
+protected:
+
+bool[isInitialized](intdevice)const;
+
+
+
+staticsize_t[getDefaultTempMemForGPU](intdevice,size_trequested);
+
+protected:
+
+
+std::unordered_map<int,std::unordered_map<void*,AllocRequest>>[allocs_];
-
-std::unordered_map<int,cudaStream_t>[defaultStreams_];
+
+std::unordered_map<int,std::unique_ptr<StackDeviceMemory>>[tempMemory_];
-
-
-std::unordered_map<int,cudaStream_t>[userDefaultStreams_];
-
-
-std::unordered_map<int,std::vector<cudaStream_t>>[alternateStreams_];
+
+std::unordered_map<int,cudaStream_t>[defaultStreams_];
+
+
+
+std::unordered_map<int,cudaStream_t>[userDefaultStreams_];
-
-std::unordered_map<int,cudaStream_t>[asyncCopyStreams_];
+
+std::unordered_map<int,std::vector<cudaStream_t>>[alternateStreams_];
-
-std::unordered_map<int,cublasHandle_t>[blasHandles_];
-
-#ifdefinedUSE_NVIDIA_CUVS
-
-std::unordered_map<int,raft::device_resources>raftHandles_;
-
-
-
-
-
-
-
-
-
-
-
-
-std::unique_ptr<rmm::mr::device_memory_resource>mmr_;
+
+std::unordered_map<int,cudaStream_t>[asyncCopyStreams_];
+
+
+std::unordered_map<int,cublasHandle_t>[blasHandles_];
+
+#ifdefinedUSE_NVIDIA_CUVS
+
+std::unordered_map<int,raft::device_resources>raftHandles_;
+
+
+
+
+
+
+
+
+
+
-
-std::unique_ptr<rmm::mr::host_memory_resource>pmr_;
-#endif
-
-
-void*[pinnedMemAlloc_];
-size_tpinnedMemAllocSize_;
-
-
-
-size_t[tempMemSize_];
-
-
-size_t[pinnedMemSize_];
+
+std::unique_ptr<rmm::mr::device_memory_resource>mmr_;
+
+
+std::unique_ptr<rmm::mr::host_memory_resource>pmr_;
+#endif
+
+
+void*[pinnedMemAlloc_];
+size_tpinnedMemAllocSize_;
+
+
+
+size_t[tempMemSize_];
-
-bool[allocLogging_];
-};
-
-
-
-
-
-class[StandardGpuResources]:public[GpuResourcesProvider]{
-public:
-[StandardGpuResources]();
-~[StandardGpuResources]()override;
-
-std::shared_ptr<GpuResources>[getResources]()override;
-
-
-
-void[noTempMemory]();
-
-
-
-
-
-
-void[setTempMemory](size_tsize);
-
-
-
-void[setPinnedMemory](size_tsize);
-
-
-
-
-
-
-void[setDefaultStream](intdevice,cudaStream_tstream);
-
-
-
-void[revertDefaultStream](intdevice);
-
-
-
-void[setDefaultNullStreamAllDevices]();
-
-
-std::map<int,std::map<std::string,std::pair<int,size_t>>>[getMemoryInfo]()
-const;
-
-cudaStream_t[getDefaultStream](intdevice);
-
-#ifdefinedUSE_NVIDIA_CUVS
-
-
-raft::device_resources&getRaftHandle(intdevice);
-#endif
-
-
-size_t[getTempMemoryAvailable](intdevice)const;
-
-
-void[syncDefaultStreamCurrentDevice]();
-
-
-
-void[setLogMemoryAllocations](boolenable);
-
-private:
-std::shared_ptr<StandardGpuResourcesImpl>res_;
-};
-
-}
-}
-#pragmaGCCvisibilitypop
+
+size_t[pinnedMemSize_];
+
+
+bool[allocLogging_];
+};
+
+
+
+
+
+class[StandardGpuResources]:public[GpuResourcesProvider]{
+public:
+[StandardGpuResources]();
+~[StandardGpuResources]()override;
+
+std::shared_ptr<GpuResources>[getResources]()override;
+
+
+bool [supportsBFloat16](int device);
+
+
+bool [supportsBFloat16CurrentDevice]();
+
+
+
+void[noTempMemory]();
+
+
+
+
+
+
+void[setTempMemory](size_tsize);
+
+
+
+void[setPinnedMemory](size_tsize);
+
+
+
+
+
+
+void[setDefaultStream](intdevice,cudaStream_tstream);
+
+
+
+void[revertDefaultStream](intdevice);
+
+
+
+void[setDefaultNullStreamAllDevices]();
+
+
+std::map<int,std::map<std::string,std::pair<int,size_t>>>[getMemoryInfo]()
+const;
+
+cudaStream_t[getDefaultStream](intdevice);
+
+#ifdefinedUSE_NVIDIA_CUVS
+
+
+raft::device_resources&getRaftHandle(intdevice);
+#endif
+
+
+size_t[getTempMemoryAvailable](intdevice)const;
+
+
+void[syncDefaultStreamCurrentDevice]();
+
+
+
+void[setLogMemoryAllocations](boolenable);
+
+private:
+std::shared_ptr<StandardGpuResourcesImpl>res_;
+};
+
+}
+}
+#pragma GCC visibility pop
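
At the user-facing level, the new queries make it straightforward to fall back to fp32 when a device cannot take bf16 input. A minimal sketch (illustrative, not part of this diff):

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/StandardGpuResources.h>

// Illustrative only: pick the brute-force input type from the capability
// reported by the new supportsBFloat16CurrentDevice() query.
faiss::gpu::DistanceDataType pickInputType(
        faiss::gpu::StandardGpuResources& res) {
    return res.supportsBFloat16CurrentDevice()
            ? faiss::gpu::DistanceDataType::BF16
            : faiss::gpu::DistanceDataType::F32;
}
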
diff --git a/xml/classfaiss_1_1gpu_1_1GpuResources.xml b/xml/classfaiss_1_1gpu_1_1GpuResources.xml
index e89727a5e1..25b1df88a2 100644
--- a/xml/classfaiss_1_1gpu_1_1GpuResources.xml
+++ b/xml/classfaiss_1_1gpu_1_1GpuResources.xml
@@ -37,6 +37,25 @@
+
+ bool
+ virtual bool faiss::gpu::GpuResources::supportsBFloat16
+ (int device)=0
+ supportsBFloat16
+ supportsBFloat16
+
+ int
+ device
+
+
+Does the given GPU support bfloat16?
+
+
+
+
+
+
+
cublasHandle_t
virtual cublasHandle_t faiss::gpu::GpuResources::getBlasHandle
@@ -54,7 +73,7 @@
-
+
cudaStream_t
@@ -73,7 +92,7 @@
-
+
void
@@ -96,7 +115,7 @@
-
+
std::vector< cudaStream_t >
@@ -115,7 +134,7 @@
-
+
void *
@@ -134,7 +153,7 @@
-
+
void
@@ -157,7 +176,7 @@
-
+
size_t
@@ -176,7 +195,7 @@
-
+
std::pair< void *, size_t >
@@ -191,7 +210,7 @@
-
+
cudaStream_t
@@ -210,7 +229,22 @@
-
+
+
+
+ bool
+ bool faiss::gpu::GpuResources::supportsBFloat16CurrentDevice
+ ()
+ supportsBFloat16CurrentDevice
+
+Does the current GPU support bfloat16?
+
+
+Functions provided by default
+
+
+
+
cublasHandle_t
@@ -221,11 +255,10 @@
Calls getBlasHandle with the current device.
-Functions provided by default
-
+
cudaStream_t
@@ -239,7 +272,7 @@
-
+
size_t
@@ -253,7 +286,7 @@
-
+
[GpuMemoryReservation]
@@ -271,7 +304,7 @@
-
+
void
@@ -289,7 +322,7 @@
-
+
void
@@ -303,7 +336,7 @@
-
+
std::vector< cudaStream_t >
@@ -317,7 +350,7 @@
-
+
cudaStream_t
@@ -331,7 +364,7 @@
-
+
@@ -351,7 +384,7 @@
-
+
faiss::gpu::GpuResourcesallocMemory
faiss::gpu::GpuResourcesallocMemoryHandle
@@ -369,6 +402,8 @@
faiss::gpu::GpuResourcesgetTempMemoryAvailableCurrentDevice
faiss::gpu::GpuResourcesinitializeForDevice
faiss::gpu::GpuResourcessetDefaultStream
+ faiss::gpu::GpuResourcessupportsBFloat16
+ faiss::gpu::GpuResourcessupportsBFloat16CurrentDevice
faiss::gpu::GpuResourcessyncDefaultStream
faiss::gpu::GpuResourcessyncDefaultStreamCurrentDevice
faiss::gpu::GpuResources~GpuResources
diff --git a/xml/classfaiss_1_1gpu_1_1GpuResourcesProvider.xml b/xml/classfaiss_1_1gpu_1_1GpuResourcesProvider.xml
index db0b7ba05d..a725ee71dc 100644
--- a/xml/classfaiss_1_1gpu_1_1GpuResourcesProvider.xml
+++ b/xml/classfaiss_1_1gpu_1_1GpuResourcesProvider.xml
@@ -17,7 +17,7 @@
-
+
std::shared_ptr< [GpuResources] >
@@ -33,7 +33,7 @@
-
+
@@ -59,7 +59,7 @@
-
+
faiss::gpu::GpuResourcesProvidergetResources
faiss::gpu::GpuResourcesProvider~GpuResourcesProvider
diff --git a/xml/classfaiss_1_1gpu_1_1GpuResourcesProviderFromInstance.xml b/xml/classfaiss_1_1gpu_1_1GpuResourcesProviderFromInstance.xml
index a2f59d7ea7..93dde0263a 100644
--- a/xml/classfaiss_1_1gpu_1_1GpuResourcesProviderFromInstance.xml
+++ b/xml/classfaiss_1_1gpu_1_1GpuResourcesProviderFromInstance.xml
@@ -16,7 +16,7 @@
-
+
@@ -35,7 +35,7 @@
-
+
@@ -48,7 +48,7 @@
-
+
std::shared_ptr< [GpuResources] >
@@ -63,7 +63,7 @@
-
+
@@ -95,7 +95,7 @@
-
+
faiss::gpu::GpuResourcesProviderFromInstancegetResources
faiss::gpu::GpuResourcesProviderFromInstanceGpuResourcesProviderFromInstance
diff --git a/xml/classfaiss_1_1gpu_1_1StandardGpuResources.xml b/xml/classfaiss_1_1gpu_1_1StandardGpuResources.xml
index 4999a50704..d3dc5124ac 100644
--- a/xml/classfaiss_1_1gpu_1_1StandardGpuResources.xml
+++ b/xml/classfaiss_1_1gpu_1_1StandardGpuResources.xml
@@ -16,7 +16,7 @@
-
+
@@ -31,7 +31,7 @@
-
+
@@ -44,7 +44,7 @@
-
+
std::shared_ptr< [GpuResources] >
@@ -59,7 +59,39 @@
-
+
+
+
+ bool
+ bool faiss::gpu::StandardGpuResources::supportsBFloat16
+ (int device)
+ supportsBFloat16
+
+ int
+ device
+
+
+Whether or not the given device supports native bfloat16 arithmetic.
+
+
+
+
+
+
+
+
+ bool
+ bool faiss::gpu::StandardGpuResources::supportsBFloat16CurrentDevice
+ ()
+ supportsBFloat16CurrentDevice
+
+Whether or not the current device supports native bfloat16 arithmetic.
+
+
+
+
+
+
void
@@ -73,7 +105,7 @@
-
+
void
@@ -91,7 +123,7 @@
-
+
void
@@ -109,7 +141,7 @@
-
+
void
@@ -131,7 +163,7 @@
-
+
void
@@ -149,7 +181,7 @@
-
+
void
@@ -163,7 +195,7 @@
-
+
std::map< int, std::map< std::string, std::pair< int, size_t > > >
@@ -177,7 +209,7 @@
-
+
cudaStream_t
@@ -195,7 +227,7 @@
-
+
size_t
@@ -213,7 +245,7 @@
-
+
void
@@ -227,7 +259,7 @@
-
+
void
@@ -245,7 +277,7 @@
-
+
@@ -277,7 +309,7 @@
-
+
faiss::gpu::StandardGpuResourcesgetDefaultStream
faiss::gpu::StandardGpuResourcesgetMemoryInfo
@@ -292,6 +324,8 @@
faiss::gpu::StandardGpuResourcessetPinnedMemory
faiss::gpu::StandardGpuResourcessetTempMemory
faiss::gpu::StandardGpuResourcesStandardGpuResources
+ faiss::gpu::StandardGpuResourcessupportsBFloat16
+ faiss::gpu::StandardGpuResourcessupportsBFloat16CurrentDevice
faiss::gpu::StandardGpuResourcessyncDefaultStreamCurrentDevice
faiss::gpu::StandardGpuResources~GpuResourcesProvider
faiss::gpu::StandardGpuResources~StandardGpuResources
diff --git a/xml/classfaiss_1_1gpu_1_1StandardGpuResourcesImpl.xml b/xml/classfaiss_1_1gpu_1_1StandardGpuResourcesImpl.xml
index 69dd083d64..df9323c3a2 100644
--- a/xml/classfaiss_1_1gpu_1_1StandardGpuResourcesImpl.xml
+++ b/xml/classfaiss_1_1gpu_1_1StandardGpuResourcesImpl.xml
@@ -17,7 +17,7 @@
-
+
std::unordered_map< int, std::unique_ptr< [StackDeviceMemory] > >
@@ -31,7 +31,7 @@
-
+
std::unordered_map< int, cudaStream_t >
@@ -45,7 +45,7 @@
-
+
std::unordered_map< int, cudaStream_t >
@@ -59,7 +59,7 @@
-
+
std::unordered_map< int, std::vector< cudaStream_t > >
@@ -73,7 +73,7 @@
-
+
std::unordered_map< int, cudaStream_t >
@@ -87,7 +87,7 @@
-
+
std::unordered_map< int, cublasHandle_t >
@@ -101,7 +101,7 @@
-
+
void *
@@ -115,7 +115,7 @@
-
+
size_t
@@ -128,7 +128,7 @@
-
+
size_t
@@ -142,7 +142,7 @@
-
+
size_t
@@ -156,7 +156,7 @@
-
+
bool
@@ -170,7 +170,7 @@
-
+
@@ -200,6 +200,25 @@
+
+ bool
+ bool faiss::gpu::StandardGpuResourcesImpl::supportsBFloat16
+ (int device) override
+ supportsBFloat16
+ supportsBFloat16
+
+ int
+ device
+
+
+Does the given GPU support bfloat16?
+
+
+
+
+
+
+
void
void faiss::gpu::StandardGpuResourcesImpl::noTempMemory
@@ -212,7 +231,7 @@
-
+
void
@@ -230,7 +249,7 @@
-
+
void
@@ -248,7 +267,7 @@
-
+
void
@@ -271,7 +290,7 @@
-
+
void
@@ -289,7 +308,7 @@
-
+
cudaStream_t
@@ -308,7 +327,7 @@
-
+
void
@@ -322,7 +341,7 @@
-
+
void
@@ -340,7 +359,7 @@
-
+
void
@@ -360,7 +379,7 @@
-
+
cublasHandle_t
@@ -379,7 +398,7 @@
-
+
std::vector< cudaStream_t >
@@ -398,7 +417,7 @@
-
+
void *
@@ -417,7 +436,7 @@
-
+
void
@@ -440,7 +459,7 @@
-
+
size_t
@@ -459,7 +478,7 @@
-
+
std::map< int, std::map< std::string, std::pair< int, size_t > > >
@@ -473,7 +492,7 @@
-
+
std::pair< void *, size_t >
@@ -488,7 +507,7 @@
-
+
cudaStream_t
@@ -507,7 +526,22 @@
-
+
+
+
+ bool
+ bool faiss::gpu::GpuResources::supportsBFloat16CurrentDevice
+ ()
+ supportsBFloat16CurrentDevice
+
+Does the current GPU support bfloat16?
+
+
+Functions provided by default
+
+
+
+
cublasHandle_t
@@ -518,11 +552,10 @@
Calls getBlasHandle with the current device.
-Functions provided by default
-
+
cudaStream_t
@@ -536,7 +569,7 @@
-
+
size_t
@@ -550,7 +583,7 @@
-
+
[GpuMemoryReservation]
@@ -568,7 +601,7 @@
-
+
void
@@ -586,7 +619,7 @@
-
+
void
@@ -600,7 +633,7 @@
-
+
std::vector< cudaStream_t >
@@ -614,7 +647,7 @@
-
+
cudaStream_t
@@ -628,7 +661,7 @@
-
+
@@ -648,7 +681,7 @@
-
+
@@ -672,7 +705,7 @@
-
+
@@ -704,7 +737,7 @@
-
+
faiss::gpu::StandardGpuResourcesImplallocLogging_
faiss::gpu::StandardGpuResourcesImplallocMemory
@@ -741,6 +774,8 @@
faiss::gpu::StandardGpuResourcesImplsetPinnedMemory
faiss::gpu::StandardGpuResourcesImplsetTempMemory
faiss::gpu::StandardGpuResourcesImplStandardGpuResourcesImpl
+ faiss::gpu::StandardGpuResourcesImplsupportsBFloat16
+ faiss::gpu::StandardGpuResourcesImplsupportsBFloat16CurrentDevice
faiss::gpu::StandardGpuResourcesImplsyncDefaultStream
faiss::gpu::StandardGpuResourcesImplsyncDefaultStreamCurrentDevice
faiss::gpu::StandardGpuResourcesImpltempMemory_
diff --git a/xml/index.xml b/xml/index.xml
index 87512b4e28..c448539fea 100644
--- a/xml/index.xml
+++ b/xml/index.xml
@@ -1549,6 +1549,7 @@
faiss::gpu::GpuResources
~GpuResources
initializeForDevice
+ supportsBFloat16
getBlasHandle
getDefaultStream
setDefaultStream
@@ -1558,6 +1559,7 @@
getTempMemoryAvailable
getPinnedMemory
getAsyncCopyStream
+ supportsBFloat16CurrentDevice
getBlasHandleCurrentDevice
getDefaultStreamCurrentDevice
getTempMemoryAvailableCurrentDevice
@@ -8725,6 +8727,8 @@
StandardGpuResources
~StandardGpuResources
getResources
+ supportsBFloat16
+ supportsBFloat16CurrentDevice
noTempMemory
setTempMemory
setPinnedMemory
@@ -8752,6 +8756,7 @@
allocLogging_
StandardGpuResourcesImpl
~StandardGpuResourcesImpl
+ supportsBFloat16
noTempMemory
setTempMemory
setPinnedMemory
@@ -8769,6 +8774,7 @@
getMemoryInfo
getPinnedMemory
getAsyncCopyStream
+ supportsBFloat16CurrentDevice
getBlasHandleCurrentDevice
getDefaultStreamCurrentDevice
getTempMemoryAvailableCurrentDevice
@@ -9508,6 +9514,7 @@
DistanceDataType
F32
F16
+ BF16
IndicesDataType
I64
I32
diff --git a/xml/namespacefaiss_1_1gpu.xml b/xml/namespacefaiss_1_1gpu.xml
index 5374a8ab5c..f098dedc98 100644
--- a/xml/namespacefaiss_1_1gpu.xml
+++ b/xml/namespacefaiss_1_1gpu.xml
@@ -68,13 +68,20 @@
+
+ BF16
+
+
+
+
+
-
+
@@ -100,7 +107,7 @@
-
+
@@ -566,7 +573,7 @@
-
+
void
@@ -590,7 +597,7 @@
-
+
void
@@ -619,7 +626,7 @@
-
+
void
@@ -681,7 +688,7 @@
-
+
bool
diff --git a/xml/structfaiss_1_1gpu_1_1GpuDistanceParams.xml b/xml/structfaiss_1_1gpu_1_1GpuDistanceParams.xml
index 9cb9c0a673..14c5c4de5d 100644
--- a/xml/structfaiss_1_1gpu_1_1GpuDistanceParams.xml
+++ b/xml/structfaiss_1_1gpu_1_1GpuDistanceParams.xml
@@ -17,7 +17,7 @@
-
+
float
@@ -32,7 +32,7 @@
-
+
int
@@ -47,7 +47,7 @@
-
+
int
@@ -62,7 +62,7 @@
-
+
const void *
@@ -77,7 +77,7 @@
-
+
DistanceDataType
@@ -91,7 +91,7 @@
-
+
bool
@@ -105,7 +105,7 @@
-
+
[idx_t]
@@ -119,7 +119,7 @@
-
+
const float *
@@ -134,7 +134,7 @@
-
+
const void *
@@ -149,7 +149,7 @@
-
+
DistanceDataType
@@ -163,7 +163,7 @@
-
+
bool
@@ -177,7 +177,7 @@
-
+
[idx_t]
@@ -191,7 +191,7 @@
-
+
float *
@@ -206,7 +206,7 @@
-
+
bool
@@ -221,7 +221,7 @@
-
+
IndicesDataType
@@ -236,7 +236,7 @@
-
+
void *
@@ -250,7 +250,7 @@
-
+
int
@@ -265,7 +265,7 @@
-
+
bool
@@ -280,7 +280,7 @@
-
+
@@ -288,7 +288,7 @@
-
+
faiss::gpu::GpuDistanceParamsdevice
faiss::gpu::GpuDistanceParamsdims