RFC Added SkipMissing wrapper metric #261

Open
wants to merge 1 commit into
base: master
2 changes: 2 additions & 0 deletions Project.toml
@@ -4,6 +4,7 @@ version = "0.10.11"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
@@ -19,6 +20,7 @@ DistancesSparseArraysExt = "SparseArrays"
[compat]
ChainRulesCore = "1"
LinearAlgebra = "<0.0.1, 1"
Missings = "1"
SparseArrays = "<0.0.1, 1"
Statistics = "<0.0.1, 1"
StatsAPI = "1"
2 changes: 2 additions & 0 deletions src/Distances.jl
@@ -1,6 +1,7 @@
module Distances

using LinearAlgebra
using Missings
using Statistics: mean
import StatsAPI: pairwise, pairwise!

@@ -116,6 +117,7 @@ include("haversine.jl")
include("mahalanobis.jl")
include("bhattacharyya.jl")
include("bregman.jl")
include("missing.jl")

include("deprecated.jl")

47 changes: 47 additions & 0 deletions src/missing.jl
@@ -0,0 +1,47 @@
"""
Exclude any indices with missing values from the wrapped distance metric.
"""
struct SkipMissing{D<:PreMetric} <: PreMetric
Member Author


I'm not sure if there's a benefit to have specific subtypes here? All the metrics I tested seemed to work.

d::D
end

result_type(dist::SkipMissing, a::Type, b::Type) = result_type(dist.d, a, b)

# Always fall back to the internal metric behaviour
(dist::SkipMissing)(a, b) = dist.d(a, b)

# Special case vector arguments where we can mask out incomplete cases
function (dist::SkipMissing)(a::AbstractVector, b::AbstractVector)
require_one_based_indexing(a)
require_one_based_indexing(b)
n = length(a)
length(b) == n || throw(DimensionMismatch("a and b have different lengths"))

mask = BitVector(undef, n)
@inbounds for i in 1:n
mask[i] = !(ismissing(a[i]) || ismissing(b[i]))
end

# Calling `_evaluate` allows us to also mask metric parameters like weights or periods
# I don't think this can be generalized to user defined metric types though without knowing
# what the parameters mean.
# NOTE: We call disallowmissing to avoid downstream type promotion issues.
if dist.d isa UnionMetrics
params = parameters(dist.d)
Member Author


AFAICT, parameters is only defined for UnionMetrics?
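
For context on why the parameters need masking at all: per-index values such as weights have to be filtered with the same mask as the data, otherwise the lengths no longer line up. A minimal sketch in plain Julia, assuming a hypothetical `weighted_sqeuclid` stand-in and a hypothetical `masked_weighted` helper (neither is part of Distances.jl or of this PR):

```julia
# Hypothetical stand-in for a weighted metric; not the Distances.jl implementation.
weighted_sqeuclid(a, b, w) = sum(w .* abs2.(a .- b))

# Sketch: build one complete-case mask and apply it to the data and the weights alike.
function masked_weighted(a::AbstractVector, b::AbstractVector, w::AbstractVector)
    length(a) == length(b) == length(w) || throw(DimensionMismatch("lengths differ"))
    mask = .!(ismissing.(a) .| ismissing.(b))
    # `identity.(...)` narrows the element type once the missing entries are gone
    return weighted_sqeuclid(identity.(a[mask]), identity.(b[mask]), w[mask])
end

a = [1, missing, 3, missing, 5]
b = [6, missing, missing, 9, 10]
w = [0.2, 0.4, 0.6, 0.8, 1.0]

result = masked_weighted(a, b, w)
expected = weighted_sqeuclid([1, 5], [6, 10], [0.2, 1.0])
```

Only indices 1 and 5 are complete in both vectors, so `result` should agree with `expected`, which uses the weights at those same two positions.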


return _evaluate(
dist.d,
disallowmissing(view(a, mask)),
disallowmissing(view(b, mask)),
isnothing(params) ? params : view(params, mask),
)
else
return dist.d(
Member Author


I'm not sure if there's a safer fallback?

disallowmissing(view(a, mask)),
disallowmissing(view(b, mask)),
)
end
end

# Convenience function
skipmissing(dist::PreMetric, args...) = SkipMissing(dist)(args...)
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -2,6 +2,7 @@ using Distances

using Test
using LinearAlgebra
using Missings
using OffsetArrays
using Random
using Statistics
@@ -15,7 +16,7 @@ include("test_dists.jl")
# Test ChainRules definitions on Julia versions that support weak dependencies
# Support for extensions was added in
# https://github.com/JuliaLang/julia/commit/93587d7c1015efcd4c5184e9c42684382f1f9ab2
# https://github.com/JuliaLang/julia/pull/47695
# https://github.com/JuliaLang/julia/pull/47695
if VERSION >= v"1.9.0-alpha1.18"
include("chainrules.jl")
end
120 changes: 119 additions & 1 deletion test/test_dists.jl
@@ -575,7 +575,7 @@ function test_colwise(dist, x, y, T)
# ≈ and all( .≈ ) seem to behave slightly differently for F64
@test all(colwise(dist, x, y) .≈ r1)
@test all(colwise(dist, (x[:,i] for i in axes(x, 2)), (y[:,i] for i in axes(y, 2))) .≈ r1)

@test colwise!(dist, r4, x, y) ≈ @test_deprecated(colwise!(r5, dist, x, y))
@test r4 ≈ r5

@@ -1051,6 +1051,124 @@ end
end
end

@testset "skip missing" begin
x = Float64[]
a = [1, missing, 3, missing, 5]
b = [6, missing, missing, 9, 10]
A = allowmissing(reshape(1:20, 5, 4))
B = allowmissing(reshape(21:40, 5, 4))
A[3, 1] = missing
B[4, 2] = missing
w = collect(0.2:0.2:1.0)

# Sampling of different types of distance calculations to check against
dists = [
SqEuclidean(),
Euclidean(),
Cityblock(),
TotalVariation(),
Chebyshev(),
Minkowski(2.5),
Hamming(),
Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x),
CosineDist(),
CorrDist(),
ChiSqDist(),
KLDivergence(),
GenKLDivergence(),
RenyiDivergence(0.0),
JSDivergence(),
SpanNormDist(),
BhattacharyyaDist(),
HellingerDist(),
BrayCurtis(),
Jaccard(),
WeightedEuclidean(w),
WeightedSqEuclidean(w),
WeightedCityblock(w),
WeightedMinkowski(w, 2.5),
WeightedHamming(w),
MeanAbsDeviation(),
MeanSqDeviation(),
RMSDeviation(),
NormRMSDeviation(),
]

# NOTE: Most of this is special casing the failure conditions
# for using `missing` in the base metrics
@testset "Distance $dist" for dist in dists
D = Distances.SkipMissing(dist)

# Baseline: check that the wrapped metric matches the base metric
# in the empty case
if dist isa NormRMSDeviation
@test_throws ArgumentError dist(x, x)
@test_throws ArgumentError D(x, x)
elseif nameof(typeof(dist)) in nameof.(Distances.weightedmetrics)
@test_throws DimensionMismatch dist(x, x)
@test_throws BoundsError D(x, x) # Could choose to special case this?
elseif dist isa Union{BhattacharyyaDist, HellingerDist, MeanAbsDeviation, MeanSqDeviation, RMSDeviation}
@test isnan(dist(x, x))
@test isnan(D(x, x))
else
@test D(x, x) == dist(x, x)
end

# Cover existing failure cases with missings
# TODO: Simplify this with an error variable
if dist isa Union{Hamming, WeightedHamming, ChiSqDist, KLDivergence, GenKLDivergence, SpanNormDist}
@test_throws TypeError dist(a, b)
@test_throws TypeError colwise(dist, A, B)
@test_throws TypeError pairwise(dist, A, B, dims=2)
@test_throws TypeError pairwise(dist, A, dims=2)
elseif dist isa JSDivergence
@test_throws TypeError dist(a, b)
@test_throws MethodError colwise(dist, A, B)
@test_throws MethodError pairwise(dist, A, B, dims=2)
@test_throws MethodError pairwise(dist, A, dims=2)
elseif dist isa Bregman
@test_throws ArgumentError dist(a, b)
@test_throws ArgumentError colwise(dist, A, B)
@test_throws ArgumentError pairwise(dist, A, B, dims=2)
@test_throws ArgumentError pairwise(dist, A, dims=2)
elseif dist isa CorrDist
@test_throws UndefVarError dist(a, b)
@test_throws UndefVarError colwise(dist, A, B)
@test_throws MethodError pairwise(dist, A, B, dims=2)
@test_throws MethodError pairwise(dist, A, dims=2)
elseif dist isa RenyiDivergence
@test_throws MethodError dist(a, b)
@test_throws MethodError colwise(dist, A, B)
@test_throws MethodError pairwise(dist, A, B, dims=2)
@test_throws MethodError pairwise(dist, A, dims=2)

# Doesn't handle eltype Union{T, Missing}
@test_throws MethodError colwise(dist, A[:, 3:4], B[:, 3:4])
else
@test ismissing(dist(a, b))
@test_throws MethodError colwise(dist, A, B)
@test_throws MethodError pairwise(dist, A, B, dims=2)
@test_throws MethodError pairwise(dist, A, dims=2)
end

# Handle weights
Member Author


We're just testing that skipmissing returns the same results as deleting indices missing in either vector.
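
To spell out the equivalence being tested, here is a minimal sketch in plain Julia. It assumes a hypothetical `sqeuclid` stand-in and a hypothetical `masked` helper, so it illustrates the idea rather than the `SkipMissing` code in this PR:

```julia
# Hypothetical stand-in for a metric; not part of Distances.jl.
sqeuclid(a, b) = sum(abs2, a .- b)

# Sketch of complete-case masking: keep only indices where both entries are present.
function masked(dist, a::AbstractVector, b::AbstractVector)
    length(a) == length(b) || throw(DimensionMismatch("a and b have different lengths"))
    mask = .!(ismissing.(a) .| ismissing.(b))
    # `identity.(...)` narrows the element type from Union{T, Missing} to T
    return dist(identity.(a[mask]), identity.(b[mask]))
end

a = [1, missing, 3, missing, 5]
b = [6, missing, missing, 9, 10]

via_mask = masked(sqeuclid, a, b)
via_deletion = sqeuclid([1, 5], [6, 10])
```

Indices 2, 3 and 4 are missing in at least one vector, so both calls reduce to the same two pairs, (1, 6) and (5, 10).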

if dist isa Distances.UnionMetrics && Distances.parameters(dist) isa Vector
@test D(a, b) == Distances._evaluate(dist, [1, 5], [6, 10], Distances.parameters(dist)[[1, 5]])
else
@test D(a, b) == dist([1, 5], [6, 10])
end

@test colwise(D, A, B)[3:4] ≈ colwise(dist, disallowmissing(A[:, 3:4]), disallowmissing(B[:, 3:4]))
Member Author


Double check results are the same for non-missing columns.


M = pairwise(D, A, B, dims=2)
@test size(M) == (4, 4)
@test !any(ismissing, M)

M = pairwise(D, A, dims=2)
@test size(M) == (4, 4)
@test !any(ismissing, M)
Member Author


Could manually check the calculations here, but I figured a few smoke tests that it doesn't error or return missings was good enough?
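
If a manual check is wanted later, something along these lines might do. This is only a sketch in plain Julia with a hypothetical `sqeuclid` stand-in and a hypothetical `masked` helper, using a comprehension in place of `pairwise`, so it checks the expected values rather than the PR's code path:

```julia
# Hypothetical stand-in for a metric; not part of Distances.jl.
sqeuclid(a, b) = sum(abs2, a .- b)

# Keep only the indices where both columns have a value, then evaluate the metric.
function masked(dist, a, b)
    keep = [i for i in eachindex(a, b) if !ismissing(a[i]) && !ismissing(b[i])]
    return dist(identity.(a[keep]), identity.(b[keep]))
end

# Same data as the test set above, built without Missings.jl
A = Matrix{Union{Missing, Int}}(reshape(1:20, 5, 4))
B = Matrix{Union{Missing, Int}}(reshape(21:40, 5, 4))
A[3, 1] = missing
B[4, 2] = missing

# Column-against-column distances, masking each pair independently
M = [masked(sqeuclid, view(A, :, i), view(B, :, j)) for i in axes(A, 2), j in axes(B, 2)]
```

Each pair of columns gets its own mask, so `M[1, 2]` uses three rows (1, 2 and 5) while `M[3, 3]` uses all five.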

end
end
#=
@testset "zero allocation colwise!" begin
d = Euclidean()