diff --git a/faiss/impl/NNDescent.cpp b/faiss/impl/NNDescent.cpp index 8878349ff6..b609aba390 100644 --- a/faiss/impl/NNDescent.cpp +++ b/faiss/impl/NNDescent.cpp @@ -195,8 +195,9 @@ void NNDescent::update() { int l = 0; while ((l < maxl) && (c < S)) { - if (nn.pool[l].flag) + if (nn.pool[l].flag) { ++c; + } ++l; } nn.M = l; @@ -305,8 +306,9 @@ void NNDescent::generate_eval_set( for (int i = 0; i < c.size(); i++) { std::vector tmp; for (int j = 0; j < N; j++) { - if (c[i] == j) + if (c[i] == j) { continue; // skip itself + } float dist = qdis.symmetric_dis(c[i], j); tmp.push_back(Neighbor(j, dist, true)); } @@ -360,8 +362,9 @@ void NNDescent::init_graph(DistanceComputer& qdis) { for (int j = 0; j < S; j++) { int id = tmp[j]; - if (id == i) + if (id == i) { continue; + } float dist = qdis.symmetric_dis(i, id); graph[i].pool.push_back(Neighbor(id, dist, true)); @@ -418,30 +421,30 @@ void NNDescent::search( float* dists, VisitedTable& vt) const { FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet."); - int L = std::max(search_L, topk); + int L_2 = std::max(search_L, topk); // candidate pool, the K best items is the result. - std::vector retset(L + 1); + std::vector retset(L_2 + 1); - // Randomly choose L points to initialize the candidate pool - std::vector init_ids(L); + // Randomly choose L_2 points to initialize the candidate pool + std::vector init_ids(L_2); std::mt19937 rng(random_seed); - gen_random(rng, init_ids.data(), L, ntotal); - for (int i = 0; i < L; i++) { + gen_random(rng, init_ids.data(), L_2, ntotal); + for (int i = 0; i < L_2; i++) { int id = init_ids[i]; float dist = qdis(id); retset[i] = Neighbor(id, dist, true); } // Maintain the candidate pool in ascending order - std::sort(retset.begin(), retset.begin() + L); + std::sort(retset.begin(), retset.begin() + L_2); int k = 0; - // Stop until the smallest position updated is >= L - while (k < L) { - int nk = L; + // Stop until the smallest position updated is >= L_2 + while (k < L_2) { + int nk = L_2; if (retset[k].flag) { retset[k].flag = false; @@ -449,25 +452,28 @@ void NNDescent::search( for (int m = 0; m < K; ++m) { int id = final_graph[n * K + m]; - if (vt.get(id)) + if (vt.get(id)) { continue; + } vt.set(id); float dist = qdis(id); - if (dist >= retset[L - 1].distance) + if (dist >= retset[L_2 - 1].distance) { continue; + } Neighbor nn(id, dist, true); - int r = insert_into_pool(retset.data(), L, nn); + int r = insert_into_pool(retset.data(), L_2, nn); if (r < nk) nk = r; } } - if (nk <= k) + if (nk <= k) { k = nk; - else + } else { ++k; + } } for (size_t i = 0; i < topk; i++) { indices[i] = retset[i].id;