-
Notifications
You must be signed in to change notification settings - Fork 98
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC Added SkipMissing wrapper metric #261
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Exclude any indices holding a missing value from being included in the wrapped distance metric. | ||
""" | ||
struct SkipMissing{D<:PreMetric} <: PreMetric | ||
# The wrapped metric that performs the actual distance computation. | ||
d::D | ||
end | ||
|
||
# The result type is whatever the wrapped metric reports for the same argument types. | ||
result_type(dist::SkipMissing, a::Type, b::Type) = result_type(dist.d, a, b) | ||
|
||
# Generic fallback: arguments that are not both `AbstractVector`s are forwarded to the | ||
# wrapped metric unchanged, so no missing values are masked on this path. | ||
(dist::SkipMissing)(a, b) = dist.d(a, b) | ||
|
||
# Special case vector arguments where we can mask out incomplete cases. | ||
# An index is kept only when both `a[i]` and `b[i]` are non-missing; the wrapped metric is | ||
# then evaluated on the kept entries only. | ||
# Throws `DimensionMismatch` when `a` and `b` have different lengths. | ||
function (dist::SkipMissing)(a::AbstractVector, b::AbstractVector) | ||
# The mask loop below indexes with `1:n`, so both inputs must be one-based. | ||
require_one_based_indexing(a) | ||
require_one_based_indexing(b) | ||
n = length(a) | ||
length(b) == n || throw(DimensionMismatch("a and b have different lengths")) | ||
|
||
# `mask[i]` is true when neither `a[i]` nor `b[i]` is missing. | ||
mask = BitVector(undef, n) | ||
@inbounds for i in 1:n | ||
mask[i] = !(ismissing(a[i]) || ismissing(b[i])) | ||
end | ||
|
||
# Calling `_evaluate` allows us to also mask metric parameters like weights or periods. | ||
# I don't think this can be generalized to user defined metric types though without knowing | ||
# what the parameters mean. | ||
# NOTE: We call disallowmissing to avoid downstream type promotion issues. | ||
if dist.d isa UnionMetrics | ||
params = parameters(dist.d) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAICT, |
||
|
||
# `params` may be `nothing`, which is passed through unchanged; otherwise the | ||
# parameters are masked with the same mask as the data. | ||
return _evaluate( | ||
dist.d, | ||
disallowmissing(view(a, mask)), | ||
disallowmissing(view(b, mask)), | ||
isnothing(params) ? params : view(params, mask), | ||
) | ||
else | ||
# Any other metric is called directly on the masked inputs; its parameters (if any) | ||
# are not masked on this path. | ||
return dist.d( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if there's a safer fallback? |
||
disallowmissing(view(a, mask)), | ||
disallowmissing(view(b, mask)), | ||
) | ||
end | ||
end | ||
|
||
# Convenience function: wrap `dist` in `SkipMissing` and immediately evaluate it on `args`. | ||
skipmissing(dist::PreMetric, args...) = SkipMissing(dist)(args...) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -575,7 +575,7 @@ function test_colwise(dist, x, y, T) | |
# ≈ and all( .≈ ) seem to behave slightly differently for F64 | ||
@test all(colwise(dist, x, y) .≈ r1) | ||
@test all(colwise(dist, (x[:,i] for i in axes(x, 2)), (y[:,i] for i in axes(y, 2))) .≈ r1) | ||
|
||
@test colwise!(dist, r4, x, y) ≈ @test_deprecated(colwise!(r5, dist, x, y)) | ||
@test r4 ≈ r5 | ||
|
||
|
@@ -1051,6 +1051,124 @@ end | |
end | ||
end | ||
|
||
# Checks that `SkipMissing` wrappers evaluate on the non-missing entries only, and records | ||
# how the unwrapped metrics currently behave when given `missing` values. | ||
@testset "skip missing" begin | ||
# Empty input used to compare the empty-case behaviour of wrapped and unwrapped metrics. | ||
x = Float64[] | ||
# `a` and `b` are both non-missing only at indices 1 and 5. | ||
a = [1, missing, 3, missing, 5] | ||
b = [6, missing, missing, 9, 10] | ||
# 5x4 matrices with one missing entry in column 1 of `A` and one in column 2 of `B`; | ||
# columns 3 and 4 of both contain no missing values. | ||
A = allowmissing(reshape(1:20, 5, 4)) | ||
B = allowmissing(reshape(21:40, 5, 4)) | ||
A[3, 1] = missing | ||
B[4, 2] = missing | ||
# Five weights (0.2 through 1.0) used by the weighted metrics below. | ||
w = collect(0.2:0.2:1.0) | ||
|
||
# Sampling of different types of distance calculations to check against | ||
dists = [ | ||
SqEuclidean(), | ||
Euclidean(), | ||
Cityblock(), | ||
TotalVariation(), | ||
Chebyshev(), | ||
Minkowski(2.5), | ||
Hamming(), | ||
Bregman(x -> sqeuclidean(x, zero(x)), x -> 2*x), | ||
CosineDist(), | ||
CorrDist(), | ||
ChiSqDist(), | ||
KLDivergence(), | ||
GenKLDivergence(), | ||
RenyiDivergence(0.0), | ||
JSDivergence(), | ||
SpanNormDist(), | ||
BhattacharyyaDist(), | ||
HellingerDist(), | ||
BrayCurtis(), | ||
Jaccard(), | ||
WeightedEuclidean(w), | ||
WeightedSqEuclidean(w), | ||
WeightedCityblock(w), | ||
WeightedMinkowski(w, 2.5), | ||
WeightedHamming(w), | ||
MeanAbsDeviation(), | ||
MeanSqDeviation(), | ||
RMSDeviation(), | ||
NormRMSDeviation(), | ||
] | ||
|
||
# NOTE: Most of this is special casing the failure conditions | ||
# for using `missing` in the base metrics | ||
@testset "Distance $dist" for dist in dists | ||
D = Distances.SkipMissing(dist) | ||
|
||
# Baseline that our wrapped metric has the same empty case | ||
# behaviour | ||
if dist isa NormRMSDeviation | ||
@test_throws ArgumentError dist(x, x) | ||
@test_throws ArgumentError D(x, x) | ||
elseif nameof(typeof(dist)) in nameof.(Distances.weightedmetrics) | ||
# The wrapped and unwrapped weighted metrics throw different error types here. | ||
@test_throws DimensionMismatch dist(x, x) | ||
@test_throws BoundsError D(x, x) # Could choose to special case this? | ||
elseif dist isa Union{BhattacharyyaDist, HellingerDist, MeanAbsDeviation, MeanSqDeviation, RMSDeviation} | ||
@test isnan(dist(x, x)) | ||
@test isnan(D(x, x)) | ||
else | ||
@test D(x, x) == dist(x, x) | ||
end | ||
|
||
# Cover existing failure cases with missings | ||
# TODO: Simplify this with an error variable | ||
if dist isa Union{Hamming, WeightedHamming, ChiSqDist, KLDivergence, GenKLDivergence, SpanNormDist} | ||
@test_throws TypeError dist(a, b) | ||
@test_throws TypeError colwise(dist, A, B) | ||
@test_throws TypeError pairwise(dist, A, B, dims=2) | ||
@test_throws TypeError pairwise(dist, A, dims=2) | ||
elseif dist isa JSDivergence | ||
@test_throws TypeError dist(a, b) | ||
@test_throws MethodError colwise(dist, A, B) | ||
@test_throws MethodError pairwise(dist, A, B, dims=2) | ||
@test_throws MethodError pairwise(dist, A, dims=2) | ||
elseif dist isa Bregman | ||
@test_throws ArgumentError dist(a, b) | ||
@test_throws ArgumentError colwise(dist, A, B) | ||
@test_throws ArgumentError pairwise(dist, A, B, dims=2) | ||
@test_throws ArgumentError pairwise(dist, A, dims=2) | ||
elseif dist isa CorrDist | ||
@test_throws UndefVarError dist(a, b) | ||
@test_throws UndefVarError colwise(dist, A, B) | ||
@test_throws MethodError pairwise(dist, A, B, dims=2) | ||
@test_throws MethodError pairwise(dist, A, dims=2) | ||
elseif dist isa RenyiDivergence | ||
@test_throws MethodError dist(a, b) | ||
@test_throws MethodError colwise(dist, A, B) | ||
@test_throws MethodError pairwise(dist, A, B, dims=2) | ||
@test_throws MethodError pairwise(dist, A, dims=2) | ||
|
||
# Doesn't handle eltype Union{T, Missing} | ||
@test_throws MethodError colwise(dist, A[:, 3:4], B[:, 3:4]) | ||
else | ||
@test ismissing(dist(a, b)) | ||
@test_throws MethodError colwise(dist, A, B) | ||
@test_throws MethodError pairwise(dist, A, B, dims=2) | ||
@test_throws MethodError pairwise(dist, A, dims=2) | ||
end | ||
|
||
# Handle weights | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're just testing that |
||
# Expected values use only indices 1 and 5, the positions where both `a` and `b` | ||
# are non-missing; vector parameters are indexed the same way. | ||
if dist isa Distances.UnionMetrics && Distances.parameters(dist) isa Vector | ||
@test D(a, b) == Distances._evaluate(dist, [1, 5], [6, 10], Distances.parameters(dist)[[1, 5]]) | ||
else | ||
@test D(a, b) == dist([1, 5], [6, 10]) | ||
end | ||
|
||
# Columns 3 and 4 have no missing values, so the wrapper must agree with the plain metric. | ||
@test colwise(D, A, B)[3:4] ≈ colwise(dist, disallowmissing(A[:, 3:4]), disallowmissing(B[:, 3:4])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Double check results are the same for non-missing columns. |
||
|
||
# Smoke tests: pairwise results have the expected size and contain no `missing`. | ||
M = pairwise(D, A, B, dims=2) | ||
@test size(M) == (4, 4) | ||
@test !any(ismissing, M) | ||
|
||
M = pairwise(D, A, dims=2) | ||
@test size(M) == (4, 4) | ||
@test !any(ismissing, M) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could manually check the calculations here, but I figured a few smoke tests that it doesn't error or return |
||
end | ||
end | ||
#= | ||
@testset "zero allocation colwise!" begin | ||
d = Euclidean() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if there's a benefit to have specific subtypes here? All the metrics I tested seemed to work.