Skip to content

Commit

Permalink
🔨 version 1.0.2 - Update hnswlib compatible version to 0.7.0 and Add …
Browse files Browse the repository at this point in the history
…some APIs
  • Loading branch information
sunhailin-Leo committed Feb 27, 2023
1 parent b82b958 commit 7a3eb65
Show file tree
Hide file tree
Showing 14 changed files with 1,106 additions and 496 deletions.
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
# hnswlib-to-go
Hnswlib to go: a Golang interface to [hnswlib](https://github.com/nmslib/hnswlib). For more information, see [hnswlib](https://github.com/nmslib/hnswlib) and the paper [Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs](https://arxiv.org/abs/1603.09320).
**But in this project, we make compatible hnswlib to 0.5.2.**

**Note: this project is compatible with hnswlib 0.7.0.**


### Version

* version 1.0.2
* Update hnswlib compatible version to 0.7.0
* Add `AddBatchPoints`, `SearchBatchKNN`, `SetNormalize`, `ResizeIndex`, `MarkDelete`, `UnmarkDelete`, `GetLabelIsMarkedDeleted` API

* version 1.0.1
* Code format
* Add an API to unload the graph (experimental)

* version 1.0.0
* hnswlib compatible version 0.5.2.
* hnswlib compatible version 0.5.2.


### Build

* Linux/MacOS
* Build Golang Env
* `go mod init`
* `make`
* Build Golang Env
* `go mod init`
* `make`

### Usage

Expand Down
76 changes: 69 additions & 7 deletions example/demo.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ func randVector(dim int) []float32 {
return vec
}

func main() {
var dim, M, ef int = 128, 32, 300
// 单个写入
func exampleAddPoint(indexFileName string) {
var dim, M, ef = 128, 32, 300
// 最大的 elements 数
var maxElements uint32 = 50000
var maxElements uint32 = 10000
// 定义距离 cosine
var spaceType, indexLocation string = "cosine", "hnsw_demo_index.bin"
var randomSeed int = 100
var spaceType = "cosine"
var randomSeed = 100
fmt.Println("Before Create HNSW")
traceMemStats()
// Init new index
Expand All @@ -61,8 +62,43 @@ func main() {
}
h.AddPoint(randVector(dim), i)
}
h.Save(indexLocation)
h = hnswgo.Load(indexLocation, dim, spaceType)
h.Save(indexFileName)
}

// exampleBatchAddPoint builds a cosine-space index from random vectors using
// the batch insertion API and saves it to indexFileName.
func exampleBatchAddPoint(indexFileName string) {
	const (
		dim        = 128
		m          = 32
		ef         = 300
		randomSeed = 100
		// Distance metric used by the index.
		spaceType = "cosine"
	)
	// Maximum number of elements the index is created with.
	const maxElements uint32 = 20000

	fmt.Println("Before Create HNSW")

	// Create a new, empty index.
	h := hnswgo.New(dim, m, ef, randomSeed, maxElements, spaceType)

	// Generate every vector up front; label i is assigned to vector i.
	ids := make([]uint32, maxElements)
	vectorList := make([][]float32, maxElements)
	for i := uint32(0); i < maxElements; i++ {
		// Progress output once per 1000 generated vectors.
		if i%1000 == 0 {
			fmt.Println(i)
		}
		ids[i] = i
		vectorList[i] = randVector(dim)
	}

	// Insert all vectors using 10 goroutines.
	h.AddBatchPoints(vectorList, ids, 10)

	// Save the index to disk.
	h.Save(indexFileName)
}

// 读取
func exampleLoadIndex(indexFileName, spaceType string, dim int) {
h := hnswgo.Load(indexFileName, dim, spaceType)
// Search vector with maximum 5 NN
h.SetEf(15)
searchVector := randVector(dim)
Expand All @@ -73,10 +109,36 @@ func main() {
fmt.Println(endTime - startTime)
fmt.Println(labels, vectors)

// Test ResizeIndex API
isResize := h.ResizeIndex(12000)
fmt.Println("Size flag: ", isResize)

// Test Mark API
isMarkDelete := h.MarkDelete(10)
fmt.Println("isMarkDelete: ", isMarkDelete)

labelIsDelete := h.GetLabelIsMarkedDeleted(10)
fmt.Println("labelIsDelete: ", labelIsDelete)

isUnmarkDelete := h.UnmarkDelete(10)
fmt.Println("isUnmarkDelete: ", isUnmarkDelete)

// Test Unload API
fmt.Println("Before Unload")
traceMemStats()
h.Unload()
fmt.Println("After Unload")
traceMemStats()
}

func main() {
	// Build an index one point at a time, then load it back and run the
	// search, resize, mark-delete and unload checks against it.
	const singleIndexFile = "hnsw_demo_single.bin"
	exampleAddPoint(singleIndexFile)
	exampleLoadIndex(singleIndexFile, "cosine", 128)

	// Batch-mode variant (disabled): build the index with AddBatchPoints,
	// then load it back the same way.
	//exampleBatchAddPoint("hnsw_demo_multiple.bin")
	//exampleLoadIndex("hnsw_demo_multiple.bin", "cosine", 128)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module github.com/sunhailin-Leo/hnswlib-to-go

go 1.15
go 1.18
92 changes: 92 additions & 0 deletions hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@ package hnswgo

// #cgo LDFLAGS: -L${SRCDIR} -lhnsw -lm
// #include <stdlib.h>
// #include <stdbool.h>
// #include "hnsw_wrapper.h"
// HNSW initHNSW(int dim, unsigned long int max_elements, int M, int ef_construction, int rand_seed, char stype);
// HNSW loadHNSW(char *location, int dim, char stype);
// void addPoint(HNSW index, float *vec, unsigned long int label);
// int searchKnn(HNSW index, float *vec, int N, unsigned long int *label, float *dist);
// void setEf(HNSW index, int ef);
// bool resizeIndex(HNSW index, unsigned long int new_max_elements);
// bool markDelete(HNSW index, unsigned long int label);
// bool unmarkDelete(HNSW index, unsigned long int label);
// bool isMarkedDeleted(HNSW index, unsigned long int label);
// bool updatePoint(HNSW index, float *vec, unsigned long int label);
import "C"
import (
"math"
"runtime"
"sync"
"unsafe"
)

Expand Down Expand Up @@ -105,6 +112,33 @@ func (h *HNSW) AddPoint(vector []float32, label uint32) bool {
return true
}

// AddBatchPoints adds vectors[i] to the graph under labels[i] for every i,
// splitting the work into contiguous chunks that are inserted concurrently,
// and blocks until every chunk has been processed.
//
// coroutines is the maximum number of goroutines to use. Values below 1 are
// treated as 1, and the count is capped at len(vectors) so that no goroutine
// is started without work.
//
// It returns false when len(vectors) != len(labels) (nothing is added in that
// case) or when at least one AddPoint call reports failure; the remaining
// points are still attempted. An empty batch returns true.
func (h *HNSW) AddBatchPoints(vectors [][]float32, labels []uint32, coroutines int) bool {
	total := len(vectors)
	if total != len(labels) {
		return false
	}
	if total == 0 {
		return true
	}

	workers := coroutines
	if workers < 1 {
		// Guards the division below against zero and negative counts.
		workers = 1
	}
	if workers > total {
		workers = total
	}

	// Ceiling division keeps chunk sizes within one element of each other
	// instead of piling the remainder onto the last goroutine.
	chunk := (total + workers - 1) / workers
	numChunks := (total + chunk - 1) / chunk

	// Each goroutine writes only its own slot, so no locking is required.
	results := make([]bool, numChunks)

	var wg sync.WaitGroup
	for c := 0; c < numChunks; c++ {
		start := c * chunk
		end := start + chunk
		if end > total {
			end = total
		}

		wg.Add(1)
		go func(slot int, thisVectors [][]float32, thisLabels []uint32) {
			defer wg.Done()
			ok := true
			for j := range thisVectors {
				if !h.AddPoint(thisVectors[j], thisLabels[j]) {
					ok = false
				}
			}
			results[slot] = ok
		}(c, vectors[start:end], labels[start:end])
	}
	wg.Wait()

	for _, ok := range results {
		if !ok {
			return false
		}
	}
	return true
}

// SearchKNN search points on graph with knn-algorithm
func (h *HNSW) SearchKNN(vector []float32, N int) ([]uint32, []float32) {
if h.index == nil {
Expand All @@ -125,10 +159,68 @@ func (h *HNSW) SearchKNN(vector []float32, N int) ([]uint32, []float32) {
return labels[:numResult], dists[:numResult]
}

// SearchBatchKNN runs SearchKNN(vectors[i], N) for every query vector and
// returns the per-query results in input order: element i of the first
// returned slice holds the labels for vectors[i] and element i of the second
// holds the matching distances.
//
// The queries are split into contiguous chunks that are searched
// concurrently. coroutines is the maximum number of goroutines to use; values
// below 1 are treated as 1 and the count is capped at len(vectors). The call
// blocks until all queries have finished.
func (h *HNSW) SearchBatchKNN(vectors [][]float32, N, coroutines int) ([][]uint32, [][]float32) {
	total := len(vectors)
	labelList := make([][]uint32, total)
	distList := make([][]float32, total)
	if total == 0 {
		return labelList, distList
	}

	workers := coroutines
	if workers < 1 {
		// Guards the division below against zero and negative counts.
		workers = 1
	}
	if workers > total {
		workers = total
	}

	// Ceiling division keeps chunk sizes within one element of each other
	// instead of piling the remainder onto the last goroutine.
	chunk := (total + workers - 1) / workers

	var wg sync.WaitGroup
	for start := 0; start < total; start += chunk {
		end := start + chunk
		if end > total {
			end = total
		}

		wg.Add(1)
		go func(start, end int) {
			defer wg.Done()
			// Every goroutine owns a disjoint index range, so the result
			// slices can be filled without a mutex.
			for j := start; j < end; j++ {
				labelList[j], distList[j] = h.SearchKNN(vectors[j], N)
			}
		}(start, end)
	}
	wg.Wait()

	return labelList, distList
}

// SetEf forwards ef to the native index through C.setEf.
// It is a no-op when the index handle is nil.
func (h *HNSW) SetEf(ef int) {
	if h.index != nil {
		C.setEf(h.index, C.int(ef))
	}
}

// SetNormalize sets the receiver's normalize flag to isNeedNormalize.
// NOTE(review): presumably this flag controls whether vectors are normalized
// before being added or searched; the code that reads it is not visible
// here, so confirm before relying on that.
func (h *HNSW) SetNormalize(isNeedNormalize bool) {
h.normalize = isNeedNormalize
}

// ResizeIndex asks the native index to change its capacity to newMaxElements.
//
// It returns true when the native resize succeeded. It returns false when the
// index handle is nil, or when the native wrapper rejects the request (the
// new capacity is below the current element count, or hnswlib raised an
// exception).
func (h *HNSW) ResizeIndex(newMaxElements uint32) bool {
	// Same nil guard as SetEf: never hand a nil handle to the C side.
	if h.index == nil {
		return false
	}
	return bool(C.resizeIndex(h.index, C.ulong(newMaxElements)))
}

// MarkDelete marks the element stored under label as deleted in the native
// index.
//
// It returns true on success. It returns false when the index handle is nil
// or when the native call raised an exception.
func (h *HNSW) MarkDelete(label uint32) bool {
	// Same nil guard as SetEf: never hand a nil handle to the C side.
	if h.index == nil {
		return false
	}
	return bool(C.markDelete(h.index, C.ulong(label)))
}

// UnmarkDelete removes the deleted mark from the element stored under label
// in the native index.
//
// It returns true on success. It returns false when the index handle is nil
// or when the native call raised an exception.
func (h *HNSW) UnmarkDelete(label uint32) bool {
	// Same nil guard as SetEf: never hand a nil handle to the C side.
	if h.index == nil {
		return false
	}
	return bool(C.unmarkDelete(h.index, C.ulong(label)))
}

// GetLabelIsMarkedDeleted reports whether the element stored under label is
// currently marked as deleted in the native index.
//
// It returns false when the index handle is nil or when the label is not
// present in the index.
func (h *HNSW) GetLabelIsMarkedDeleted(label uint32) bool {
	// Same nil guard as SetEf: never hand a nil handle to the C side.
	if h.index == nil {
		return false
	}
	return bool(C.isMarkedDeleted(h.index, C.ulong(label)))
}
88 changes: 76 additions & 12 deletions hnsw_wrapper.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//hnsw_wrapper.cpp
#include <vector>
#include <iostream>
#include "hnswlib/hnswlib.h"
#include "hnsw_wrapper.h"
Expand All @@ -12,8 +13,9 @@ HNSW initHNSW(int dim, unsigned long int max_elements, int M, int ef_constructio
} else {
space = new hnswlib::L2Space(dim);
}
hnswlib::HierarchicalNSW<float> *appr_alg = new hnswlib::HierarchicalNSW<float>(space, max_elements, M, ef_construction, rand_seed);
return (void*)appr_alg;
hnswlib::HierarchicalNSW<float> *appr_alg = new hnswlib::HierarchicalNSW<float>(space, max_elements, M,
ef_construction, rand_seed);
return (void *) appr_alg;
}

HNSW loadHNSW(char *location, int dim, char stype) {
Expand All @@ -23,38 +25,100 @@ HNSW loadHNSW(char *location, int dim, char stype) {
} else {
space = new hnswlib::L2Space(dim);
}
hnswlib::HierarchicalNSW<float> *appr_alg = new hnswlib::HierarchicalNSW<float>(space, std::string(location), false, 0);
return (void*)appr_alg;
hnswlib::HierarchicalNSW<float> *appr_alg = new hnswlib::HierarchicalNSW<float>(space, std::string(location), false,
0);
return (void *) appr_alg;
}

HNSW saveHNSW(HNSW index, char *location) {
((hnswlib::HierarchicalNSW<float>*)index)->saveIndex(location);
((hnswlib::HierarchicalNSW<float> *) index)->saveIndex(location);
return 0;
}

void addPoint(HNSW index, float *vec, unsigned long int label) {
((hnswlib::HierarchicalNSW<float>*)index)->addPoint(vec, label);
((hnswlib::HierarchicalNSW<float> *) index)->addPoint(vec, label);
}

int searchKnn(HNSW index, float *vec, int N, unsigned long int *label, float *dist) {
std::priority_queue<std::pair<float, hnswlib::labeltype>> gt;
std::priority_queue <std::pair<float, hnswlib::labeltype>> gt;
try {
gt = ((hnswlib::HierarchicalNSW<float>*)index)->searchKnn(vec, N);
} catch (const std::exception& e) {
gt = ((hnswlib::HierarchicalNSW<float> *) index)->searchKnn(vec, N);
} catch (const std::exception &e) {
return 0;
}

int n = gt.size();
std::pair<float, hnswlib::labeltype> pair;
for (int i = n - 1; i >= 0; i--) {
pair = gt.top();
*(dist+i) = pair.first;
*(label+i) = pair.second;
*(dist + i) = pair.first;
*(label + i) = pair.second;
gt.pop();
}
return n;
}

void setEf(HNSW index, int ef) {
((hnswlib::HierarchicalNSW<float>*)index)->ef_ = ef;
((hnswlib::HierarchicalNSW<float> *) index)->ef_ = ef;
}

// Resizes the index to hold new_max_elements elements.
// Returns false if the requested capacity is below the current element count
// or if hnswlib throws; returns true otherwise.
bool resizeIndex(HNSW index, unsigned long int new_max_elements) {
    auto *graph = (hnswlib::HierarchicalNSW<float> *) index;
    if (new_max_elements < graph->getCurrentElementCount()) {
        return false;
    }
    try {
        graph->resizeIndex(new_max_elements);
        return true;
    } catch (const std::exception &) {
        // Exceptions must not cross the C boundary; report failure instead.
        return false;
    }
}

// Marks the element with the given label as deleted.
// Returns true on success and false if hnswlib throws.
bool markDelete(HNSW index, unsigned long int label) {
    auto *graph = (hnswlib::HierarchicalNSW<float> *) index;
    bool ok = true;
    try {
        graph->markDelete(label);
    } catch (const std::exception &) {
        // Exceptions must not cross the C boundary; report failure instead.
        ok = false;
    }
    return ok;
}

// Clears the deleted mark from the element with the given label.
// Returns true on success and false if hnswlib throws.
bool unmarkDelete(HNSW index, unsigned long int label) {
    auto *graph = (hnswlib::HierarchicalNSW<float> *) index;
    bool ok = true;
    try {
        graph->unmarkDelete(label);
    } catch (const std::exception &) {
        // Exceptions must not cross the C boundary; report failure instead.
        ok = false;
    }
    return ok;
}

// Reports whether the element with the given label is marked as deleted.
// Returns false when the label is not present in the index.
bool isMarkedDeleted(HNSW index, unsigned long int label) {
    auto *graph = (hnswlib::HierarchicalNSW<float> *) index;

    // Hold the label-table lock for the lookup and the flag read; the guard
    // releases it on every return path.
    std::lock_guard<std::mutex> guard(graph->label_lookup_lock);
    const auto it = graph->label_lookup_.find(label);
    if (it == graph->label_lookup_.end()) {
        return false;
    }
    return graph->isMarkedDeleted(it->second);
}

// Overwrites the vector stored under an existing label with vec and repairs
// the graph around it (update-neighbor probability 1.0).
// Returns true when the label exists and the update completed; returns false
// when the label is not in the index or when hnswlib throws.
bool updatePoint(HNSW index, float *vec, unsigned long int label) {
    auto *graph = (hnswlib::HierarchicalNSW<float> *) index;
    try {
        hnswlib::tableint existingInternalId;
        {
            // The label table is only locked for the lookup itself.
            std::unique_lock<std::mutex> lock_table(graph->label_lookup_lock);
            auto search = graph->label_lookup_.find(label);
            if (search == graph->label_lookup_.end()) {
                return false;
            }
            existingInternalId = search->second;
        }
        // const void *dataPoint, tableint internalId, float updateNeighborProbability
        graph->updatePoint(vec, existingInternalId, 1.0);
        return true;
    } catch (const std::exception &) {
        // Like the sibling wrappers, never let a C++ exception escape through
        // the C/cgo boundary; report failure instead.
        return false;
    }
}

// TODO
//std::vector<float> getDataByLabel(HNSW index, unsigned long int label) {
// return ((hnswlib::HierarchicalNSW<float>*)index)->getDataByLabel<float>(label);
//}
Loading

0 comments on commit 7a3eb65

Please sign in to comment.