Added IndexLSH to the demo (#4009)
Summary:
Demonstrate IndexLSH does not need training or codebook serialization

Pull Request resolved: #4009

Reviewed By: junjieqi

Differential Revision: D65274645

Pulled By: asadoughi

fbshipit-source-id: c9af463757edbd07cc07b1cf607b88373fa334c4
asadoughi authored and facebook-github-bot committed Nov 4, 2024
1 parent 2c961cc commit a11c1db
Showing 1 changed file with 227 additions and 13 deletions: demos/index_pq_flat_separate_codes_from_codebook.py
#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# fmt: off
# flake8: noqa


""":md
# Serializing codes separately, with IndexLSH and IndexPQ
Let's say, for example, you have a few vector embeddings per user
and want to shard a flat index by user so you can re-use the same LSH or PQ method
for all users but store each user's codes independently.
"""

""":py"""
import faiss
import numpy as np

""":py"""
d = 768
n = 1_000
ids = np.arange(n).astype('int64')
training_data = np.random.rand(n, d).astype('float32')

""":py"""
def read_ids_codes():
    try:
        return np.load("/tmp/ids.npy"), np.load("/tmp/codes.npy")
    except FileNotFoundError:
        return None, None


def write_ids_codes(ids, codes):
    np.save("/tmp/ids.npy", ids)
    # reshape so the stored codes are always (n, code_size) rows
    np.save("/tmp/codes.npy", codes.reshape(len(ids), -1))


def write_template_index(template_index):
    faiss.write_index(template_index, "/tmp/template.index")


def read_template_index_instance():
    return faiss.read_index("/tmp/template.index")
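
""":md
The helpers above keep one global set of files. For the per-user sharding described in the intro, one file pair per user works the same way. The sketch below is illustrative only; the `/tmp/user_{user_id}_*.npy` paths and the `*_for_user` names are our assumptions, not part of the demo:
"""

""":py"""
def read_ids_codes_for_user(user_id):
    # per-user variant of read_ids_codes (hypothetical helper)
    try:
        return (np.load(f"/tmp/user_{user_id}_ids.npy"),
                np.load(f"/tmp/user_{user_id}_codes.npy"))
    except FileNotFoundError:
        return None, None


def write_ids_codes_for_user(user_id, ids, codes):
    # per-user variant of write_ids_codes (hypothetical helper)
    np.save(f"/tmp/user_{user_id}_ids.npy", ids)
    np.save(f"/tmp/user_{user_id}_codes.npy", codes.reshape(len(ids), -1))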

""":md
## IndexLSH: separate codes
The first half of this notebook demonstrates how to store LSH codes. Unlike PQ, LSH does not require training: its compression method, a random projection matrix, is generated deterministically at construction from a seed value that's [hardcoded](https://github.com/facebookresearch/faiss/blob/2c961cc308ade8a85b3aa10a550728ce3387f625/faiss/IndexLSH.cpp#L35).
"""

""":py"""
# index parameters (LSH requires no training step)
nbits = 1536

""":py"""
# demonstrating encoding is deterministic

codes = []
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
for i in range(10):
    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
    code = index.index.sa_encode(database_vector_float32)
    codes.append(code)

for i in range(1, 10):
    assert np.array_equal(codes[0], codes[i])

""":py"""
# new database vector

ids, codes = read_ids_codes()
database_vector_id = max(ids) + 1 if ids is not None else 1
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))

code = index.index.sa_encode(database_vector_float32)

if ids is not None and codes is not None:
    ids = np.concatenate((ids, [database_vector_id]))
    codes = np.vstack((codes, code))
else:
    ids = np.array([database_vector_id])
    codes = np.array([code])

write_ids_codes(ids, codes)

""":py '2840581589434841'"""
# then at query time

query_vector_float32 = np.random.rand(1, d).astype(np.float32)
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
ids, codes = read_ids_codes()

index.add_sa_codes(codes, ids)

index.search(query_vector_float32, k=5)
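
""":md
For IndexLSH, the distances returned by search are Hamming distances between the binary codes. A quick sketch of inspecting the results (unpacking into `D`, `I` is our own naming, not part of the demo):
"""

""":py"""
D, I = index.search(query_vector_float32, k=5)
print(I)  # ids of the 5 nearest database vectors
print(D)  # Hamming distances from the query's binary code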

""":py"""
!rm /tmp/ids.npy /tmp/codes.npy

""":md
## IndexPQ: separate codes from codebook
The second half of this notebook demonstrates how to serialize and deserialize the PQ codebook
(via faiss.write_index for IndexPQ) independently of the vector codes. For example, when you
have a few vector embeddings per user and want to shard the flat index by user, you can
re-use the same PQ method for all users while storing each user's codes independently.
"""

""":py"""
M = d//8
nbits = 8
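# i.e. M = 96 sub-quantizers, each encoding an 8-dim sub-vector with
# 2**nbits = 256 centroids, for a code size of M * nbits / 8 = 96 bytes
# (vs. d * 4 = 3072 bytes for the raw float32 vector)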

""":py"""
# at train time
template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
template_index.train(training_data)
write_template_index(template_index)
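
""":md
A quick sanity check, added here for illustration (not part of the original demo): the PQ codebook is serialized together with the template index, so a freshly deserialized instance already reports as trained.
"""

""":py"""
# is_trained is True because the codebook travels with the serialized index
assert read_template_index_instance().is_trained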
""":py"""
# New database vector

index = read_template_index_instance()
ids, codes = read_ids_codes()
database_vector_id = max(ids) + 1 if ids is not None else 1
database_vector_float32 = np.random.rand(1, d).astype(np.float32)

code = index.index.sa_encode(database_vector_float32)

if ids is not None and codes is not None:
    ids = np.concatenate((ids, [database_vector_id]))
    codes = np.vstack((codes, code))
else:
    ids = np.array([database_vector_id])
    codes = np.array([code])

write_ids_codes(ids, codes)

""":py '331546060044009'"""
""":py '1858280061369209'"""
# then at query time
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
id_wrapper_index = read_template_index_instance()
ids, codes = read_ids_codes()

id_wrapper_index.add_sa_codes(codes, ids)

id_wrapper_index.search(query_vector_float32, k=5)

""":py"""
!rm /tmp/ids.npy /tmp/codes.npy /tmp/template.index

""":md
## Comparing these methods
- methods: Flat, LSH, PQ
- vary cost: nbits and M, giving 1x, 2x, 4x, 8x, 16x, 32x compression
- measure: recall@1, the fraction of queries whose top-1 result matches the exact Flat index's top-1 result

We don't measure latency, as the number of vectors per user shard is small.
"""

""":py '2898032417027201'"""
n, d

""":py"""
database_vector_ids, database_vector_float32s = np.arange(n), np.random.rand(n, d).astype(np.float32)
query_vector_float32s = np.random.rand(n, d).astype(np.float32)

""":py"""
index = faiss.index_factory(d, "IDMap2,Flat")
index.add_with_ids(database_vector_float32s, database_vector_ids)
_, ground_truth_result_ids = index.search(query_vector_float32s, k=1)

""":py '857475336204238'"""
from dataclasses import dataclass

pq_m_nbits = (
# 96 bytes
(96, 8),
(192, 4),
# 192 bytes
(192, 8),
(384, 4),
# 384 bytes
(384, 8),
(768, 4),
)
lsh_nbits = (768, 1536, 3072, 6144, 12288, 24576)
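
# a flat float32 vector takes d * 4 = 3072 bytes; a PQ code takes M * nbits / 8
# bytes and an LSH code nbits / 8 bytes, so the settings above span 32x/16x/8x
# compression for PQ and 32x down to 1x for LSH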


@dataclass
class Record:
    type_: str
    index: faiss.Index
    args: tuple
    recall: float


results = []

for m, nbits in pq_m_nbits:
    print("pq", m, nbits)
    index = faiss.index_factory(d, f"IDMap2,PQ{m}x{nbits}")
    index.train(training_data)
    index.add_with_ids(database_vector_float32s, database_vector_ids)
    _, result_ids = index.search(query_vector_float32s, k=1)
    recall = sum(result_ids == ground_truth_result_ids)
    results.append(Record("pq", index, (m, nbits), recall))

for nbits in lsh_nbits:
    print("lsh", nbits)
    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
    index.add_with_ids(database_vector_float32s, database_vector_ids)
    _, result_ids = index.search(query_vector_float32s, k=1)
    recall = sum(result_ids == ground_truth_result_ids)
    results.append(Record("lsh", index, (nbits,), recall))

""":py '556918346720794'"""
import matplotlib.pyplot as plt
import numpy as np

def create_grouped_bar_chart(x_values, y_values_list, labels_list, xlabel, ylabel, title):
    plt.figure(figsize=(12, 6))

    for x, y_values, labels in zip(x_values, y_values_list, labels_list):
        num_bars = len(y_values)
        bar_width = 0.08 * x
        bar_positions = np.arange(num_bars) * bar_width - (num_bars - 1) * bar_width / 2 + x

        bars = plt.bar(bar_positions, y_values, width=bar_width)

        for bar, label in zip(bars, labels):
            height = bar.get_height()
            plt.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom',
            )

    plt.xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(x_values, labels=[str(x) for x in x_values])
    plt.tight_layout()
    plt.show()

# # Example usage:
# x_values = [1, 2, 4, 8, 16, 32]
# y_values_list = [
# [2.5, 3.6, 1.8],
# [3.0, 2.8],
# [2.5, 3.5, 4.0, 1.0],
# [4.2],
# [3.0, 5.5, 2.2],
# [6.0, 4.5]
# ]
# labels_list = [
# ['A1', 'B1', 'C1'],
# ['A2', 'B2'],
# ['A3', 'B3', 'C3', 'D3'],
# ['A4'],
# ['A5', 'B5', 'C5'],
# ['A6', 'B6']
# ]

# create_grouped_bar_chart(x_values, y_values_list, labels_list, "x axis", "y axis", "title")

""":py '1630106834206134'"""
# x-axis: compression ratio
# y-axis: recall@1

from collections import defaultdict

x = defaultdict(list)
x[1].append(("flat", 1.00))
for r in results:
    y_value = r.recall[0] / n
    x_value = int(d * 4 / r.index.sa_code_size())
    label = None
    if r.type_ == "pq":
        label = f"PQ{r.args[0]}x{r.args[1]}"
    if r.type_ == "lsh":
        label = f"LSH{r.args[0]}"
    x[x_value].append((label, y_value))

x_values = sorted(list(x.keys()))
create_grouped_bar_chart(
    x_values,
    [[e[1] for e in x[x_value]] for x_value in x_values],
    [[e[0] for e in x[x_value]] for x_value in x_values],
    "compression ratio",
    "recall@1 q=1,000 queries",
    "recall@1 for a database of n=1,000 d=768 vectors",
)
