Skip to content

Commit

Permalink
Improve performance of reconnectOrphanedNodes (#359)
Browse files Browse the repository at this point in the history
* Improve performance of reconnectOrphanedNodes by limiting neighbor connection targets to nodes that were reachable by the entry node at the start of the pass. Instead of using exclusion bits for connection targets, perform several rounds of resumes and post-filter for connectionTargets. Log basic debugging information when reconnecting orphaned nodes by introducing slf4j-api.

* simplify connectionTargets, connectedNodes, and excludedBits

no need for resuming the search again

add backlinking of new edges from search

* preserve connectionTargets across passes

* No need to invert bits -- the argument is for results to include, which already matches connectionTargets

* switch back to Joel's search + resume approach

* we do need to exclude self

* r/m unused

* Re-split connected nodes/global connection targets. Don't filter neighbors found via search by the connected set

* Javadoc updates

* DRY connectToClosestNeighbor by passing Bits.ALL for connectedNodes when connecting through search

---------

Co-authored-by: Jonathan Ellis <[email protected]>
  • Loading branch information
jkni and jbellis authored Sep 27, 2024
1 parent 8f3c682 commit 78ee760
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.AtomicFixedBitSet;
import io.github.jbellis.jvector.util.BitSet;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.util.ExceptionUtils;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
Expand All @@ -29,6 +30,8 @@
import io.github.jbellis.jvector.vector.types.VectorFloat;
import org.agrona.collections.IntArrayList;
import org.agrona.collections.IntArrayQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Closeable;
import java.io.IOException;
Expand Down Expand Up @@ -56,6 +59,8 @@
* that spawning a new Thread per call is not advisable. This includes virtual threads.
*/
public class GraphIndexBuilder implements Closeable {
private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class);

private final int beamWidth;
private final ExplicitThreadLocal<NodeArray> naturalScratch;
private final ExplicitThreadLocal<NodeArray> concurrentScratch;
Expand Down Expand Up @@ -235,86 +240,107 @@ public void cleanup() {
}

private void reconnectOrphanedNodes() {
var searchPathNeighbors = new ConcurrentHashMap<Integer, NodeArray>();
// It's possible that reconnecting one node will result in disconnecting another, since we are maintaining
// the maxConnections invariant. So, we do a best effort of 3 loops. We claim the entry node as an
// already used connectionTarget so that we don't clutter its edge list.
var connectionTargets = ConcurrentHashMap.<Integer>newKeySet();
connectionTargets.add(graph.entry());
for (int i = 0; i < 3; i++) {
// find all nodes reachable from the entry node
// Set of nodes already used as connection targets, initialized to the entry point. Since reconnection edges are
// usually worse (by distance and/or diversity) than the original ones, we update this as edges are added to
// avoid reusing the same target node more than once.
AtomicFixedBitSet globalConnectionTargets = new AtomicFixedBitSet(graph.getIdUpperBound());
globalConnectionTargets.set(graph.entry());
// Reconnection is best-effort: reconnecting one node may result in disconnecting another, since we are maintaining
// the maxConnections invariant. So, we do a maximum of 5 loops.
for (int i = 0; i < 5; i++) {
// determine the nodes reachable from the entry point at the start of this pass
var connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound());
connectedNodes.set(graph.entry());
ConcurrentNeighborMap.Neighbors self1 = graph.getNeighbors(graph.entry());
var entryNeighbors = (NodeArray) self1;
var entryNeighbors = graph.getNeighbors(graph.entry());
parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size()).parallel().forEach(node -> findConnected(connectedNodes, entryNeighbors.getNode(node)))).join();

// reconnect unreachable nodes
var nReconnected = new AtomicInteger();
// Gather basic debug information about efficacy/efficiency of reconnection attempts
var nReconnectAttempts = new AtomicInteger();
var nReconnectedViaNeighbors = new AtomicInteger();
var nResumesRun = new AtomicInteger();
var nReconnectedViaSearch = new AtomicInteger();

simdExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(node -> {
if (connectedNodes.get(node) || !graph.containsNode(node)) {
return;
}
nReconnected.incrementAndGet();
nReconnectAttempts.incrementAndGet();

// first, attempt to connect one of our own neighbors to us
// first, attempt to connect one of our own connected neighbors to us. Filtering
// to connected nodes tends to help for partitioned graphs with large partitions.
ConcurrentNeighborMap.Neighbors self = graph.getNeighbors(node);
var neighbors = (NodeArray) self;
if (connectToClosestNeighbor(node, neighbors, connectionTargets)) {
if (connectToClosestNeighbor(node, neighbors, connectedNodes, globalConnectionTargets) != null) {
nReconnectedViaNeighbors.incrementAndGet();
return;
}

// no unused candidate found -- search for more neighbors and try again
neighbors = searchPathNeighbors.get(node);
// run search again if neighbors is empty or if every neighbor is already in connection targets
if (neighbors == null || isSubset(neighbors, connectionTargets)) {
SearchResult result;
try (var gs = searchers.get()) {
var excludeBits = createExcludeBits(node, connectionTargets);
var ssp = scoreProvider.searchProviderFor(node);
int ep = graph.entry();
result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, excludeBits);
} catch (Exception e) {
throw new RuntimeException(e);
}
// if we can't find a connected neighbor to reconnect to, we'll have to search. We start with a small
// search, and we resume the search in a bounded loop to try to find an eligible connection target.
// This significantly improves behavior for large (1M+ node) partitioned graphs. We don't add
// connectionTargets to excludeBits because large partitions lead to excessively large excludeBits,
// greatly degrading search performance.
SearchResult result;
try (var gs = searchers.get()) {
var ssp = scoreProvider.searchProviderFor(node);
int ep = graph.entry();
result = gs.searchInternal(ssp, beamWidth, beamWidth, 0.0f, 0.0f, ep, other -> other != node);
neighbors = new NodeArray(result.getNodes().length);
toScratchCandidates(result.getNodes(), neighbors);
searchPathNeighbors.put(node, neighbors);
var j = 0;
// no need to filter to connected nodes here, as they're connected by virtue of being reachable via
// search
var reconnectedTo = connectToClosestNeighbor(node, neighbors, Bits.ALL, globalConnectionTargets);
// if we can't find a valid connectionTarget within 2*degree of the search destination, give up
while (reconnectedTo == null && j < 2 * graph.maxDegree) {
j++;
nResumesRun.incrementAndGet();
result = gs.resume(beamWidth, beamWidth);
toScratchCandidates(result.getNodes(), neighbors);
reconnectedTo = connectToClosestNeighbor(node, neighbors, Bits.ALL, globalConnectionTargets);
}

if (reconnectedTo != null) {
nReconnectedViaSearch.incrementAndGet();
// since we went to the trouble of finding the closest available neighbor, let `backlink`
// check to see if it should be added as an edge to the original node as well
var na = new NodeArray(1);
na.addInOrder(reconnectedTo.node, reconnectedTo.score);
graph.nodes.backlink(na, node, 1.0f);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
connectToClosestNeighbor(node, neighbors, connectionTargets);
})).join();
if (nReconnected.get() == 0) {
break;
}
}
}

private boolean isSubset(NodeArray neighbors, Set<Integer> nodeIds) {
for (int i = 0; i < neighbors.size(); i++) {
if (!nodeIds.contains(neighbors.getNode(i))) {
return false;
logger.debug("Reconnecting {} nodes out of {} on pass {}. {} neighbor reconnects. {} searches/resumes run. {} nodes reconnected via search",
nReconnectAttempts.get(), graph.size(), i, nReconnectedViaNeighbors.get(), nResumesRun.get(), nReconnectedViaSearch.get());

if (nReconnectAttempts.get() == 0) {
break;
}
}
return true;
}

/**
* Connect `node` to the closest neighbor that is not already a connection target.
* @return true if such a neighbor was found.
* Connect `node` to the closest connected neighbor that is not already a connection target.
*
* @return the node score id if such a neighbor was found, else null.
*/
private boolean connectToClosestNeighbor(int node, NodeArray neighbors, Set<Integer> connectionTargets) {
// connect this node to the closest neighbor that hasn't already been used as a connection target
private SearchResult.NodeScore connectToClosestNeighbor(int node, NodeArray neighbors, Bits connectedNodes, BitSet connectionTargets) {
// connect this node to the closest connected neighbor that hasn't already been used as a connection target
// (since this edge is likely to be the "worst" one in that target's neighborhood, it's likely to be
// overwritten by the next node to need reconnection if we don't choose a unique target)
for (int i = 0; i < neighbors.size(); i++) {
var neighborNode = neighbors.getNode(i);
if (!connectedNodes.get(neighborNode) || connectionTargets.get(neighborNode))
continue;

var neighborScore = neighbors.getScore(i);
if (connectionTargets.add(neighborNode)) {
graph.nodes.insertEdgeNotDiverse(neighborNode, node, neighborScore);
return true;
}
graph.nodes.insertEdgeNotDiverse(neighborNode, node, neighborScore);
connectionTargets.set(neighborNode);
return new SearchResult.NodeScore(neighborNode, neighborScore);
}
return false;
return null;
}

private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
Expand Down Expand Up @@ -577,10 +603,6 @@ public synchronized long removeDeletedNodes() {
return nRemoved * graph.ramBytesUsedOneNode();
}

private static Bits createExcludeBits(int node, Set<Integer> connectionTargets) {
return index -> index != node && !connectionTargets.contains(index);
}

/**
* Returns the ordinal of the node that is closest to the centroid of the graph,
* or NO_ENTRY_POINT if there are no live nodes in the graph.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,13 @@ public long ramBytesUsed() {
long storageSize = (long) storage.length() * longSizeInBytes + arrayOverhead;
return BASE_RAM_BYTES_USED + storageSize;
}

public AtomicFixedBitSet copy() {
AtomicFixedBitSet copy = new AtomicFixedBitSet(length());
for (int i = 0; i < storage.length(); i++) {
copy.storage.set(i, storage.get(i));
}
return copy;
}
}

5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@
<artifactId>agrona</artifactId>
<version>1.20.0</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>2.0.16</version>
</dependency>
</dependencies>
<dependencyManagement>
<dependencies>
Expand Down

0 comments on commit 78ee760

Please sign in to comment.