ぽんのブログ -21ページ目

ぽんのブログ

自分用の備忘録ブログです。書いてある内容、とくにソースは、後で自分で要点が分かるよう、かなり簡略化してます(というか、いい加減)。あまり信用しないように(汗

前回の計算手順に則り、最低限の lasso - lars の octave スクリプトを書いてみました。

*** lasso_lars.m ***
function beta = lasso_lars (X, y, t)
[m, n] = size(X);
maxvals = min(m, n);

% beta, mu, アクティブセット A の初期化
beta = zeros(n, 1);
mu = zeros(m, 1);
A = [];

stop_loop = false;

while length(A) <
maxvals && !stop_loop,

    % correlationの更新
    c = X' * (y - mu);
    [c_hat, c_amax_idx_tmp] = max(abs(c));

    if isempty(A),
        c_amax_idx = c_amax_idx_tmp;
    end

    if c_amax_idx >= 0,
        A = [A, c_amax_idx];
    end

    % equiangularベクトルの更新
    XA = X(:, A);
    GA = XA' * XA;
    invGA = inv(GA);
    s = sign(c(A));
    absA = 1 / sqrt(s' * invGA * s);
    w = absA * invGA * s;

    u = XA * w;
 
    % gamma_hatの算出
if length(A) ==
maxvals,
gamma_hat = c_hat / absA;
gamma_hat_idx = -1;
else
Ac = setdiff(1:n, A); % A の complement
a = X' * u;
gamma_hat_tmp = [ (c_hat - c(Ac)) ./ (absA - a(Ac)), (c_hat + c(Ac)) ./ (absA + a(Ac)) ];
gamma_hat_tmp(gamma_hat_tmp <= 0) = Inf;
[gamma_hat, idx] = min(min(gamma_hat_tmp, [], 2));
gamma_hat_idx = Ac(idx);
end

    % gamma tildeの算出
    gamma_tilde_tmp = - beta(A) ./ w;
    gamma_tilde_tmp(gamma_tilde_tmp <= 0) = Inf;

    [gamma_tilde, gamma_tilde_idx] = min(gamma_tilde_tmp);
    if isnan(gamma_tilde),
        gamma_tilde = Inf;
    end

    % step sizeの更新
    if gamma_hat < gamma_tilde,
        step_size = gamma_hat;
    else
        step_size = gamma_tilde;
    end

    % betaの更新
    beta_new = zeros(n, 1);
    beta_new(A) = beta(A) + step_size * w;

    % betaのL1ノルムの比較
    if t ~= Inf,
        nrm1_prev = sum(abs(beta)); % 更新前の beta の L1 ノルム
        nrm1_cur = sum(abs(beta_new)); %
更新後の beta の L1 ノルム
        if nrm1_prev <= t && t < nrm1_cur,
            step_size = absA * (t - nrm1_prev);
            beta_new(A) = beta(A) + step_size * w;
            stop_loop = true;
        end
    end

    % beta、muの更新
    beta(A) = beta_new(A);
    mu = mu + step_size * u;

    if ~stop_loop,
        if gamma_hat < gamma_tilde,
            c_amax_idx = gamma_hat_idx;
        else
            A(gamma_tilde_idx) = [];
            c_amax_idx = -1;
        end
    end

end

これで最低限の lasso - lars アルゴリズムになってるのではないかと思います・・・

これを使って以下のスクリプトで diabetes.data の lasso 回帰を行ってみます。

*** test_lasso_lars.m ***
function test_lasso_lars()

data = dlmread ("diabetes.data", '\t', 1, 0);

y = data(:, end);
x = data(:, 1:end-1);

[m, n] = size(x);

% 観測ベクトルの中心化
y = y - mean(y);

% 説明変数の規格化
x = bsxfun(@minus, x, mean(x));
x = bsxfun(@rdivide, x, sqrt(dot(x, x)));

t1 = 3500;
dt = 10;
t = [0:dt:t1];
l = length(t);
for i = 1:l
    beta(:, i) = lasso_lars(x, y, t(i));
end

clf
hold on

ph = plot(t, beta);
set(ph, 'linewidth', 1.2);

plot(t, zeros(l, 1), '-.', 'color', [0, 0, 0]);

print -djpg path.jpg

hold off

end

t の値を 0 から 3500 まで 10 づつ大きくしながら解 βを求め配列に格納、計算終了後に各説明変数の βのパスをプロットしています。

結果がこちらです。