Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Wasserstein distance #159

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

charleskawczynski
Copy link

This PR basically includes a translation of wasserstein_distance. I'm not familiar with this distance, so if someone can help add/improve correctness tests I would greatly appreciate it!

@charleskawczynski
Copy link
Author

Also, I've only tested this function for my own application without weights. The function call with weights does not break for some inputs, but I've not yet tested it for correctness.

@johnnychen94
Copy link
Contributor

johnnychen94 commented Feb 16, 2020

Needs to be consistent with the current API and naming style:

dist = Wasserstein(u_weights, v_weights)

evaluate(dist, u, v) # might be deprecated in the future: #22
dist(u, v)
wasserstein(u, v, u_weights, v_weights)

Also, could you add some tests for pairwise and colwise, just to make sure they're working as expected?

FYI, there's an implementation of earthmoving distance in https://github.com/mirkobunse/EarthMoversDistance.jl and JuliaImages/ImageDistances.jl#4 (cc: @timholy )

@charleskawczynski
Copy link
Author

Needs to be consistent with the current API and naming style:

Ok, thanks, I can update this.

Also, could you add some tests for pairwise and colwise, just to make sure they're working as expected?

I may need help with this. Can you elaborate?

FYI, there's a version in https://github.com/mirkobunse/EarthMoversDistance.jl

I tried this first, but the results didn't match up when I compared this with Python's wasserstein_distance.

@johnnychen94
Copy link
Contributor

johnnychen94 commented Feb 16, 2020

Also, could you add some tests for pairwise and colwise, just to make sure they're working as expected?

I may need help with this. Can you elaborate?

There're test_pairwise and test_colwise helpers in test codes that you can easily take advantage of.

You can take a look at https://github.com/JuliaStats/Distances.jl/blob/master/src/bregman.jl on how codes are organized for a new distance type.

@charleskawczynski
Copy link
Author

Ok, I adjusted the API, but I'm not exactly sure what to expect with pairwise and colwise and they're not passing. I'm not sure I tried using JuliaImages/ImageDistances.jl#4, so perhaps that's what I need. But I imagine that having an implementation in this package seems natural. I'll try to fix the broken bits later, but if anyone else would like to, please feel free!

@Sh4pe
Copy link

Sh4pe commented Feb 23, 2020

Please note that I also created a PR that adds the Wasserstein distance (#158). I implemented it using formula (2.5) from Optimal Transport on Discrete Domains and that paper shows that calculating the Wasserstein distance entails solving a linear program. My PR is thus more "heavy" because I include a solver for linear programs.

Do you know how its possible that you (at least not on first sight) do not need to solve a LP to calculate the distance?

@charleskawczynski
Copy link
Author

Please note that I also created a PR that adds the Wasserstein distance (#158). I implemented it using formula (2.5) from Optimal Transport on Discrete Domains and that paper shows that calculating the Wasserstein distance entails solving a linear program. My PR is thus more "heavy" because I include a solver for linear programs.

I actually only saw that PR after opening this one.

Do you know how its possible that you (at least not on first sight) do not need to solve a LP to calculate the distance?

I don't know, because I'm not really familiar with this measure. I was just being pragmatic and naively translated the scipy wasserstein_distance because it was being used in a Python code that I needed in Julia, and thought I could remove its dependence on PyCall.

@dsweber2
Copy link

Do you know how its possible that you (at least not on first sight) do not need to solve a LP to calculate the distance?

This is because the 1D solution, as noted on the scipy page, albeit vaguely is closed form, whereas for 2+D you need to actually compute the transport plan. @Sh4pe, given that there don't seem to be 2D distances here, it would probably make more sense as a stand-alone package, or possibly integrated into the EMD package. Either way, thanks for writing this up.

FYI, there's a version in https://github.com/mirkobunse/EarthMoversDistance.jl
I tried this first, but the results didn't match up when I compared this with Python's wasserstein_distance.

Mind going into this? As someone who was planning on using one of these, the inconsistency gives me pause. Guessing you used Euclidean() for W2 and Cityblock() for W1?

@devmotion
Copy link
Member

FYI: https://github.com/JuliaOptimalTransport/OptimalTransport.jl contains specialized implementations for the 1D solution, both for discrete and continuous distributions, but also a pure Julia implementation for solving the optimal transport problem with discrete measures in arbitrary dimensions (and soon there will be also a specialized implementation for multivariate normal distributions). We just added wasserstein as an interface for computing the p-Wasserstein distance.

@charleskawczynski
Copy link
Author

FYI: https://github.com/JuliaOptimalTransport/OptimalTransport.jl contains specialized implementations for the 1D solution, both for discrete and continuous distributions, but also a pure Julia implementation for solving the optimal transport problem with discrete measures in arbitrary dimensions (and soon there will be also a specialized implementation for multivariate normal distributions). We just added wasserstein as an interface for computing the p-Wasserstein distance.

Is this equivalent to https://github.com/scipy/scipy/blob/v1.4.1/scipy/stats/stats.py#L6934-L7008?

I'm a bit tired of losing traction on PRs like this. Why is the pairwise functionality necessary @johnnychen94 ? Should I just pull this out and use it where we need and not bother using Distances.jl?

@devmotion
Copy link
Member

Yes, it provides everything in that function (and some additional features). For instance, the example in the scipy docstring could be written as (functionality for exact computations was extracted to the more lightweight ExactOptimalTransport.jl package but it's reexported by OptimalTransport.jl so you could use that one as well):

julia> using ExactOptimalTransport, LinearAlgebra

julia> wasserstein(discretemeasure([0, 1, 3]), discretemeasure([5, 6, 8]))
5.0

julia> wasserstein(discretemeasure([0, 1], normalize!([3.0, 1.0], 1)), discretemeasure([0, 1], normalize!([2.0, 2.0], 1)))
0.25

julia> wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)))
4.078133143804784

Additionally, you can change the metric (default: metric = Distances.Euclidean()) and the order (default: p = Val(1)):

julia> using Distances

julia> wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)); metric = SqEuclidean())
19.08963752665245

julia> wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)); p = 2)^2
19.089637526652446

And a bit more efficient if you want to compute the squared 2-Wasserstein distance (by default with respect to the Euclidean metric):

julia> squared2wasserstein(discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)))
19.08963752665245

Internally, discretemeasure just uses a Distributions.DiscreteNonParametric for the univariate case, so you could use it directly instead (note though that discretemeasure supports the multivariate case whereas DiscreteNonParametric doesn't):

julia> using Distributions

julia> wasserstein(DiscreteNonParametric([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), DiscreteNonParametric([4.5, 1.4], normalize!([3.2, 3.5], 1)))
4.078133143804784

julia> wasserstein(Categorical(normalize!([1.4, 0.9, 3.1, 7.2], 1)), Categorical(normalize!([3.2, 3.5], 1)))
1.7553897180762852

Wasserstein distances can be computed analytically also for uni- and multivariate normal distributions with respect to the squared Euclidean distance:

julia> wasserstein(MvNormal([1.2, -4.2, 3.3], Diagonal([0.1, 1.4, 0.6])), MvNormal(I(3)); p = Val(2))
5.5246646247752915

julia> squared2wasserstein(MvNormal([1.2, -4.2, 3.3], Diagonal([0.1, 1.4, 0.6])), MvNormal(I(3)))
30.521919216243514

More generally, for continuous univariat distributions the Wasserstein distance can be computed as an integral of the metric composed with the respective quantile functions on the unit interval. ExactOptimalTransport computes the integral numerically with QuadGK:

julia> wasserstein(Normal(3.1, 2.1), Laplace(1.2, 0.5))
1.9961751319645216

Fundamentally, wasserstein and squared2wasserstein are more convenient interfaces with common defaults for ot_cost. You can also obtain the optimal transport plan (as a sparse matrix for discrete distributions and a function for continuous distributions):

julia> ot_plan(Euclidean(), discretemeasure([3.4, 3.9, 7.5, 7.8], normalize!([1.4, 0.9, 3.1, 7.2], 1)), discretemeasure([4.5, 1.4], normalize!([3.2, 3.5], 1)))
4×2 SparseArrays.SparseMatrixCSC{Float64, Int64} with 5 stored entries:
 0.111111    
 0.0714286   
 0.246032    
 0.0938166  0.477612

julia> f = ot_plan(Euclidean(), Normal(3.1, 2.1), Normal())
(::ExactOptimalTransport.var"#T#10"{Normal{Float64}, Normal{Float64}}) (generic function with 1 method)

julia> f(1.0)
-1.0000000000000002

julia> f(3.1)
0.0

If you have two histograms and a cost matrix, you can compute the optimal transport cost with emd2 and the corresponding plan with emd (IIRC the names are taken from the Python OT package, for which a basic interface exists as well: https://github.com/JuliaOptimalTransport/PythonOT.jl).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants