Skip to content

Commit

Permalink
make topK configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
Vivek Narang committed Nov 11, 2024
1 parent f9c2df7 commit 1597158
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public static void main(String[] args) throws Throwable {

// Query
CuVSQuery query = new CuVSQuery.Builder()
.withTopK(1)
.withSearchParams(cagraSearchParams)
.withQueryVectors(queries)
.withMapping(map)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ public SearchResult search(CuVSQuery query) throws Throwable {
MemoryLayout rvML = linker.canonicalLayouts().get("int");
MemorySegment rvMS = arena.allocate(rvML);

searchMH.invokeExact(ref.indexMemorySegment, getMemorySegment(query.queryVectors), 2, 4L, 2L, res.getResource(),
neighborsMS, distancesMS, rvMS, query.searchParams.cagraSearchParamsMS);
searchMH.invokeExact(ref.indexMemorySegment, getMemorySegment(query.getQueries()), query.getTopK(), 4L, 2L, res.getResource(),
neighborsMS, distancesMS, rvMS, query.getSearchParams().cagraSearchParamsMS);

return new SearchResult(neighborsSL, distancesSL, neighborsMS, distancesMS, 2, query.mapping);
return new SearchResult(neighborsSL, distancesSL, neighborsMS, distancesMS, query.getTopK(), query.getMapping());
}

/**
Expand Down
25 changes: 23 additions & 2 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/cagra/CuVSQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ public class CuVSQuery {
PreFilter preFilter;
float[][] queryVectors;
public Map<Integer, Integer> mapping;
int topK;

public CuVSQuery(CagraSearchParams searchParams, PreFilter preFilter, float[][] queryVectors,
Map<Integer, Integer> mapping) {
Map<Integer, Integer> mapping, int topK) {
super();
this.searchParams = searchParams;
this.preFilter = preFilter;
this.queryVectors = queryVectors;
this.mapping = mapping;
this.topK = topK;
}

@Override
Expand All @@ -37,11 +39,20 @@ public float[][] getQueries() {
return queryVectors;
}

public Map<Integer, Integer> getMapping() {
return mapping;
}

public int getTopK() {
return topK;
}

public static class Builder {
CagraSearchParams searchParams;
PreFilter preFilter;
float[][] queryVectors;
Map<Integer, Integer> mapping;
int topK = 2;

/**
*
Expand Down Expand Up @@ -89,14 +100,24 @@ public Builder withMapping(Map<Integer, Integer> mapping) {
this.mapping = mapping;
return this;
}

/**
*
* @param topK
* @return
*/
public Builder withTopK(int topK) {
this.topK = topK;
return this;
}

/**
*
* @return
* @throws Throwable
*/
public CuVSQuery build() throws Throwable {
return new CuVSQuery(searchParams, preFilter, queryVectors, mapping);
return new CuVSQuery(searchParams, preFilter, queryVectors, mapping, topK);
}
}

Expand Down

0 comments on commit 1597158

Please sign in to comment.