Rcppで格子遊び2

  • 完全2部グラフから、条件を満たす辺のリストを作る
    • 2部グラフの一塊はx+y+z+w=n-1のベクトルの集まり、もう一塊がx+y+z+w=nのベクトルの集まり
    • それぞれ、相当のベクトル数ではあるけれど、全ペアの距離を調べてリストアップする
    • ただし、出来上がるエッジ本数はあらかじめわかっている
    • さきほどの、頂点数を計算する関数と同じファイルに以下のように書いておく(関数の順番はどちらがどちらでもよい)
  • RとC++の間でのデータ転送はそれなりに時間がかかるらしい。特に大きな行列でのやりとりは大変らしいが、そんなときは、reference渡しがよいそうだ(こちら)。使ってみる。おそらくdeep copyをせずにshallow copyをするだけなのが軽い理由なのかと…
    • 引数として const arma::mat& x これが引数をreferenceとしてarma::matの形で取り込み、関数の中でxという名前でアクセスする、ということ
  • その他
    • 行列xの行数をx.n_rowsで取り出す
    • a行b列の全要素が0の行列をarma::zeros(a,b)で作る
    • 二つの引数行列x,yのそれぞれの全行の総当たり。ベクトルペアのハミング距離を出して、条件判定をして登録している
// latticeConnect.cpp
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>
using namespace arma;  // use the Armadillo library for matrix computations
using namespace Rcpp;

// [[Rcpp::export]]
arma::mat latticeConnect(const arma::mat& x, const arma::mat& y) {
  
  int nrowx = x.n_rows;
  int ncolx = x.n_cols;
  int nrowy = y.n_rows;
  arma::mat ret = arma::zeros(nrowx*ncolx,2);
  int cnt = 0;
  for (int i = 0; i < nrowx; i++) {
    arma::rowvec vx = x.row(i);
    for(int j = 0; j < nrowy; j++) {
      arma::rowvec vy = y.row(j);
      int tmp = 0;
      for(int k = 0; k < ncolx; k++){
        tmp = tmp + abs(vx(k)-vy(k));
      }
      if(tmp==1){
        ret(cnt,0) = i;
        ret(cnt,1) = j;
        cnt++;
      }
    }
  }
  return ret;

}

// [[Rcpp::export]]
arma::vec numTetrahedron(int n) {
  arma::vec ret = vec(n+1);
  ret[0] = 1;
  for(int i=1; i < n+1 ; i++){
    ret[i] = (i+3)*ret[i-1]/i;
  }
  return ret;
}
// END
  • これを使ってRの関数を作り直し
table.graph4 <- function(n,d){
	# Make a list of tables when i = 0,1,2,...,n, 
	num.tab <- rep(0,n+1)
	num.tab[1] <- 1
	for(i in 2:(n+1)){
		num.tab[i] <- (i+2)*num.tab[i-1]/(i-1)
	}
	ret <- list()
	ret[[1]] <- matrix(0,1,d)
	for(i in 1:n){
		ret[[i+1]] <- matrix(0,num.tab[i+1],d)
		cnt <- 1
		tmp <- ret[[i]]
		tmp[,1] <- tmp[,1] + 1
		ret[[i+1]][cnt:(cnt+length(tmp[,1])-1),] <- tmp
		cnt <- cnt + length(tmp[,1])
		for(j in 2:d){
			if(j==2){
				s <- which(ret[[i]][,1]==0)
			}else{
				s <- which(apply(matrix(ret[[i]][,1:(j-1)],ncol=j-1),1,sum)==0)
			}
			
			tmpret <- matrix(ret[[i]][s,],nrow=length(s),ncol=d)
			tmpret[,j] <- tmpret[,j] + 1
			#tmp <- rbind(tmp,tmpret)
			ret[[i+1]][cnt:(cnt+length(s)-1),] <- tmpret
			cnt <- cnt+length(s)
		}
		#ret[[i+1]] <- tmp
	}
	# Make matrices that indicate connection between tables for n=k and tables for n=k+1
	ret2 <- list()
	for(i in 1:n){
		#ret2[[i]] <- matrix(0,nrow(ret[[i]]),nrow(ret[[i+1]]))
		ret2[[i]] <- latticeConnect(ret[[i]],ret[[i+1]]) + 1
	}
	return(list(tblmat=ret,connection=ret2))
}
  • 速さ比べ
system.time(tbl.out2 <- table.graph2(15,4))
system.time(tbl.out3 <- table.graph3(15,4))
library(Rcpp)
sourceCpp("latticeConnect.cpp")
system.time(tbl.out4 <- table.graph4(15,4))
> system.time(tbl.out2 <- table.graph2(15,4))
   ユーザ   システム       経過  
      2.76       0.00       2.77 
> system.time(tbl.out3 <- table.graph3(15,4))
   ユーザ   システム       経過  
      2.67       0.00       2.67 
> library(Rcpp)
Warning message:
 パッケージ ‘Rcpp’ はバージョン 3.1.2 の R の下で造られました  
> sourceCpp("latticeConnect.cpp")
> system.time(tbl.out4 <- table.graph4(15,4))
   ユーザ   システム       経過  
      0.03       0.00       0.03