失敗しました(涙
これまで何回かに分けてEMアルゴリズムのサンプルコードを載せましたが、エラー処理とかをすっ飛ばしてる割に、全部で600行超えてしまいましたね(涙
ちょこっと遊ぶだけのものだったはずなのに、それにしては大げさすぎました・・・
正規分布パラメータの構造体や混合分布の構造体なんか定義したせいで、ちょこっと遊ぶには無駄なコードが増えすぎてしまいましたw
まぁ、このシリーズの一回目書いてて気づいてたんですけど(汗。
パラメータや混合分布構造体を定義したのは、例えば楕円族など(正規分布、コーシー分布、t分布など)についてそのまま転用するのに便利かと思ったんですが。。。
例えばいろんな分布の更新式では、事後確率(負担率)の計算なんかは「前のときと同じ」で式を書くのを省略してましたが、ということは、その計算についてはいろんな分布で共通の関数が使える訳なんでしょうね。
で、スーパークラス的なものを作って、それぞれの分布で共通している計算は、内部でスーパークラスの同じ関数を呼び出すインターフェースを個別の分布用に作ればいいのでは、なんて考えた訳です。
しかしブログで小出しに載せるには適してませんでしたね(涙
とはいえここまで引っ張ってきたので、これまでの関数群を用いてEM法を動かすサンプルのメイン関数を載せておきましょう。
で、このmain関数で行うことは以下。
1.データの用意
2.混合分布構造体の用意
3.EMアルゴリズムの実行
4.結果の出力
1.については、前回のコレスキー少佐のくだりを使って正規乱数を作ります。
2.については
gaussian_mixture *gm = gaussian_mixture_new (ndata, ndim, data, ncls);
で、1.で作成したデータを渡して混合分布の構造体を作ります。
さらに、それに ncls 個分の正規分布の平均ベクトル、共分散行列の初期値を設定します。
gm->params[0] = gaussian_params_new (ndim, mu0, cov0);
gm->params[1] = ...
こんな風に、直にメンバにアクセスしてパラメータを登録するんですが、前にも書いた様に指定したクラスタ数分パラメータが登録されてるか確認出来るよう、パラメータ登録のメソッドを用意して内部カウンタでいくつパラメータが登録されたか数えておく、なんてした方がいいのでしょう。
3.では、単に
cell = gaussian_mixture (tol, gm, &iter, maxiter);
を呼び出すだけです。
これを呼び出すことで、対数尤度の条件付期待値の値の更新量が tol 以下になるまでパラメータを更新し続けます(又は対数尤度が前より減少したら停止)。
4.の出力は、各データを事後確率が最も大きなクラスタに振り分け、それぞれ別ファイルに出力します。
<サンプルのメイン関数>
int
main (void)
{
size_t ndata;
size_t ndim;
size_t ncls;
double *data;
gaussian_mixture *gm;
ndata = 200;
ndim = 3;
ncls = 2;
data = (double *) malloc (ndata * ndim * sizeof (double));
/* テストデータの作成 */
{
int i, l;
double test_mu0[] = {-2., -2., -1.};
double test_sigma0[] = {sqrt (2.), sqrt (2.), sqrt (2.)};
double test_rho0[] = {0.01, 0.05, -0.03};
double *data0 = test_data (100, ndata / 2, ndim, test_mu0, test_sigma0, test_rho0);
double test_mu1[] = {2., 2., -1.};
double test_sigma1[] = {sqrt (3.), sqrt (3.), sqrt (3.)};
double test_rho1[] = {0.05, 0.01, 0.05};
double *data1 = test_data (200, ndata - ndata / 2, ndim, test_mu1, test_sigma1, test_rho1);
FILE *fp;
l = 0;
for (i = 0; i < (ndata / 2) * ndim; i++) data[l++] = data0[i];
free (data0);
for (i = 0; i < (ndata - ndata / 2) * ndim; i++) data[l++] = data1[i];
free (data1);
if ((fp = fopen ("input.dat", "w")) != NULL) {
fprintf_row_major_array (fp, ndata, ndim, data, "%f");
fclose (fp);
}
}
/* 混合分布構造体の作成 */
gm = gaussian_mixture_new (ndata, ndim, data, ncls);
/* 混合分布を構成する正規分布パラメータの初期値設定 */
{
double init_mu0[] = {-1., -1., -1.};
double init_mu1[] = { 1., 1., 1.};
double init_sigma[] = {1., 1., 1.};
double init_rho[] = {0., 0., 0.};
gsl_matrix *init_cov = covariance_matrix (ndim, init_sigma, init_rho);
gm->params[0] = gaussian_params_new (ndim, init_mu0, init_cov->data);
gm->params[1] = gaussian_params_new (ndim, init_mu1, init_cov->data);
}
/* EMアルゴリズム */
{
int iter;
double cell;
cell = gaussian_mixture_EM (1.e-3, gm, &iter, 100);
fprintf (stderr, "cell = %f, iter = %d\n", cell, iter);
}
/* クラスタリング結果の出力 */
{
int i, k;
FILE *fp, *fp_out[ncls];
for (k = 0; k < ncls; k++) {
char fn[80];
sprintf (fn, "cluster%02d.res", k + 1);
fp_out[k] = fopen (fn, "w");
}
for (i = 0; i < ndata; i++) {
int zmaxk = 0;
for (k = 1; k < ncls; k++) {
zmaxk = (gm->z[i * ncls + k] > gm->z[i * ncls + zmaxk]) ? k : zmaxk;
}
fp = fp_out[zmaxk];
if (fp) {
int j;
for (j = 0; j < ndim; j++) {
fprintf (fp, "%f", data[i * ndim + j]);
if (j < ndim - 1) fprintf (fp, "\t");
else fprintf (fp, "\n");
}
}
}
for (k = 0; k < ncls; k++) if (fp_out[k]) fclose (fp_out[k]);
}
free (gm);
free (data);
return 0;
}
以上です。
まぁ、長いメイン関数ですね・・・(汗
上のサンプルのメイン関数では
平均ベクトル μ 、共分散行列 Σ がそれぞれ
μ1=[-2, -1, -2]
Σ1=[
2.00, 0.02, 0.10
0.02, 2.00, -0.06
0.10, -0.06, 2.00
]
μ2=[4, 3, -2]
Σ2=[
3.00, 0.15, 0.03
0.15, 3.00, 0.15
0.03, 0.15, 3.00
]
の2つの正規分布から100個づつ発生させたデータを混合させたものを入力に用いています。
上が入力となるテストデータです。
3Dなので見にくいですが、データが右と左の塊に分かれているように見えますね。
で、これをEM法でクラスタリングすると
こんなふうに赤と緑の二つのクラスタに分けられました。
というような感じで、それなりに上手くクラスタリングできているようです。
ただ、上のサンプルメイン関数では、クラスタパラメータの初期値をかなりいーかげんに決めていますが、データ数やデータの分布の仕方では初期値依存が大きくなったりして上手くクラスタ分けできない場合もあるようです。
パラメータの初期値についてですが、よく見かけるのは kmeans と併用している例でしょうか?
つまり、まず kmeans法でクラスタ分けし、その結果から得られる平均・分散を初期値にしてEM法を実行する、という事です。
さて、ここまで C でサンプルプログラムを載せてきましたが、ちょこっと遊ぶにはやっぱり長すぎると思うので、次回、Octave (びんぼー人用まっとら○)を使ってもっと短いサンプルコードを載せたいと思います・・・