多変量混合正規分布のEMアルゴリズム・サンプルプログラム その8 | ぽんのブログ

ぽんのブログ

自分用の備忘録ブログです。書いてある内容、とくにソースは、後で自分で要点が分かるよう、かなり簡略化してます(というか、いい加減)。あまり信用しないように(汗

前回までで C言語でEMアルゴリズムのサンプルプロいグラムを載せてましたが、octave 用の EM アルゴリズムの例として以下を載せました。

関数名は GMEM (Gaussian Mixture の EM)で、以下を

GMEM.m

というファイル名で、octave が探しに来れる適当なところに保存すれば呼びだす事が出来るはずです。

======== GMEM.m ========

function [z,xi,mu,cov,iter]=GMEM(x, K, tol, maxiter)

if nargin()<3 tol=1.e-3 endif
if nargin()<4 maxiter=100 endif
rand('seed', 10);

[N,D]=size(x);
mu=rand(D,K); % 平均ベクトルを乱数で初期化
cov=zeros(D,D,K);
for k=1:K % 共分散行列を単位行列で初期化
cov(:,:,k)=eye(D);
endfor
xi=ones(1,K)/K; % 混合比は 1 / クラスタ数

cell=-realmax;
iter=1;
while (1)
cell_pre=cell;

z=Estep(x,xi,mu,cov,K); % Eステップ

cell=conditional_expect_ll(x,xi,z,mu,cov,K)
if (cell-cell_pre<tol) break; endif

[xi,mu,cov]=Mstep(x,z,K); % Mステップ

iter=iter+1;
if iter > maxiter break; endif

endwhile

endfunction

% 尤度関数
function y=likelihood(x,m,cov)
[N,D]=size(x);
xm=(x-repmat(m',N,1))';
s2=sum(xm.*(inv(cov)*xm),1)';
nrm=sqrt(2*pi)^D*sqrt(det(cov));
y=exp(-0.5*s2)/nrm;
endfunction

% 対数尤度の条件付き期待値
function y=conditional_expect_ll(x,xi,z,mu,cov,K)
[N,D]=size(x);

for k=1:K
ll(:,k)=log(likelihood(x,mu(:,k),cov(:,:,k)))+log(xi(k));
endfor
y=sum(sum(z.*ll)); 

endfunction

% Eステップ
function z=Estep(x,xi,mu,cov,K)
[N,D]=size(x);

for k=1:K
l=likelihood(x,mu(:,k),cov(:,:,k));
z(:,k)=xi(k)*l;
endfor

z=z./repmat(sum(z,2),1,2);

endfunction

% Mステップ
function [xi,mu,cov]=Mstep(x,z,K)
[N,D]=size(x);

% 混合比の更新
xi=mean(z);

% 平均の更新
mu=x'*z;
for k=1:K
mu(:,k)=mu(:,k)/(N*xi(k));
endfor

% 分散共分散行列の更新
cov=zeros(D,D,K);
for k=1:K
for i=1:N
xm=x(i,:)-mu(:,k)';
cov(:,:,k)=cov(:,:,k)+z(i,k)*xm'*xm;
endfor
cov(:,:,k)=cov(:,:,k)/(N*xi(k));
endfor

endfunction

=== ここまで ===

100行以下になってますかね??
まぁ、 octaveじゃなくても、上の様にパラメータなんかを直にやり取りするようにすれば、Cでのコードだってもっと短くなるんですが。

octave でスクリプト書くのなんて久々なのですが、多分もっと最適化しようとすればできるのでしょう・・・

octaveという事でforループ使っちゃいけない症候群が中途半端に発現しています(謎)

べたに for ループで書いた方が分かりやすいのですが(というか、上ではやってる事が非常に分かりにくい・・・汗)そうすると目に見えて遅くなります。

ただfor ループ回避の書き方が作法に合ったものになってるかは自信梨です(汗

ぽんのブログ-なし



でも共分散行列の計算ではループ使う以外に思いつきませんでした(涙


それから、共分散行列の初期値は単位行列平均ベクトルの初期値は乱数で決めてますが、ここはどうなんでしょうね?毎回の計算で結果が変るのは困るので、乱数の種を設定しています(= 10)。

前回書いたように、ここのところは別途 kmeans の関数を作り、kmeans の結果から初期値を決める方が良いのかもしれません。




では、これを使って前回の C のサンプルメインと同じ事をやってみましょう。

%%% サンプルスクリプト %%%

% 指定された平均 mu、共分散 cov の正規乱数作成
function y=randmg(seed,N,mu,cov)

D=length(mu);

randn("seed",seed); % seed値をセット
x=randn(D,N); % 標準正規乱数
y=(chol(cov)'*x)' - repmat(mu,N,1);

endfunction

K=2; % クラスタ数

% 正規分布1
mu1=[-2, -1, -2];
cov1=[
2.00, 0.02, 0.10;
0.02, 2.00, -0.06;
0.10, -0.06, 2.00
];

% 正規分布2
mu2=[4, 3, -2];
cov2=[
3.00, 0.15, 0.03;
0.15, 3.00, 0.15;
0.03, 0.15, 3.00
];

% 2種類の正規分布からデータを取得(100個づつ)
x=[randmg(100,100,mu1,cov1);randmg(200,100,mu2,cov2)];

[z,xi,mu,cov,iter]=GMEM(x,K) % EM法


%*** 計算結果の出力 ***
[N,D]=size(x);

% 出力ファイルオープン
for k=1:K
fn=sprintf("cluster%02d.res",k);
fp(k)=fopen(fn,"w");
endfor

% 各データが、zが最大になるクラスタに属すとし
% 各クラスタに属すデータを各々のファイルに出力
for i=1:N
maxidx=find(z(i,:)==max(z(i,:)));
if ~isempty(fp(maxidx))
for j=1:D
fprintf(fp(maxidx),"%f\t",x(i,j));
endfor
fprintf(fp(maxidx),"\n");
endif
endfor

% ファイルクロース
for k=1:K
if ~isempty(fp(k)) fclose(fp(k)); endif
endfor



100行とか言いながら、これも合わせれば200行になりますか?(汗

まぁ、いいや・・・

これを動かすと、およそ20回のイテレーションで


ぽんのブログ-出力2


上の様に2つのクラスタに分解する事が出来ました。

Cに比べ octave だと大分動作は遅くなりますが、何といっても手軽ですね。

手軽なのでいろいろ遊べます。


例えばこんな事も♪



上のスクリプトの

x=[randmg(100,100,mu1,cov1);randmg(200,100,mu2,cov2)];

これで2種類の正規分布から100個づつデータを取り出してるんですが、これを

x=[randmg(100,100,mu1,cov1);randmg(200,100,mu2,cov2), 20*(rand(50,3)-0.5)];

なんてすれば、2種類の正規乱数計100個に加え、さらに -10 < (x,y,z) < 10 の範囲の一様乱数を50個追加したりできます。


で、それでEMアルゴリズムを実行させると

ぽんのブログ-出力2B

こんな結果が得られま・・・

あれ・・・???

なんか変な風になりましたね。。。

推定された正規分布の分散は

[30, 27, 27]



[12, 7, 3]

と、元々の分布よりかなり大きくなってしまいました・・・

まぁ、それもそうですよね。

一様乱数が加わった、つまりとんでもない「はずれ値」が乗ってしまったので、このノイズに推定結果が引っ張られてしまったようです。


でも、こんな風にノイズが乗ったデータを扱わなければならない場合も多くあるんでしょうね。

という訳で次回、「はずれ値」が乗ってしまってもよりロバストだといわれる t分布を扱ってみましょう。