diff --git a/demos/index_pq_flat_separate_codes_from_codebook.py b/demos/index_pq_flat_separate_codes_from_codebook.py
index 982c805262..d138a374b4 100644
--- a/demos/index_pq_flat_separate_codes_from_codebook.py
+++ b/demos/index_pq_flat_separate_codes_from_codebook.py
@@ -1,20 +1,19 @@
-#!/usr/bin/env -S grimaldi --kernel faiss_binary_local
+#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
 # fmt: off
 # flake8: noqa
 
 """:md
-# IndexPQ: separate codes from codebook
+# Serializing codes separately, with IndexLSH and IndexPQ
+
+Let's say you have a few vector embeddings per user
+and want to shard a flat index by user, so you can re-use the same LSH or PQ method
+ for all users but store each user's codes independently.
 
-This notebook demonstrates how to separate serializing and deserializing the PQ codebook
- (via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
- where you have a few vector embeddings per user and want to shard the flat index by user you
- can re-use the same PQ method for all users but store each user's codes independently.
 """
 
@@ -24,11 +23,9 @@
 """:py"""
 d = 768
-n = 10000
+n = 1_000
 ids = np.arange(n).astype('int64')
 training_data = np.random.rand(n, d).astype('float32')
-M = d//8
-nbits = 8
 
 """:py"""
 def read_ids_codes():
@@ -50,9 +47,76 @@ def write_template_index(template_index):
 def read_template_index_instance():
     return faiss.read_index("/tmp/template.index")
 
+""":md
+## IndexLSH: separate codes
+
+The first half of this notebook demonstrates how to store LSH codes separately from the index. Unlike PQ, LSH does not require training. In fact, its compression method, a random projection matrix, is deterministic on construction, based on a random seed value that's [hardcoded](https://github.com/facebookresearch/faiss/blob/2c961cc308ade8a85b3aa10a550728ce3387f625/faiss/IndexLSH.cpp#L35).
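+
+Because that projection matrix depends only on d, nbits, and the fixed seed, an IndexLSH rebuilt
+ with the same parameters produces identical codes, which is what makes persisting only the codes
+ (and their ids) safe. As a rough worked example with the nbits = 1536 used below, each code takes
+ 1536 / 8 = 192 bytes, versus d * 4 = 3072 bytes for the raw float32 vector, i.e. about 16x less.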
+"""
+
 """:py"""
-# at train time
+nbits = 1536
+
+""":py"""
+# demonstrating encoding is deterministic
+
+codes = []
+database_vector_float32 = np.random.rand(1, d).astype(np.float32)
+for i in range(10):
+    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
+    code = index.index.sa_encode(database_vector_float32)
+    codes.append(code)
+
+for i in range(1, 10):
+    assert np.array_equal(codes[0], codes[i])
+
+""":py"""
+# new database vector
+
+ids, codes = read_ids_codes()
+database_vector_id, database_vector_float32 = max(ids) + 1 if ids is not None else 1, np.random.rand(1, d).astype(np.float32)
+index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
+
+code = index.index.sa_encode(database_vector_float32)
+if ids is not None and codes is not None:
+    ids = np.concatenate((ids, [database_vector_id]))
+    codes = np.vstack((codes, code))
+else:
+    ids = np.array([database_vector_id])
+    codes = code  # sa_encode already returns a (1, code_size) uint8 array
+
+write_ids_codes(ids, codes)
+
+""":py '2840581589434841'"""
+# then at query time
+
+query_vector_float32 = np.random.rand(1, d).astype(np.float32)
+index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
+ids, codes = read_ids_codes()
+
+index.add_sa_codes(codes, ids)
+
+index.search(query_vector_float32, k=5)
+
+""":py"""
+!rm /tmp/ids.npy /tmp/codes.npy
+
+""":md
+## IndexPQ: separate codes from codebook
+
+The second half of this notebook demonstrates how to serialize and deserialize the PQ codebook
+ (via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
+ where you have a few vector embeddings per user and want to shard the flat index by user, you
+ can re-use the same PQ method for all users but store each user's codes independently.
+
+"""
+
+""":py"""
+M = d//8
+nbits = 8
+
+""":py"""
+# at train time
 template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
 template_index.train(training_data)
 write_template_index(template_index)
@@ -61,8 +125,8 @@ def read_template_index_instance():
 # New database vector
 index = read_template_index_instance()
-database_vector_id, database_vector_float32 = np.random.randint(10000), np.random.rand(1, d).astype(np.float32)
 ids, codes = read_ids_codes()
+database_vector_id, database_vector_float32 = max(ids) + 1 if ids is not None else 1, np.random.rand(1, d).astype(np.float32)
 
 code = index.index.sa_encode(database_vector_float32)
 
@@ -75,7 +139,7 @@ def read_template_index_instance():
 write_ids_codes(ids, codes)
 
-""":py '331546060044009'"""
+""":py '1858280061369209'"""
 # then at query time
 query_vector_float32 = np.random.rand(1, d).astype(np.float32)
 id_wrapper_index = read_template_index_instance()
@@ -87,3 +151,153 @@ def read_template_index_instance():
 
 """:py"""
 !rm /tmp/ids.npy /tmp/codes.npy /tmp/template.index
+
+""":md
+## Comparing these methods
+
+- methods: Flat, LSH, PQ
+- vary cost: nbits, M for 1x, 2x, 4x, 8x, 16x, 32x compression
+- measure: recall@1
+
+We don't measure latency, as the number of vectors per user shard is insignificant.
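+
+For reference, with d = 768 a raw float32 vector is 3072 bytes, so the parameter grid below works
+ out to: PQ96x8 and PQ192x4 give 96-byte codes (32x), PQ192x8 and PQ384x4 give 192 bytes (16x),
+ PQ384x8 and PQ768x4 give 384 bytes (8x), and an LSH code is nbits / 8 bytes, from 96 bytes
+ (nbits=768, 32x) up to 3072 bytes (nbits=24576, 1x).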
+
+"""
+
+""":py '2898032417027201'"""
+n, d
+
+""":py"""
+database_vector_ids, database_vector_float32s = np.arange(n), np.random.rand(n, d).astype(np.float32)
+query_vector_float32s = np.random.rand(n, d).astype(np.float32)
+
+""":py"""
+index = faiss.index_factory(d, "IDMap2,Flat")
+index.add_with_ids(database_vector_float32s, database_vector_ids)
+_, ground_truth_result_ids = index.search(query_vector_float32s, k=1)
+
+""":py '857475336204238'"""
+from dataclasses import dataclass
+
+pq_m_nbits = (
+    # 96 bytes
+    (96, 8),
+    (192, 4),
+    # 192 bytes
+    (192, 8),
+    (384, 4),
+    # 384 bytes
+    (384, 8),
+    (768, 4),
+)
+lsh_nbits = (768, 1536, 3072, 6144, 12288, 24576)
+
+
+@dataclass
+class Record:
+    type_: str
+    index: faiss.Index
+    args: tuple
+    recall: float
+
+
+results = []
+
+for m, nbits in pq_m_nbits:
+    print("pq", m, nbits)
+    index = faiss.index_factory(d, f"IDMap2,PQ{m}x{nbits}")
+    index.train(training_data)
+    index.add_with_ids(database_vector_float32s, database_vector_ids)
+    _, result_ids = index.search(query_vector_float32s, k=1)
+    # number of queries whose top-1 id matches the exact (Flat) result, as a length-1 array
+    recall = sum(result_ids == ground_truth_result_ids)
+    results.append(Record("pq", index, (m, nbits), recall))
+
+for nbits in lsh_nbits:
+    print("lsh", nbits)
+    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
+    index.add_with_ids(database_vector_float32s, database_vector_ids)
+    _, result_ids = index.search(query_vector_float32s, k=1)
+    recall = sum(result_ids == ground_truth_result_ids)
+    results.append(Record("lsh", index, (nbits,), recall))
+
+""":py '556918346720794'"""
+import matplotlib.pyplot as plt
+import numpy as np
+
+def create_grouped_bar_chart(x_values, y_values_list, labels_list, xlabel, ylabel, title):
+    num_bars_per_group = len(x_values)
+
+    plt.figure(figsize=(12, 6))
+
+    for x, y_values, labels in zip(x_values, y_values_list, labels_list):
+        num_bars = len(y_values)
+        bar_width = 0.08 * x
+        bar_positions = np.arange(num_bars) * bar_width - (num_bars - 1) * bar_width / 2 + x
+
+        bars = plt.bar(bar_positions, y_values, width=bar_width)
+
+        for bar, label in zip(bars, labels):
+            height = bar.get_height()
+            plt.annotate(
+                label,
+                xy=(bar.get_x() + bar.get_width() / 2, height),
+                xytext=(0, 3),
+                textcoords="offset points",
+                ha='center', va='bottom'
+            )
+
+    plt.xscale('log')
+    plt.xlabel(xlabel)
+    plt.ylabel(ylabel)
+    plt.title(title)
+    plt.xticks(x_values, labels=[str(x) for x in x_values])
+    plt.tight_layout()
+    plt.show()
+
+# # Example usage:
+# x_values = [1, 2, 4, 8, 16, 32]
+# y_values_list = [
+#     [2.5, 3.6, 1.8],
+#     [3.0, 2.8],
+#     [2.5, 3.5, 4.0, 1.0],
+#     [4.2],
+#     [3.0, 5.5, 2.2],
+#     [6.0, 4.5]
+# ]
+# labels_list = [
+#     ['A1', 'B1', 'C1'],
+#     ['A2', 'B2'],
+#     ['A3', 'B3', 'C3', 'D3'],
+#     ['A4'],
+#     ['A5', 'B5', 'C5'],
+#     ['A6', 'B6']
+# ]
+
+# create_grouped_bar_chart(x_values, y_values_list, labels_list, "x axis", "y axis", "title")
+
+""":py '1630106834206134'"""
+# x-axis: compression ratio
+# y-axis: recall@1
+
+from collections import defaultdict
+
+x = defaultdict(list)
+x[1].append(("flat", 1.00))  # exact Flat search is its own ground truth, so recall@1 is 1.0
+for r in results:
+    y_value = r.recall[0] / n
+    x_value = int(d * 4 / r.index.sa_code_size())  # compression ratio vs. raw float32 vectors
+    label = None
+    if r.type_ == "pq":
+        label = f"PQ{r.args[0]}x{r.args[1]}"
+    if r.type_ == "lsh":
+        label = f"LSH{r.args[0]}"
+    x[x_value].append((label, y_value))
+
+x_values = sorted(list(x.keys()))
+create_grouped_bar_chart(
+    x_values,
+    [[e[1] for e in x[x_value]] for x_value in x_values],
+    [[e[0] for e in x[x_value]] for x_value in x_values],
+    "compression ratio",
+    "recall@1 q=1,000 queries",
+    "recall@1 for a database of n=1,000 d=768 vectors",
+)
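+
+""":md
+As an optional, rough text complement to the chart above (a minimal sketch that only reuses the
+ results list, d, and n defined earlier), the next cell prints each method's code size, compression
+ ratio, and recall@1.
+"""
+
+""":py"""
+for r in results:
+    label = f"PQ{r.args[0]}x{r.args[1]}" if r.type_ == "pq" else f"LSH{r.args[0]}"
+    code_size = r.index.sa_code_size()  # bytes per stored code
+    compression = d * 4 / code_size     # a raw vector is d float32 values = 3072 bytes
+    recall_at_1 = r.recall[0] / n
+    print(f"{label:>9}  {code_size:>4} B  {compression:>4.0f}x  recall@1={recall_at_1:.3f}")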