混合ディリクレ分布のEMアルゴリズム・その2 | ぽんのブログ

ぽんのブログ

自分用の備忘録ブログです。書いてある内容、とくにソースは、後で自分で要点が分かるよう、かなり簡略化してます(というか、いい加減)。あまり信用しないように(汗

前回の混合ディリクレ分布のEMアルゴリズムに則りテストするためのプログラムを作ってみました。

ユーティリティ群は混合多項分布の時と同じ。
但し digamma 関数の逆関数を計算するため gsl/gsl_sf.h をインクルードしています。
#include <gsl/gsl_math.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_sf.h>

const gsl_rng    *_rng_;

/*** ユーティリティー ***/
/* ベクトルの和をとる */
double
_sum (int n, double *vector)
{
    int       i;
    double    sum = 0.;
    for (i = 0; i < n; i++) sum += vector[i];
    return sum;
}

/* ベクトルをその和で規格化する */
void
_normalize (int n, double *vector)
{
    int       i;
    double    sum = _sum (n, vector);
    for (i = 0; i < n; i++) vector[i] /= sum;
    return;
}

/* ベクトルを出力 */
void
fprintf_vector (FILE *stream, int n, double *vector, char *format)
{
    int       i;
    for (i = 0; i < n; i++) {
        fprintf (stream, format, vector[i]);
        if (i < n - 1) fprintf (stream, "\t");
    }
    fprintf (stream, "\n");
   return;
}

/* 行列を出力 */
void
fprintf_matrix (FILE *stream, int nrow, int ncol, double **matrix, char *format)
{
    int       i;
    for (i = 0; i < nrow; i++) fprintf_vector (stream, ncol, matrix[i], format);
    return;
}

次にEMアルゴリズムの部分。
/*** EMアルゴリズム ***/
/* digamma関数の逆関数 初期値計算用 */
double
psi_inverse0 (double y)
{
    return (y < -2.22) ? - 1. / (y + M_EULER) : exp (y) + 0.5;
}

/* digamma関数の逆関数 */
double
psi_inverse (double y, double tol, int maxiter)
{
    int       iter = 0;
    double    x = psi_inverse0 (y);
    while (1) {
        double    dx = (y - gsl_sf_psi (x)) / gsl_sf_psi_1 (x);
        x += dx;
        if (fabs (dx) <= tol) break;
        if (++iter > maxiter) {
            fprintf (stderr, "tol = %.3e, iter = %d, dx = %.3e\n", tol, iter, dx);
            break;
        }
    }
    return x;
}

/* 対数尤度の条件付き期待値 */
double
log_likelihood (int ncls, int nitems, double *xi, double **z, double **x, double **alpha)
{
    int       i, k;
    double    ll = 0.;
    // ll = sum_i sum_k { z[i][k] * log (xi[k] * f (x[i] | alpha[k]) ) }
    for (i = 0; i < nitems; i++) {
        for (k = 0; k < ncls; k++) {
            ll += z[i][k] * (gsl_ran_dirichlet_lnpdf (nitems, alpha[k], x[i]) + log (xi[k]));
        }
    }
    return ll;
}

/* Eステップ */
void
E_step (int n, int ncls, int nitems, double *xi, double **z, double **x, double **alpha)
{
    int        i, k;
    for (i = 0; i < n; i++) {
        for (k = 0; k < ncls; k++) {
            // z[i][k] = xi[k] * f(x[i] | alpha[k])
            z[i][k] = xi[k] * gsl_ran_dirichlet_pdf (nitems, alpha[k], x[i]);       
        }
        // z[i][k] = z[i][k] / sum_k z[i][k]
        _normalize (ncls, z[i]);
    }
    return;
}

/* Mステップ */
void
M_step (int n, int ncls, int nitems, double *xi, double **z, double **x, double **alpha)
{
    int       i, k, l;
    for (k = 0; k < ncls; k++) {
/* alpha[k][l]の更新 */
double alpha_k = _sum (nitems, alpha[k]);
for (l = 0; l < nitems; l++) {
double logx_kl = 0.;
double z_k = 0.;
for (i = 0; i < n; i++) {
logx_kl += z[i][k] * log (x[i][l]);
z_k += z[i][k];
}
logx_kl /= z_k;
alpha[k][l] = psi_inverse (gsl_sf_psi (alpha_k) + logx_kl, 1.e-8, 1000);
}

        /* xi の更新 */
        // xi[k] = sum_i z[i][k] / n
        xi[k] = 0.;
        for (i = 0; i < n; i++) xi[k] += z[i][k];
        xi[k] /= (double) n;
    }
    return;
}

ディリクレ分布の場合、gsl の関数

double    gsl_ran_dirichlet_lnpdf (size_t K, const double p[], const unsigned int x[]);

で対数尤度を計算できます。
また

double
    gsl_ran_dirichlet_pdf (size_t K, const double p[], const unsigned int x[]);

が密度関数になります。

続いて入力とするデータを作成する部分

/*** テストデータ作成用 ***/
/* パラメータを作成 (nitems 次元ベクトルを ncls 個分) */
double **
create_parameter (int ncls, int nitems)
{
    int           k, l;
    double        **alpha = (double **) malloc (ncls * sizeof (double *));
    for (k = 0; k < ncls; k++) {
        alpha[k] = (double *) malloc (nitems * sizeof (double));
        for (l = 0; l < nitems; l++) alpha[k][l] = gsl_ran_flat (_rng_, 0., 100.);
    }
    return alpha;
}

/* データ作成 (nitems 次元ベクトルを n 個分) */
double **
create_mixture_dirichlet_sample (int n, int ncls, int nitems, double *xi, double **alpha)
{
    int        i, k;
    double     **x;

    x = (double **) malloc (n * sizeof (double *));
    for (i = 0; i < n; i++) {
        // 混合率 xi に応じて多項分布(パラメータを指定する k の値)を選ぶ
        double    r = gsl_ran_flat (_rng_, 0., 1.);
        for (k = 0; k < ncls; k++) {
            if (r < xi[k]) break;
            r -= xi[k];
        }
        x[i] = (double *) malloc (nitems * sizeof (double));
        gsl_ran_dirichlet (_rng_, nitems, alpha[k], x[i]);
    }
    return x;
}

最後にメイン関数。
int
main (void)
{
    int       ncls = 3;      // コンポーネント数 (K)
    int       nitems = 4;     // アイテム数 (L)
    int       n = 1000;      // データセット数 (N)

    double    **x;           // データ

    double    **alpha;        // パラメータ
    double    *xi;            // 混合比
    double    **z;            // 事後確率

    _rng_ = gsl_rng_alloc (gsl_rng_default); // gslの乱数初期化

    // 混合ディリクレ分布に従うデータの作成
    {
        int       i;
        double    **alpha0 = create_parameter (ncls, nitems); // 各ディリクレ分布のパラメータ
        double    xi0[3] = {0.2, 0.5, 0.3};                   // 混合率

        x = create_mixture_dirichlet_sample (n, ncls, nitems, xi0, alpha0);

        fprintf (stdout, "xi0 :\n");
        fprintf_vector (stdout, ncls, xi0, "%.3f");
        fprintf (stdout, "\nalpha0 :\n");
        fprintf_matrix (stdout, ncls, nitems, alpha0, "%.3f");
    }

    // 領域確保と初期値設定
    {
        int       i, k, l;

        alpha = (double **) malloc (ncls * sizeof (double *));
        for (k = 0; k < ncls; k++) {
            alpha[k] = (double *) malloc (nitems * sizeof (double));
            for (l = 0; l < nitems; l++) alpha[k][l] = 100.;
        }

        z = (double **) malloc (n * sizeof (double *));
        for (i = 0; i < n; i++) {
            z[i] = (double *) malloc (ncls * sizeof (double));
            for (k = 0; k < ncls; k++) z[i][k] = 0.;
        }

        xi = (double *) malloc (ncls * sizeof (double));
        for (k = 0; k < ncls; k++) xi[k] = gsl_ran_flat (_rng_, 0., 1.);
        _normalize (ncls, xi);
    }

    // EMアルゴリズムを実行
    {
        int       iter = 0;
        double    l_prev;
        double    l = -GSL_DBL_MAX;

        while (1) {
            l_prev = l;
            E_step (n, ncls, nitems, xi, z, x, alpha);
            l = log_likelihood (ncls, nitems, xi, z, x, alpha);
            if (fabs (l - l_prev) < 1.e-8 || ++iter >= 5000) break;
            M_step (n, ncls, nitems, xi, z, x, alpha);
        }
        fprintf (stdout, "\niter = %d, l = %f\n\n", iter, l);
        fprintf (stdout, "xi :\n");
        fprintf_vector (stdout, ncls, xi, "%.3f");
        fprintf (stdout, "\nalpha :\n");
        fprintf_matrix (stdout, ncls, nitems, alpha, "%.3f");
    }
    return 0;
}

これを動かした結果は

<真のパラメータ>
xi0 :
0.200    0.500    0.300

alpha0 :
99.974    16.291    28.262    94.720
23.166    48.497    95.748    74.431
54.004    73.995    75.994    65.864

iter = 2005, l = 26.521813

<計算結果>
xi :
0.190    0.291    0.519

alpha :
99.328    15.842    28.203    94.319
56.989    78.458    81.157    70.535
23.748    50.103    98.876    77.305


となりました。
何処か間違ってる気がする・・・