Skip to content

Commit

Permalink
TimeoutCallback
Browse files Browse the repository at this point in the history
Summary: #3351

Differential Revision: D56732720
  • Loading branch information
Amir Sadoughi authored and facebook-github-bot committed Apr 30, 2024
1 parent 825cbac commit e9c23b2
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
7 changes: 7 additions & 0 deletions faiss/gpu/perf/PerfClustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/

#include <faiss/Clustering.h>
#include <faiss/MetricType.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <faiss/gpu/perf/IndexWrapper.h>
Expand All @@ -17,6 +18,7 @@
#include <vector>

#include <cuda_profiler_api.h>
#include <faiss/impl/AuxIndexStructures.h>

DEFINE_int32(num, 10000, "# of vecs");
DEFINE_int32(k, 100, "# of clusters");
Expand All @@ -34,6 +36,7 @@ DEFINE_int64(
"minimum size to use CPU -> GPU paged copies");
DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use");
DEFINE_int32(max_points, -1, "max points per centroid");
DEFINE_double(timeout, 0, "timeout in seconds");

using namespace faiss::gpu;

Expand Down Expand Up @@ -99,10 +102,14 @@ int main(int argc, char** argv) {
cp.max_points_per_centroid = FLAGS_max_points;
}

auto tc = new faiss::TimeoutCallback();
faiss::InterruptCallback::instance.reset(tc);

faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp);

// Time k-means
{
tc->set_timeout(FLAGS_timeout);
CpuTimer timer;

kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex()));
Expand Down
19 changes: 19 additions & 0 deletions faiss/impl/AuxIndexStructures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,4 +236,23 @@ size_t InterruptCallback::get_period_hint(size_t flops) {
return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
}

void TimeoutCallback::set_timeout(double timeout_in_seconds) {
timeout = timeout_in_seconds;
start = std::chrono::steady_clock::now();
}

bool TimeoutCallback::want_interrupt() {
if (timeout == 0) {
return false;
}
auto end = std::chrono::steady_clock::now();
std::chrono::duration<float, std::milli> duration = end - start;
float elapsed_in_seconds = duration.count() / 1000.0;
if (elapsed_in_seconds > timeout) {
timeout = 0;
return true;
}
return false;
}

} // namespace faiss
7 changes: 7 additions & 0 deletions faiss/impl/AuxIndexStructures.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ struct FAISS_API InterruptCallback {
static size_t get_period_hint(size_t flops);
};

struct TimeoutCallback : InterruptCallback {
std::chrono::time_point<std::chrono::steady_clock> start;
double timeout;
bool want_interrupt() override;
void set_timeout(double timeout_in_seconds);
};

/// set implementation optimized for fast access.
struct VisitedTable {
std::vector<uint8_t> visited;
Expand Down

0 comments on commit e9c23b2

Please sign in to comment.