ぽんのブログ -16ページ目

ぽんのブログ

自分用の備忘録ブログです。書いてある内容、とくにソースは、後で自分で要点が分かるよう、かなり簡略化してます(というか、いい加減)。あまり信用しないように(汗

Elastic Net が、L2正則化問題をlassoで解く、つまり

式(1)


subject to

を解くと説明しました。Zou and Hasti (2007) では

式(2)


を解きます。係数の



の対角成分でのスケーリングを表します。

elastic net 解 は、式(2)の解 に対し



なる関係を持ちます。

式(2)に則り Lars の計算方法を書きなおしてみます。

相関ベクトル c



とおきます。 X と残差とのcorrelation c



ここで μ



です。

c を書きなおすと



となります。

等角ベクトル u

アクティブセット A の元について c の符号を s = sign(c(A)) とすると

式(3)




ここに


で等角ベクトルが求められます。



なので



です。下に挙げるスクリプトでは だけ求めています。

ステップサイズ

を求める為に必要なベクトル a



で求められます。

式(3)の計算では cholesky分解を用いています。

XA に j 番目のpredictor xj を加える、つまり X の j 列目を加えるとします。

アクティブセットに j を加え ( A = [A j] )、更新された A から XA を求めたとします( XA = X(:, A) )。
すると Z


となっています。
従って



以上から Z^T・Z のcholesky分解 L



を順次 cholinsert で L の最後に追加していけばよい事になります。
また predictorを除く、つまり XA の列を除く場合は L から choldelete で除きます。

どこか間違っている可能性は大!

=== larsen.m ===


function [beta, mu] = larsen2(X, y, lambda, stop)

% n : dim of observation, p : num of predictors
[n, p] = size(X);

if lambda > eps,
    maxvariables = min(n, p);
else
    maxvariables = p;
end

% active set
A = [];

% initial value of beta and mu
beta = zeros(p, 1);
mu = zeros(n, 1);

% cholesky decomposition of gram matrix XA' * XA
L = [];

lasso_cond = true;
stop_loop = false;

% scale factor for larsen
scale = 1 / sqrt(1 + lambda);

while length(A) < maxvariables && ~stop_loop,

    % correlation c
    % [y ; 0] = scale * [X ; sqrt(lambda) * E] * beta
    % residuals: r = [ y - mu ; 0 - scale * sqrt(lambda) * beta]
    % c = scale * [X ; sqrt(lambda) * E]' * r
    c = scale * X' * (y - mu) - lambda * scale^2 * beta;

    [c_hat, c_amax_idx] = max(abs(c));

    if isempty(A),
        lars_idx = c_amax_idx;
    end

    if lasso_cond,
        A = [A lars_idx];
    end

    % XA for updated active set
    XA = X(:, A);

    % update cholesky decomp.
    if lasso_cond,
        % add new col to L
        if isempty(L),
            L = 1;
        else
            % add vector t to the last col of L
            % t = [ (XA' * x)(1:end-1), (1+lambda)] / (1+lambda)
            x = X(:, lars_idx);
            t = scale^2 * XA' * x;
            t(end) = 1;
            L = cholinsert(L, length(A), t);
        end
    else
        % delete col from L
        L = choldelete(L, lars_idx);
    end

    % update equiangular vector
    s = sign(c(A));
    w = L \ (L' \ s);
    absA = 1 / sqrt(s' * w);
    w = absA * w;

    u = scale * XA * w;

    Ac = setdiff(1:p, A);  % complement of A
    % gamma hat
    if length(A) == maxvariables,
        gamma_hat = c_hat / absA;
        gamma_hat_idx = -1;
    else
        a = scale * X' * u;
        a(A) = a(A) + lambda * scale^2 * w;
        gamma_hat_tmp = [ (c_hat - c(Ac)) ./ (absA - a(Ac)), (c_hat + c(Ac)) ./ (absA + a(Ac)) ];
        gamma_hat_tmp(gamma_hat_tmp <= 0) = inf;

        [gamma_hat, idx] = min(min(gamma_hat_tmp, [], 2));
        gamma_hat_idx = Ac(idx);
    end

    % gamma tilde
    gamma_tilde_tmp = - beta(A) ./ w;
    gamma_tilde_tmp(gamma_tilde_tmp <= 0) = inf;

    [gamma_tilde, gamma_tilde_idx] = min(gamma_tilde_tmp);
    if isnan(gamma_tilde),
        gamma_tilde = inf;
    end

    % update step size
    if gamma_hat < gamma_tilde,
        step_size = gamma_hat;
    else
        step_size = gamma_tilde;
    end

    % candidate of next beta
    beta_new = zeros(n, 1);
    beta_new(A) = beta(A) + step_size * w;

    % compeare stop and norm1 of beta
    nrm1_prev = sum(abs(beta));
    nrm1_cur = sum(abs(beta_new));
    t1 = scale * stop;
    if nrm1_prev <= t1 && t1 < nrm1_cur,
        step_size = absA * (t1 - nrm1_prev);
        beta_new(A) = beta(A) + step_size * w;
        stop_loop = true;
    end

    % update beta and mu
    beta(A) = beta_new(A);
    mu = mu + step_size * u;

    if ~stop_loop,
        if gamma_hat < gamma_tilde,
            lars_idx = gamma_hat_idx;
            lasso_cond = true;
        else
            A(gamma_tilde_idx) = [];
            lars_idx = gamma_tilde_idx;
            lasso_cond = false;
        end
    end

end

beta = beta / scale;

 

 


このスクリプトを用いて diatetes.data を lambda_2 = 1 で解いた結果がこちらです。