Skip to content

Commit

Permalink
support bfloat16 in python
Browse files Browse the repository at this point in the history
Summary: Add support in python

Differential Revision: D66074156
  • Loading branch information
mdouze authored and facebook-github-bot committed Nov 20, 2024
1 parent eaab46c commit 7dfc2b9
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions faiss/gpu/test/torch_test_contrib_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,11 @@ def test_knn_gpu(self, use_cuvs=False):
def test_knn_gpu_cuvs(self):
self.test_knn_gpu(use_cuvs=True)

def test_knn_gpu_datatypes(self, use_cuvs=False):
def test_knn_gpu_datatypes(self, use_cuvs=False, use_bf16=False):
torch.manual_seed(10)
d = 10
nb = 1024
nq = 5
nq = 50
k = 10
res = faiss.StandardGpuResources()

Expand All @@ -361,29 +361,46 @@ def test_knn_gpu_datatypes(self, use_cuvs=False):
index.add(xb)
gt_D, gt_I = index.search(xq, k)

xb_c = xb.cuda().half()
xq_c = xq.cuda().half()
# convert to float16
if use_bf16:
xb_c = xb.cuda().bfloat16()
xq_c = xq.cuda().bfloat16()
else:
xb_c = xb.cuda().half()
xq_c = xq.cuda().half()

# use i32 output indices
D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32)

faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)

self.assertTrue(torch.equal(I.long().cpu(), gt_I))
self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3)
ndiff = (I.cpu() != gt_I).sum().item()
MSE = ((D.float().cpu() - gt_D) ** 2).sum().item()
if use_bf16:
# 57 -- bf16 is not as accurate as fp16
self.assertLess(ndiff, 80)
# 0.00515
self.assertLess(MSE, 8e-3)
else:
# 5
self.assertLess(ndiff, 10)
# 8.565e-5
self.assertLess(MSE, 1e-4)

# Test using numpy
D = np.zeros((nq, k), dtype=np.float32)
I = np.zeros((nq, k), dtype=np.int32)
if not use_bf16: # bf16 not supported by numpy
# use i32 output indices
D = np.zeros((nq, k), dtype=np.float32)
I = np.zeros((nq, k), dtype=np.int32)

xb_c = xb.half().numpy()
xq_c = xq.half().numpy()
xb_c = xb.half().numpy()
xq_c = xq.half().numpy()

faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)

self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)
def test_knn_gpu_bf16(self):
self.test_knn_gpu_datatypes(use_bf16=True)


class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase):
Expand Down

0 comments on commit 7dfc2b9

Please sign in to comment.