球面場のWasserstein Barycenter

  • 球面にスカラー場があって、離散観察をしたとする
  • その場は非負、総和が1
  • 球面の点の間には測地線距離(角)が取れる
  • メトリックスペースに非負総和1分布があれば、その間にWasserstein 距離が定まり、最小運搬コスト運搬行列も定まる
  • pythonのPOTパッケージを使っていじってみる
import ot
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib.cm as cmx
n = 500
X,Y,Z = np.random.randn(n),np.random.randn(n),np.random.randn(n)
r = np.sqrt(X**2+Y**2+Z**2)
X,Y,Z = X/r,Y/r,Z/r
fig = plt.figure()
ax = Axes3D(fig)
ax.scatter3D(np.ravel(X),np.ravel(Y),np.ravel(Z))
plt.show()
XYZ = np.array([X,Y,Z])
ip = np.dot(XYZ.T,XYZ)
ip[ip>1] = 1
ip[ip<0] = 0
cs = np.arccos(ip)
cs.shape
im1 = np.random.randn(X.size)
im2 = np.random.randn(X.size)
im1 = im1 - min(im1) + 0.01
im2 = im2 - min(im2) + 0.01
plt.plot(im1,im2)
plt.show()
im1[(X>0.) & (Y>0.) & (Z>0.)] = max(im1) + 5
im2[(X<0.) & (Y>0.1) & (Z>-0.1)] = max(im2) + 5
im1 = im1/sum(im1)
im2 = im2/sum(im2)
import matplotlib.cm as cmx
import matplotlib.colors
from mpl_toolkits.mplot3d import Axes3D
def scatter3d(x,y,z, cs, colorsMap='jet'):
    cm = plt.get_cmap(colorsMap)
    cNorm = matplotlib.colors.Normalize(vmin=min(cs), vmax=max(cs))
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cm)
    fig = plt.figure()
    ax = Axes3D(fig)
    ax.scatter(x, y, z, c=scalarMap.to_rgba(cs))
    scalarMap.set_array(cs)
    fig.colorbar(scalarMap)
    plt.show()
scatter3d(X,Y,Z,im1)
plt.plot(X,im1)
plt.show()
reg = 60/np.median(cs)
sinkhornout = ot.sinkhorn(im1,im2, cs, reg)
plt.plot(sinkhornout)
plt.show()
q = np.where(sinkhornout > np.mean(sinkhornout))
q