diff --git a/Project.toml b/Project.toml index 405b8c7..770df59 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.10.11" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" @@ -19,6 +20,7 @@ DistancesSparseArraysExt = "SparseArrays" [compat] ChainRulesCore = "1" LinearAlgebra = "<0.0.1, 1" +Missings = "1" SparseArrays = "<0.0.1, 1" Statistics = "<0.0.1, 1" StatsAPI = "1" diff --git a/src/Distances.jl b/src/Distances.jl index 854e28c..3349077 100644 --- a/src/Distances.jl +++ b/src/Distances.jl @@ -1,6 +1,7 @@ module Distances using LinearAlgebra +using Missings using Statistics: mean import StatsAPI: pairwise, pairwise! @@ -116,6 +117,7 @@ include("haversine.jl") include("mahalanobis.jl") include("bhattacharyya.jl") include("bregman.jl") +include("missing.jl") include("deprecated.jl") diff --git a/src/missing.jl b/src/missing.jl new file mode 100644 index 0000000..3ed89be --- /dev/null +++ b/src/missing.jl @@ -0,0 +1,47 @@ +""" +Exclude any missing indices from being included the wrappped distance metric. +""" +struct SkipMissing{D<:PreMetric} <: PreMetric + d::D +end + +result_type(dist::SkipMissing, a::Type, b::Type) = result_type(dist.d, a, b) + +# Always fallback to the internal metric behaviour +(dist::SkipMissing)(a, b) = dist.d(a, b) + +# Special case vector arguments where we can mask out incomplete cases +function (dist::SkipMissing)(a::AbstractVector, b::AbstractVector) + require_one_based_indexing(a) + require_one_based_indexing(b) + n = length(a) + length(b) == n || throw(DimensionMismatch("a and b have different lengths")) + + mask = BitVector(undef, n) + @inbounds for i in 1:n + mask[i] = !(ismissing(a[i]) || ismissing(b[i])) + end + + # Calling `_evaluate` allows us to also mask metric parameters like weights or periods + # I don't think this can be generalized to user defined metric types though without knowing + # what the parameters mean. + # NTOE: We call disallowmissings to avoid downstream type promotion issues. + if dist.d isa UnionMetrics + params = parameters(dist.d) + + return _evaluate( + dist.d, + disallowmissing(view(a, mask)), + disallowmissing(view(b, mask)), + isnothing(params) ? params : view(params, mask), + ) + else + return dist.d( + disallowmissing(view(a, mask)), + disallowmissing(view(b, mask)), + ) + end +end + +# Convenience function +skipmissing(dist::PreMetric, args...) = SkipMissing(dist)(args...) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c7b60ce..d210a17 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Distances using Test using LinearAlgebra +using Missings using OffsetArrays using Random using Statistics @@ -15,7 +16,7 @@ include("test_dists.jl") # Test ChainRules definitions on Julia versions that support weak dependencies # Support for extensions was added in # https://github.com/JuliaLang/julia/commit/93587d7c1015efcd4c5184e9c42684382f1f9ab2 -# https://github.com/JuliaLang/julia/pull/47695 +# https://github.com/JuliaLang/julia/pull/47695 if VERSION >= v"1.9.0-alpha1.18" include("chainrules.jl") end \ No newline at end of file diff --git a/test/test_dists.jl b/test/test_dists.jl index 395cc8d..4870ec6 100644 --- a/test/test_dists.jl +++ b/test/test_dists.jl @@ -575,7 +575,7 @@ function test_colwise(dist, x, y, T) # ≈ and all( .≈ ) seem to behave slightly differently for F64 @test all(colwise(dist, x, y) .≈ r1) @test all(colwise(dist, (x[:,i] for i in axes(x, 2)), (y[:,i] for i in axes(y, 2))) .≈ r1) - + @test colwise!(dist, r4, x, y) ≈ @test_deprecated(colwise!(r5, dist, x, y)) @test r4 ≈ r5 @@ -1051,6 +1051,124 @@ end end end +@testset "skip missing" begin + x = Float64[] + a = [1, missing, 3, missing, 5] + b = [6, missing, missing, 9, 10] + A = allowmissing(reshape(1:20, 5, 4)) + B = allowmissing(reshape(21:40, 5, 4)) + A[3, 1] = missing + B[4, 2] = missing + w = collect(0.2:0.2:1.0) + + # Sampling of different types of distance calculations to check against + dists = [ + SqEuclidean(), + Euclidean(), + Cityblock(), + TotalVariation(), + Chebyshev(), + Minkowski(2.5), + Hamming(), + Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), + CosineDist(), + CorrDist(), + ChiSqDist(), + KLDivergence(), + GenKLDivergence(), + RenyiDivergence(0.0), + JSDivergence(), + SpanNormDist(), + BhattacharyyaDist(), + HellingerDist(), + BrayCurtis(), + Jaccard(), + WeightedEuclidean(w), + WeightedSqEuclidean(w), + WeightedCityblock(w), + WeightedMinkowski(w, 2.5), + WeightedHamming(w), + MeanAbsDeviation(), + MeanSqDeviation(), + RMSDeviation(), + NormRMSDeviation(), + ] + + # NOTE: Most of this is special casing the failure conditions + # for using `missing` in the base metrics + @testset "Distance $dist" for dist in dists + D = Distances.SkipMissing(dist) + + # Baseline that our wrapped metric has the same empty case + # behaviour + if dist isa NormRMSDeviation + @test_throws ArgumentError dist(x, x) + @test_throws ArgumentError D(x, x) + elseif nameof(typeof(dist)) in nameof.(Distances.weightedmetrics) + @test_throws DimensionMismatch dist(x, x) + @test_throws BoundsError D(x, x) # Could choose to special case this? + elseif dist isa Union{BhattacharyyaDist, HellingerDist, MeanAbsDeviation, MeanSqDeviation, RMSDeviation} + @test isnan(dist(x, x)) + @test isnan(D(x, x)) + else + @test D(x, x) == dist(x, x) + end + + # Cover existing failure cases with missings + # TODO: Simplify this with an error variable + if dist isa Union{Hamming, WeightedHamming, ChiSqDist, KLDivergence, GenKLDivergence, SpanNormDist} + @test_throws TypeError dist(a, b) + @test_throws TypeError colwise(dist, A, B) + @test_throws TypeError pairwise(dist, A, B, dims=2) + @test_throws TypeError pairwise(dist, A, dims=2) + elseif dist isa JSDivergence + @test_throws TypeError dist(a, b) + @test_throws MethodError colwise(dist, A, B) + @test_throws MethodError pairwise(dist, A, B, dims=2) + @test_throws MethodError pairwise(dist, A, dims=2) + elseif dist isa Bregman + @test_throws ArgumentError dist(a, b) + @test_throws ArgumentError colwise(dist, A, B) + @test_throws ArgumentError pairwise(dist, A, B, dims=2) + @test_throws ArgumentError pairwise(dist, A, dims=2) + elseif dist isa CorrDist + @test_throws UndefVarError dist(a, b) + @test_throws UndefVarError colwise(dist, A, B) + @test_throws MethodError pairwise(dist, A, B, dims=2) + @test_throws MethodError pairwise(dist, A, dims=2) + elseif dist isa RenyiDivergence + @test_throws MethodError dist(a, b) + @test_throws MethodError colwise(dist, A, B) + @test_throws MethodError pairwise(dist, A, B, dims=2) + @test_throws MethodError pairwise(dist, A, dims=2) + + # Doesn't handle eltype Union{T, Missing} + @test_throws MethodError colwise(dist, A[:, 3:4], B[:, 3:4]) + else + @test ismissing(dist(a, b)) + @test_throws MethodError colwise(dist, A, B) + @test_throws MethodError pairwise(dist, A, B, dims=2) + @test_throws MethodError pairwise(dist, A, dims=2) + end + + # Handle weights + if dist isa Distances.UnionMetrics && Distances.parameters(dist) isa Vector + @test D(a, b) == Distances._evaluate(dist, [1, 5], [6, 10], Distances.parameters(dist)[[1, 5]]) + else + @test D(a, b) == dist([1, 5], [6, 10]) + end + + @test colwise(D, A, B)[3:4] ≈ colwise(dist, disallowmissing(A[:, 3:4]), disallowmissing(B[:, 3:4])) + + M = pairwise(D, A, B, dims=2) + @test size(M) == (4, 4) + @test !any(ismissing, M) + + M = pairwise(D, A, dims=2) + @test size(M) == (4, 4) + @test !any(ismissing, M) + end +end #= @testset "zero allocation colwise!" begin d = Euclidean()