From aa0635585c07298be0195871494713a9161dc316 Mon Sep 17 00:00:00 2001 From: Andrew Kimball Date: Fri, 22 Nov 2024 13:00:17 -0800 Subject: [PATCH] vecindex: add adaptive search to increase accuracy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adaptive search modifies the search algorithm to increase the search breadth in denser sections of the vector space and to decrease it in sparser sections. At each level of the tree, the search algorithm compiles a list of candidate partitions to search the next level down. It computes a Z-score for the candidates’ distances from the query vector, which indicates the “spread” of those distances, relative to the average. A negative Z-score indicates that partitions are more densely packed, and more should be searched. A positive Z-score indicates the opposite, and that fewer partitions should be searched. Epic: CRDB-42943 Release note: None --- pkg/cmd/vecbench/main.go | 2 +- pkg/sql/vecindex/fixup_processor.go | 2 + pkg/sql/vecindex/index_stats.go | 210 ++++++++++++++++++ pkg/sql/vecindex/testdata/insert.ddt | 12 +- pkg/sql/vecindex/testdata/search-features.ddt | 147 ++++++------ pkg/sql/vecindex/vecstore/in_memory_store.go | 49 ++-- .../vecindex/vecstore/in_memory_store_test.go | 99 ++++++--- pkg/sql/vecindex/vecstore/partition.go | 13 +- pkg/sql/vecindex/vecstore/partition_test.go | 14 +- pkg/sql/vecindex/vecstore/vecstorepb.go | 11 + pkg/sql/vecindex/vecstore/vecstorepb_test.go | 47 ++++ pkg/sql/vecindex/vector_index.go | 72 +++++- pkg/sql/vecindex/vector_index_test.go | 44 +++- 13 files changed, 568 insertions(+), 154 deletions(-) create mode 100644 pkg/sql/vecindex/index_stats.go diff --git a/pkg/cmd/vecbench/main.go b/pkg/cmd/vecbench/main.go index 3d1edea8e3ab..35ad74d50589 100644 --- a/pkg/cmd/vecbench/main.go +++ b/pkg/cmd/vecbench/main.go @@ -187,7 +187,7 @@ func searchIndex(ctx context.Context, datasetName string) { fmt.Printf("%d train vectors, %d test vectors, %d dimensions, %d/%d min/max partitions, base beam size %d\n", data.Train.Count, data.Test.Count, data.Test.Dims, indexOptions.MinPartitionSize, indexOptions.MaxPartitionSize, indexOptions.BaseBeamSize) - fmt.Println() + fmt.Println(index.FormatStats()) fmt.Printf("beam\trecall\tleaf\tall\tfull\tpartns\tqps\n") diff --git a/pkg/sql/vecindex/fixup_processor.go b/pkg/sql/vecindex/fixup_processor.go index 04c72e4a32bd..78b0b272cc61 100644 --- a/pkg/sql/vecindex/fixup_processor.go +++ b/pkg/sql/vecindex/fixup_processor.go @@ -71,6 +71,8 @@ type partitionFixupKey struct { // fixup, then that will likewise be enqueued and performed in a separate // transaction, in order to avoid contention and re-entrancy, both of which can // cause problems. +// +// All entry methods (i.e. capitalized methods) in fixupProcess are thread-safe. type fixupProcessor struct { // -------------------------------------------------- // These fields can be accessed on any goroutine once the lock is acquired. diff --git a/pkg/sql/vecindex/index_stats.go b/pkg/sql/vecindex/index_stats.go new file mode 100644 index 000000000000..7e8b561f21f0 --- /dev/null +++ b/pkg/sql/vecindex/index_stats.go @@ -0,0 +1,210 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package vecindex + +import ( + "context" + "math" + "slices" + "sync/atomic" + + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" + "gonum.org/v1/gonum/stat" +) + +// statsAlpha is the weight applied to a new sample of search distances when +// computing exponentially weighted moving averages. +const statsAlpha = 0.01 + +// statsReportingInterval specifies how many vectors need to be inserted or +// deleted before local statistics will be merged with global statistics. +const statsReportingInterval = 100 + +// statsManager maintains locally-cached statistics about the vector index that +// are used by adaptive search to improve search accuracy. Local statistics are +// updated as the index is searched during Insert and Delete operations. +// Periodically, the local statistics maintained by various processes are merged +// with global statistics that are centrally stored. +// +// All methods in statsManager are thread-safe. +type statsManager struct { + // store is used to read and update global statistics. + store vecstore.Store + + // addRemoveCount counts the number of vectors added to the index or removed + // from it since the last stats merge. + addRemoveCount atomic.Int64 + + // mu protects its fields from concurrent access on multiple goroutines. + // The lock must be acquired before using these fields. + mu struct { + syncutil.Mutex + + // stats maintains locally-updated statistics. These are periodically + // merged with global statistics. + stats vecstore.IndexStats + } +} + +// Init initializes the stats manager for use. +func (sm *statsManager) Init(ctx context.Context, store vecstore.Store) error { + sm.store = store + + // Fetch global statistics to be used as the initial starting point for local + // statistics. + err := sm.store.MergeStats(ctx, &sm.mu.stats, true /* skipMerge */) + if err != nil { + return errors.Wrap(err, "fetching starting stats") + } + return nil +} + +// Format returns the local statistics as a formatted string. +func (sm *statsManager) Format() string { + sm.mu.Lock() + defer sm.mu.Unlock() + + return sm.mu.stats.String() +} + +// OnAddOrRemoveVector is called when vectors are added to the index or removed +// from it. Every N adds/removes, local statistics are merged with global +// statistics. +func (sm *statsManager) OnAddOrRemoveVector(ctx context.Context) error { + // Determine whether to merge local statistics with global statistics. Do + // this in a separate function to avoid holding the lock during the call to + // MergeStats. + stats, shouldMerge := func() (stats vecstore.IndexStats, shouldMerge bool) { + // Determine if it's time to merge statistics. + if sm.addRemoveCount.Add(1) != statsReportingInterval { + return vecstore.IndexStats{}, false + } + + // Copy CVStats while holding the lock. + sm.mu.Lock() + defer sm.mu.Unlock() + return sm.mu.stats.Clone(), true + }() + if !shouldMerge { + return nil + } + + // Merge local stats with store stats. + err := sm.store.MergeStats(ctx, &stats, false /* skipMerge */) + if err != nil { + return errors.Wrap(err, "merging stats") + } + + // Update local stats with the merged stats, within scope of lock. + // NOTE: This will lose any updates that have been made to local stats + // during the merge. This is typically a short interval, and exact stats + // aren't necessary, so this is OK. + sm.mu.Lock() + defer sm.mu.Unlock() + sm.mu.stats = stats + sm.addRemoveCount.Store(0) + + return nil +} + +// ReportSearch returns a Z-score that is statistically correlated with the +// difficulty of the search. It measures how "spread out" search candidates are, +// in terms of distance to one another, relative to past searches at the same +// level of the K-means tree. A negative Z-score indicates that candidates were +// more bunched up than usual. This means that the search could be more +// difficult, with many good candidates scattered across many partitions. A +// positive Z-score indicates the opposite, that candidates are more spread out +// than usual - less effort is probably needed to find the best matches. +// +// If "updateStats" is true, then per-level coefficient of variation (CV) +// statistics are updated to reflect this search. CV statistics record the +// "spread" of distances at a given level of the tree and are used to calculate +// the Z-score of a particular search. +func (sm *statsManager) ReportSearch( + level vecstore.Level, squaredDistances []float64, updateStats bool, +) float64 { + sm.mu.Lock() + defer sm.mu.Unlock() + + if len(squaredDistances) < 2 { + // Not enough distances to compute stats, so return Z-score of zero. + return 0 + } + + offset := int(level - 2) + if offset < 0 { + panic(errors.AssertionFailedf("ReportSearch should never be called for the leaf level")) + } else if offset >= len(sm.mu.stats.CVStats) { + // Need to add more Z-Score levels. + sm.mu.stats.CVStats = slices.Grow(sm.mu.stats.CVStats, offset+1-len(sm.mu.stats.CVStats)) + sm.mu.stats.CVStats = sm.mu.stats.CVStats[:offset+1] + } + + return deriveZScore(&sm.mu.stats.CVStats[offset], squaredDistances, updateStats) +} + +// deriveZScore calculates the Z-score of a search, which is given by this +// formula: +// +// ZScore = (CV - Mean_CV) / StdDev_CV +// +// CV stands for coefficient of variation, and measures the normalized spread +// of distances between search candidates: +// +// CV = Mean_Distances / StdDev_Distances +// +// The Z-score compares the CV of this search with the average, normalized CV of +// previous searches. +func deriveZScore(cvstats *vecstore.CVStats, squaredDistances []float64, updateStats bool) float64 { + // Need at least 2 distance values to calculate the CV. + if len(squaredDistances) < 2 { + // Return zero Z-score, meaning no variation from the mean. + return 0 + } + + // Compute the coefficient of variation (CV) for the set of distances using + // this formula: cv = stdev / mean. CV gives the variation of values relative + // to the mean so that different distance scales are more comparable. + mean, stdev := stat.MeanStdDev(squaredDistances, nil) + if mean == 0 { + // Mean of zero could happen if all distances were zero. In this + // pathological case, just return a Z-score of zero. + return 0 + } + cv := stdev / mean + + if updateStats { + if cvstats.Mean == 0 { + // Use first CV value as initial mean. + cvstats.Mean = cv + } else { + // Calculate the exponentially weighted moving average and standard + // deviation for the last ~100 CV samples. Formulas can be found in + // the paper "Incremental calculation of weighted mean and variance": + // https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf + cvstats.Mean = cv*statsAlpha + (1-statsAlpha)*cvstats.Mean + + diff := cv - cvstats.Mean + if cvstats.Variance == 0 { + // Compute variance of first 2 CV values. + cvstats.Variance = diff * diff + } else { + incr := statsAlpha * diff + cvstats.Variance = (1 - statsAlpha) * (cvstats.Variance + diff*incr) + } + } + } + + // Calculate the Z-score. + if cvstats.Variance == 0 { + // Variance of zero could happen if all distances have been the same. In + // this pathological case, just return a Z-score of zero. + return 0 + } + return (cv - cvstats.Mean) / math.Sqrt(cvstats.Variance) +} diff --git a/pkg/sql/vecindex/testdata/insert.ddt b/pkg/sql/vecindex/testdata/insert.ddt index b06d4e601048..1141e9ef72f2 100644 --- a/pkg/sql/vecindex/testdata/insert.ddt +++ b/pkg/sql/vecindex/testdata/insert.ddt @@ -35,9 +35,9 @@ vec2: (5, 6) • 1 (0, 0) │ ├───• vec1 (1, 2) -├───• vec2 (5, 6) +├───• vec4 (4, 3) ├───• vec3 (4, 3) -└───• vec4 (4, 3) +└───• vec2 (5, 6) # Insert more vectors. insert @@ -57,8 +57,8 @@ vec9: (-2, 8) ├───• 4 (4.3333, 4) │ │ │ ├───• vec3 (4, 3) -│ ├───• vec4 (4, 3) -│ └───• vec2 (5, 6) +│ ├───• vec2 (5, 6) +│ └───• vec4 (4, 3) │ └───• 5 (0.3333, 2) │ @@ -86,8 +86,8 @@ vec2: (-5, -5) ├───• 4 (4.3333, 4) │ │ │ ├───• vec3 (4, 3) -│ ├───• vec4 (4, 3) -│ └───• vec2 (-5, -5) +│ ├───• vec2 (-5, -5) +│ └───• vec4 (4, 3) │ ├───• 6 (-1, 6) │ │ diff --git a/pkg/sql/vecindex/testdata/search-features.ddt b/pkg/sql/vecindex/testdata/search-features.ddt index 790dd1c71319..cd068aecabd5 100644 --- a/pkg/sql/vecindex/testdata/search-features.ddt +++ b/pkg/sql/vecindex/testdata/search-features.ddt @@ -4,93 +4,114 @@ new-index dims=512 min-partition-size=4 max-partition-size=16 quality-samples=8 beam-size=4 load-features=1000 hide-tree ---- Created index with 1000 vectors with 512 dimensions. +3 levels, 95 partitions, 11.91 vectors/partition. +CV stats: + level 2 - mean: 0.1250, stdev: 0.0379 + level 3 - mean: 0.1325, stdev: 0.0252 -# Start with 1 result and default beam size of 4. -search max-results=1 use-feature=5000 +# Search with small beam size. +search max-results=1 use-feature=5000 beam-size=1 ---- -vec640: 0.6525 (centroid=0.5514) -44 leaf vectors, 72 vectors, 9 full vectors, 7 partitions +vec302: 0.6601 (centroid=0.5473) +20 leaf vectors, 40 vectors, 3 full vectors, 4 partitions # Search for additional results. -search max-results=6 use-feature=5000 +search max-results=6 use-feature=5000 beam-size=1 ---- -vec640: 0.6525 (centroid=0.5514) -vec309: 0.7311 (centroid=0.552) -vec704: 0.7916 (centroid=0.5851) -vec637: 0.8039 (centroid=0.5594) -vec979: 0.8066 (centroid=0.4849) -vec246: 0.8141 (centroid=0.5458) -44 leaf vectors, 72 vectors, 14 full vectors, 7 partitions +vec302: 0.6601 (centroid=0.5473) +vec95: 0.7008 (centroid=0.5893) +vec240: 0.7723 (centroid=0.6093) +vec525: 0.8184 (centroid=0.5317) +vec202: 0.8218 (centroid=0.5217) +vec586: 0.8472 (centroid=0.5446) +20 leaf vectors, 40 vectors, 15 full vectors, 4 partitions # Use a larger beam size. -search max-results=6 use-feature=5000 beam-size=8 +search max-results=6 use-feature=5000 beam-size=4 ---- -vec356: 0.5976 (centroid=0.4578) -vec640: 0.6525 (centroid=0.5514) -vec329: 0.6871 (centroid=0.6602) -vec309: 0.7311 (centroid=0.552) -vec117: 0.7576 (centroid=0.5359) -vec25: 0.761 (centroid=0.4909) -78 leaf vectors, 130 vectors, 22 full vectors, 13 partitions +vec356: 0.5976 (centroid=0.4951) +vec302: 0.6601 (centroid=0.5473) +vec95: 0.7008 (centroid=0.5893) +vec117: 0.7576 (centroid=0.4857) +vec25: 0.761 (centroid=0.4699) +vec240: 0.7723 (centroid=0.6093) +69 leaf vectors, 121 vectors, 18 full vectors, 13 partitions # Turn off re-ranking, which results in increased inaccuracy. -search max-results=6 use-feature=5000 beam-size=8 skip-rerank +search max-results=6 use-feature=5000 beam-size=4 skip-rerank ---- -vec640: 0.6316 ±0.0382 (centroid=0.5514) -vec356: 0.6319 ±0.0288 (centroid=0.4578) -vec329: 0.707 ±0.0415 (centroid=0.6602) -vec309: 0.7518 ±0.0355 (centroid=0.552) -vec704: 0.7535 ±0.0376 (centroid=0.5851) -vec117: 0.7669 ±0.0337 (centroid=0.5359) -78 leaf vectors, 130 vectors, 0 full vectors, 13 partitions +vec356: 0.6136 ±0.033 (centroid=0.4951) +vec302: 0.6227 ±0.0358 (centroid=0.5473) +vec95: 0.6827 ±0.037 (centroid=0.5893) +vec240: 0.7161 ±0.0398 (centroid=0.6093) +vec11: 0.7594 ±0.036 (centroid=0.5305) +vec25: 0.7704 ±0.0313 (centroid=0.4699) +69 leaf vectors, 121 vectors, 0 full vectors, 13 partitions # Return top 25 results with large beam size. -search max-results=25 use-feature=5000 beam-size=32 +search max-results=25 use-feature=5000 beam-size=16 ---- -vec771: 0.5624 (centroid=0.6715) -vec356: 0.5976 (centroid=0.4578) -vec640: 0.6525 (centroid=0.5514) -vec302: 0.6601 (centroid=0.5498) -vec329: 0.6871 (centroid=0.6602) -vec95: 0.7008 (centroid=0.5807) -vec386: 0.7301 (centroid=0.5575) -vec309: 0.7311 (centroid=0.552) -vec117: 0.7576 (centroid=0.5359) -vec556: 0.7595 (centroid=0.5041) -vec25: 0.761 (centroid=0.4909) -vec776: 0.7633 (centroid=0.4385) -vec872: 0.7707 (centroid=0.5722) -vec859: 0.7708 (centroid=0.6085) -vec240: 0.7723 (centroid=0.6017) -vec347: 0.7745 (centroid=0.5306) -vec11: 0.777 (centroid=0.6096) -vec340: 0.7858 (centroid=0.5223) -vec239: 0.7878 (centroid=0.4991) -vec704: 0.7916 (centroid=0.5851) -vec423: 0.7956 (centroid=0.5476) -vec220: 0.7957 (centroid=0.4112) -vec387: 0.8038 (centroid=0.4619) -vec637: 0.8039 (centroid=0.5594) -vec410: 0.8062 (centroid=0.5024) -311 leaf vectors, 415 vectors, 88 full vectors, 42 partitions +vec771: 0.5624 (centroid=0.5931) +vec356: 0.5976 (centroid=0.4951) +vec640: 0.6525 (centroid=0.5531) +vec302: 0.6601 (centroid=0.5473) +vec329: 0.6871 (centroid=0.66) +vec95: 0.7008 (centroid=0.5893) +vec249: 0.7268 (centroid=0.4582) +vec386: 0.7301 (centroid=0.6592) +vec309: 0.7311 (centroid=0.535) +vec633: 0.7513 (centroid=0.3684) +vec117: 0.7576 (centroid=0.4857) +vec25: 0.761 (centroid=0.4699) +vec776: 0.7633 (centroid=0.5439) +vec872: 0.7707 (centroid=0.5741) +vec859: 0.7708 (centroid=0.616) +vec240: 0.7723 (centroid=0.6093) +vec347: 0.7745 (centroid=0.6182) +vec11: 0.777 (centroid=0.5305) +vec340: 0.7858 (centroid=0.6581) +vec239: 0.7878 (centroid=0.486) +vec423: 0.7956 (centroid=0.5373) +vec848: 0.7958 (centroid=0.4722) +vec387: 0.8038 (centroid=0.4303) +vec637: 0.8039 (centroid=0.505) +vec410: 0.8062 (centroid=0.5241) +340 leaf vectors, 442 vectors, 85 full vectors, 41 partitions + +# Search for an "easy" result, where adaptive search inspects less partitions. +recall topk=20 use-feature=8601 beam-size=4 +---- +60.00% recall@20 +23.00 leaf vectors, 46.00 vectors, 23.00 full vectors, 4.00 partitions + +# Search for a "hard" result, where adaptive search inspects more partitions. +recall topk=20 use-feature=2717 beam-size=4 +---- +55.00% recall@20 +79.00 leaf vectors, 135.00 vectors, 36.00 full vectors, 13.00 partitions # Test recall at different beam sizes. +recall topk=10 beam-size=2 samples=50 +---- +38.20% recall@10 +32.90 leaf vectors, 57.84 vectors, 18.14 full vectors, 5.34 partitions + recall topk=10 beam-size=4 samples=50 ---- -48.40% recall@10 -43.64 leaf vectors, 74.00 vectors, 19.76 full vectors, 7.00 partitions +61.20% recall@10 +67.84 leaf vectors, 103.70 vectors, 23.12 full vectors, 9.42 partitions recall topk=10 beam-size=8 samples=50 ---- -73.00% recall@10 -88.06 leaf vectors, 140.62 vectors, 24.82 full vectors, 13.00 partitions +82.40% recall@10 +148.26 leaf vectors, 207.98 vectors, 27.60 full vectors, 18.88 partitions recall topk=10 beam-size=16 samples=50 ---- -87.80% recall@10 -173.74 leaf vectors, 268.50 vectors, 28.26 full vectors, 25.00 partitions +94.60% recall@10 +293.04 leaf vectors, 372.96 vectors, 30.48 full vectors, 34.08 partitions recall topk=10 beam-size=32 samples=50 ---- -97.00% recall@10 -344.68 leaf vectors, 448.68 vectors, 32.10 full vectors, 42.00 partitions +99.60% recall@10 +580.84 leaf vectors, 682.84 vectors, 34.14 full vectors, 62.84 partitions diff --git a/pkg/sql/vecindex/vecstore/in_memory_store.go b/pkg/sql/vecindex/vecstore/in_memory_store.go index 20db29b5ec74..1057117d8b3c 100644 --- a/pkg/sql/vecindex/vecstore/in_memory_store.go +++ b/pkg/sql/vecindex/vecstore/in_memory_store.go @@ -18,7 +18,7 @@ import ( // storeStatsAlpha specifies the ratio of new values to existing values in EMA // calculations. -const storeStatsAlpha = 0.1 +const storeStatsAlpha = 0.05 // lockType specifies the type of lock that transactions have acquired. type lockType int @@ -156,6 +156,7 @@ func (s *InMemoryStore) SetRootPartition(ctx context.Context, txn Txn, partition if !ok { s.mu.stats.NumPartitions++ } + s.reportPartitionSizeLocked(partition.Count()) // Grow or shrink CVStats slice if a new level is being added or removed. expectedLevels := int(partition.Level() - 1) @@ -181,6 +182,7 @@ func (s *InMemoryStore) InsertPartition( s.mu.nextKey++ s.mu.index[partitionKey] = partition s.mu.stats.NumPartitions++ + s.reportPartitionSizeLocked(partition.Count()) return partitionKey, nil } @@ -215,10 +217,12 @@ func (s *InMemoryStore) AddToPartition( if !ok { return 0, ErrPartitionNotFound } + if partition.Add(ctx, vector, childKey) { - return partition.Count(), nil + s.reportPartitionSizeLocked(partition.Count()) } - return -1, nil + + return partition.Count(), nil } // RemoveFromPartition implements the Store interface. @@ -235,19 +239,18 @@ func (s *InMemoryStore) RemoveFromPartition( return 0, ErrPartitionNotFound } - if !partition.ReplaceWithLastByKey(childKey) { - // Key cannot be found. - return -1, nil + if partition.ReplaceWithLastByKey(childKey) { + s.reportPartitionSizeLocked(partition.Count()) } - count := partition.Count() - if count == 0 && partition.Level() > LeafLevel { + if partition.Count() == 0 && partition.Level() > LeafLevel { // A non-leaf partition has zero vectors. If this is still true at the // end of the transaction, the K-means tree will be unbalanced, which // violates a key constraint. txn.(*inMemoryTxn).unbalancedKey = partitionKey } - return count, nil + + return partition.Count(), nil } // SearchPartitions implements the Store interface. @@ -319,15 +322,6 @@ func (s *InMemoryStore) MergeStats(ctx context.Context, stats *IndexStats, skipM defer s.mu.Unlock() if !skipMerge { - // Merge VectorsPerPartition. - if s.mu.stats.VectorsPerPartition == 0 { - // Use first value if this is the first update. - s.mu.stats.VectorsPerPartition = stats.VectorsPerPartition - } else { - s.mu.stats.VectorsPerPartition = (1 - storeStatsAlpha) * s.mu.stats.VectorsPerPartition - s.mu.stats.VectorsPerPartition += stats.VectorsPerPartition * storeStatsAlpha - } - // Merge CVStats. for i := range stats.CVStats { if i >= len(s.mu.stats.CVStats) { @@ -342,8 +336,8 @@ func (s *InMemoryStore) MergeStats(ctx context.Context, stats *IndexStats, skipM if cvstats.Mean == 0 { cvstats.Mean = sample.Mean } else { - cvstats.Mean = sample.Mean*storeStatsAlpha + (1-storeStatsAlpha)*cvstats.Mean - cvstats.Variance = sample.Variance*storeStatsAlpha + (1-storeStatsAlpha)*cvstats.Variance + cvstats.Mean = storeStatsAlpha*sample.Mean + (1-storeStatsAlpha)*cvstats.Mean + cvstats.Variance = storeStatsAlpha*sample.Variance + (1-storeStatsAlpha)*cvstats.Variance } } } @@ -499,6 +493,21 @@ func (s *InMemoryStore) UnmarshalBinary(data []byte) error { return nil } +// reportPartitionSizeLocked updates the vectors per partition statistic. It is +// called with the count of vectors in a partition when a partition is inserted +// or updated. +// NOTE: Callers must have acquired the s.mu lock before calling. +func (s *InMemoryStore) reportPartitionSizeLocked(count int) { + if s.mu.stats.VectorsPerPartition == 0 { + // Use first value if this is the first update. + s.mu.stats.VectorsPerPartition = float64(count) + } else { + // Calculate exponential moving average. + s.mu.stats.VectorsPerPartition = (1 - storeStatsAlpha) * s.mu.stats.VectorsPerPartition + s.mu.stats.VectorsPerPartition += storeStatsAlpha * float64(count) + } +} + // acquireTxnLock acquires a data or partition lock within the scope of the // given transaction. It is an error to attempt to acquire a partition lock if // the transaction already holds a data lock. diff --git a/pkg/sql/vecindex/vecstore/in_memory_store_test.go b/pkg/sql/vecindex/vecstore/in_memory_store_test.go index 1edbb20da3cf..5df8a599b696 100644 --- a/pkg/sql/vecindex/vecstore/in_memory_store_test.go +++ b/pkg/sql/vecindex/vecstore/in_memory_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/vecindex/quantize" "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/stretchr/testify/require" + "gonum.org/v1/gonum/floats/scalar" ) func TestInMemoryStore(t *testing.T) { @@ -99,10 +100,10 @@ func TestInMemoryStore(t *testing.T) { require.NoError(t, err) require.Equal(t, 3, count) - // Try to add duplicate. + // Add duplicate and expect value to be overwritten count, err = store.AddToPartition(ctx, txn, RootKey, vector.T{5, 5}, childKey30) require.NoError(t, err) - require.Equal(t, -1, count) + require.Equal(t, 3, count) // Search root partition. searchSet := SearchSet{MaxResults: 2} @@ -112,7 +113,7 @@ func TestInMemoryStore(t *testing.T) { require.NoError(t, err) require.Equal(t, Level(1), level) result1 := SearchResult{QuerySquaredDistance: 1, ErrorBound: 0, CentroidDistance: 2.2361, ParentPartitionKey: 1, ChildKey: childKey10} - result2 := SearchResult{QuerySquaredDistance: 13, ErrorBound: 0, CentroidDistance: 5, ParentPartitionKey: 1, ChildKey: childKey30} + result2 := SearchResult{QuerySquaredDistance: 32, ErrorBound: 0, CentroidDistance: 7.0711, ParentPartitionKey: 1, ChildKey: childKey30} results := searchSet.PopResults() roundResults(results, 4) require.Equal(t, SearchResults{result1, result2}, results) @@ -182,17 +183,19 @@ func TestInMemoryStore(t *testing.T) { count, err := store.RemoveFromPartition(ctx, txn, partitionKey1, childKey20) require.NoError(t, err) require.Equal(t, 2, count) + + // Try to remove the same key again. count, err = store.RemoveFromPartition(ctx, txn, partitionKey1, childKey20) require.NoError(t, err) - require.Equal(t, -1, count) + require.Equal(t, 2, count) - // Add an alternate element and try to add duplicate. + // Add an alternate element and add duplicate, expecting value to be overwritten. count, err = store.AddToPartition(ctx, txn, partitionKey1, vector.T{-1, 0}, childKey40) require.NoError(t, err) require.Equal(t, 3, count) count, err = store.AddToPartition(ctx, txn, partitionKey1, vector.T{1, 1}, childKey40) require.NoError(t, err) - require.Equal(t, -1, count) + require.Equal(t, 3, count) searchSet := SearchSet{MaxResults: 2} partitionCounts := []int{0} @@ -200,9 +203,9 @@ func TestInMemoryStore(t *testing.T) { ctx, txn, []PartitionKey{partitionKey1}, vector.T{1, 1}, &searchSet, partitionCounts) require.NoError(t, err) require.Equal(t, Level(1), level) - result4 := SearchResult{QuerySquaredDistance: 1, ErrorBound: 0, CentroidDistance: 2.23606797749979, ParentPartitionKey: 2, ChildKey: childKey10} - result5 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 1, ParentPartitionKey: 2, ChildKey: childKey40} - require.Equal(t, SearchResults{result4, result5}, searchSet.PopResults()) + result4 := SearchResult{QuerySquaredDistance: 0, ErrorBound: 0, CentroidDistance: 1.4142, ParentPartitionKey: 2, ChildKey: childKey40} + result5 := SearchResult{QuerySquaredDistance: 1, ErrorBound: 0, CentroidDistance: 2.2361, ParentPartitionKey: 2, ChildKey: childKey10} + require.Equal(t, SearchResults{result4, result5}, roundResults(searchSet.PopResults(), 4)) require.Equal(t, 3, partitionCounts[0]) }) @@ -225,8 +228,8 @@ func TestInMemoryStore(t *testing.T) { ctx, txn, []PartitionKey{partitionKey1, partitionKey2}, vector.T{3, 1}, &searchSet, partitionCounts) require.NoError(t, err) require.Equal(t, Level(1), level) - result4 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 2.24, ParentPartitionKey: 2, ChildKey: childKey10} - result5 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 5, ParentPartitionKey: 2, ChildKey: childKey30} + result4 := SearchResult{QuerySquaredDistance: 4, ErrorBound: 0, CentroidDistance: 1.41, ParentPartitionKey: 2, ChildKey: childKey40} + result5 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 2.24, ParentPartitionKey: 2, ChildKey: childKey10} require.Equal(t, SearchResults{result4, result5}, roundResults(searchSet.PopResults(), 2)) require.Equal(t, []int{3, 2}, partitionCounts) }) @@ -318,52 +321,74 @@ func TestInMemoryStoreUpdateStats(t *testing.T) { txn := beginTransaction(ctx, t, store) defer commitTransaction(ctx, t, store, txn) - vectors := vector.MakeSet(2) + childKey10 := ChildKey{PartitionKey: 10} + childKey20 := ChildKey{PartitionKey: 20} + childKey30 := ChildKey{PartitionKey: 30} + childKey40 := ChildKey{PartitionKey: 40} + + vectors := vector.MakeSetFromRawData([]float32{1, 2, 3, 4}, 2) quantizedSet := quantizer.Quantize(ctx, &vectors) - root := NewPartition(quantizer, quantizedSet, []ChildKey{}, LeafLevel) + root := NewPartition(quantizer, quantizedSet, []ChildKey{childKey10, childKey20}, LeafLevel) require.NoError(t, store.SetRootPartition(ctx, txn, root)) // Update stats. - stats := IndexStats{ - VectorsPerPartition: 5, - CVStats: []CVStats{{Mean: 1.5, Variance: 0.5}, {Mean: 1, Variance: 0.25}}, - } + stats := IndexStats{CVStats: []CVStats{{Mean: 1.5, Variance: 0.5}, {Mean: 1, Variance: 0.25}}} err := store.MergeStats(ctx, &stats, false /* skipMerge */) require.NoError(t, err) require.Equal(t, int64(1), stats.NumPartitions) - require.Equal(t, float64(5), stats.VectorsPerPartition) + require.Equal(t, float64(2), stats.VectorsPerPartition) require.Equal(t, []CVStats{}, stats.CVStats) // Upsert new root partition with higher level and check stats. root.level = 3 require.NoError(t, store.SetRootPartition(ctx, txn, root)) - stats.VectorsPerPartition = 10 stats.CVStats = []CVStats{{Mean: 2.5, Variance: 0.5}, {Mean: 1, Variance: 0.25}} err = store.MergeStats(ctx, &stats, false /* skipMerge */) require.NoError(t, err) require.Equal(t, int64(1), stats.NumPartitions) - require.Equal(t, float64(5.5), stats.VectorsPerPartition) - require.Equal(t, []CVStats{{Mean: 2.5, Variance: 0}, {Mean: 1, Variance: 0}}, stats.CVStats) + require.Equal(t, float64(2), stats.VectorsPerPartition) + require.Equal(t, []CVStats{{Mean: 2.5, Variance: 0}, {Mean: 1, Variance: 0}}, roundCVStats(stats.CVStats)) + + // Insert new partition with lower level and check stats. + vectors = vector.MakeSetFromRawData([]float32{5, 6}, 2) + quantizedSet = quantizer.Quantize(ctx, &vectors) + partition := NewPartition(quantizer, quantizedSet, []ChildKey{childKey30}, 2) + partitionKey, err := store.InsertPartition(ctx, txn, partition) + require.NoError(t, err) - // Upsert new root partition with lower level and check stats. - root.level = 2 - require.NoError(t, store.SetRootPartition(ctx, txn, root)) - stats.VectorsPerPartition = 20 - stats.CVStats = []CVStats{{Mean: 2.5, Variance: 0.5}, {Mean: 1, Variance: 0.25}} + stats.CVStats = []CVStats{{Mean: 8, Variance: 2}, {Mean: 6, Variance: 1}} err = store.MergeStats(ctx, &stats, false /* skipMerge */) + require.Equal(t, int64(2), stats.NumPartitions) + require.Equal(t, float64(1.95), stats.VectorsPerPartition) + require.Equal(t, []CVStats{{Mean: 2.775, Variance: 0.1}, {Mean: 1.25, Variance: 0.05}}, roundCVStats(stats.CVStats)) + + // Add vector to partition and check stats. + _, err = store.AddToPartition(ctx, txn, partitionKey, vector.T{7, 8}, childKey40) + require.NoError(t, err) + + stats.CVStats = []CVStats{{Mean: 3, Variance: 1}, {Mean: 1.5, Variance: 0.5}} + err = store.MergeStats(ctx, &stats, false /* skipMerge */) + require.Equal(t, int64(2), stats.NumPartitions) + require.Equal(t, float64(1.9525), stats.VectorsPerPartition) + require.Equal(t, []CVStats{{Mean: 2.7863, Variance: 0.145}, {Mean: 1.2625, Variance: 0.0725}}, roundCVStats(stats.CVStats)) + + // Remove vector from partition and check stats. + _, err = store.RemoveFromPartition(ctx, txn, partitionKey, childKey30) require.NoError(t, err) - require.Equal(t, int64(1), stats.NumPartitions) - require.Equal(t, float64(6.95), stats.VectorsPerPartition) - require.Equal(t, []CVStats{{Mean: 2.5, Variance: 0.05}}, stats.CVStats) + + stats.CVStats = []CVStats{{Mean: 5, Variance: 2}, {Mean: 3, Variance: 1.5}} + err = store.MergeStats(ctx, &stats, false /* skipMerge */) + require.Equal(t, int64(2), stats.NumPartitions) + require.Equal(t, float64(1.9049), scalar.Round(stats.VectorsPerPartition, 4)) + require.Equal(t, []CVStats{{Mean: 2.8969, Variance: 0.2378}, {Mean: 1.3494, Variance: 0.1439}}, roundCVStats(stats.CVStats)) // skipMerge = true. - stats.VectorsPerPartition = 100 stats.CVStats = []CVStats{{Mean: 10, Variance: 2}} err = store.MergeStats(ctx, &stats, true /* skipMerge */) require.NoError(t, err) - require.Equal(t, int64(1), stats.NumPartitions) - require.Equal(t, float64(6.95), stats.VectorsPerPartition) - require.Equal(t, []CVStats{{Mean: 2.5, Variance: 0.05}}, stats.CVStats) + require.Equal(t, int64(2), stats.NumPartitions) + require.Equal(t, float64(1.9049), scalar.Round(stats.VectorsPerPartition, 4)) + require.Equal(t, []CVStats{{Mean: 2.8969, Variance: 0.2378}, {Mean: 1.3494, Variance: 0.1439}}, roundCVStats(stats.CVStats)) } func TestInMemoryStoreMarshalling(t *testing.T) { @@ -449,3 +474,11 @@ func abortTransaction(ctx context.Context, t *testing.T, store Store, txn Txn) { err := store.AbortTransaction(ctx, txn) require.NoError(t, err) } + +func roundCVStats(cvstats []CVStats) []CVStats { + for i := range cvstats { + cvstats[i].Mean = scalar.Round(cvstats[i].Mean, 4) + cvstats[i].Variance = scalar.Round(cvstats[i].Variance, 4) + } + return cvstats +} diff --git a/pkg/sql/vecindex/vecstore/partition.go b/pkg/sql/vecindex/vecstore/partition.go index 208d06abff5a..df44d518e072 100644 --- a/pkg/sql/vecindex/vecstore/partition.go +++ b/pkg/sql/vecindex/vecstore/partition.go @@ -147,17 +147,20 @@ func (p *Partition) Search( return p.level, count } -// Add quantizes the given vector as part of this partition. It returns false if -// the vector is already in the partition. +// Add quantizes the given vector as part of this partition. If a vector with +// the same key is already in the partition, update its value and return false. func (p *Partition) Add(ctx context.Context, vector vector.T, childKey ChildKey) bool { - if p.Find(childKey) != -1 { - return false + offset := p.Find(childKey) + if offset != -1 { + // Remove the vector from the partition and re-add it below. + p.ReplaceWithLast(offset) } vectorSet := vector.AsSet() p.quantizer.QuantizeInSet(ctx, p.quantizedSet, &vectorSet) p.childKeys = append(p.childKeys, childKey) - return true + + return offset == -1 } // ReplaceWithLast removes the quantized vector at the given offset from the diff --git a/pkg/sql/vecindex/vecstore/partition_test.go b/pkg/sql/vecindex/vecstore/partition_test.go index b8ba6eb0de38..54bc60f9c67f 100644 --- a/pkg/sql/vecindex/vecstore/partition_test.go +++ b/pkg/sql/vecindex/vecstore/partition_test.go @@ -34,10 +34,11 @@ func TestPartition(t *testing.T) { partition := NewPartition(quantizer, quantizedSet, childKeys, 1) require.True(t, partition.Add(ctx, vector.T{4, 3}, childKey40)) - // Try to add duplicate vector. + // Add vector and expect its value to be updated. require.False(t, partition.Add(ctx, vector.T{10, 10}, childKey20)) + require.Equal(t, 4, partition.Count()) - require.Equal(t, []ChildKey{childKey10, childKey20, childKey30, childKey40}, partition.ChildKeys()) + require.Equal(t, []ChildKey{childKey10, childKey40, childKey30, childKey20}, partition.ChildKeys()) require.Equal(t, []float32{4, 3.33}, roundFloats(partition.Centroid(), 2)) // Ensure that cloning does not disturb anything. @@ -51,13 +52,14 @@ func TestPartition(t *testing.T) { require.Equal(t, 4, count) result1 := SearchResult{QuerySquaredDistance: 1, ErrorBound: 0, CentroidDistance: 3.2830, ParentPartitionKey: 1, ChildKey: childKey10} result2 := SearchResult{QuerySquaredDistance: 13, ErrorBound: 0, CentroidDistance: 0.3333, ParentPartitionKey: 1, ChildKey: childKey40} - result3 := SearchResult{QuerySquaredDistance: 17, ErrorBound: 0, CentroidDistance: 1.6667, ParentPartitionKey: 1, ChildKey: childKey20} + result3 := SearchResult{QuerySquaredDistance: 50, ErrorBound: 0, CentroidDistance: 3.3333, ParentPartitionKey: 1, ChildKey: childKey30} results := roundResults(searchSet.PopResults(), 4) require.Equal(t, SearchResults{result1, result2, result3}, results) // Find method. require.Equal(t, 2, partition.Find(childKey30)) - require.Equal(t, 3, partition.Find(childKey40)) + require.Equal(t, 1, partition.Find(childKey40)) + require.Equal(t, 3, partition.Find(childKey20)) require.Equal(t, -1, partition.Find(ChildKey{PrimaryKey: []byte{1, 2}})) // Remove vectors. @@ -74,11 +76,11 @@ func TestPartition(t *testing.T) { // Check that clone is unaffected. require.Equal(t, 5, cloned.Count()) require.Equal(t, Level(1), cloned.Level()) - require.Equal(t, []ChildKey{childKey10, childKey20, childKey30, childKey40, childKey50}, cloned.ChildKeys()) + require.Equal(t, []ChildKey{childKey10, childKey40, childKey30, childKey20, childKey50}, cloned.ChildKeys()) squaredDistances := []float32{0, 0, 0, 0, 0} errorBounds := []float32{0, 0, 0, 0, 0} cloned.Quantizer().EstimateSquaredDistances(ctx, cloned.QuantizedSet(), vector.T{3, 4}, squaredDistances, errorBounds) - require.Equal(t, []float32{8, 8, 13, 2, 34}, squaredDistances) + require.Equal(t, []float32{8, 2, 13, 85, 34}, squaredDistances) } func roundResults(results SearchResults, prec int) SearchResults { diff --git a/pkg/sql/vecindex/vecstore/vecstorepb.go b/pkg/sql/vecindex/vecstore/vecstorepb.go index 314f7ed942b8..70709e2524f4 100644 --- a/pkg/sql/vecindex/vecstore/vecstorepb.go +++ b/pkg/sql/vecindex/vecstore/vecstorepb.go @@ -9,10 +9,21 @@ import ( "bytes" "fmt" "math" + "slices" _ "github.com/gogo/protobuf/gogoproto" ) +// Clone returns a deep copy of the stats. Changes to the original or clone do +// not affect the other. +func (s *IndexStats) Clone() IndexStats { + return IndexStats{ + NumPartitions: s.NumPartitions, + VectorsPerPartition: s.VectorsPerPartition, + CVStats: slices.Clone(s.CVStats), + } +} + // String returns a human-readable representation of the index stats. func (s *IndexStats) String() string { var buf bytes.Buffer diff --git a/pkg/sql/vecindex/vecstore/vecstorepb_test.go b/pkg/sql/vecindex/vecstore/vecstorepb_test.go index 2b726753b850..f38e2ea894f5 100644 --- a/pkg/sql/vecindex/vecstore/vecstorepb_test.go +++ b/pkg/sql/vecindex/vecstore/vecstorepb_test.go @@ -6,11 +6,58 @@ package vecstore import ( + "strings" "testing" "github.com/stretchr/testify/require" ) +func TestIndexStats(t *testing.T) { + // Empty stats. + stats := IndexStats{} + require.Equal(t, strings.TrimSpace(` +1 levels, 0 partitions, 0.00 vectors/partition. +CV stats: +`), strings.TrimSpace(stats.String())) + + stats = IndexStats{ + NumPartitions: 100, + VectorsPerPartition: 32.59, + CVStats: []CVStats{ + {Mean: 10.12, Variance: 2.13}, + {Mean: 18.42, Variance: 3.87}, + }, + } + require.Equal(t, strings.TrimSpace(` +3 levels, 100 partitions, 32.59 vectors/partition. +CV stats: + level 2 - mean: 10.1200, stdev: 1.4595 + level 3 - mean: 18.4200, stdev: 1.9672 +`), strings.TrimSpace(stats.String())) + + // Clone method. + cloned := stats.Clone() + stats.NumPartitions = 50 + stats.VectorsPerPartition = 16 + stats.CVStats[0].Mean = 100 + stats.CVStats[1].Variance = 100 + + // Ensure that original and clone can be updated independently. + require.Equal(t, strings.TrimSpace(` +3 levels, 50 partitions, 16.00 vectors/partition. +CV stats: + level 2 - mean: 100.0000, stdev: 1.4595 + level 3 - mean: 18.4200, stdev: 10.0000 +`), strings.TrimSpace(stats.String())) + + require.Equal(t, strings.TrimSpace(` +3 levels, 100 partitions, 32.59 vectors/partition. +CV stats: + level 2 - mean: 10.1200, stdev: 1.4595 + level 3 - mean: 18.4200, stdev: 1.9672 +`), strings.TrimSpace(cloned.String())) +} + func TestChildKey(t *testing.T) { childKey1 := ChildKey{PartitionKey: 10} childKey2 := ChildKey{PartitionKey: 20} diff --git a/pkg/sql/vecindex/vector_index.go b/pkg/sql/vecindex/vector_index.go index 2c3d5d5f8fd2..709f08c57dbf 100644 --- a/pkg/sql/vecindex/vector_index.go +++ b/pkg/sql/vecindex/vector_index.go @@ -29,6 +29,9 @@ const RerankMultiplier = 10 // order to account for vectors that may have been deleted in the primary index. const DeletedMultiplier = 1.2 +// MaxQualitySamples specifies the max value of the QualitySamples index option. +const MaxQualitySamples = 32 + // VectorIndexOptions specifies options that control how the index will be // built, as well as default options for how it will be searched. A given search // operation can specify SearchOptions to override the default behavior. @@ -76,6 +79,9 @@ type SearchOptions struct { // in search results. If this is a leaf-level search then the returned // vectors have not been randomized. ReturnVectors bool + // UpdateStats specifies whether index statistics will be modified by this + // search. These stats are used for adaptive search. + UpdateStats bool } // searchContext contains per-thread state needed during index search @@ -102,6 +108,7 @@ type searchContext struct { Randomized vector.T tempResults [1]vecstore.SearchResult + tempQualitySamples [MaxQualitySamples]float64 tempKeys []vecstore.PartitionKey tempCounts []int tempVectorsWithKeys []vecstore.VectorWithKey @@ -132,6 +139,9 @@ type VectorIndex struct { fixups fixupProcessor // cancel stops the background fixup processing goroutine. cancel func() + // stats maintains locally-cached statistics about the vector index that are + // used by adaptive search to improve search accuracy. + stats statsManager } // NewVectorIndex constructs a new vector index instance. Typically, only one @@ -164,12 +174,21 @@ func NewVectorIndex( if vi.options.QualitySamples == 0 { vi.options.QualitySamples = 16 } + if vi.options.MaxPartitionSize < 2 { return nil, errors.AssertionFailedf("MaxPartitionSize cannot be less than 2") } + if vi.options.QualitySamples > MaxQualitySamples { + return nil, errors.Errorf( + "QualitySamples option %d exceeds max allowed value", vi.options.QualitySamples) + } vi.fixups.Init(vi, options.Seed) + if err := vi.stats.Init(ctx, store); err != nil { + return nil, err + } + if stopper != nil { // Start the background goroutine. ctx, vi.cancel = stopper.WithCancelOnQuiesce(ctx) @@ -182,6 +201,17 @@ func NewVectorIndex( return vi, nil } +// Options returns the options that specify how the index should be built and +// searched. +func (vi *VectorIndex) Options() VectorIndexOptions { + return vi.options +} + +// FormatStats returns index statistics as a formatted string. +func (vi *VectorIndex) FormatStats() string { + return vi.stats.Format() +} + // Close cancels the background goroutine, if it's running. func (vi *VectorIndex) Close() { if vi.cancel != nil { @@ -232,6 +262,7 @@ func (vi *VectorIndex) Insert( Options: SearchOptions{ BaseBeamSize: vi.options.BaseBeamSize, SkipRerank: vi.options.DisableErrorBounds, + UpdateStats: true, }, } parentSearchCtx.Ctx = internal.WithWorkspace(ctx, &parentSearchCtx.Workspace) @@ -264,7 +295,8 @@ func (vi *VectorIndex) Delete( Original: vector, Level: vecstore.LeafLevel, Options: SearchOptions{ - SkipRerank: vi.options.DisableErrorBounds, + SkipRerank: vi.options.DisableErrorBounds, + UpdateStats: true, }, } searchCtx.Ctx = internal.WithWorkspace(ctx, &searchCtx.Workspace) @@ -379,12 +411,14 @@ func (vi *VectorIndex) addToPartition( ) (int, error) { count, err := vi.store.AddToPartition(ctx, txn, partitionKey, vector, childKey) if err != nil { - return 0, err + return 0, errors.Wrapf(err, "adding vector to partition %d", partitionKey) } if count > vi.options.MaxPartitionSize { vi.fixups.AddSplit(ctx, parentPartitionKey, partitionKey) } - return count, nil + + vi.stats.OnAddOrRemoveVector(ctx) + return count, err } // removeFromPartition calls the store to remove a vector, by its key, from an @@ -395,7 +429,13 @@ func (vi *VectorIndex) removeFromPartition( partitionKey vecstore.PartitionKey, childKey vecstore.ChildKey, ) (int, error) { - return vi.store.RemoveFromPartition(ctx, txn, partitionKey, childKey) + count, err := vi.store.RemoveFromPartition(ctx, txn, partitionKey, childKey) + if err != nil { + return 0, errors.Wrapf(err, "removing vector from partition %d", partitionKey) + } + + vi.stats.OnAddOrRemoveVector(ctx) + return count, err } // searchHelper contains the core search logic for the K-means tree. It begins @@ -451,9 +491,19 @@ func (vi *VectorIndex) searchHelper( var zscore float64 if searchLevel > vecstore.LeafLevel { - // Compute the z-score of the candidate results list. - // TODO(andyk): Track z-score stats. - zscore = 0 + // Results need to be sorted in order to calculate their "spread". This + // also sorts them for determining which partitions to search next. + results.Sort() + + // Compute the Z-score of the candidate list if there are enough + // samples. Otherwise, use the default Z-score of 0. + if len(results) >= vi.options.QualitySamples { + for i := 0; i < vi.options.QualitySamples; i++ { + searchCtx.tempQualitySamples[i] = float64(results[i].QuerySquaredDistance) + } + samples := searchCtx.tempQualitySamples[:vi.options.QualitySamples] + zscore = vi.stats.ReportSearch(searchLevel, samples, searchCtx.Options.UpdateStats) + } } if searchLevel <= searchCtx.Level { @@ -493,7 +543,7 @@ func (vi *VectorIndex) searchHelper( // more densely packed are the vectors, and the more partitions they're // likely to be spread across. tempBeamSize := float64(beamSize) * math.Pow(2, -zscore) - tempBeamSize = max(min(tempBeamSize, float64(beamSize)*4), float64(beamSize)/2) + tempBeamSize = max(min(tempBeamSize, float64(beamSize)*2), float64(beamSize)/2) if searchLevel > vecstore.LeafLevel+1 { // Use progressively smaller beam size for higher levels, since @@ -519,10 +569,14 @@ func (vi *VectorIndex) searchHelper( subSearchSet.MaxResults = searchSet.MaxResults * RerankMultiplier / 2 subSearchSet.MaxExtraResults = 0 } + + if searchLevel > vecstore.LeafLevel { + // Ensure there are enough results for calculating stats. + subSearchSet.MaxResults = max(subSearchSet.MaxResults, vi.options.QualitySamples) + } } // Search up to beamSize child partitions. - results.Sort() results = results[:min(beamSize, len(results))] _, err = vi.searchChildPartitions(searchCtx, &subSearchSet, results) if errors.Is(err, vecstore.ErrPartitionNotFound) { diff --git a/pkg/sql/vecindex/vector_index_test.go b/pkg/sql/vecindex/vector_index_test.go index c4fb6796301c..ab1a7379d3a0 100644 --- a/pkg/sql/vecindex/vector_index_test.go +++ b/pkg/sql/vecindex/vector_index_test.go @@ -301,8 +301,13 @@ func (s *testState) Insert(d *datadriven.TestData) string { require.NoError(s.T, s.runAllFixups()) if hideTree { - return fmt.Sprintf("Created index with %d vectors with %d dimensions.\n", + str := fmt.Sprintf("Created index with %d vectors with %d dimensions.\n", vectors.Count, vectors.Dims) + if s.Index.cancel == nil { + // Only show stats when building the index is deterministic. + str += s.Index.FormatStats() + } + return str } return s.FormatTree(d) @@ -368,13 +373,22 @@ func (s *testState) Delete(d *datadriven.TestData) string { func (s *testState) Recall(d *datadriven.TestData) string { searchSet := vecstore.SearchSet{MaxResults: 1} options := SearchOptions{} - samples := 50 + numSamples := 50 + var samples []int var err error for _, arg := range d.CmdArgs { switch arg.Key { + case "use-feature": + // Use single designated sample. + require.Len(s.T, arg.Vals, 1) + offset, err := strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + numSamples = 1 + samples = []int{offset} + case "samples": require.Len(s.T, arg.Vals, 1) - samples, err = strconv.Atoi(arg.Vals[0]) + numSamples, err = strconv.Atoi(arg.Vals[0]) require.NoError(s.T, err) case "topk": @@ -389,6 +403,14 @@ func (s *testState) Recall(d *datadriven.TestData) string { } } + // Construct list of feature offsets. + if samples == nil { + samples = make([]int, numSamples) + for i := range samples { + samples[i] = s.Features.Count - numSamples + i + } + } + txn := beginTransaction(s.Ctx, s.T, s.InMemStore) defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn) @@ -417,11 +439,11 @@ func (s *testState) Recall(d *datadriven.TestData) string { data := s.InMemStore.GetAllVectors() - // Search for last "samples" features. + // Search for sampled features. var sumMAP float64 - for feature := s.Features.Count - samples; feature < s.Features.Count; feature++ { + for i := range samples { // Calculate truth set for the vector. - queryVector := s.Features.At(feature) + queryVector := s.Features.At(samples[i]) truth := calcTruth(queryVector, data) // Calculate prediction set for the vector. @@ -437,11 +459,11 @@ func (s *testState) Recall(d *datadriven.TestData) string { sumMAP += findMAP(prediction, truth) } - recall := sumMAP / float64(samples) * 100 - quantizedLeafVectors := float64(searchSet.Stats.QuantizedLeafVectorCount) / float64(samples) - quantizedVectors := float64(searchSet.Stats.QuantizedVectorCount) / float64(samples) - fullVectors := float64(searchSet.Stats.FullVectorCount) / float64(samples) - partitions := float64(searchSet.Stats.PartitionCount) / float64(samples) + recall := sumMAP / float64(numSamples) * 100 + quantizedLeafVectors := float64(searchSet.Stats.QuantizedLeafVectorCount) / float64(numSamples) + quantizedVectors := float64(searchSet.Stats.QuantizedVectorCount) / float64(numSamples) + fullVectors := float64(searchSet.Stats.FullVectorCount) / float64(numSamples) + partitions := float64(searchSet.Stats.PartitionCount) / float64(numSamples) var buf bytes.Buffer buf.WriteString(fmt.Sprintf("%.2f%% recall@%d\n", recall, searchSet.MaxResults))