Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for COSINE in fused ADC #329

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.pq.FusedADCPQDecoder;
import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.pq.QuickADCPQDecoder;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
Expand Down Expand Up @@ -82,7 +82,7 @@ static FusedADC load(CommonHeader header, RandomAccessReader reader) {

ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf, OnDiskGraphIndex.View view, ScoreFunction.ExactScoreFunction esf) {
var neighbors = new PackedNeighbors(view);
return QuickADCPQDecoder.newDecoder(neighbors, pq, queryVector, reusableResults.get(), vsf, esf);
return FusedADCPQDecoder.newDecoder(neighbors, pq, queryVector, reusableResults.get(), vsf, esf);
}

@Override
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public CosineDecoder(PQVectors cv, VectorFloat<?> query) {
var pq = this.cv.pq;

// this part is not query-dependent, so we can cache it
aMagnitude = cv.partialMagnitudes().updateAndGet(current -> {
aMagnitude = cv.partialSquaredMagnitudes().updateAndGet(current -> {
if (current != null) {
return current;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ VectorFloat<?> reusablePartialSums() {
return pq.reusablePartialSums();
}

AtomicReference<VectorFloat<?>> partialMagnitudes() {
return pq.partialMagnitudes();
AtomicReference<VectorFloat<?>> partialSquaredMagnitudes() {
return pq.partialSquaredMagnitudes();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ public class ProductQuantization implements VectorCompressor<ByteSequence<?>>, A
final float anisotropicThreshold; // parallel cost multiplier
private final float[][] centroidNormsSquared; // precomputed norms of the centroids, for encoding
private final ThreadLocal<VectorFloat<?>> partialSums; // for dot product, euclidean, and cosine partials
private final AtomicReference<VectorFloat<?>> partialMagnitudes; // for cosine partials
private final ThreadLocal<VectorFloat<?>> partialBestDistances; // for partial best distances during fused ADC
private final ThreadLocal<ByteSequence<?>> partialQuantizedSums; // for quantized sums during fused ADC

private final AtomicReference<VectorFloat<?>> partialSquaredMagnitudes; // for cosine partials
private final AtomicReference<ByteSequence<?>> partialQuantizedSquaredMagnitudes; // for quantized squared magnitude partials during cosine fused ADC
protected volatile float squaredMagnitudeDelta = 0; // for cosine fused ADC squared magnitude quantization delta (since this is invariant for a given PQ)
protected volatile float minSquaredMagnitude = 0; // for cosine fused ADC minimum squared magnitude (invariant for a given PQ)

/**
* Initializes the codebooks by clustering the input data using Product Quantization.
Expand Down Expand Up @@ -201,9 +203,10 @@ public ProductQuantization refine(RandomAccessVectorValues ravv,
}
this.anisotropicThreshold = anisotropicThreshold;
this.partialSums = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(getSubspaceCount() * getClusterCount()));
this.partialQuantizedSums = ThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(getSubspaceCount() * getClusterCount() * 2));
this.partialMagnitudes = new AtomicReference<>(null);
this.partialBestDistances = ThreadLocal.withInitial(() -> vectorTypeSupport.createFloatVector(getSubspaceCount()));
this.partialQuantizedSums = ThreadLocal.withInitial(() -> vectorTypeSupport.createByteSequence(getSubspaceCount() * getClusterCount() * 2));
this.partialSquaredMagnitudes = new AtomicReference<>(null);
this.partialQuantizedSquaredMagnitudes= new AtomicReference<>(null);


centroidNormsSquared = new float[M][clusterCount];
Expand Down Expand Up @@ -523,8 +526,12 @@ VectorFloat<?> reusablePartialBestDistances() {
return partialBestDistances.get();
}

AtomicReference<VectorFloat<?>> partialMagnitudes() {
return partialMagnitudes;
AtomicReference<VectorFloat<?>> partialSquaredMagnitudes() {
return partialSquaredMagnitudes;
}

AtomicReference<ByteSequence<?>> partialQuantizedSquaredMagnitudes() {
return partialQuantizedSquaredMagnitudes;
}

public void write(DataOutput out, int version) throws IOException
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,14 @@ public void calculatePartialSums(VectorFloat<?> codebook, int codebookIndex, int
}

@Override
public void quantizePartialSums(float delta, VectorFloat<?> partialSums, VectorFloat<?> partialBest, ByteSequence<?> partialQuantizedSums) {
var codebookSize = partialSums.length() / partialBest.length();
for (int i = 0; i < partialBest.length(); i++) {
var localBest = partialBest.get(i);
public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?> partialBases, ByteSequence<?> quantizedPartials) {
var codebookSize = partials.length() / partialBases.length();
for (int i = 0; i < partialBases.length(); i++) {
var localBest = partialBases.get(i);
for (int j = 0; j < codebookSize; j++) {
var val = partialSums.get(i * codebookSize + j);
var val = partials.get(i * codebookSize + j);
var quantized = (short) Math.min((val - localBest) / delta, 65535);
partialQuantizedSums.setLittleEndianShort(i * codebookSize + j, quantized);
quantizedPartials.setLittleEndianShort(i * codebookSize + j, quantized);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ public static void bulkShuffleQuantizedSimilarity(ByteSequence<?> shuffles, int
impl.bulkShuffleQuantizedSimilarity(shuffles, codebookCount, quantizedPartials, delta, minDistance, vsf, results);
}

public static void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int codebookCount,
ByteSequence<?> quantizedPartialSums, float sumDelta, float minDistance,
ByteSequence<?> quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude,
float queryMagnitudeSquared, VectorFloat<?> results) {
impl.bulkShuffleQuantizedSimilarityCosine(shuffles, codebookCount, quantizedPartialSums, sumDelta, minDistance, quantizedPartialMagnitudes, magnitudeDelta, minMagnitude, queryMagnitudeSquared, results);
}

public static int hammingDistance(long[] v1, long[] v2) {
return impl.hammingDistance(v1, v2);
}
Expand All @@ -167,8 +174,8 @@ public static void calculatePartialSums(VectorFloat<?> codebook, int codebookInd
impl.calculatePartialSums(codebook, codebookIndex, size, clusterCount, query, offset, vsf, partialSums);
}

public static void quantizePartialSums(float delta, VectorFloat<?> partialSums, VectorFloat<?> partialBest, ByteSequence<?> partialQuantizedSums) {
impl.quantizePartialSums(delta, partialSums, partialBest, partialQuantizedSums);
public static void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?> partialBase, ByteSequence<?> quantizedPartials) {
impl.quantizePartials(delta, partials, partialBase, quantizedPartials);
}

/**
Expand Down
Loading
Loading