From e9c23b24e0b2dd5ace658124be08e3fc59730e3b Mon Sep 17 00:00:00 2001 From: Amir Sadoughi Date: Tue, 30 Apr 2024 09:08:26 -0700 Subject: [PATCH] TimeoutCallback Summary: https://github.com/facebookresearch/faiss/issues/3351 Differential Revision: D56732720 --- faiss/gpu/perf/PerfClustering.cpp | 7 +++++++ faiss/impl/AuxIndexStructures.cpp | 19 +++++++++++++++++++ faiss/impl/AuxIndexStructures.h | 7 +++++++ 3 files changed, 33 insertions(+) diff --git a/faiss/gpu/perf/PerfClustering.cpp b/faiss/gpu/perf/PerfClustering.cpp index 0322f0e490..3589249eff 100644 --- a/faiss/gpu/perf/PerfClustering.cpp +++ b/faiss/gpu/perf/PerfClustering.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include #include @@ -17,6 +18,7 @@ #include #include +#include DEFINE_int32(num, 10000, "# of vecs"); DEFINE_int32(k, 100, "# of clusters"); @@ -34,6 +36,7 @@ DEFINE_int64( "minimum size to use CPU -> GPU paged copies"); DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use"); DEFINE_int32(max_points, -1, "max points per centroid"); +DEFINE_double(timeout, 0, "timeout in seconds"); using namespace faiss::gpu; @@ -99,10 +102,14 @@ int main(int argc, char** argv) { cp.max_points_per_centroid = FLAGS_max_points; } + auto tc = new faiss::TimeoutCallback(); + faiss::InterruptCallback::instance.reset(tc); + faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp); // Time k-means { + tc->set_timeout(FLAGS_timeout); CpuTimer timer; kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex())); diff --git a/faiss/impl/AuxIndexStructures.cpp b/faiss/impl/AuxIndexStructures.cpp index cebe8a1e23..01c7dd5267 100644 --- a/faiss/impl/AuxIndexStructures.cpp +++ b/faiss/impl/AuxIndexStructures.cpp @@ -236,4 +236,23 @@ size_t InterruptCallback::get_period_hint(size_t flops) { return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1); } +void TimeoutCallback::set_timeout(double timeout_in_seconds) { + timeout = timeout_in_seconds; + start = std::chrono::steady_clock::now(); +} + +bool TimeoutCallback::want_interrupt() { + if (timeout == 0) { + return false; + } + auto end = std::chrono::steady_clock::now(); + std::chrono::duration duration = end - start; + float elapsed_in_seconds = duration.count() / 1000.0; + if (elapsed_in_seconds > timeout) { + timeout = 0; + return true; + } + return false; +} + } // namespace faiss diff --git a/faiss/impl/AuxIndexStructures.h b/faiss/impl/AuxIndexStructures.h index f8b5cca842..5dc15eae46 100644 --- a/faiss/impl/AuxIndexStructures.h +++ b/faiss/impl/AuxIndexStructures.h @@ -161,6 +161,13 @@ struct FAISS_API InterruptCallback { static size_t get_period_hint(size_t flops); }; +struct TimeoutCallback : InterruptCallback { + std::chrono::time_point start; + double timeout; + bool want_interrupt() override; + void set_timeout(double timeout_in_seconds); +}; + /// set implementation optimized for fast access. struct VisitedTable { std::vector visited;