octaveでEMアルゴリズム・その2 混合t分布のEMアルゴリズム(改) | ぽんのブログ

ぽんのブログ

自分用の備忘録ブログです。書いてある内容、とくにソースは、後で自分で要点が分かるよう、かなり簡略化してます(というか、いい加減)。あまり信用しないように(汗

大分前に、t分布を用いたEMアルゴリズムについても書いていました。
なので前回のスクリプトを書き変えてt分布版も作ってみました。

動かすには digamma と trigamma 関数が必要になるのですがそれぞれこちら(digamma.m)とこちら(trigamma.m)から入手して同じディレクトリにおいておく必要があります(あと faithful.txt も)。

ただ、どっか間違っている可能性は非常に大なのであまり信用できませんが(汗

ファイル名は何でもいいのですが、とりあえず test_faithful_tmm.oct で。。。
% 各ステージの描画
function plotstage(x,mu,cov,K)
    clf    % クリア
    axis([-3,3,-3,3],'square');
    hold    % hold on
    plot(x(:,1),x(:,2),'+','markersize',4); % データをプロット
    for k=1:K
        plotellipse(mu(:,k),cov(:,:,k)); % 共分散をプロット
        ph=plot(mu(1,k),mu(2,k),'o','color',[0,0.5,0.5]); % 平均をプロット
        set(ph,'linewidth',2);
    endfor
    hold    % hold off
endfunction

% 標準偏差の楕円
function plotellipse(mu,cov)
    if(size(cov)(1)!=size(cov)(2) || size(cov)(1)!=2) return; endif
    p=[0:0.01:2*pi]';
    w=[cos(p),sin(p)];

    [v,e]=eig(cov);
    t=repmat(sqrt(abs(diag(e)')),size(w)(1),1).*w;
    t=t*v'+repmat(mu',size(t)(1),1);
    ph=plot(t(:,1),t(:,2),"color",[1,0,0]);
    set(ph,'linewidth',2);
endfunction

% 対数尤度関数
function [y,s2]=loglikelihood(x,nu,m,cov)
    [N,D]=size(x);
    xm=(x-repmat(m',N,1))';
    s2=sum(xm.*(inv(cov)*xm),1)';
    lnrm=0.5*D*log(pi*nu)+lgamma(0.5*nu)+0.5*log(det(cov));
    y=lgamma(0.5*(D+nu))-0.5*(D+nu)*log(1+s2/nu)-lnrm;
endfunction

% dq/dnu, d^2q/dnu^2
function [dq,d2q]=d_d2_q(nu)

    dq=digamma(0.5*nu)-log(0.5*nu);
    d2q=0.5*trigamma(0.5*nu)-1./nu;

endfunction

% nuの更新
function nu=most_likelihood_nu(nu,u,w,N)

    K=length(nu);

    iter=0;
    y=1+(u-w)/N;

y0=y-log(0.5*nu);
nu=exp(y0)+0.5;
if ~isempty((i=find(y0<-2.22))) nu(i)=-1./(y0(i)-digamma(1)); endif

    % ニュートン法
    while (1)
        [dq,d2q]=d_d2_q(nu);
        dnu=(y-dq)./d2q;
        if ~isempty(find(nu+dnu<0)) break; endif
        nu=nu+dnu;
        if (sqrt(dnu*dnu')<1.e-8) break; endif
        iter=iter+1;
        if(iter>100) break; endif
    endwhile

endfunction

% log (sum_k exp (x_k) ) )
function ret=logsumexp(x)
    ret=x(1);
    for k=2:length(x)
        vmax=max(ret,x(k));
        vmin=min(ret,x(k));
        if (vmax == vmin) ret=vmax+log(2);
        elseif (vmax>vmin+50) ret=vmax;
        else ret=vmax+log(exp(vmin-vmax)+1);
        endif
    endfor
endfunction

% Eステップ
function [z,w,u,logp]=Estep(x,nu,xi,mu,cov,K)
    [N,D]=size(x);

    % log z = log(xi*f / sum xi*f) = log (xi*f) - log (sum xi*f)

    % log (xi*f)
    for k=1:K
        [lnz(:,k),s2(:,k)]=loglikelihood(x,nu(k),mu(:,k),cov(:,:,k));
        lnz(:,k)=lnz(:,k)+log(xi(k));
    endfor

    % log (sum xi*f) = log ( sum exp ( log (xi * f) ) )
    for i=1:N
        lnsumz(i,1)=logsumexp(lnz(i,:));
    endfor

logp=sum(lnsumz); % 対数尤度

    % z = exp( log z ) = exp ( log (xi*f) - log (sum xi*f) )
    z=max(exp(lnz-repmat(lnsumz,1,K)),eps);
    % zを規格化
    z=z./repmat(sum(z,2),1,K);

    anu=repmat(nu,N,1);
    w=(anu+D)./(anu+s2);
%    u=digamma(0.5*(anu+1))-log(anu+s2);
    u=digamma(0.5*(anu+D))-log(0.5*(anu+s2));

endfunction

% Mステップ
function [xi,nu,mu,cov]=Mstep(x,z,nu,w,u,K)
    [N,D]=size(x);

    wz=w.*z;

    % 混合比の更新
    xi=mean(z);

    % 平均の更新
    mu=x'*wz;
    mu=mu./repmat(sum(wz,1),D,1);

    % 分散共分散行列の更新
    cov=zeros(D,D,K);
    for k=1:K
        for i=1:N
            xm=x(i,:)'-mu(:,k);
            cov(:,:,k)=cov(:,:,k)+wz(i,k)*xm*xm';
        endfor
        cov(:,:,k)=cov(:,:,k)/(N*xi(k));
    endfor

    % 自由度の更新
    nu=most_likelihood_nu(nu,sum(u,1),sum(w,1),N);

endfunction

% EMアルゴリズム
function [z,xi,nu,mu,cov,iter]=TMM(x,nu,xi,mu,cov,K,tol,maxiter)

    if nargin()<3 tol=1.e-3; endif
    if nargin()<4 maxiter=100; endif

    [N,D]=size(x);

    logp=-realmax;
    iter=1;
    while (1)
        logp_pre=logp;

        [z,w,u,logp]=Estep(x,nu,xi,mu,cov,K);
logp % 対数尤度を表示

        if ((logp-logp_pre)<tol) break; endif

        [xi,nu,mu,cov]=Mstep(x,z,nu,w,u,K);

        plotstage(x,mu,cov,K);
        usleep(10);

        iter=iter+1;
        if (iter > maxiter) break; endif

    endwhile

endfunction

%%% サンプルスクリプト %%%
x=load('faithful.txt');

[N,D]=size(x);
x=x-repmat(mean(x,1),N,1);
x=x./repmat(sqrt(var(x,1)),N,1);

% ノイズを加える。
x=[x;3-6*rand(150,2)];

K=2;

rand('seed', 100);

mu=rand(D,K);        % 平均ベクトルを乱数で初期化
cov=zeros(D,D,K);
for k=1:K            % 共分散行列を単位行列で初期化
    cov(:,:,k)=eye(D);
endfor
xi=ones(1,K)/K;    % 混合比の初期値(=1 / クラスタ数)
nu=ones(1,k);        % 自由度の初期値(=1)

[z,nu,xi,mu,cov,iter]=TMM(x,nu,xi,mu,cov,K,1.e-3,1000)




以上を digamma.m、trigamma.m 及び faithful.txt と同じディレクトリに保存し

$ octave -q test_faithful_tmm.oct

で実行できるはずです。

使ったデータは、Old Faithful のデータに一様なノイズ([-3、3]の範囲)を150個加えたものです。


上の赤い点が元データ、それ以外の緑の点がノイズです。
これを混合正規分布モデル(GMM)、混合t分布モデル(TMMとでも呼ぶのでしょうか?)で K=2 とした場合の更新の具合を見てみましょう。

まず最初に、ノイズを含まないデータの場合がこちら(前回のスクリプトで K=2 とした場合)。



実際はカクカクしていてこんなに滑らかに動きませんが(汗
クラスタリング結果はこんな感じでした。



次にノイズありの場合。
まずは GMM の場合。前回のスクリプトでノイズを加え K=2 としたもの。


加えられたノイズの影響で上手に分けられていません。
クラスタリング結果はこんな感じ。



真ん中の塊とそれを取り巻くグループに分けられてしまっています。。。

次にTMMの場合。今回のスクリプトの実行結果です。



クラスタリング結果はこんな感じ。


ノイズで無い真のデータ部分は K=2 のノイズ無しの場合に似た形でクラスタ分けがされているようです。
一般にすそ野の広いt分布の場合、ノイズやはずれ値に対してよりロバストだそうです。
今回の結果では、ノイズ混じりの場合でも混合t分布ではうまくいっているように見えますが、もちろん初期値を変えればうまくいかなかったり、あるいは混合正規分布でもうまくいくような場合もあるのかもしれません。

今後もう少し色々試してみたいと思いますが何もしません(キッパリ
あとコードもどうにかしたいけど何もしません(キッパリ・その2