-
Notifications
You must be signed in to change notification settings - Fork 3
/
task3-bruteforce.jl
72 lines (61 loc) · 2.58 KB
/
task3-bruteforce.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
using SimilaritySearch, SurrogatedDistanceModels, HDF5, JLD2, CSV, Glob, LinearAlgebra, Dates
include("common.jl")
# This file is based on the Julia example from the 2023 edition.
"""
    run_search_task3(idx, queries::AbstractDatabase, k::Integer, meta, resfile_::String)

Run a batch `k`-nearest-neighbor search of `queries` against `idx` and store the
outcome in a single `.h5` result file.

# Arguments
- `idx`: search index handed to `searchbatch`.
- `queries`: database holding the query objects.
- `k`: number of neighbors requested per query.
- `meta`: metadata container supporting `meta[key] = value`; the elapsed search time
  (in seconds, as measured by `@elapsed`) is written under the `"querytime"` key
  before the results are saved.
- `resfile_`: output path, given with or without the `.h5` extension.

# Returns
Whatever `save_results` returns.
"""
function run_search_task3(idx, queries::AbstractDatabase, k::Integer, meta, resfile_::String)
    # Only a trailing ".h5" is treated as the extension. Stripping every ".h5"
    # occurrence would also rewrite directory names that happen to contain it and
    # silently redirect the output somewhere else.
    resfile = endswith(resfile_, ".h5") ? resfile_ : resfile_ * ".h5"
    @info "searching $resfile"
    querytime = @elapsed knns, dists = searchbatch(idx, queries, k)
    meta["querytime"] = querytime
    return save_results(knns, dists, meta, resfile)
end
"""
    task3(; dbsize, dfile, qfile, k=30, outdir, nbits=8 * 4 * 128)

Encode the database and the queries with the model returned by `create_pca_model`,
run an exhaustive search over the encoded database, and write the results under
`outdir`.

# Keywords
- `dbsize`: database size label (e.g. `"300K"`); interpolated into the default
  `dfile` and `outdir`, and stored in the metadata under `"size"`.
- `dfile`: path of the HDF5 file with the database vectors.
- `qfile`: path of the HDF5 file with the query vectors.
- `k`: number of neighbors to retrieve per query.
- `outdir`: directory for the result file; created if missing. The default embeds
  a timestamp taken when the function is called.
- `nbits`: bit budget forwarded to `create_pca_model` and `predict_h5`. The default,
  `8 * 4 * 128 = 4096` bits, is the memory of 128 `Float32` values.

# Returns
Whatever `run_search_task3` returns.

NOTE(review): `create_pca_model`, `predict_h5` and `save_results` are not defined in
this file (presumably they come from `common.jl`); their exact contracts should be
confirmed there.
"""
function task3(;
    dbsize,
    dfile="data2024/laion2B-en-clip768v2-n=$dbsize.h5",
    qfile="data2024/private-queries-2024-laion2B-en-clip768v2-n=10k-epsilon=0.2.h5",
    k=30,
    outdir="results-task3-bruteforce/$dbsize/$(Dates.format(Dates.now(), "yyyymmdd-HHMMSS"))",
    nbits=8 * 4 * 128
)
    mkpath(outdir)
    dist = NormalizedCosineDistance()  # 1 - dot(·, ·)

    # Each preprocessing stage is timed separately so it can be reported in the metadata.
    modelingtime = @elapsed model, dist_proj, nick = create_pca_model(dist, dfile; nbits)
    encdatabasetime = @elapsed db = predict_h5(model, dfile; nbits)
    encqueriestime = @elapsed queries = predict_h5(model, qfile; nbits)

    # Brute force: no index construction or optimization step, hence the zero
    # build/optimization times recorded below.
    G = ExhaustiveSearch(; dist=dist_proj, db)

    meta = Dict{String,Any}(
        "buildtime" => 0.0,
        "matrix_size" => size(db.matrix),
        "optimtime" => 0.0,
        "algo" => "Bruteforce",
        # The leading space is kept so the stored value matches earlier result files.
        "params" => " $nick",
        "size" => dbsize,
        "modelingtime" => modelingtime,
        "encdatabasetime" => encdatabasetime,
        "encqueriestime" => encqueriestime,
    )

    resfile = joinpath(outdir, "bruteforce-$nick-k=$k")
    return run_search_task3(G, queries, k, meta, resfile)
end
# functions for each database; these should have all required hyperparameters
# Script entry point: when run non-interactively, every command-line argument must be
# one of the supported database size labels; `task3` is then run once per argument,
# in the order given.
if !isinteractive()
    if isempty(ARGS) || !all(in(("300K", "10M", "100M")), ARGS)
        throw(ArgumentError("this script must be called with a list of the following arguments: 300K, 10M or 100M"))
    end

    for sizelabel in ARGS
        task3(; dbsize=sizelabel)
    end
end