Add Wasserstein distance #159
Also, I've only tested this function for my own application without weights. The function call with weights does not break for some inputs, but I've not yet tested it for correctness. |
Needs to be consistent with the current API and naming style:
dist = Wasserstein(u_weights, v_weights)
evaluate(dist, u, v) # might be deprecated in the future: #22
dist(u, v)
wasserstein(u, v, u_weights, v_weights)
Also, could you add some tests for
FYI, there's an implementation of earth mover's distance in https://github.com/mirkobunse/EarthMoversDistance.jl and JuliaImages/ImageDistances.jl#4 (cc: @timholy ) |
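For readers unfamiliar with the convention being requested here, the pattern is a struct that stores the weights and is callable on the two samples, plus a lowercase convenience function. Below is a minimal, package-free sketch of that shape. It is not the code from this PR: `weighted_cdf` is a hypothetical helper, the locally defined `evaluate` only stands in for `Distances.evaluate`, and the body assumes the 1D, p = 1 CDF-difference computation. Inside Distances.jl the struct would presumably also subtype one of the package's abstract distance types.

```julia
# Illustrative sketch only; names and layout are assumptions, not the PR's code.
struct Wasserstein{WU,WV}
    u_weights::WU   # `nothing` means uniform weights
    v_weights::WV
end
Wasserstein() = Wasserstein(nothing, nothing)

# Hypothetical helper: empirical CDF of a (weighted) sample at the points `xs`.
function weighted_cdf(values, weights, xs)
    order = sortperm(values)
    sorted = values[order]
    # For each x, the number of sample points <= x
    counts = [searchsortedlast(sorted, x) for x in xs]
    weights === nothing && return counts ./ length(values)
    cumw = vcat(0.0, cumsum(weights[order]))
    return cumw[counts .+ 1] ./ cumw[end]
end

# Calling the distance object: sum |F_u - F_v| over the gaps between
# consecutive points of the pooled, sorted sample.
function (d::Wasserstein)(u, v)
    all_values = sort(vcat(u, v))
    deltas = diff(all_values)
    xs = all_values[1:end-1]
    u_cdf = weighted_cdf(u, d.u_weights, xs)
    v_cdf = weighted_cdf(v, d.v_weights, xs)
    return sum(abs.(u_cdf .- v_cdf) .* deltas)
end

# Stand-in for `Distances.evaluate`, plus the lowercase convenience form.
evaluate(d::Wasserstein, u, v) = d(u, v)
wasserstein(u, v, u_weights=nothing, v_weights=nothing) =
    Wasserstein(u_weights, v_weights)(u, v)
```

With this layout, `Wasserstein(w_u, w_v)(u, v)`, `evaluate(dist, u, v)`, and `wasserstein(u, v, w_u, w_v)` all go through one code path, so correctness tests (for example against the scipy docstring values quoted in this thread) only need to exercise one of them.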
Ok, thanks, I can update this.
I may need help with this. Can you elaborate?
I tried this first, but the results didn't match up when I compared this with Python's |
There're You can take a look at https://github.com/JuliaStats/Distances.jl/blob/master/src/bregman.jl to see how the code for a new distance type is organized. |
Ok, I adjusted the API, but I'm not exactly sure what to expect with |
Please note that I also created a PR that adds the Wasserstein distance (#158). I implemented it using formula (2.5) from "Optimal Transport on Discrete Domains", and that paper shows that calculating the Wasserstein distance entails solving a linear program. My PR is thus "heavier" because I include a solver for linear programs. Do you know how it's possible that you (at least at first sight) do not need to solve an LP to calculate the distance? |
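For context, the linear program in question is usually written in the following Kantorovich form. This is the standard textbook statement for two discrete measures with weights $\mu \in \mathbb{R}^m$, $\nu \in \mathbb{R}^n$ and a cost matrix $C$; the notation is an assumption here and may differ from formula (2.5) in the cited paper.

```latex
\begin{aligned}
\mathrm{OT}(\mu, \nu) \;=\; \min_{\gamma \in \mathbb{R}^{m \times n}} \quad
  & \sum_{i=1}^{m} \sum_{j=1}^{n} \gamma_{ij}\, C_{ij} \\
\text{subject to} \quad
  & \sum_{j=1}^{n} \gamma_{ij} = \mu_i \quad (i = 1, \dots, m), \\
  & \sum_{i=1}^{m} \gamma_{ij} = \nu_j \quad (j = 1, \dots, n), \\
  & \gamma_{ij} \ge 0 .
\end{aligned}
```

The unknown $\gamma$ is the transport plan, so the problem has $mn$ variables and $m + n$ equality constraints, which is why a general implementation needs an LP solver.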
I actually only saw that PR after opening this one.
I don't know, because I'm not really familiar with this measure. I was just being pragmatic and naively translated scipy's wasserstein_distance because it was being used in Python code that I needed in Julia, and I thought I could remove the dependence on PyCall. |
This is because the 1D solution, as noted (albeit vaguely) on the scipy page, is closed form, whereas for 2+D you need to actually compute the transport plan. @Sh4pe, given that there don't seem to be 2D distances here, it would probably make more sense as a stand-alone package, or possibly integrated into the EMD package. Either way, thanks for writing this up.
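To spell out the closed form being referred to: for distributions on the real line with CDFs $F_u$ and $F_v$, the 1-Wasserstein distance reduces to an integral of the CDF difference, and for empirical samples that integral becomes a finite sum over the pooled, sorted sample points $x_{(1)} \le \dots \le x_{(n)}$. The symbols below are introduced here for illustration and are not taken from the PR.

```latex
\begin{aligned}
W_1(u, v) &= \int_{-\infty}^{\infty} \bigl| F_u(x) - F_v(x) \bigr| \, dx
           = \int_0^1 \bigl| F_u^{-1}(q) - F_v^{-1}(q) \bigr| \, dq, \\[4pt]
W_1(u, v) &= \sum_{k=1}^{n-1} \bigl| F_u(x_{(k)}) - F_v(x_{(k)}) \bigr| \,
             \bigl( x_{(k+1)} - x_{(k)} \bigr)
           \quad \text{(empirical samples)}.
\end{aligned}
```

Only sorting and cumulative sums are needed to evaluate the second line, so no transport plan or LP solver is involved in one dimension.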
Mind going into this? As someone who was planning on using one of these, the inconsistency gives me pause. Guessing you used |
FYI: https://github.com/JuliaOptimalTransport/OptimalTransport.jl contains specialized implementations for the 1D solution, both for discrete and continuous distributions, but also a pure Julia implementation for solving the optimal transport problem with discrete measures in arbitrary dimensions (and soon there will also be a specialized implementation for multivariate normal distributions). We just added |
Is this equivalent to https://github.com/scipy/scipy/blob/v1.4.1/scipy/stats/stats.py#L6934-L7008? I'm a bit tired of losing traction on PRs like this. Why is the pairwise functionality necessary @johnnychen94 ? Should I just pull this out and use it where we need and not bother using Distances.jl? |
Yes, it provides everything in that function (and some additional features). For instance, the example in the scipy docstring could be written as follows (functionality for exact computations was extracted to the more lightweight ExactOptimalTransport.jl package, but it's reexported by OptimalTransport.jl, so you could use that one as well):
julia> using ExactOptimalTransport, LinearAlgebra
julia> wasserstein(discretemeasure([0, 1, 3]), discretemeasure([5, 6, 8]))
5.0
julia> wasserstein(discretemeasure([0, 1], normalize!([3.0, 1.0], 1)), discretemeasure([0, 1], normalize!([2.0, 2.0], 1)))
0.25
julia> wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)))
4.078133143804784
Additionally, you can change the metric (default:
julia> using Distances
julia> wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)); metric = SqEuclidean())
19.08963752665245
julia> wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)); p = 2)^2
19.089637526652446
It is a bit more efficient to use squared2wasserstein if you want to compute the squared 2-Wasserstein distance (by default with respect to the Euclidean metric):
julia> squared2wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)))
19.08963752665245
Internally,
julia> using Distributions
julia> wasserstein(DiscreteNonParametric([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), DiscreteNonParametric([4.5, 1.4], normalize!([3.2, 3.5], 1)))
4.078133143804784
julia> wasserstein(Categorical(normalize!([1.4, 0.9, 3.1, 7.2], 1)), Categorical(normalize!([3.2, 3.5], 1)))
1.7553897180762852
Wasserstein distances can also be computed analytically for uni- and multivariate normal distributions with respect to the squared Euclidean distance:
julia> wasserstein(MvNormal([1.2, -4.2, 3.3], Diagonal([0.1, 1.4, 0.6])), MvNormal(I(3)); p = Val(2))
5.5246646247752915
julia> squared2wasserstein(MvNormal([1.2, -4.2, 3.3], Diagonal([0.1, 1.4, 0.6])), MvNormal(I(3)))
30.521919216243514
More generally, for continuous univariate distributions the Wasserstein distance can be computed as an integral of the metric composed with the respective quantile functions on the unit interval. ExactOptimalTransport computes the integral numerically with QuadGK:
julia> wasserstein(Normal(3.1, 2.1), Laplace(1.2, 0.5))
1.9961751319645216
Fundamentally,
julia> ot_plan(Euclidean(), discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)))
4×2 SparseArrays.SparseMatrixCSC{Float64, Int64} with 5 stored entries:
 0.111111    ⋅
 0.0714286   ⋅
 0.246032    ⋅
 0.0938166  0.477612
julia> f = ot_plan(Euclidean(), Normal(3.1, 2.1), Normal())
(::ExactOptimalTransport.var"#T#10"{Normal{Float64}, Normal{Float64}}) (generic function with 1 method)
julia> f(1.0)
-1.0000000000000002
julia> f(3.1)
0.0
If you have two histograms and a cost matrix, you can compute the optimal transport cost with |
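The analytic result for normal distributions mentioned in this comment is easy to check by hand when both covariance matrices are diagonal: the squared 2-Wasserstein distance is then the squared distance between the means plus the squared distance between the vectors of standard deviations. The sketch below uses a hypothetical helper, `squared2wasserstein_diag`, which is not part of ExactOptimalTransport.jl and assumes diagonal covariances given as vectors of variances.

```julia
# Hypothetical helper for illustration; not part of any package.
# Squared 2-Wasserstein distance between N(m1, Diagonal(var1)) and
# N(m2, Diagonal(var2)). `var1` and `var2` hold variances, not standard
# deviations. The general (non-diagonal) case needs matrix square roots.
function squared2wasserstein_diag(m1, var1, m2, var2)
    mean_term = sum(abs2, m1 .- m2)                    # squared distance of the means
    cov_term = sum(abs2, sqrt.(var1) .- sqrt.(var2))   # squared distance of the std devs
    return mean_term + cov_term
end

# Same inputs as the MvNormal example in the thread: mean [1.2, -4.2, 3.3] with
# variances [0.1, 1.4, 0.6], against a standard normal in three dimensions.
w2sq = squared2wasserstein_diag([1.2, -4.2, 3.3], [0.1, 1.4, 0.6], zeros(3), ones(3))
println(w2sq)        # squared 2-Wasserstein distance
println(sqrt(w2sq))  # 2-Wasserstein distance
```

Up to floating-point rounding, the two printed values should agree with the `squared2wasserstein` and `wasserstein(...; p = Val(2))` results quoted in the comment.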
This PR basically includes a translation of wasserstein_distance. I'm not familiar with this distance, so if someone can help add/improve correctness tests I would greatly appreciate it!