Skip to content

Commit

Permalink
Fix: Set dims in metric_punned_t::stateful
Browse files Browse the repository at this point in the history
Co-authored-by: Ash Vardanian <[email protected]>
Co-authored-by: Terence Liu <[email protected]>
Co-authored-by: Terence Z. Liu <[email protected]>
  • Loading branch information
3 people committed Nov 18, 2024
1 parent dccdd8e commit 4fae008
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_
USEARCH_ASSERT(index && error && "Missing arguments");
auto& index_dense = *reinterpret_cast<index_dense_t*>(index);
auto metric_punned =
state ? metric_punned_t::stateful(reinterpret_cast<std::uintptr_t>(metric),
state ? metric_punned_t::stateful(index_dense.dimensions(), reinterpret_cast<std::uintptr_t>(metric),
reinterpret_cast<std::uintptr_t>(state), metric_kind_to_cpp(kind),
index_dense.scalar_kind())
: metric_punned_t::stateless(index_dense.dimensions(), reinterpret_cast<std::uintptr_t>(metric),
Expand Down
14 changes: 9 additions & 5 deletions include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1767,20 +1767,22 @@ class metric_punned_t {
* @brief Creates a metric using the provided function pointer for a stateful metric.
* The third argument is the state that will be passed to the metric function.
*
* @param dimensions The number of elements in the input arrays.
* @param metric_uintptr The function pointer to the metric function.
* @param metric_state The state to pass to the metric function.
* @param metric_kind The kind of metric to use.
* @param scalar_kind The kind of scalar to use.
* @return A metric object that can be used to compute distances between vectors.
*/
inline static metric_punned_t stateful(std::uintptr_t metric_uintptr, std::uintptr_t metric_state,
metric_kind_t metric_kind = metric_kind_t::unknown_k,
scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept {
inline static metric_punned_t stateful( //
std::size_t dimensions, std::uintptr_t metric_uintptr, std::uintptr_t metric_state,
metric_kind_t metric_kind = metric_kind_t::unknown_k,
scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept {
metric_punned_t metric;
metric.metric_routed_ = &metric_punned_t::invoke_array_array_third;
metric.metric_ptr_ = metric_uintptr;
metric.metric_third_arg_ = metric_state;
metric.dimensions_ = 0;
metric.dimensions_ = dimensions;
metric.metric_kind_ = metric_kind;
metric.scalar_kind_ = scalar_kind;
return metric;
Expand Down Expand Up @@ -2223,6 +2225,8 @@ template <typename allocator_at = std::allocator<char>> class kmeans_clustering_
scalar_kind_t original_scalar_kind, std::size_t dimensions, executor_at&& executor = executor_at{},
progress_at&& progress = progress_at{}) {

(void)progress; // TODO

// Perform sanity checks for algorithm settings.
kmeans_clustering_result_t result;
if (max_iterations < 1)
Expand Down Expand Up @@ -2332,7 +2336,7 @@ template <typename allocator_at = std::allocator<char>> class kmeans_clustering_

// For every point, find the closest centroid.
std::atomic<std::size_t> points_shifted{0};
executor.dynamic(points_count, [&](std::size_t thread_idx, std::size_t points_idx) {
executor.dynamic(points_count, [&](std::size_t, std::size_t points_idx) {
byte_t const* quantized_point =
points_quantized_buffer.data() + points_idx * stride_per_vector_quantized;
byte_t const* quantized_centroids = centroids_quantized_buffer.data();
Expand Down
1 change: 1 addition & 0 deletions rust/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ void NativeIndex::change_expansion_search(size_t n) const { index_->change_expan

void NativeIndex::change_metric(uptr_t metric, uptr_t state) const {
index_->change_metric(metric_punned_t::stateful( //
index_->dimensions(), //
static_cast<std::uintptr_t>(metric), //
static_cast<std::uintptr_t>(state), //
index_->metric().metric_kind(), //
Expand Down

0 comments on commit 4fae008

Please sign in to comment.