鍋あり谷あり

テーマを決めずに適当に書いています。

in-place merge sort

http://blog.livedoor.jp/dankogai/archives/50957658.html

現代では一時メモリーを使わないin-place merge sortが開発されている

と書いてある。そういえば、STLの stable_sort の計算量が O( N (log N)**2 )だったよなぁと思い、どうやったらそうなるのか調べてみたら
http://thomas.baudel.name/Visualisation/VisuTri/inplacestablesort.html
に実装があるのを発見した(いやその前にSTLのソースを見たんだが、とても読みにくかったので断念した)。
で。ソースから計算内容を理解し、ああなるほどそうするのかと思ってみた。
というわけで、どんな計算なのかを私なりに説明してみる:

  1. 左半分と右半分を整列する。
  2. 左半分と右半分をマージする。

ここまではマージソートそのものだが、普通にマージするとO(N)のメモリが要る。そこで、ちょっとクイックソートと似た感じの工夫をする:

    1. 適当な値 p を決め、「左でp未満」「左でp以上」「右でp未満」「右でp以上」に分ける。
    2. 「左でp以上」と「右でp未満」を、ローテートで入れ替える。すると。左の方にp未満が集まり、右の方にp以上が集まる。
    3. 「左でp未満」と「元右でp未満」をマージする。
    4. 「元左でp以上」と「右でp以上」をマージする。

この「適当な値p」がちゃんと選べないと、クイックソートのように大損害なんだと思う。違うかな。
あと。
ローテートを反転三回で実装するのは、あんまり速くない。面倒でも地道に書いた方がいい。
で。
折角なので書いてみた。

int 
find_last_less( Foo * p, int n, Foo const & key, int cmp( Foo const &, Foo const & ) )
{
	if ( !( cmp( p[0], key )<0 ) ){
		return -1;
	}
	if ( cmp( p[n-1], key )<0 ){
		return n-1;
	}
	int less=0;
	int ge=n-1;
	while( less+1<ge ){
		int mid = less + (ge-less)/2;
		if ( cmp( p[ mid ], key )<0 ){
			less = mid;
		} else {
			ge = mid;
		}
	}
	return less;
}

int gcd( int a, int b )
{
	for(;;){
		if ( b==0 ){
			return a;
		}
		int c = a%b;
		a=b;
		b=c;
	}
}

void rot_left( Foo * p, int size, int rot )
{
	int g = gcd( size, rot );
	for( int start=0 ; start<g ; ++start ){
		int i=( start + rot )%size;
		Foo head = p[i];
		while( i != start ){
			int next = ( i+rot )%size;
			p[i] = p[ next ];
			i=next;
		}
		p[start]=head;
	}
}

inline
Foo max2( Foo const & a, Foo const & b, int cmp( Foo const &, Foo const & ) )
{
	return cmp( a, b )<0 ? b : a;
}

void merge( Foo * p, int left, int right, int cmp( Foo const & a, Foo const & b ) )
{
	if ( left==0 || right==0 ){
		return;
	}
	if ( cmp( p[ left-1 ], p[ left ] )<=0 ){
		return;
	}
	if ( left==1 && right ==1 ){
		if ( cmp( p[1], p[0] )<0 ){
			swap( p[0], p[1] );
		}
		return;
	}
	Foo key = left<right ? p[ left + (right+1)/2 ] : p[ (left+1)/2 ];
	if ( cmp( key, p[ left ] )<=0  && cmp( key, p[0] ) <= 0 ){
		key = max2( p[left-1], p[left+right-1], cmp );
	}
	int mL = find_last_less( p, left, key, cmp );
	int mR = find_last_less( p+left, right, key, cmp );
	int mLL = mL + 1;
	int mRL = mR + 1;
	int mLGE = left-mLL;
	int mRGE = right - mRL;
	rot_left( p+mLL, mLGE + mRL, mLGE );
	merge( p, mLL, mRL, cmp );
	merge( p+mLL+mRL, mLGE, mRGE, cmp );
}

void
ipms( Foo * p, int n, int cmp( Foo const & a, Foo const & b ) )
{
	if ( n<2 ){
		return;
	}
	if ( n==2 ){
		if ( cmp( p[1], p[0] )<0 ){
			swap( p[0], p[1] );
		}
		return;
	}
	int mid = n/2;
	ipms( p, mid, cmp );
	ipms( p+mid, n-mid, cmp );
	merge( p, mid, n-mid, cmp );
}

こんなにたくさん書かなきゃ行けないとは思ってなかったのでびっくりした。整列したいだけなのにgcdまで必要とは。
実行してみると、確かに N(log(N)**2) に比例していた。
ちなみに。
c++ template を使ってないのは、意識的にそうしたから。
テンプレートが使える環境なら、STL の stable_sort を使う方がいい。逆に言えば。ちゃんとした C++ コンパイラがない環境で役に立てるようにと書いた。
あと。
STL の stable_sort に比べるとだいぶ遅い。比較回数は4倍、コピー回数は3倍ぐらい。
たぶん、数が少ないときは挿入ソートにするとか、もっと少ないときにはまた別のソートとか、そういうことをきちんとやればいいんだと思う。書いてないので思うだけ。それとも無駄な計算してるかなぁ。
あとあと。
このソースは NYSL ライセンスにしておく。ご自由にお使いください。使う人がいるような気はしないけど。