Skip to content

Commit

Permalink
modeling time
Browse files Browse the repository at this point in the history
  • Loading branch information
sadit committed Mar 20, 2024
1 parent d651888 commit 0265316
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 30 deletions.
23 changes: 19 additions & 4 deletions common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ Similarity search on neighbor's graphs with automatic Pareto optimal performance
ES Tellez, G Ruiz - arXiv preprint arXiv:2201.07917, 2022
```
"""
function build_searchgraph(dist::SemiMetric, db::AbstractDatabase; minrecall=0.9)
function build_searchgraph(dist::SemiMetric, db::AbstractDatabase; minrecall=0.95, logbase=2.0)
algo = "SearchGraph"
logbase = 2
ctx = SearchGraphContext(;
hyperparameters_callback = OptimizeParameters(MinRecall(minrecall)),
neighborhood = Neighborhood(; logbase),
Expand Down Expand Up @@ -54,7 +53,9 @@ function save_results(knns::Matrix, dists::Matrix, meta, resfile::AbstractString
algo=meta["algo"],
buildtime=meta["buildtime"] + meta["optimtime"],
querytime=meta["querytime"],
preprocessingtime=meta["preprocessingtime"],
modelingtime=get(meta, "modelingtime", 0.0),
encdatabasetime=get(meta, "encdatabasetime", 0.0),
encqueriestime=get(meta, "encqueriestime", 0.0),
params=meta["params"],
size=meta["size"]
)
Expand Down Expand Up @@ -98,6 +99,20 @@ function create_binperms_model(dist, file::String; nbits::Int, nrefs::Int=2048)
fit(BinPerms, dist, refs, nbits), BinaryHammingDistance(), "BinPerms-$nbits"
end

"""
    create_heh_model(dist, file::String; nbits::Int, nsamples::Int=2^15)

Fit a `HighEntropyHyperplanes` binary-sketch model on a sample of the
embeddings stored in `file` (HDF5 dataset `"emb"`, one vector per column).

Returns a `(model, dist, nick)` tuple: the fitted model, the
`BinaryHammingDistance` to be used on the encoded vectors, and a string tag
identifying the configuration (same contract as the sibling
`create_binperms_model` / `create_pca_model` constructors).

# Arguments
- `dist`: distance used while fitting the hyperplanes.
- `file`: path to an HDF5 file containing the `"emb"` matrix.
- `nbits`: number of bits in the produced sketches.
- `nsamples`: number of database columns sampled to fit the model
  (previously hard-coded to `2^15`; kept as the default for compatibility).
"""
function create_heh_model(dist, file::String; nbits::Int, nsamples::Int=2^15)
    A = h5open(file) do f
        X = f["emb"]
        n = size(X, 2)
        # never request more columns than the dataset actually holds
        X[:, 1:min(nsamples, n)]
    end

    model = fit(HighEntropyHyperplanes, dist, MatrixDatabase(A), nbits;
                minent=0.5, sample_for_hyperplane_selection=2^13)
    model, BinaryHammingDistance(), "HighEntropyHyperplanes-$nbits"
end


function predict_h5(model::Union{PCAProjection,GaussianRandomProjection}, file::String; nbits, block::Int=10^5)
dim = nbits ÷ 32
Expand All @@ -114,7 +129,7 @@ function predict_h5(model::Union{PCAProjection,GaussianRandomProjection}, file::
end
end

function predict_h5(model::BinPerms, file::String; nbits, block::Int=10^5)
function predict_h5(model::Union{BinPerms,HighEntropyHyperplanes}, file::String; nbits, block::Int=10^5)
h5open(file) do f
X = f["emb"]
m, n = size(X)
Expand Down
55 changes: 38 additions & 17 deletions eval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,61 @@ using JLD2, SimilaritySearch, DataFrames, CSV, Glob, UnicodePlots

"""
    evaluate_results(gfile, resultfiles, k)

Compute macro-recall@`k` for every result file in `resultfiles` against the
gold-standard knns stored in `gfile`, collecting one row per result file.

Returns a `DataFrame` with size, algorithm, the three timing stages
(`modelingtime`, `encdatabasetime`, `encqueriestime` — defaulting to `0.0`
for older result files that predate these fields), build/query times,
parameter string, and recall, sorted by recall in ascending order.
"""
function evaluate_results(gfile, resultfiles, k)
    gold_knns = jldopen(f -> f["knns"][1:k, :], gfile)
    res = DataFrame(size=[], algo=[], modelingtime=[], encdatabasetime=[], encqueriestime=[], buildtime=[], querytime=[], params=[], recall=[])

    for resfile in resultfiles
        @info resfile
        jldopen(resfile) do f
            knns = f["knns"][1:k, :]
            recall = macrorecall(gold_knns, knns)
            push!(res, (f["size"], f["algo"],
                get(f, "modelingtime", 0.0),
                get(f, "encdatabasetime", 0.0),
                get(f, "encqueriestime", 0.0),
                f["buildtime"], f["querytime"], f["params"], recall))
        end
    end

    sort!(res, :recall)
    res
end

# Write the evaluation report for one (gold file, result files) pair to the
# stream `f`: the file listing, the full results table, and a unicode line
# plot of the recall column across parameter settings.
# `task` and `dbsize` are not read here; presumably kept so all reporting
# call sites share one signature — TODO confirm with callers.
function print_results(f, D, gfile, files, task, dbsize)
    # NOTE(review): the pair is emitted twice (compact `println` then pretty
    # `show`) — confirm both lines are intentional and not a paste artifact.
    println(f, gfile => files)
    show(f, "text/plain", gfile => files); println(f)
    show(f, "text/plain", D); println(f)
    # recall curve; x axis labeled by the first and last parameter settings
    p = lineplot(D.recall; ylim=(0, 1), title=String(D.algo[1]), ylabel="recall", xlabel="$(D.params[1]) to $(D.params[end])")
    show(f, "text/plain", p); println(f)
end

# Script entry point (skipped when the file is included interactively).
#
# With no CLI arguments: scan every `results-task*/*` directory, evaluate the
# newest run found in each, write a combined `results-summary.txt`, echo each
# report to stdout, and emit one `<task>-<dbsize>.csv` per directory.
#
# With arguments: each argument is a specific run directory to evaluate;
# reports go to stdout only.
#
# NOTE(review): the scraped diff interleaved pre- and post-commit lines here
# (stale references to `lastpath` and `f` in the ARGS branch); this is the
# reconstructed coherent version — confirm against the committed eval.jl.
if !isinteractive()
    goldsuffix = "public-queries-2024-laion2B-en-clip768v2-n=10k.h5"
    k = 30

    if length(ARGS) == 0
        open("results-summary.txt", "w") do f
            for path in sort!(glob("results-task*/*"))
                task, dbsize = splitpath(path)
                # evaluate only the most recent run for this task/db size
                lastpath = glob(joinpath(path, "*")) |> sort! |> last
                gfile = joinpath("data2024", "gold-standard-dbsize=$dbsize--$goldsuffix")
                files = glob(joinpath(lastpath, "*.h5"))
                length(files) == 0 && continue
                D = evaluate_results(gfile, files, k)
                println(f, "\n\n=== results for $dbsize $goldsuffix ===")
                print_results(stdout, D, gfile, files, task, dbsize)
                print_results(f, D, gfile, files, task, dbsize)
                CSV.write("$task-$dbsize.csv", D)
            end
        end
    else
        for path in ARGS
            # expects run paths like "results-task1/<dbsize>/<run>"
            task, dbsize, _ = splitpath(path)
            gfile = joinpath("data2024", "gold-standard-dbsize=$dbsize--$goldsuffix")
            files = glob(joinpath(path, "*.h5"))
            D = evaluate_results(gfile, files, k)
            println("\n\n=== results for $dbsize $goldsuffix ===")
            print_results(stdout, D, gfile, files, task, dbsize)
        end
    end

    nothing
end
1 change: 0 additions & 1 deletion task1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ function task1(;
# loading or computing knns
@info "indexing, this can take a while!"
G, meta = build_searchgraph(dist, db)
meta["preprocessingtime"] = 0.0
meta["size"] = dbsize
resfile = joinpath(outdir, "searchgraph-k=$k")
run_search_task1(G, queries, k, meta, resfile)
Expand Down
10 changes: 6 additions & 4 deletions task2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,16 @@ function task2(;
mkpath(outdir)
dist = NormalizedCosineDistance() # 1 - dot(·, ·)
nbits = 8 * 4 * 96 # same than 96 FP32
preprocessingtime = @elapsed model, dist_proj, nick = create_pca_model(dist, dfile; nbits)
preprocessingtime += @elapsed db = predict_h5(model, dfile; nbits)
preprocessingtime += @elapsed queries = predict_h5(model, qfile; nbits)
modelingtime = @elapsed model, dist_proj, nick = create_pca_model(dist, dfile; nbits)
encdatabasetime = @elapsed db = predict_h5(model, dfile; nbits)
encqueriestime = @elapsed queries = predict_h5(model, qfile; nbits)

# loading or computing knns
@info "indexing, this can take a while!"
G, meta = build_searchgraph(dist_proj, db)
meta["preprocessingtime"] = preprocessingtime
meta["modelingtime"] = modelingtime
meta["encdatabasetime"] = encdatabasetime
meta["encqueriestime"] = encqueriestime
meta["size"] = dbsize
meta["params"] = "$(meta["params"]) $nick-$nbits"
resfile = joinpath(outdir, "searchgraph-$nick-k=$k")
Expand Down
12 changes: 8 additions & 4 deletions task3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,22 @@ function task3(;

mkpath(outdir)
dist = NormalizedCosineDistance() # 1 - dot(·, ·)
#nbits = 8 * 4 * 128 # memory eq to 128 fp32
nbits = 8 * 4 * 128 # memory eq to 128 fp32
#model, dist_proj, nick = create_binperms_model(dist, dfile; nbits)
#model, dist_proj, nick = create_rp_model(dist, dfile; nbits)
preprocessingtime = @elapsed model, dist_proj, nick = create_pca_model(dist, dfile; nbits)
preprocessingtime += @elapsed db = predict_h5(model, dfile; nbits)
preprocessingtime += @elapsed queries = predict_h5(model, qfile; nbits)
modelingtime = @elapsed model, dist_proj, nick = create_pca_model(dist, dfile; nbits)
#modelingtime = @elapsed model, dist_proj, nick = create_heh_model(dist, dfile; nbits)
encdatabasetime = @elapsed db = predict_h5(model, dfile; nbits)
encqueriestime = @elapsed queries = predict_h5(model, qfile; nbits)

# loading or computing knns
@info "indexing, this can take a while!"
G, meta = build_searchgraph(dist_proj, db)
meta["size"] = dbsize
meta["preprocessingtime"] = preprocessingtime
meta["modelingtime"] = modelingtime
meta["encdatabasetime"] = encdatabasetime
meta["encqueriestime"] = encqueriestime
meta["params"] = "$(meta["params"]) $nick"
resfile = joinpath(outdir, "searchgraph-$nick-k=$k")
run_search_task3(G, queries, k, meta, resfile)
Expand Down

0 comments on commit 0265316

Please sign in to comment.