Skip to content

Commit

Permalink
add support for non-sequential remapped ordinals (#349)
Browse files Browse the repository at this point in the history
* add OrdinalMapper::maxOrdinal and add support for ordinal "holes"
  • Loading branch information
jbellis authored Aug 1, 2024
1 parent 0ddead5 commit b8298a4
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,26 @@ public synchronized void write(Map<FeatureId, IntFunction<Feature.State>> featur
writeHeader();

// for each graph node, write the associated features, followed by its neighbors
int graphSize = graph.size();
for (int newOrdinal = 0; newOrdinal < graphSize; newOrdinal++) {
for (int newOrdinal = 0; newOrdinal <= ordinalMapper.maxOrdinal(); newOrdinal++) {
var originalOrdinal = ordinalMapper.newToOld(newOrdinal);

// if no node exists with the given ordinal, write a placeholder
if (originalOrdinal == OrdinalMapper.OMITTED) {
out.writeInt(-1);
for (var feature : featureMap.values()) {
out.seek(out.position() + feature.inlineSize());
}
out.writeInt(0);
for (int n = 0; n < graph.maxDegree(); n++) {
out.writeInt(-1);
}
continue;
}

if (!graph.containsNode(originalOrdinal)) {
var msg = String.format("Ordinal mapper mapped new ordinal %s to non-existing node %s", newOrdinal, originalOrdinal);
throw new IllegalStateException(msg);
}

out.writeInt(newOrdinal); // unnecessary, but a reasonable sanity check
assert out.position() == featureOffsetForOrdinal(newOrdinal) : String.format("%d != %d", out.position(), featureOffsetForOrdinal(newOrdinal));
for (var feature : featureMap.values()) {
Expand All @@ -190,8 +202,8 @@ public synchronized void write(Map<FeatureId, IntFunction<Feature.State>> featur
int n = 0;
for (; n < neighbors.size(); n++) {
var newNeighborOrdinal = ordinalMapper.oldToNew(neighbors.nextInt());
if (newNeighborOrdinal < 0 || newNeighborOrdinal >= graphSize) {
var msg = String.format("Neighbor ordinal out of bounds: %d/%d", newNeighborOrdinal, graphSize);
if (newNeighborOrdinal < 0 || newNeighborOrdinal > ordinalMapper.maxOrdinal()) {
var msg = String.format("Neighbor ordinal out of bounds: %d/%d", newNeighborOrdinal, ordinalMapper.maxOrdinal());
throw new IllegalStateException(msg);
}
out.writeInt(newNeighborOrdinal);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,48 @@

package io.github.jbellis.jvector.graph.disk;

import org.agrona.collections.Int2IntHashMap;

import java.util.Map;

public interface OrdinalMapper {
/**
* Used by newToOld to indicate that the new ordinal is a "hole" that has no corresponding old ordinal.
*/
int OMITTED = Integer.MIN_VALUE;

/**
* OnDiskGraphIndexWriter will iterate from 0..maxOrdinal(), inclusive.
*/
int maxOrdinal();

/**
* Map old ordinals (in the graph as constructed) to new ordinals (written to disk).
* Should always return a valid ordinal (between 0 and maxOrdinal).
*/
int oldToNew(int oldOrdinal);

/**
* Map new ordinals (written to disk) to old ordinals (in the graph as constructed).
* May return OMITTED if there is a "hole" at the new ordinal.
*/
int newToOld(int newOrdinal);

/**
* A mapper that leaves the original ordinals unchanged.
*/
class IdentityMapper implements OrdinalMapper {
private final int maxOrdinal;

public IdentityMapper(int maxOrdinal) {
this.maxOrdinal = maxOrdinal;
}

@Override
public int maxOrdinal() {
return maxOrdinal;
}

@Override
public int oldToNew(int oldOrdinal) {
return oldOrdinal;
Expand All @@ -35,14 +69,24 @@ public int newToOld(int newOrdinal) {
}
}

/**
* Converts a Map of old to new ordinals into an OrdinalMapper.
*/
class MapMapper implements OrdinalMapper {
private final int maxOrdinal;
private final Map<Integer, Integer> oldToNew;
private final int[] newToOld;
private final Int2IntHashMap newToOld;

public MapMapper(Map<Integer, Integer> oldToNew) {
this.oldToNew = oldToNew;
this.newToOld = new int[oldToNew.size()];
oldToNew.forEach((old, newOrdinal) -> newToOld[newOrdinal] = old);
this.newToOld = new Int2IntHashMap(oldToNew.size(), 0.65f, OMITTED);
oldToNew.forEach((old, newOrdinal) -> newToOld.put(newOrdinal, old));
this.maxOrdinal = oldToNew.values().stream().mapToInt(i -> i).max().orElse(-1);
}

@Override
public int maxOrdinal() {
return maxOrdinal;
}

@Override
Expand All @@ -52,7 +96,7 @@ public int oldToNew(int oldOrdinal) {

@Override
public int newToOld(int newOrdinal) {
return newToOld[newOrdinal];
return newToOld.get(newOrdinal);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ public int size() {
* @return the value of the key, or null if not set
*/
public T get(int key) {
var ref = objects;
if (key >= ref.length()) {
if (key >= objects.length()) {
return null;
}
return ref.get(key);

return objects.get(key);
}

private void ensureCapacity(int node) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ private static BuilderWithSuppliers builderWithSuppliers(Set<FeatureId> features
ProductQuantization pq)
throws FileNotFoundException
{
var builder = new OnDiskGraphIndexWriter.Builder(onHeapGraph, outPath).withMapper(new OrdinalMapper.IdentityMapper());
var identityMapper = new OrdinalMapper.IdentityMapper(onHeapGraph.getIdUpperBound() - 1);
var builder = new OnDiskGraphIndexWriter.Builder(onHeapGraph, outPath).withMapper(identityMapper);
Map<FeatureId, IntFunction<Feature.State>> suppliers = new EnumMap<>(FeatureId.class);
for (var featureId : features) {
switch (featureId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ public static void siftDiskAnnLTM(List<VectorFloat<?>> baseVectors, List<VectorF
// explicit Writer for the first time, this is what's behind OnDiskGraphIndex.write
OnDiskGraphIndexWriter writer = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexPath)
.with(new InlineVectors(ravv.dimension()))
.withMapper(new OrdinalMapper.IdentityMapper())
.withMapper(new OrdinalMapper.IdentityMapper(builder.getGraph().getIdUpperBound() - 1))
.build();
// output for the compressed vectors
DataOutputStream pqOut = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(pqPath))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,35 @@ public void testReorderingRenumbering() throws IOException {
}
}

@Test
public void testReorderingWithHoles() throws IOException {
    // build a small graph over 3 vectors
    var vectors = new TestVectorGraph.CircularFloatVectorValues(3);
    var graphBuilder = new GraphIndexBuilder(vectors, VectorSimilarityFunction.COSINE, 2, 10, 1.0f, 1.0f);
    var graph = TestUtil.buildSequentially(graphBuilder, vectors);

    // remap old ordinals to non-contiguous new ones, leaving "holes"
    // (new ordinals 1 and 3..9 have no corresponding node)
    Map<Integer, Integer> remap = new HashMap<>();
    remap.put(0, 2);
    remap.put(1, 10);
    remap.put(2, 0);

    // persist the graph under the remapped ordinals
    var graphPath = testDirectory.resolve("renumbered_graph");
    OnDiskGraphIndex.write(graph, vectors, remap, graphPath);

    // reload and confirm each new ordinal resolves to the original vector
    try (var reader = new SimpleMappedReader(graphPath.toAbsolutePath().toString());
         var diskGraph = OnDiskGraphIndex.load(reader::duplicate);
         var diskView = diskGraph.getView())
    {
        assertEquals(diskView.getVector(0), vectors.getVector(2));
        assertEquals(diskView.getVector(10), vectors.getVector(1));
        assertEquals(diskView.getVector(2), vectors.getVector(0));
    } catch (Exception e) {
        // surface close/load failures from the try-with-resources as unchecked
        throw new RuntimeException(e);
    }
}

private static void validateVectors(OnDiskGraphIndex.View view, RandomAccessVectorValues ravv) {
for (int i = 0; i < view.size(); i++) {
assertEquals("Incorrect vector at " + i, view.getVector(i), ravv.getVector(i));
Expand Down

0 comments on commit b8298a4

Please sign in to comment.