From 2e0a2257e44c7e3ca212be7eb3415e8f70741939 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Sat, 11 May 2024 08:56:33 -0500 Subject: [PATCH] always rerank at least the best result found so that caller will have something to compare this index's results with --- .../jbellis/jvector/graph/NodeQueue.java | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java index 3436e6f4..2eed2447 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/NodeQueue.java @@ -141,30 +141,50 @@ public int[] nodesCopy() { /** * Rerank results and return the worst approximate score that made it into the topK. * The topK results will be placed into `reranked`, and the remainder into `unused`. + *

+ * Only the best result or results whose approximate score is at least `rerankFloor` will be reranked. */ public float rerank(int topK, ScoreFunction.Reranker reranker, float rerankFloor, NodeQueue reranked, NodesUnsorted unused) { // Rescore the nodes whose approximate score meets the floor. Nodes that do not will be marked as -1 int[] ids = new int[size()]; float[] exactScores = new float[size()]; var approximateScoresById = new Int2ObjectHashMap(); + float bestScore = Float.NEGATIVE_INFINITY; + int bestIndex = -1; + int scoresAboveFloor = 0; for (int i = 0; i < size(); i++) { long heapValue = heap.get(i + 1); float score = decodeScore(heapValue); var nodeId = decodeNodeId(heapValue); + // track the best score found so far in case nothing is above the floor + if (score > bestScore) { + bestScore = score; + bestIndex = i; + } + if (score >= rerankFloor) { + // rerank this one ids[i] = nodeId; exactScores[i] = reranker.similarityTo(ids[i]); approximateScoresById.put(ids[i], Float.valueOf(score)); + scoresAboveFloor++; } else { - // if it didn't qualify for reranking, add it to the unused pile - unused.add(nodeId, score); + // mark it unranked ids[i] = -1; } } + if (scoresAboveFloor == 0 && bestIndex >= 0) { + // if nothing was above the floor, then rerank the best one found + ids[bestIndex] = decodeNodeId(heap.get(bestIndex + 1)); + exactScores[bestIndex] = reranker.similarityTo(ids[bestIndex]); + approximateScoresById.put(ids[bestIndex], Float.valueOf(bestScore)); + } + // go through the entries and add to the appropriate collection for (int i = 0; i < ids.length; i++) { if (ids[i] == -1) { + unused.add(decodeNodeId(heap.get(i + 1)), decodeScore(heap.get(i + 1))); continue; }