Skip to content

Commit

Permalink
Merge pull request #91 from frankier/fix-loess-dispatch
Browse files Browse the repository at this point in the history
Fix loess dispatch to avoid infinite recursion
  • Loading branch information
andreasnoack authored Aug 30, 2024
2 parents 1e63d3f + 5286502 commit 060956f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 24 deletions.
59 changes: 35 additions & 24 deletions src/Loess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,14 @@ modelmatrix(model::LoessModel) = model.xs

response(model::LoessModel) = model.ys

"""
loess(xs, ys; normalize=true, span=0.75, degree=2)
Fit a loess model.
Args:
- `xs`: A `n` by `m` matrix with `n` observations from `m` independent predictors
- `ys`: A length `n` response vector.
- `normalize`: Normalize the scale of each predicitor. (default true when `m > 1`)
- `span`: The degree of smoothing, typically in [0,1]. Smaller values result in smaller
local context in fitting.
- `degree`: Polynomial degree.
- `cell`: Control parameter for bucket size. Internal interpolation nodes will be
added to the K-D tree until the number of bucket element is below `n * cell * span`.
Returns:
A fit `LoessModel`.
"""
function loess(
function _loess(
xs::AbstractMatrix{T},
ys::AbstractVector{T};
normalize::Bool = true,
span::AbstractFloat = 0.75,
degree::Integer = 2,
cell::AbstractFloat = 0.2
) where T<:AbstractFloat

Base.require_one_based_indexing(xs)
Base.require_one_based_indexing(ys)

Expand Down Expand Up @@ -146,12 +126,43 @@ function loess(
LoessModel(convert(Matrix{T}, xs), convert(Vector{T}, ys), predictions_and_gradients, kdtree)
end

loess(xs::AbstractVector{T}, ys::AbstractVector{T}; kwargs...) where {T<:AbstractFloat} =
"""
loess(xs, ys; normalize=true, span=0.75, degree=2)
Fit a loess model.
Args:
- `xs`: A `n` by `m` matrix with `n` observations from `m` independent predictors
- `ys`: A length `n` response vector.
- `normalize`: Normalize the scale of each predicitor. (default true when `m > 1`)
- `span`: The degree of smoothing, typically in [0,1]. Smaller values result in smaller
local context in fitting.
- `degree`: Polynomial degree.
- `cell`: Control parameter for bucket size. Internal interpolation nodes will be
added to the K-D tree until the number of bucket element is below `n * cell * span`.
Returns:
A fit `LoessModel`.
"""
function loess(
xs::AbstractMatrix{T},
ys::AbstractVector{T};
normalize::Bool = true,
span::AbstractFloat = 0.75,
degree::Integer = 2,
cell::AbstractFloat = 0.2
) where T<:AbstractFloat
_loess(xs, ys; normalize, span, degree, cell)
end

loess(xs::AbstractVector{T}, ys::AbstractVector{S}; kwargs...) where {T,S} =
loess(reshape(xs, (length(xs), 1)), ys; kwargs...)

function loess(xs::AbstractArray{T,N}, ys::AbstractVector{S}; kwargs...) where {T,N,S}
function loess(xs::AbstractMatrix{T}, ys::AbstractVector{S}; kwargs...) where {T,S}
R = float(promote_type(T, S))
loess(convert(AbstractArray{R,N}, xs), convert(AbstractVector{R}, ys); kwargs...)
# Dispatch to another function here to avoid potential infinite recursion
_loess(convert(AbstractMatrix{R}, xs), convert(AbstractVector{R}, ys); kwargs...)
end


Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,9 @@ end
@test predict(ft, 1.0) == 1.0
@test_throws ArgumentError loess(Float64[], Float64[])
end

@testset "infinite recursion. Issue 60" begin
x = collect(1.0:10.0)
y = convert(Vector{Union{Nothing, Float64}}, x)
@test_throws MethodError loess(x, y)
end

0 comments on commit 060956f

Please sign in to comment.