-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
157 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
using SimilaritySearch, SurrogatedDistanceModels, JLD2, CSV, Glob, LinearAlgebra, Dates | ||
|
||
include("common.jl") | ||
|
||
# This file is based in the julia's example of the 2023 edition | ||
|
||
""" | ||
load_database(file) | ||
Loads a dataset stored in `file`n | ||
""" | ||
function load_database(file) | ||
@info "loading clip768 (converting Float16 -> Float32)" | ||
X = jldopen(file) do f | ||
Matrix{Float32}(f["emb"]) | ||
end | ||
|
||
#=for col in eachcol(X) | ||
normalize!(col) | ||
end=# | ||
|
||
StrideMatrixDatabase(X) | ||
end | ||
|
||
function run_search_task1(idx, queries::AbstractDatabase, k::Integer, meta, resfile_::AbstractString) | ||
resfile_ = replace(resfile_, ".h5" => "") | ||
resfile = "$resfile_.h5" | ||
@info "searching $resfile" | ||
querytime = @elapsed knns, dists = searchbatch(idx, queries, k) | ||
meta["querytime"] = querytime | ||
save_results(knns, dists, meta, resfile) | ||
end | ||
|
||
""" | ||
task1(; kwargs...) | ||
Runs an entire beenchmark | ||
- `dbsize`: string denoting the size of the dataset (e.g., "300K", "100M"), million scale should not be used in GitHub Actions. | ||
- `k`: the number of neighbors to find | ||
""" | ||
function task1(; | ||
dbsize, | ||
dfile="data2024/laion2B-en-clip768v2-n=$dbsize.h5", | ||
#qfile="data2024/public-queries-2024-laion2B-en-clip768v2-n=10k.h5", | ||
qfile="data2024/private-queries-2024-laion2B-en-clip768v2-n=10k-epsilon=0.2.h5", | ||
k=30, | ||
outdir="results-task1-bruteforce/$dbsize/$(Dates.format(Dates.now(), "yyyymmdd-HHMMSS"))" | ||
) | ||
|
||
mkpath(outdir) | ||
|
||
dist = NormalizedCosineDistance() # 1 - dot(·, ·) | ||
@info "loading $qfile and $dfile" | ||
@time db = load_database(dfile) | ||
@time queries = load_database(qfile) | ||
|
||
# loading or computing knns | ||
@info "indexing, this can take a while!" | ||
G = ExhaustiveSearch(; dist, db) | ||
meta = Dict( | ||
"buildtime" => 0.0, | ||
"matrix_size" => size(db.matrix), | ||
"optimtime" => 0.0, | ||
"algo" => "Bruteforce", | ||
"params" => "" | ||
) | ||
meta["size"] = dbsize | ||
meta["modelingtime"] = 0.0 | ||
meta["encdatabasetime"] = 0.0 | ||
meta["encqueriestime"] = 0.0 | ||
meta["buildtime"] = 0.0 | ||
resfile = joinpath(outdir, "bruteforce-k=$k") | ||
run_search_task1(G, queries, k, meta, resfile) | ||
end | ||
|
||
if !isinteractive() | ||
if length(ARGS) == 0 || any(dbsize -> dbsize ∉ ("300K", "10M", "100M"), ARGS) | ||
throw(ArgumentError("this script must be called with a list of the following arguments: 300K, 10M or 100M")) | ||
end | ||
|
||
for dbsize in ARGS | ||
task1(; dbsize) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
using SimilaritySearch, SurrogatedDistanceModels, HDF5, JLD2, CSV, Glob, LinearAlgebra, Dates | ||
|
||
include("common.jl") | ||
|
||
# This file is based in the julia's example of the 2023 edition | ||
|
||
|
||
function run_search_task3(idx, queries::AbstractDatabase, k::Integer, meta, resfile_::String) | ||
resfile_ = replace(resfile_, ".h5" => "") | ||
#step = 1.05f0 | ||
#delta = idx.search_algo.Δ / step^3 | ||
params = meta["params"] | ||
|
||
# produces result files for different search hyperparameters | ||
#while delta < 2f0 | ||
# idx.search_algo.Δ = delta | ||
# dt = "delta=$(round(delta; digits=3))" | ||
resfile = "$resfile_.h5" | ||
@info "searching $resfile" | ||
querytime = @elapsed knns, dists = searchbatch(idx, queries, k) | ||
meta["querytime"] = querytime | ||
save_results(knns, dists, meta, resfile) | ||
#end | ||
end | ||
|
||
function task3(; | ||
dbsize, | ||
dfile="data2024/laion2B-en-clip768v2-n=$dbsize.h5", | ||
qfile="data2024/private-queries-2024-laion2B-en-clip768v2-n=10k-epsilon=0.2.h5", | ||
k=30, | ||
outdir="results-task3-bruteforce/$dbsize/$(Dates.format(Dates.now(), "yyyymmdd-HHMMSS"))" | ||
) | ||
|
||
mkpath(outdir) | ||
dist = NormalizedCosineDistance() # 1 - dot(·, ·) | ||
nbits = 8 * 4 * 128 # memory eq to 128 fp32 | ||
#model, dist_proj, nick = create_rp_model(dist, dfile; nbits) | ||
modelingtime = @elapsed model, dist_proj, nick = create_pca_model(dist, dfile; nbits) | ||
encdatabasetime = @elapsed db = predict_h5(model, dfile; nbits) | ||
encqueriestime = @elapsed queries = predict_h5(model, qfile; nbits) | ||
|
||
# loading or computing knns | ||
G = ExhaustiveSearch(;dist=dist_proj, db) | ||
meta = Dict( | ||
"buildtime" => 0.0, | ||
"matrix_size" => size(db.matrix), | ||
"optimtime" => 0.0, | ||
"algo" => "Bruteforce", | ||
"params" => "" | ||
) | ||
# saveindex("index.task3.nbits=$nbits.jld2", G; meta=(; nbits, modelingtime, encdatabasetime, encqueriestime), store_db=false) | ||
meta["size"] = dbsize | ||
meta["modelingtime"] = modelingtime | ||
meta["encdatabasetime"] = encdatabasetime | ||
meta["encqueriestime"] = encqueriestime | ||
meta["buildtime"] = 0.0 | ||
meta["params"] = "$(meta["params"]) $nick" | ||
resfile = joinpath(outdir, "bruteforce-$nick-k=$k") | ||
run_search_task3(G, queries, k, meta, resfile) | ||
end | ||
|
||
# functions for each database; these should have all required hyperparameters | ||
|
||
if !isinteractive() | ||
if length(ARGS) == 0 || any(dbsize -> dbsize ∉ ("300K", "10M", "100M"), ARGS) | ||
throw(ArgumentError("this script must be called with a list of the following arguments: 300K, 10M or 100M")) | ||
end | ||
|
||
for dbsize in ARGS | ||
task3(; dbsize) | ||
end | ||
end |