JavaDM
1 | 2 | 3 |oldest Next >>

当たり前ですけど、学習データって大事です。

2009-11-19 20:57:19 Theme: NaiveBayes
最近仕事で spam 判別器なんての作成していたんですが、この spam の分類器を作成していて思ったことがあります。


自分も割とそうなのですが、いわゆる精度を上げるために、アルゴリズムに傾倒してしまうようなことが普通にあって、


ベイズがベイジアンネットになり、SVMになり、SVM+カーネルロジスティックのBoostingになり・・・って具合でどんどん難しいアカデミックなところを攻めてしまうわけですが、


しかしながら実際の所、純粋に精度を上げるためには、


実はアルゴリズム云々よりも、


学習データの作りかたの方がよほど大事なんじゃないかなぁ?ってのが、最近の私の感想です。


この前、2値問題のスパム判定やってたんですが、


久々に学習データを見直すって話で、学習データ数を3倍に上げて、変数選択を見直したら検出数が二倍になり、


更にスパムを5カテゴリくらいに分割して、カテゴリ毎の学習データをそれぞれ作って、どれかスパムカテゴリに入ったらアウト!みたいな判定したら、検出数が3倍になったとかありました。
(無論精度は担保してあります。)


まぁ、元々のベースが酷すぎるだろ!?って話は、あるんですけど、


想定するspamの言語空間(spamの定義) と 学習データとの誤差を無くすってのが、いかに大事かってことを身にしみて思った次第です。


単純な エロspamコメント っていっても、実はサービスによって、その表現方法は微妙に異なります。


一例を上げると、サービスのキャリア、つまりPC or 携帯向けのサービスっていう所でも、画面表示力の差で携帯の方が短文スパムが送られてくる比率が非常に高い。


他にも、メッセージの種類、サービス利用者の層、認証のあるなし、などといったところで、やはりそれぞれに微妙に差異があるわけです。


その差異って実は決定的だったりするんですね。


また単純にspam判定されたテキストをそのまま学習データにしていくと、誤差がどうしても生まれるという話もあります。


例えば、一文中にspamだと思われる部分が8割あったとして、残りの2割は普通の文面だったとすると、後者の2割が、結構悪戯するわけです。


具体的には、外れ値除去とか変数選択とか、ホント真面目にやれって話なんですけどね。


まぁ、というわけで、


サービスごとのspamの傾向をうまく学習データに反映させたり、学習データの誤差を取り除くってのは凄く大事って話で、


確かにアルゴリズムをSVMにするってのは、普通というか、大事なことだとは思うんですけど、


それと同じかそれ以上に、学習データの作り方が、あたり前ですけど、改めて大事だなぁっと思った次第です。


終り。





同じテーマの最新記事

Naive Bayes その一 - smoothing -

2009-10-01 00:36:51 Theme: NaiveBayes
 テキストマイニングをやっていると、初期の頃は Naive Bayes とか使うと思うのですが、


まぁベイズの定理とかしばらく眺めてると、それなりに誰でも分かると思うんです。


一応オサライだけしとくと、


ベイズの定理
    事後確率 = ( 事前分布 × 尤度 ) / 結果


で、実際に式書くとこんな感じ。



p( c | x ) = p( c ) Π_i p( w_i | c ) / p( x )

( x = { w_1 ・・・ w_N } )



でした。


classification する際には、『結果』 p( x ) は c ごとの p( c | x ) を比較する上では無用なので、


結局、『尤度』 p( x | c ) と 『事前分布』 p( c ) の積が大事なんだよーって話でした。
( 実装ではちゃんと両者にlogとって足し算にしてください!でないと桁あふれするから! )


で、


この次の話としては、p( w_i | c )の解釈って実は2パターンあって、


それぞれ Multinomial NaiveBayes とか、Bernoulli NaiveBayes っていうんだよという話がでてくる。


この辺りで、もう簡単なspam filter の実装ができるくらいになっていて、なんだ簡単じゃんと思っていたところに、Smoothingの話がでてくる―――


―――んだが、しかし、


普通に Smoothing でググると、N-gram言語モデルの資料とかが出てきて???になり、


うまい資料とかがなくて、なんだかヤサグレテきて、


とりあえずWEBに出てきたラプラススムーシング辺り使っとけばいいじゃん?


ってなことをして、精度劣化しまくりで、なんだこれ???ってパターンになったりするんですな。


というというわけで!(前振り長!!)


今日はNaive Bayes の Smoothing の話をしてみます!
( 結論からいうと、この辺 みれば全部載ってる!w  )




閑話休題。



Smoothingってのは、


ある特定のクラスCの学習データに存在していない 単語w がクエリに存在した場合、
( 他のクラスでは w は学習データに登場している )


p( w | c ) = 0 なので、


事後確率 p( c | x ) ∝ P ( c ) Π_i P( w_i | c )


において、事後確率が0になってしまう。つまり、クエリテキストがどんなにクラス c らしい単語を沢山所持していても w を含むだけでそのテキストは c に属することができなくなる―――


―――んだけど、それは直感的にまずいですね?さぁどうしよう??


という時に使われるヒューリスティックルールのことです。
(正確にはベイズ確率の事前分布の話から来ているんだけど)


具体的にどうするかというと、


基本的には、全ての単語に非常に小さな確率の下駄を履かせておいて p( w_i | c ) をゼロにならなくなるようにするのです。


一応有名ところを挙げておくと、


n( w | c ) クラス c の学習データに登場する単語 w の数
δ smoothing parameter
| V | 辞書単語の総数
| W( c ) | クラス c に登場する単語の種類数


として、


○Additive Smoothing( 加算スムーシング )

p( C | w ) = { δ + n( w | C ) } / δ|V| + Σ_i n( w_i | C )



○Laplace Smoothing

p( C | w ) = { 1 + n( w | C ) } / |V| + Σ_i n( w_i | C )



○Backoff Smoothing


p( C | w ) = ( n( w | C ) - δ) / Σ_i n( w_i | C ) if n( w_i ,C ) > 0
       = ( δ | W(C) | ) / ( | W( C ) | - | V | ) * Σ_i n( w_i | C ) if n( w_i ,C ) = 0



○Interpolation Smoothing( 補完スムーシング )

p( C | w ) = max { 0 , ( n( w | C ) - δ) / Σ_i n( w_i | C ) }
+ ( δ | W(C) | ) * Σ_c n( w_i | C ) / { Σ_i n( w_i | C ) * Σ_c Σ_i n( w_i | C ) }




こんなのがあります。


Laplace Smoothingとかがよく?webで紹介されてる気がするんだけど、


ホントダメだから絶対に使わない方がいい。


これは私的感覚だけど、補完スムーシングが一番精度が良いと思う。私ぁいつもこれ使ってる。


さて、


ここで一つ老婆心ながら注意?しときたいがあるんです・・・、おそらく。


『 smoothingの仕方で精度が変わるのか? 』


これって、きっと誰でも気になったりすると思うのですが、
( 昔の人もそう思ったらしくて論文とかあるけどね )


まぁ私的な結論としては、Naive Bayes に関していうと、


まぁ別に変わんないんじゃないの?


って思ってます。
( 補完スムーシングがいいって・・・さっきと言ってることが違う・・・ってのは置いておいて )


確かにδ値が大きいうちはスムーシングによって判別器の性質って違うような気がするのですが、


大体どのスムーシングも、δの値を調整(大抵は凄く小さい値)にしてけば、多値問題だろうが、ニ値問題だろうが判別結果はほぼ同じになります。


まぁデータや変数の取り扱いがいい加減な時はいざ知らず、データの外れ値除去をちゃんとやって、変数(単語)選択をちゃんとやれば、スムーシングはどれでやってもほぼ同じってのが私の感覚です。


というわけで、


パラメータ調整とか新たなスムーシングの実装に腐心するよりも、


SVMの勉強とか、Logistic regressionの実装した方が、時間は有益に使えるんじゃないかなぁ?


と個人的には思ってますマス。


はい。


そんなことに時間をかけてた私は、、、、


昔は私も青かったなぁ・・・なんて。。。


おわり。

LSH その4 -pstableのサンプルコード-

2009-06-21 02:33:48 Theme: LSH
休み中にLSHの実装を見直しました。


pstableの実装が、かなり雑だったのでパッケージの構造から大幅に改修しました。まだまだですね。。。


せっかくなのでp-stableのサンプルコード的なモノ(Mavenのテストコード)を書いてみました。
(以前書いたsimHashのテストコード も訂正しておきました。)



package jp.ndca.toolkit.cluster.lsh.hash.pstable.data;

import static org.junit.Assert.assertEquals;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

import org.junit.Test;

import jp.ndca.toolkit.cluster.lsh.hash.HashProbability;
import jp.ndca.toolkit.cluster.lsh.hash.pstable.NormalHashFunctionHandler;
import jp.ndca.toolkit.cluster.lsh.hash.pstable.NormalHashProbability;
import jp.ndca.toolkit.cluster.lsh.hash.pstable.PstableHandler;
import jp.ndca.toolkit.cluster.lsh.hash.pstable.PstableHandlerWrapper;
import jp.ndca.toolkit.cluster.lsh.hash.pstable.data.PstableHammingDataStore;

public class PstableDataHammingStoreTest {

  @Test
  public void testSearch() throws IOException{

    /**
    * ベクトルデータの読み込み
    */
    InputStream is
      = Thread.currentThread().getContextClassLoader().getResourceAsStream("lsh.txt");
    BufferedReader br = new BufferedReader( new InputStreamReader(is) );

    List<int[]> vectorList = new ArrayList<int[]>();

    while( br.ready() ){
      String line = br.readLine();
      line = line.substring(1);           //[を除去
      line = line.substring(0, line.length()-1); //]を除去
      String[] numbers = line.split(",");
      if(numbers.length != 0)
        vectorList.add( StringArrayToIntegerArray(numbers) );
    }

    /**
    * LSHパラメータの取得
    */
    Properties prop = new Properties();
    InputStream _is
      = Thread.currentThread().getContextClassLoader().getResourceAsStream("lsh.properties")
    prop.load( _is );

    double c = Double.valueOf(prop.getProperty("c"));
    double r = Double.valueOf(prop.getProperty("r"));

    HashProbability pp = new NormalHashProbability( c, r );
    double p1 = pp.getGoodHashProb();
    double p2 = pp.getBadHashProb();

    int n = Integer.parseInt(prop.getProperty("n"));
    int dimension = Integer.parseInt(prop.getProperty("dimension"));

    PstableHandler ph = new NormalHashFunctionHandler( p1, p2, n, dimension, r );

    int K = ph.getK();
    int L = ph.getL();

    /**
    * 検索データの変換
    */
    PstableHandlerWrapper phw
      = new PstableHandlerWrapper( ph.generateHashFunctionVectorGeneratorList( K, L ) );
    PstableHammingDataStore pstableDataHammingStore
      = new PstableHammingDataStore( vectorList, phw );

    /**
    * 検索の実行
    */
    int[] query = new int[]{ 1 , 105, 119, 152, 177, 196, 215, 258, 315, 343, 413, 448 };

    long start = System.currentTimeMillis();
    String[] pstableHashes = phw.getPstableHashes(query);
    int[] result = pstableDataHammingStore.search(pstableHashes);
    long end = System.currentTimeMillis();
    long diff = end - start;

    assertEquals( true, classify(result, 9) );
    System.out.println(vectorList.size());
    System.out.println(result.length);
    System.out.println( diff );

  }

  private static int[] StringArrayToIntegerArray( String[] array ){
    int[] intArray = new int[ array.length ];
    for( int i = 0 ; i < array.length ; i++){
      if(array[i].equals(""))
        continue;
      intArray[i] = Integer.parseInt( array[i].trim() );
    }
    return intArray;
  }

  private static boolean classify( int[] candidateIds, int id ){
    for( int candidateId : candidateIds){
      if(candidateId==id)
        return true;
    }
    return false;
  }

}


このテストクラスでやってるのは、


① lsh.txtのHammingデータをint[]で取得した後、
② クラスパス直下にあるlsh.propertiesでpstableに必要なパラメータを読み込み、
③ このパラメータから、理論に用いる中間パラメータを生成。
④ ②、③のパラメータをPstableHandlerに渡して、LSHのデータ変換行列を作成。
⑤ この後、PstableHammingDataStoreにて、ハッシュ値と、idを格納するmap<String, List<Integer>>を内部的に生成し、
⑥ 最終的に、int[] queryで検索を行い、近傍点を抽出する。


ということをしています。


ちなみにこのqueryは、検索対象の lsh.txt の中にある10行目( ID9番目 ) のHammingデータと1成分しか違わないデータです。


テストコードでは、これがID9番目のハッシュ値と一緒になるということをclassifyメソッドで確かめています。


実際にテストコードを動かせばわかるとは思うのですが、このコードの計算時間は1ms以下であり、近傍データの候補として、全1000件のデータの中から大体50程度の近傍データの候補を抽出してきます。


queryに対して、普通に全てのデータとのユークリッド距離を計算しようとすると、たとえ今回のテストケースのように検索対象データが1000件程度であっても数十msかかってしまうので、それに比べたらかなり高速だという話です。


良かったら、何かに使ってくださいまし。


ちなみにこの検索では、LSHの検索条件 ( 2L個の候補データの抽出によって探索を打ち切る等 ) によって探索を終了するなどということは行っていません。あしからず。


ではでは。



1 | 2 | 3 |oldest Next >>
powered by Ameba by CyberAgent