Skip to content

Commit

Permalink
quick update lappy
Browse files Browse the repository at this point in the history
  • Loading branch information
kfgarrity committed Oct 4, 2024
1 parent dcf7096 commit c370a3f
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 54 deletions.
15 changes: 9 additions & 6 deletions src/CalcTB_laguerre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ const n_3body_onsite = 2
#const n_3body_onsite_same = 4
const n_3body_onsite_same = 5


EXP_a = [3.0]

using ..CrystalMod:cutoff2X
using ..CrystalMod:cutoff3bX
Expand Down Expand Up @@ -3956,7 +3956,7 @@ function calc_tb_prepare_fast(reference_tbc::tb_crys; use_threebody=false, use_t
end

if dist > 1e-5
ad = 2.0*dist
ad = EXP_a[1]*dist
expa=exp.(-0.5*ad)

rho[a1, 1] += (1.0 * expa) * cut
Expand Down Expand Up @@ -4586,7 +4586,7 @@ function calc_onsite(t1,s1,s2, database=missing)

end

function laguerre_fast!(dist, memory; a = 2.0)
function laguerre_fast!(dist, memory; a = EXP_a[1])

# a=2.0
ad = a*dist
Expand All @@ -4604,7 +4604,8 @@ end

function laguerre_fast_threebdy!(dist_0, dist_a, dist_b, same_atom, triple, memory)

a=2.0
#a=2.0
a=EXP_a[1]

ad_0 = a*dist_0
expa_0 =exp.(-0.5*ad_0) #* 10.0
Expand Down Expand Up @@ -4646,7 +4647,8 @@ end

function laguerre_fast_threebdy_onsite!(dist_0, dist_a, dist_b, same_atom, memory)

a=2.0
#a=2.0
a=EXP_a[1]

ad_0 = a*dist_b
expa_0 =exp.(-0.5*ad_0) #* 10.0
Expand Down Expand Up @@ -4682,7 +4684,8 @@ Calculate laguerre polynomials up to order `nmax`
"""
function laguerre(dist, ind=missing; nmax=6, memory=missing)

a=2.0
# a=2.0
a=EXP_a[1]


# a=3.0
Expand Down
176 changes: 128 additions & 48 deletions src/FitTB_laguerre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using Plots
#using JLD
using LinearAlgebra
#using Statistics
#using Optim
using Optim
using Random
#using Calculus
using DelimitedFiles
Expand Down Expand Up @@ -52,6 +52,7 @@ using Suppressor
using ..AtomicProj:projwfc_workf
using ..SCF:remove_scf_from_tbc

using ..CalcTB:EXP_a
#using ..CrystalMod:orbital_index

#using ..TB:get_grid
Expand Down Expand Up @@ -1428,7 +1429,7 @@ This is the primary function for fitting. Uses the self-consistent linear fittin
- `start_small = false` When fitting only 3body data, setting this to true will start the 3body terms with very small values, which can improve convergence. Not useful if also fitting 2body terms.
"""
function do_fitting_recursive(list_of_tbcs ; weights_list = missing, dft_list=missing, kpoints = [0 0 0; 0 0 0.5; 0 0.5 0.5; 0.5 0.5 0.5; 0 0 0.25; 0 0.25 0; 0.25 0 0 ; 0.25 0.25 0.25; 0.25 0 0.25], starting_database = missing, update_all = false, fit_threebody=true, fit_threebody_onsite=true, do_plot = false, energy_weight = missing, rs_weight=missing,ks_weight=missing, niters=50, lambda=0.0, leave_one_out=false, prepare_data = missing, RW_PARAM=0.0, NLIM = 100, refit_database = missing, start_small = false, fit_to_dft_eigs=false, use_factor_dict = false)
function do_fitting_recursive(list_of_tbcs ; weights_list = missing, dft_list=missing, kpoints = [0 0 0; 0 0 0.5; 0 0.5 0.5; 0.5 0.5 0.5; 0 0 0.25; 0 0.25 0; 0.25 0 0 ; 0.25 0.25 0.25; 0.25 0 0.25], starting_database = missing, update_all = false, fit_threebody=true, fit_threebody_onsite=true, do_plot = false, energy_weight = missing, rs_weight=missing,ks_weight=missing, niters=50, lambda=0.0, leave_one_out=false, prepare_data = missing, RW_PARAM=0.0, NLIM = 100, refit_database = missing, start_small = false, fit_to_dft_eigs=false, use_factor_dict = false, cs_start = missing)

if !ismissing(dft_list)
println("top")
Expand Down Expand Up @@ -1531,7 +1532,7 @@ function do_fitting_recursive(list_of_tbcs ; weights_list = missing, dft_list=mi
# database_linear, ch_lin, cs_lin, X_Hnew_BIG, Y_Hnew_BIG, X_H, X_Snew_BIG, Y_H, h_on, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3 = prepare_data
end

return do_fitting_recursive_main(list_of_tbcs, pd; weights_list = weights_list, dft_list=dft_list, kpoints = kpoints, starting_database = starting_database, update_all = update_all, fit_threebody=fit_threebody, fit_threebody_onsite=fit_threebody_onsite, do_plot = do_plot, energy_weight = energy_weight, rs_weight=rs_weight,ks_weight = ks_weight, niters=niters, lambda=lambda, leave_one_out=leave_one_out, RW_PARAM=RW_PARAM, KPOINTS=KPOINTS, KWEIGHTS=KWEIGHTS, nk_max=nk_max, start_small = start_small , fit_to_dft_eigs=fit_to_dft_eigs)
return do_fitting_recursive_main(list_of_tbcs, pd; weights_list = weights_list, dft_list=dft_list, kpoints = kpoints, starting_database = starting_database, update_all = update_all, fit_threebody=fit_threebody, fit_threebody_onsite=fit_threebody_onsite, do_plot = do_plot, energy_weight = energy_weight, rs_weight=rs_weight,ks_weight = ks_weight, niters=niters, lambda=lambda, leave_one_out=leave_one_out, RW_PARAM=RW_PARAM, KPOINTS=KPOINTS, KWEIGHTS=KWEIGHTS, nk_max=nk_max, start_small = start_small , fit_to_dft_eigs=fit_to_dft_eigs, cs_start = cs_start)

end

Expand Down Expand Up @@ -1588,7 +1589,7 @@ function do_fitting_recursive_factor(list_of_tbcs ; weights_list = missing, dft_
end


function do_fitting_recursive_optim(list_of_tbcs ; weights_list = missing, dft_list=missing, kpoints = [0 0 0; 0 0 0.5; 0 0.5 0.5; 0.5 0.5 0.5; 0 0 0.25; 0 0.25 0; 0.25 0 0 ; 0.25 0.25 0.25; 0.25 0 0.25], starting_database = missing, update_all = false, fit_threebody=true, fit_threebody_onsite=true, do_plot = false, energy_weight = missing, rs_weight=missing,ks_weight=missing, niters=50, lambda=0.0, leave_one_out=false, prepare_data = missing, RW_PARAM=0.0, NLIM = 100, refit_database = missing, start_small = false, fit_to_dft_eigs=false)
function do_fitting_recursive_optim(list_of_tbcs ; weights_list = missing, dft_list=missing, kpoints = [0 0 0; 0 0 0.5; 0 0.5 0.5; 0.5 0.5 0.5; 0 0 0.25; 0 0.25 0; 0.25 0 0 ; 0.25 0.25 0.25; 0.25 0 0.25], starting_database = missing, update_all = false, fit_threebody=true, fit_threebody_onsite=true, do_plot = false, energy_weight = missing, rs_weight=missing,ks_weight=missing, niters=50, lambda=0.0, leave_one_out=false, prepare_data = missing, RW_PARAM=0.0, NLIM = 100, refit_database = missing, start_small = false, fit_to_dft_eigs=false, cs_start = missing, a_list = [2.0])

if !ismissing(dft_list)
println("top")
Expand All @@ -1598,68 +1599,112 @@ function do_fitting_recursive_optim(list_of_tbcs ; weights_list = missing, dft_l
KPOINTS, KWEIGHTS, nk_max = get_k_simple(kpoints, list_of_tbcs)
end



# println("KWEIGHTS 3 ", size(KWEIGHTS[3]), " " , KWEIGHTS[3][1:6])
DATABASE = []
ERROR = []
for a in a_list
EXP_a[1] = a

if ismissing(prepare_data)
println("DO LINEAR FITTING")
# println("KWEIGHTS 3 ", size(KWEIGHTS[3]), " " , KWEIGHTS[3][1:6])

if ismissing(prepare_data)
println("DO LINEAR FITTING")

if update_all == true
starting_database_t = missing #keep all
if update_all == true
starting_database_t = missing #keep all
else
starting_database_t = starting_database
end

pd = do_fitting_linear(list_of_tbcs; kpoints = KPOINTS, mode=:kspace, dft_list = dft_list, fit_threebody=fit_threebody, fit_threebody_onsite=fit_threebody_onsite, do_plot = false, starting_database=starting_database_t, return_database=false, NLIM=NLIM, refit_database=refit_database)
else
starting_database_t = starting_database
println("SKIP LINEAR MISSING")
pd = prepare_data
# database_linear, ch_lin, cs_lin, X_Hnew_BIG, Y_Hnew_BIG, X_H, X_Snew_BIG, Y_H, h_on, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3 = prepare_data
end

pd = do_fitting_linear(list_of_tbcs; kpoints = KPOINTS, mode=:kspace, dft_list = dft_list, fit_threebody=fit_threebody, fit_threebody_onsite=fit_threebody_onsite, do_plot = false, starting_database=starting_database_t, return_database=false, NLIM=NLIM, refit_database=refit_database)
else
println("SKIP LINEAR MISSING")
pd = prepare_data
# database_linear, ch_lin, cs_lin, X_Hnew_BIG, Y_Hnew_BIG, X_H, X_Snew_BIG, Y_H, h_on, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3 = prepare_data
end
# DATABASE = []
#for factor = 0.1:0.1:1.1
#for factor = [0.5]

DATABASE = []
#for factor = 0.1:0.1:1.1
#for factor = [0.5]
database, ch, cs, X_Hnew_BIG, Xc_Hnew_BIG, Xc_Snew_BIG, X_H, X_Snew_BIG, Y_H, Y_S, HON, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3, keepind, keepdata, Y_Hnew_BIG, Y_Snew_BIG, YS_new, cs , ch_refit, SPIN, threebody_inds = pd

database, ch, cs, X_Hnew_BIG, Xc_Hnew_BIG, Xc_Snew_BIG, X_H, X_Snew_BIG, Y_H, Y_S, HON, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3, keepind, keepdata, Y_Hnew_BIG, Y_Snew_BIG, YS_new, cs , ch_refit, SPIN, threebody_inds = pd
returnstuff = true

function f(cs_f)
cs_temp = cs_f
topstuff = top(list_of_tbcs, pd, weights_list, dft_list, kpoints, starting_database , update_all , fit_threebody, fit_threebody_onsite, do_plot, energy_weight, rs_weight, ks_weight , niters, lambda, leave_one_out, RW_PARAM, KPOINTS, KWEIGHTS, nk_max, start_small, fit_to_dft_eigs, returnstuff)

#i don't think these matter...
Y_S_temp = Y_S
Y_Snew_BIG_temp = Y_Snew_BIG

#Y_Snew_BIG_temp = X_Snew_BIG * cs_f
YS_new_temp = X_Snew_BIG * cs_f
#YS_new_temp = YS_new
pd_temp = database, ch, cs_temp, X_Hnew_BIG, Xc_Hnew_BIG, Xc_Snew_BIG, X_H, X_Snew_BIG, Y_H, Y_S_temp, HON, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3, keepind, keepdata, Y_Hnew_BIG, Y_Snew_BIG_temp, YS_new_temp, cs_temp , ch_refit, SPIN, threebody_inds
database, ch_new, error_current = do_fitting_recursive_main(list_of_tbcs, pd_temp; weights_list = weights_list, dft_list=dft_list, kpoints = kpoints, starting_database = starting_database, update_all = update_all, fit_threebody=fit_threebody, fit_threebody_onsite=fit_threebody_onsite, do_plot = do_plot, energy_weight = energy_weight, rs_weight=rs_weight,ks_weight = ks_weight, niters=niters, lambda=lambda, leave_one_out=leave_one_out, RW_PARAM=RW_PARAM, KPOINTS=KPOINTS, KWEIGHTS=KWEIGHTS, nk_max=nk_max, start_small = start_small , fit_to_dft_eigs=fit_to_dft_eigs, returnstuff=true)
ch = ch_new
return error_current
end

function f(cs_f)
cs_temp = cs_f

#i don't think these matter...
Y_S_temp = Y_S
Y_Snew_BIG_temp = Y_Snew_BIG

#Y_Snew_BIG_temp = X_Snew_BIG * cs_f
YS_new_temp = X_Snew_BIG * cs_f
#YS_new_temp = YS_new
pd_temp = database, ch, cs_temp, X_Hnew_BIG, Xc_Hnew_BIG, Xc_Snew_BIG, X_H, X_Snew_BIG, Y_H, Y_S_temp, HON, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3, keepind, keepdata, Y_Hnew_BIG, Y_Snew_BIG_temp, YS_new_temp, cs_temp , ch_refit, SPIN, threebody_inds
ch_new = missing
error_current = 0.0
@suppress begin
database, ch_new, error_current = do_fitting_recursive_main(list_of_tbcs, pd_temp; weights_list = weights_list, dft_list=dft_list, kpoints = kpoints, starting_database = starting_database, update_all = update_all, fit_threebody=fit_threebody, fit_threebody_onsite=fit_threebody_onsite, do_plot = do_plot, energy_weight = energy_weight, rs_weight=rs_weight,ks_weight = ks_weight, niters=niters, lambda=lambda, leave_one_out=leave_one_out, RW_PARAM=RW_PARAM, KPOINTS=KPOINTS, KWEIGHTS=KWEIGHTS, nk_max=nk_max, start_small = start_small , fit_to_dft_eigs=fit_to_dft_eigs, returnstuff=true, topstuff = missing, returndatabase=returndatabase)
end
ch = ch_new
println("cs_f ", cs_f)
println("error_current $error_current")
weight = repeat([1.0, 1.0, 1.0, 10.0, 100.0, 1000.0]*1e-3, 6)[1:length(cs_f)]
err = error_current + sum(cs_f.^2 .* weight) #very light regularization
if err < err_min
cs_min = cs_f
err_min = err
end
return err
end

if ismissing(cs_start)
cs_start = cs
end
cs_min = cs_start
err_min = 10.0^15

println("cs_start ", cs_start)
returndatabase=true
f(cs_start)
returndatabase=false
ret = Optim.optimize(f, cs_start, f_tol = 1e-3, f_calls_limit = 150, iterations = 45)
#ret = Optim.optimize(f, cs_start, f_tol = 1e-3, f_calls_limit = 20, iterations = 2)
returndatabase=true
println("err_min $err_min")
println("cs_min $cs_min")
f(cs_min)
println("ret")
println(ret)
println("Optim.minimizer(ret)")
println(Optim.minimizer(ret))


#push!(DATABASE, database)
push!(DATABASE, deepcopy(database))
push!(ERROR, err_min)
end

#end

f(cs)
# ret = f(cs)
# println("ret $ret")

return database
return DATABASE, ERROR

#return do_fitting_recursive_main(list_of_tbcs, pd; weights_list = weights_list, dft_list=dft_list, kpoints = kpoints, starting_database = starting_database, update_all = update_all, fit_threebody=fit_threebody, fit_threebody_onsite=fit_threebody_onsite, do_plot = do_plot, energy_weight = energy_weight, rs_weight=rs_weight,ks_weight = ks_weight, niters=niters, lambda=lambda, leave_one_out=leave_one_out, RW_PARAM=RW_PARAM, KPOINTS=KPOINTS, KWEIGHTS=KWEIGHTS, nk_max=nk_max, start_small = start_small , fit_to_dft_eigs=fit_to_dft_eigs)


end


function do_fitting_recursive_main(list_of_tbcs, prepare_data; weights_list=missing, dft_list=missing, kpoints = [0 0 0; 0 0 0.5; 0 0.5 0.5; 0.5 0.5 0.5], starting_database = missing, update_all = false, fit_threebody=true, fit_threebody_onsite=true, do_plot = false, energy_weight = missing, rs_weight=missing, ks_weight = missing, niters=50, lambda=0.0, leave_one_out=false, RW_PARAM=0.0001, KPOINTS=missing, KWEIGHTS=missing, nk_max=0, start_small=false, fit_to_dft_eigs=false, returnstuff=false)

function top(list_of_tbcs, prepare_data, weights_list, dft_list, kpoints, starting_database , update_all , fit_threebody, fit_threebody_onsite, do_plot, energy_weight, rs_weight, ks_weight , niters, lambda, leave_one_out, RW_PARAM, KPOINTS, KWEIGHTS, nk_max, start_small, fit_to_dft_eigs, returnstuff)

# database_linear, ch_lin, cs_lin, X_Hnew_BIG, Y_Hnew_BIG, X_H, X_Snew_BIG, Y_H, h_on, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3,keepind, keepdata = prepare_data

database_linear, ch_lin, cs_lin, X_Hnew_BIG, Xc_Hnew_BIG, Xc_Snew_BIG, X_H, X_Snew_BIG, Y_H, Y_S, h_on, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3, keepind, keepdata, Y_Hnew_BIG, Y_Snew_BIG, Ys_new, cs, ch_refit, SPIN, threebody_inds = prepare_data

println("AAAAAAAA ch_lin ", ch_lin)

println("keepind " , length(keepind), " " , sum(keepind))
Expand Down Expand Up @@ -1753,7 +1798,6 @@ function do_fitting_recursive_main(list_of_tbcs, prepare_data; weights_list=miss
WEIGHTS = zeros(NCALC, nk_max, NWAN_MAX, SPIN_MAX)
ENERGIES = zeros(NCALC)

Ys = Ys_new + Xc_Snew_BIG
X_Snew_BIG = nothing
Xc_Snew_BIG = nothing

Expand Down Expand Up @@ -2070,7 +2114,35 @@ function do_fitting_recursive_main(list_of_tbcs, prepare_data; weights_list=miss
end
println("EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE")

# return

return ch_keep, keep_inds, toupdate_inds, cs_keep, keep_inds_S, toupdate_inds_S, list_of_tbcs, dft_list, KPOINTS, KWEIGHTS, energy_weight, rs_weight, ks_weight, weights_list, NWAN_MAX, SPIN_MAX, NAT_MAX, NCALC, VALS, E_DEN, H1, H1spin, DQ, ENERGY_SMEAR, OCCS, WEIGHTS, ENERGIES, X_Snew_BIG, Xc_Snew_BIG, NCOLS_orig, NCOLS, ch, keep_bool, NVAL, NAT, scf, VALS0

end

function do_fitting_recursive_main(list_of_tbcs, prepare_data; weights_list=missing, dft_list=missing, kpoints = [0 0 0; 0 0 0.5; 0 0.5 0.5; 0.5 0.5 0.5], starting_database = missing, update_all = false, fit_threebody=true, fit_threebody_onsite=true, do_plot = false, energy_weight = missing, rs_weight=missing, ks_weight = missing, niters=50, lambda=0.0, leave_one_out=false, RW_PARAM=0.0001, KPOINTS=missing, KWEIGHTS=missing, nk_max=0, start_small=false, fit_to_dft_eigs=false, returnstuff=false, topstuff=missing, returndatabase=true, cs_start=missing)

if ismissing(topstuff)
ch_keep, keep_inds, toupdate_inds, cs_keep, keep_inds_S, toupdate_inds_S, list_of_tbcs, dft_list, KPOINTS, KWEIGHTS, energy_weight, rs_weight, ks_weight, weights_list, NWAN_MAX, SPIN_MAX, NAT_MAX, NCALC, VALS, E_DEN, H1, H1spin, DQ, ENERGY_SMEAR, OCCS, WEIGHTS, ENERGIES, X_Snew_BIG, Xc_Snew_BIG, NCOLS_orig, NCOLS, ch, keep_bool, NVAL, NAT, scf, VALS0 = top(list_of_tbcs, prepare_data, weights_list, dft_list, kpoints, starting_database , update_all , fit_threebody, fit_threebody_onsite, do_plot, energy_weight, rs_weight, ks_weight , niters, lambda, leave_one_out, RW_PARAM, KPOINTS, KWEIGHTS, nk_max, start_small, fit_to_dft_eigs, returnstuff)
else
ch_keep, keep_inds, toupdate_inds, cs_keep, keep_inds_S, toupdate_inds_S, list_of_tbcs, dft_list, KPOINTS, KWEIGHTS, energy_weight, rs_weight, ks_weight, weights_list, NWAN_MAX, SPIN_MAX, NAT_MAX, NCALC, VALS, E_DEN, H1, H1spin, DQ, ENERGY_SMEAR, OCCS, WEIGHTS, ENERGIES, X_Snew_BIG, Xc_Snew_BIG, NCOLS_orig, NCOLS, ch, keep_bool, NVAL, NAT, scf, VALS0 = topstuff
end



# database_linear, ch_lin, cs_lin, X_Hnew_BIG, Y_Hnew_BIG, X_H, X_Snew_BIG, Y_H, h_on, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3,keepind, keepdata = prepare_data

database_linear, ch_lin, cs_lin, X_Hnew_BIG, Xc_Hnew_BIG, Xc_Snew_BIG, X_H, X_Snew_BIG, Y_H, Y_S, h_on, ind_BIG, KEYS, HIND, SIND, DMIN_TYPES, DMIN_TYPES3, keepind, keepdata, Y_Hnew_BIG, Y_Snew_BIG, Ys_new, cs, ch_refit, SPIN, threebody_inds = prepare_data

if !ismissing(cs_start)
cs = cs_start
cs_lin = cs_start
Ys_new = X_Snew_BIG * cs
println("use starting cs , $cs")
println("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
# sleep(10)
end

Ys = Ys_new + Xc_Snew_BIG

VALS_working = zeros(size(VALS))
ENERGIES_working = zeros(size(ENERGIES))
Expand Down Expand Up @@ -2832,8 +2904,12 @@ function do_fitting_recursive_main(list_of_tbcs, prepare_data; weights_list=miss
println(good)
println("make database")

database = make_database(chX2, csX2, KEYS, HIND, SIND,DMIN_TYPES,DMIN_TYPES3, scf=scf, starting_database=starting_database, tbc_list = list_of_tbcs[good])

if returndatabase == false
database =missing
else
database = make_database(chX2, csX2, KEYS, HIND, SIND,DMIN_TYPES,DMIN_TYPES3, scf=scf, starting_database=starting_database, tbc_list = list_of_tbcs[good])
end

return database, chX, current_error

end
Expand All @@ -2848,12 +2924,16 @@ function do_fitting_recursive_main(list_of_tbcs, prepare_data; weights_list=miss
# end


if leave_one_out == false
if leave_one_out == false && returnstuff == false
database, ch, current_error = do_iters(ch, niters)
println("return")
return database
elseif returnstuff == true
database, ch, current_error = do_iters(ch, niters)
println("returnstuff ")
println("database ", typeof(database))
println("ch ", typeof(ch))
println("current_error $current_error")
return database, ch, current_error
else
database, ch = do_iters(ch, min(niters,20) )
Expand Down

0 comments on commit c370a3f

Please sign in to comment.