Skip to content

Commit

Permalink
fix global centering and add test that raw computation equals precomp…
Browse files Browse the repository at this point in the history
…uted
  • Loading branch information
jbellis committed Jul 16, 2024
1 parent bd75108 commit 0f78056
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ public CosineDecoder(PQVectors cv, VectorFloat<?> query) {

// Compute and cache partial sums and magnitudes for query vector
partialSums = cv.reusablePartialSums();
float bMagSum = 0.0f;

VectorFloat<?> center = pq.globalCentroid;
VectorFloat<?> centeredQuery = center == null ? query : VectorUtil.sub(query, center);
Expand All @@ -121,11 +120,9 @@ public CosineDecoder(PQVectors cv, VectorFloat<?> query) {
for (int j = 0; j < pq.getClusterCount(); ++j) {
partialSums.set((m * pq.getClusterCount()) + j, VectorUtil.dotProduct(codebook, j * size, centeredQuery, offset, size));
}

bMagSum += VectorUtil.dotProduct(centeredQuery, offset, centeredQuery, offset, pq.subvectorSizesAndOffsets[m][0]);
}

this.bMagnitude = bMagSum;
this.bMagnitude = VectorUtil.dotProduct(centeredQuery, centeredQuery);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(Vector

@Override
public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
VectorFloat<?> centeredQuery = pq.globalCentroid == null ? q : VectorUtil.sub(q, pq.globalCentroid);
switch (similarityFunction) {
case DOT_PRODUCT:
return (node2) -> {
Expand All @@ -143,13 +144,13 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, q, centroidOffset, centroidLength);
dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
}
// scale to [0, 1]
return (1 + dp) / 2;
};
case COSINE:
float norm1 = VectorUtil.dotProduct(q, q);
float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery);
return (node2) -> {
var encoded = get(node2);
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
Expand All @@ -160,7 +161,7 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
var codebookOffset = centroidIndex * centroidLength;
sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, q, centroidOffset, centroidLength);
sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength);
norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
}
float cosine = sum / (float) Math.sqrt(norm1 * norm2);
Expand All @@ -176,7 +177,7 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, q, centroidOffset, centroidLength);
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
}
// scale to [0, 1]
return 1 / (1 + sum);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.disk.SimpleMappedReader;
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import org.junit.Test;

Expand All @@ -30,6 +31,7 @@
import java.util.List;

import static io.github.jbellis.jvector.TestUtil.createRandomVectors;
import static io.github.jbellis.jvector.TestUtil.nextInt;
import static java.lang.Math.abs;
import static java.lang.Math.log;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -87,7 +89,7 @@ public void testSaveLoadBQ() throws Exception {
}

private void testEncodings(int dimension, int codebooks) {
// Generate a PQ for random 2D vectors
// Generate a PQ for random vectors
var vectors = createRandomVectors(512, dimension);
var ravv = new ListRandomAccessVectorValues(vectors, dimension);
var pq = ProductQuantization.compute(ravv, codebooks, 256, false);
Expand Down Expand Up @@ -127,6 +129,35 @@ public void testEncodings() {
}
}

@Test
public void testRawEqualsPrecomputed() {
// Generate a PQ for random vectors
int dimension = nextInt(getRandom(), 4, 2048);
int codebooks = nextInt(getRandom(), 1, dimension / 2);
var vectors = createRandomVectors(512, dimension);
var ravv = new ListRandomAccessVectorValues(vectors, dimension);
for (boolean center : new boolean[] {true, false}) {
var pq = ProductQuantization.compute(ravv, codebooks, 256, center);

// Compress the vectors
var compressed = pq.encodeAll(ravv);
var cv = new PQVectors(pq, compressed);

// compare the precomputed similarities to the raw
for (int i = 0; i < 10; i++) {
var q = TestUtil.randomVector(getRandom(), dimension);
for (var vsf : List.of(VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.DOT_PRODUCT, VectorSimilarityFunction.COSINE)) {
var precomputed = cv.precomputedScoreFunctionFor(q, vsf);
var raw = cv.scoreFunctionFor(q, vsf);
for (int j = 0; j < 10; j++) {
var target = getRandom().nextInt(vectors.size());
assertEquals(raw.similarityTo(target), precomputed.similarityTo(target), 1e-6);
}
}
}
}
}

@Test
public void testCenteringDisturbance() {

Expand Down

0 comments on commit 0f78056

Please sign in to comment.