Skip to content

Commit

Permalink
bruteforce
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Aug 29, 2024
1 parent 5230db0 commit 3c4add5
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 0 deletions.
85 changes: 85 additions & 0 deletions task1-bruteforce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
using SimilaritySearch, SurrogatedDistanceModels, JLD2, CSV, Glob, LinearAlgebra, Dates

include("common.jl")

# This file is based on the Julia example from the 2023 edition

"""
load_database(file)
Loads a dataset stored in `file`n
"""
function load_database(file)
@info "loading clip768 (converting Float16 -> Float32)"
X = jldopen(file) do f
Matrix{Float32}(f["emb"])
end

#=for col in eachcol(X)
normalize!(col)
end=#

StrideMatrixDatabase(X)
end

"""
    run_search_task1(idx, queries, k, meta, resfile_)

Search `idx` for the `k` nearest neighbors of every query in `queries`,
record the elapsed query time under `meta["querytime"]`, and save the
result matrices into an HDF5 file derived from `resfile_` (any existing
`.h5` suffix is stripped before appending a fresh one).
"""
function run_search_task1(idx, queries::AbstractDatabase, k::Integer, meta, resfile_::AbstractString)
    base = replace(resfile_, ".h5" => "")
    resfile = string(base, ".h5")
    @info "searching $resfile"
    querytime = @elapsed knns, dists = searchbatch(idx, queries, k)
    meta["querytime"] = querytime
    save_results(knns, dists, meta, resfile)
end

"""
task1(; kwargs...)
Runs an entire beenchmark
- `dbsize`: string denoting the size of the dataset (e.g., "300K", "100M"), million scale should not be used in GitHub Actions.
- `k`: the number of neighbors to find
"""
function task1(;
dbsize,
dfile="data2024/laion2B-en-clip768v2-n=$dbsize.h5",
#qfile="data2024/public-queries-2024-laion2B-en-clip768v2-n=10k.h5",
qfile="data2024/private-queries-2024-laion2B-en-clip768v2-n=10k-epsilon=0.2.h5",
k=30,
outdir="results-task1-bruteforce/$dbsize/$(Dates.format(Dates.now(), "yyyymmdd-HHMMSS"))"
)

mkpath(outdir)

dist = NormalizedCosineDistance() # 1 - dot(·, ·)
@info "loading $qfile and $dfile"
@time db = load_database(dfile)
@time queries = load_database(qfile)

# loading or computing knns
@info "indexing, this can take a while!"
G = ExhaustiveSearch(; dist, db)
meta = Dict(
"buildtime" => 0.0,
"matrix_size" => size(db.matrix),
"optimtime" => 0.0,
"algo" => "Bruteforce",
"params" => ""
)
meta["size"] = dbsize
meta["modelingtime"] = 0.0
meta["encdatabasetime"] = 0.0
meta["encqueriestime"] = 0.0
meta["buildtime"] = 0.0
resfile = joinpath(outdir, "bruteforce-k=$k")
run_search_task1(G, queries, k, meta, resfile)
end

if !isinteractive()
    # validate that every CLI argument is one of the supported dataset sizes;
    # the scraped original lost the ∉ operator here, making the check a no-op call
    if length(ARGS) == 0 || any(dbsize -> dbsize ∉ ("300K", "10M", "100M"), ARGS)
        throw(ArgumentError("this script must be called with a list of the following arguments: 300K, 10M or 100M"))
    end

    for dbsize in ARGS
        task1(; dbsize)
    end
end
72 changes: 72 additions & 0 deletions task3-bruteforce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using SimilaritySearch, SurrogatedDistanceModels, HDF5, JLD2, CSV, Glob, LinearAlgebra, Dates

include("common.jl")

# This file is based on the Julia example from the 2023 edition


"""
    run_search_task3(idx, queries, k, meta, resfile_)

Search `idx` for the `k` nearest neighbors of every query in `queries`,
record the elapsed query time under `meta["querytime"]`, and save the
result matrices into an HDF5 file derived from `resfile_` (any existing
`.h5` suffix is stripped before appending a fresh one).

An earlier version swept the search hyperparameter Δ over several values
and produced one result file per setting; brute force has no such
hyperparameter, so a single search is performed.
"""
function run_search_task3(idx, queries::AbstractDatabase, k::Integer, meta, resfile_::String)
    base = replace(resfile_, ".h5" => "")
    resfile = string(base, ".h5")
    @info "searching $resfile"
    querytime = @elapsed knns, dists = searchbatch(idx, queries, k)
    meta["querytime"] = querytime
    save_results(knns, dists, meta, resfile)
end

"""
    task3(; kwargs...)

Run the entire task-3 brute-force benchmark: project the dataset with a
PCA-based surrogate model, build an exhaustive-search index on the encoded
vectors, and save the k-nearest-neighbor results.

# Keywords
- `dbsize`: string denoting the size of the dataset (e.g., "300K", "100M"); million scale should not be used in GitHub Actions.
- `dfile`: path of the database embeddings file.
- `qfile`: path of the queries file.
- `k`: the number of neighbors to find.
- `outdir`: output directory for the result files (timestamped by default).
"""
function task3(;
        dbsize,
        dfile="data2024/laion2B-en-clip768v2-n=$dbsize.h5",
        qfile="data2024/private-queries-2024-laion2B-en-clip768v2-n=10k-epsilon=0.2.h5",
        k=30,
        outdir="results-task3-bruteforce/$dbsize/$(Dates.format(Dates.now(), "yyyymmdd-HHMMSS"))"
    )

    mkpath(outdir)
    dist = NormalizedCosineDistance() # 1 - dot(·, ·)
    nbits = 8 * 4 * 128 # memory budget equivalent to 128 fp32 components
    # alternative projection: model, dist_proj, nick = create_rp_model(dist, dfile; nbits)
    modelingtime = @elapsed model, dist_proj, nick = create_pca_model(dist, dfile; nbits)
    encdatabasetime = @elapsed db = predict_h5(model, dfile; nbits)
    encqueriestime = @elapsed queries = predict_h5(model, qfile; nbits)

    # brute-force index over the encoded database
    G = ExhaustiveSearch(; dist=dist_proj, db)
    # all metadata in one place; brute force has no build/optimization cost
    meta = Dict(
        "buildtime" => 0.0,
        "matrix_size" => size(db.matrix),
        "optimtime" => 0.0,
        "algo" => "Bruteforce",
        "params" => "",
        "size" => dbsize,
        "modelingtime" => modelingtime,
        "encdatabasetime" => encdatabasetime,
        "encqueriestime" => encqueriestime,
    )
    # saveindex("index.task3.nbits=$nbits.jld2", G; meta=(; nbits, modelingtime, encdatabasetime, encqueriestime), store_db=false)
    meta["params"] = "$(meta["params"]) $nick"
    resfile = joinpath(outdir, "bruteforce-$nick-k=$k")
    run_search_task3(G, queries, k, meta, resfile)
end

# functions for each database; these should have all required hyperparameters

if !isinteractive()
    # validate that every CLI argument is one of the supported dataset sizes;
    # the scraped original lost the ∉ operator here, making the check a no-op call
    if length(ARGS) == 0 || any(dbsize -> dbsize ∉ ("300K", "10M", "100M"), ARGS)
        throw(ArgumentError("this script must be called with a list of the following arguments: 300K, 10M or 100M"))
    end

    for dbsize in ARGS
        task3(; dbsize)
    end
end

0 comments on commit 3c4add5

Please sign in to comment.