Skip to content

Commit

Permalink
Merge pull request #323 from SouthEndMusic/fix_inf_f
Browse files Browse the repository at this point in the history
Use index guessing functionality moved to `FindFirstFunctions.jl` (`Guesser`)
  • Loading branch information
ChrisRackauckas authored Aug 5, 2024
2 parents 80c8ebb + 2e9b14c commit 35db42e
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ DataInterpolationsSymbolicsExt = "Symbolics"
Aqua = "0.8"
BenchmarkTools = "1"
ChainRulesCore = "1.24"
FindFirstFunctions = "1.1"
FindFirstFunctions = "1.3"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
LinearAlgebra = "1.10"
Expand Down
4 changes: 2 additions & 2 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ end

function u_tangent(A::LinearInterpolation, t, Δ)
out = zero(A.u)
idx = get_idx(A, t, A.idx_prev[])
idx = get_idx(A, t, A.iguesser)
t_factor = (t - A.t[idx]) / (A.t[idx + 1] - A.t[idx])
out[idx] = Δ * (one(eltype(out)) - t_factor)
out[idx + 1] = Δ * t_factor
Expand All @@ -61,7 +61,7 @@ end

function u_tangent(A::QuadraticInterpolation, t, Δ)
out = zero(A.u)
i₀, i₁, i₂ = _quad_interp_indices(A, t, A.idx_prev[])
i₀, i₁, i₂ = _quad_interp_indices(A, t, A.iguesser)
t₀ = A.t[i₀]
t₁ = A.t[i₁]
t₂ = A.t[i₂]
Expand Down
6 changes: 1 addition & 5 deletions src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra, RecipesBase
using PrettyTables
using ForwardDiff
import FindFirstFunctions: searchsortedfirstcorrelated, searchsortedlastcorrelated,
bracketstrictlymontonic
Guesser

include("parameter_caches.jl")
include("interpolation_caches.jl")
Expand All @@ -22,10 +22,6 @@ include("online.jl")
include("show.jl")

(interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t)
function (interp::AbstractInterpolation)(t::Number, i::Integer)
interp.idx_prev[] = i
_interpolate(interp, t)
end

function (interp::AbstractInterpolation)(t::AbstractVector)
u = get_u(interp.u, t)
Expand Down
42 changes: 19 additions & 23 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
function derivative(A, t, order = 1)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
iguess = A.idx_prev[]
iguess = A.iguesser

return if order == 1
val, idx = _derivative(A, t, iguess)
A.idx_prev[] = idx
val
_derivative(A, t, iguess)
elseif order == 2
ForwardDiff.derivative(t -> begin
val, idx = _derivative(A, t, iguess)
A.idx_prev[] = idx
val
_derivative(A, t, iguess)
end, t)
else
throw(DerivativeNotFoundError())
Expand All @@ -20,7 +16,7 @@ end
function _derivative(A::LinearInterpolation, t::Number, iguess)
idx = get_idx(A, t, iguess; idx_shift = -1, ub_shift = -1, side = :first)
slope = get_parameters(A, idx)
slope, idx
slope
end

function _derivative(A::QuadraticInterpolation, t::Number, iguess)
Expand All @@ -29,7 +25,7 @@ function _derivative(A::QuadraticInterpolation, t::Number, iguess)
du₀ = l₀ * (2t - A.t[i₁] - A.t[i₂])
du₁ = l₁ * (2t - A.t[i₀] - A.t[i₂])
du₂ = l₂ * (2t - A.t[i₀] - A.t[i₁])
return @views @. du₀ + du₁ + du₂, i₀
return @views @. du₀ + du₁ + du₂
end

function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
Expand Down Expand Up @@ -101,21 +97,21 @@ function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
end

function _derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number, idx)
_derivative(A, t), idx
_derivative(A, t)
end
function _derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number, idx)
_derivative(A, t), idx
_derivative(A, t)
end

function _derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess; idx_shift = -1, side = :first)
j = min(idx, length(A.c)) # for smooth derivative at A.t[end]
wj = t - A.t[idx]
(@evalpoly wj A.b[idx] 2A.c[j] 3A.d[j]), idx
@evalpoly wj A.b[idx] 2A.c[j] 3A.d[j]
end

function _derivative(A::ConstantInterpolation, t::Number, iguess)
return zero(first(A.u)), iguess
return zero(first(A.u))
end

function _derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
Expand All @@ -132,7 +128,7 @@ end
function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
idx = get_idx(A, t, iguess; lb = 2, ub_shift = 0, side = :first)
σ = get_parameters(A, idx - 1)
A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx
A.z[idx - 1] + 2σ * (t - A.t[idx - 1])
end

# CubicSpline Interpolation
Expand All @@ -144,13 +140,13 @@ function _derivative(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
c₁, c₂ = get_parameters(A, idx)
dC = c₁
dD = -c₂
dI + dC + dD, idx
dI + dC + dD
end

function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Number, iguess)
# change t into param [0 1]
t < A.t[1] && return zero(A.u[1]), 1
t > A.t[end] && return zero(A.u[end]), lastindex(t)
t < A.t[1] && return zero(A.u[1])
t > A.t[end] && return zero(A.u[end])
idx = get_idx(A, t, iguess)
n = length(A.t)
scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx])
Expand All @@ -165,14 +161,14 @@ function _derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Num
ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1])
end
end
ducum * A.d * scale, idx
ducum * A.d * scale
end

# BSpline Curve Approx
function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, iguess)
# change t into param [0 1]
t < A.t[1] && return zero(A.u[1]), 1
t > A.t[end] && return zero(A.u[end]), lastindex(t)
t < A.t[1] && return zero(A.u[1])
t > A.t[end] && return zero(A.u[end])
idx = get_idx(A, t, iguess)
scale = (A.p[idx + 1] - A.p[idx]) / (A.t[idx + 1] - A.t[idx])
t_ = A.p[idx] + (t - A.t[idx]) * scale
Expand All @@ -186,7 +182,7 @@ function _derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number, ig
ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1])
end
end
ducum * A.d * scale, idx
ducum * A.d * scale
end

# Cubic Hermite Spline
Expand All @@ -198,7 +194,7 @@ function _derivative(
out = A.du[idx]
c₁, c₂ = get_parameters(A, idx)
out += Δt₀ * (Δt₀ * c₂ + 2(c₁ + Δt₁ * c₂))
out, idx
out
end

# Quintic Hermite Spline
Expand All @@ -211,5 +207,5 @@ function _derivative(
c₁, c₂, c₃ = get_parameters(A, idx)
out += Δt₀^2 *
(3c₁ + (3Δt₁ + Δt₀) * c₂ + (3Δt₁^2 + Δt₀ * 2Δt₁) * c₃)
out, idx
out
end
14 changes: 7 additions & 7 deletions src/integral_inverses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ invert_integral(A::AbstractInterpolation) = throw(IntegralInverseNotFoundError()
_integral(A::AbstractIntegralInverseInterpolation, idx, t) = throw(IntegralNotFoundError())

function _derivative(A::AbstractIntegralInverseInterpolation, t::Number, iguess)
inv(A.itp(A(t))), A.idx_prev[]
inv(A.itp(A(t)))
end

"""
Expand All @@ -38,11 +38,11 @@ struct LinearInterpolationIntInv{uType, tType, itpType, T} <:
u::uType
t::tType
extrapolate::Bool
idx_prev::Base.RefValue{Int}
iguesser::Guesser{tType}
itp::itpType
function LinearInterpolationIntInv(u, t, A)
new{typeof(u), typeof(t), typeof(A), eltype(u)}(
u, t, A.extrapolate, Ref(1), A)
u, t, A.extrapolate, Guesser(t), A)
end
end

Expand All @@ -64,7 +64,7 @@ function _interpolate(
x = A.itp.u[idx]
slope = get_parameters(A.itp, idx)
u = A.u[idx] + 2Δt / (x + sqrt(x^2 + slope * 2Δt))
u, idx
u
end

"""
Expand All @@ -84,11 +84,11 @@ struct ConstantInterpolationIntInv{uType, tType, itpType, T} <:
u::uType
t::tType
extrapolate::Bool
idx_prev::Base.RefValue{Int}
iguesser::Guesser{tType}
itp::itpType
function ConstantInterpolationIntInv(u, t, A)
new{typeof(u), typeof(t), typeof(A), eltype(u)}(
u, t, A.extrapolate, Ref(1), A
u, t, A.extrapolate, Guesser(t), A
)
end
end
Expand All @@ -112,5 +112,5 @@ function _interpolate(
# :right means that value to the right is used for interpolation
idx_ = get_idx(A, t, idx; side = :first, lb = 1, ub_shift = 0)
end
A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_], idx
A.u[idx] + (t - A.t[idx]) / A.itp.u[idx_]
end
Loading

0 comments on commit 35db42e

Please sign in to comment.