Notes on Sinkhorn's matrix balancing (メモ)

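.. math::

        \gamma = \arg\min_\gamma \langle \gamma, M \rangle_F + \mathrm{reg}\cdot\Omega(\gamma)

        \text{s.t.}\quad \gamma \mathbf{1} = a

        \gamma^T \mathbf{1} = b

        \gamma \geq 0

where:

- M is the (na, nb) cost matrix,
- \Omega is the entropic regularization term \Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}),
- a and b are the source and target weight vectors (histograms, each summing to 1).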
    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.sinkhorn(a,b,M,1)
    array([[ 0.36552929,  0.13447071],
           [ 0.13447071,  0.36552929]])
> M <- matrix(c(0.36552929,0.13447071,0.13447071,0.36552929),2,2)
> apply(M,1,sum)
[1] 0.5 0.5
> apply(M,2,sum)
[1] 0.5 0.5
  • Translation into R (Rに翻訳)
#' Entropy-regularized optimal transport via Sinkhorn-Knopp matrix balancing
#'
#' R translation of POT's `ot.sinkhorn`: alternately rescales the Gibbs
#' kernel K = exp(-M / reg) with vectors u and v so that the returned plan
#' diag(u) K diag(v) has row sums `a` and column sums `b`.
#'
#' @param a Numeric vector of source weights (length na).
#' @param b Numeric vector of target weights (length nb).
#' @param M Cost matrix with na rows and nb columns.
#' @param reg Entropic regularization strength (expected to be positive).
#' @param numItermax Maximum number of Sinkhorn iterations.
#' @param stopThr Stop early once the Euclidean norm of
#'   (column sums of the plan - b) drops below this value.
#' @return The na x nb transport plan diag(u) %*% K %*% diag(v).
#'   If a numerical failure occurs (a zero in t(K) %*% u, or a non-finite
#'   scaling), a warning is issued and the plan from the last valid
#'   iteration is returned.
my.sinkhorn <- function(a, b, M, reg, numItermax = 1000, stopThr = 1e-9) {
	M <- as.matrix(M)
	na <- length(a)
	nb <- length(b)
	stopifnot(nrow(M) == na, ncol(M) == nb)
	u <- rep(1 / na, na)
	v <- rep(1 / nb, nb)
	K <- exp(-M / reg)
	# K / a divides row i by a[i] (column-major recycling). Unlike
	# diag(1 / a) %*% K this is also correct when length(a) == 1, where
	# diag() would build an identity matrix of size 1 / a instead.
	Kp <- K / a
	for (cnt in seq_len(numItermax)) {
		uprev <- u
		vprev <- v
		KtU <- drop(crossprod(K, u))
		v <- b / KtU
		u <- 1 / drop(Kp %*% v)
		# any() rather than prod() == 0: a product of many small positive
		# entries can underflow to 0 without any entry being 0.
		if (any(KtU == 0) || !all(is.finite(u)) || !all(is.finite(v))) {
			# Machine precision reached: fall back to the last valid scaling
			# and stop (break, so the failing update is not recomputed).
			warning("numerical errors at iteration ", cnt, call. = FALSE)
			u <- uprev
			v <- vprev
			break
		}
		# Column sums of diag(u) K diag(v) are v * (t(K) %*% u); row sums
		# already equal a exactly after the u update.
		err <- sqrt(sum((v * drop(crossprod(K, u)) - b)^2))
		if (err < stopThr) {
			break
		}
	}
	# Elementwise equivalent of diag(u) %*% K %*% diag(v), safe for length 1.
	K * outer(u, v)
}

# Same inputs as the POT doctest: uniform 2-point marginals, swap cost, reg = 1.
my.sinkhorn(
	a = c(0.5, 0.5),
	b = c(0.5, 0.5),
	M = matrix(c(0, 1, 1, 0), nrow = 2, ncol = 2),
	reg = 1
)