diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/InlineVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/InlineVectorValues.java index 7e7a2751..6b63fbf2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/InlineVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/InlineVectorValues.java @@ -46,7 +46,7 @@ public InlineVectorValues(int dimension, OnDiskGraphIndexWriter writer) { @Override public int size() { - return writer.getMaxOrdinal(); + return writer.getMaxOrdinal() + 1; // +1 because ordinals are 0-based } @Override diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/LvqVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/LvqVectorValues.java index 2e830d58..cebe1374 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/LvqVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/LvqVectorValues.java @@ -48,7 +48,7 @@ public LvqVectorValues(int dimension, LVQ lvq, OnDiskGraphIndexWriter writer) { @Override public int size() { - return writer.getMaxOrdinal(); + return writer.getMaxOrdinal() + 1; // +1 because ordinals are 0-based; } @Override diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java index bac1c2fb..dfc92e28 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java @@ -29,7 +29,6 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.file.Path; -import java.util.Collection; import java.util.EnumMap; import java.util.Map; import java.util.Set; @@ -54,7 +53,7 @@ public class OnDiskGraphIndexWriter implements Closeable { private final BufferedRandomAccessWriter out; private final long startOffset; private final int headerSize; - private volatile int maxOrdinalWritten; + private volatile int maxOrdinalWritten = -1; private OnDiskGraphIndexWriter(Path outPath, int version, @@ -115,6 +114,9 @@ public synchronized void writeInline(int ordinal, Map maxOrdinalWritten = Math.max(maxOrdinalWritten, ordinal); } + /** + * @return the maximum ordinal written so far, or -1 if no ordinals have been written yet + */ public int getMaxOrdinal() { return maxOrdinalWritten; } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java index 12d2183f..33f11444 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java @@ -24,6 +24,7 @@ import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.TestVectorGraph; +import io.github.jbellis.jvector.pq.LocallyAdaptiveVectorQuantization; import io.github.jbellis.jvector.pq.PQVectors; import io.github.jbellis.jvector.pq.ProductQuantization; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; @@ -35,18 +36,14 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; -import java.util.EnumMap; import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.IntFunction; import static io.github.jbellis.jvector.TestUtil.getNeighborNodes; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class TestOnDiskGraphIndex extends RandomizedTest { @@ -275,13 +272,19 @@ public void testIncrementalWrites() throws IOException { var incrementalPath = testDirectory.resolve("bulk_graph"); try (var writer = new OnDiskGraphIndexWriter.Builder(graph, incrementalPath) .with(new InlineVectors(ravv.dimension())) - .build()) + .build(); + var ivv = new InlineVectorValues(ravv.dimension(), writer)) { + assertEquals(0, ivv.size()); + // write inline vectors incrementally for (int i = 0; i < vectors.size(); i++) { var state = Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(ravv.getVector(i))); writer.writeInline(i, state); } + + assertEquals(vectors.size(), ivv.size()); + // write graph structure writer.write(Map.of()); } @@ -323,5 +326,43 @@ public void testIncrementalWrites() throws IOException { } catch (Exception e) { throw new RuntimeException(e); } + + // write incrementally with LVQ and add Fused ADC feature + var incrementalLvqPath = testDirectory.resolve("incremental_lvq_graph"); + var lvq = LocallyAdaptiveVectorQuantization.compute(ravv); + var lvqFeature = new LVQ(lvq); + + try (var writer = new OnDiskGraphIndexWriter.Builder(graph, incrementalLvqPath) + .with(lvqFeature) + .with(new FusedADC(graph.maxDegree(), pq)) + .build(); + var lvqvv = new LvqVectorValues(ravv.dimension(), lvqFeature, writer);) + { + assertEquals(0, lvqvv.size()); + + // write inline vectors incrementally + for (int i = 0; i < vectors.size(); i++) { + var state = Feature.singleState(FeatureId.LVQ, new LVQ.State(lvq.encode(ravv.getVector(i)))); + writer.writeInline(i, state); + } + + assertEquals(vectors.size(), lvqvv.size()); + + // write graph structure, fused ADC + writer.write(Feature.singleStateFactory(FeatureId.FUSED_ADC, i -> new FusedADC.State(graph.getView(), pqv, i))); + writer.write(Map.of()); + } + + // graph and vectors should be identical + try (var bulkMarr = new SimpleMappedReader(bulkPath.toAbsolutePath().toString()); + var bulkGraph = OnDiskGraphIndex.load(bulkMarr::duplicate); + var incrementalMarr = new SimpleMappedReader(incrementalLvqPath.toAbsolutePath().toString()); + var incrementalGraph = OnDiskGraphIndex.load(incrementalMarr::duplicate)) + { + assertTrue(OnDiskGraphIndex.areHeadersEqual(incrementalGraph, bulkGraph)); + TestUtil.assertGraphEquals(incrementalGraph, bulkGraph); // incremental and bulk graph should have same structure + } catch (Exception e) { + throw new RuntimeException(e); + } } }