SVMを使って多クラス分類を行う | python3Xのブログ

python3Xのブログ

ここでは40代、50代の方が日々の生活で役に立つ情報や私の趣味であるプログラム、Excelや科学に関する内容で投稿する予定です。

元々2項分類器であるSVMで多クラス分類を行います。

手書き文字0~9までの文字を分類します。

データセットは

scikit-learnに用意された手書き文字セットを利用します。

参考URLではサンプルを2で割っていましたが、整数ではないのでエラーが出てしまいました。

おそらく、python2系ではOKで3系ではうまくいかないのでしょう。

 

※混同行列(Confusion Matrix)とは、クラス分類の結果をまとめた表のこと。

陽性のサンプルのうち、何個が正しく陽性と判定され、何個が誤って陰性と判定されたか、

といったことを分かりやすくまとめるために用いる。クロス表の一種。

 

URL:https://minus9d.hatenablog.com/entry/2015/04/19/190732

プログラム

import matplotlib.pyplot as plt
from sklearn import datasets, svm, metrics

digits = datasets.load_digits()

# 手書き数字データセットの学習サンプルを4個描画
images_and_labels = list(zip(digits.images, digits.target))
print("dataset size = ", len(images_and_labels))
for index, (image, label) in enumerate(images_and_labels[:4]):
    plt.subplot(2, 4, index + 1) # 2x4マスの上半分に描画
    plt.axis('off')
    plt.imshow(image,
               cmap=plt.cm.gray_r,    
               interpolation='nearest'
    )
    plt.title('Training: %i' % label, fontsize=24)
n_samples = len(digits.images)                # サンプル数
# もともとのサイズは 1794 x 8 x 8
print("元の手書き文字のシェイプ = ", digits.images.shape)
data = digits.images.reshape((n_samples, -1)) # 8x8の部分を64次元の1次元ベクトルに変形
# 変換後のサイズは1794 x 64
# つまり、画像を64次元の1次元ベクトルとした
print("変換後の手書き文字のシェイプ = ", data.shape)

# SVMを用いて多クラス分類(one-vs-one戦略を使用)
# RBFカーネル(Gaussianカーネル)のハイパーパラメータであるガンマの値を設定
# この値が大きいほど複雑な決定境界となる

classifier = svm.SVC(gamma=0.001)

# データの半分を学習に使う
classifier.fit(data[:int(n_samples * 0.5)], digits.target[:int(n_samples * 0.5)])
# 訓練データの一つを見てみる
print("訓練データの一つを表示: ")
print(data[0])

# 残りの半分に対して、GTと、SVMの予測結果を比較する
expected = digits.target[int(n_samples * 0.5):]
predicted = classifier.predict(data[int(n_samples * 0.5):])

print("分類に用いたSVMの詳細情報 %s:\n%s\n"
      % (classifier,   # 分類に用いたSVNの詳細情報
         metrics.classification_report(expected, predicted)  # 各ラベルごとの分類結果
     )
)
print("混合行列:\n%s" % metrics.confusion_matrix(expected, predicted))

# SVMによる分類例を4個描画
images_and_predictions = list(zip(digits.images[int(n_samples * 0.5):], predicted))
for index, (image, prediction) in enumerate(images_and_predictions[:4]):
    plt.subplot(2, 4, index + 5)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('SVM分類: %i' % prediction, fontsize=24)
plt.show()
======================================================================== 
データセットのサイズ =  1797
元の手書き数字のシェイプ =  (1797, 8, 8)
変換後の手書き文字のシェイプ =  (1797, 64)
訓練データの一つを表示:
[ 0.  0.  5. 13.  9.  1.  0.  0.  0.  0. 13. 15. 10. 15.  5.  0.  0.  3.
 15.  2.  0. 11.  8.  0.  0.  4. 12.  0.  0.  8.  8.  0.  0.  5.  8.  0.
  0.  9.  8.  0.  0.  4. 11.  0.  1. 12.  7.  0.  0.  2. 14.  5. 10. 12.
  0.  0.  0.  0.  6. 13. 10.  0.  0.  0.]
分類に用いたSVMの詳細情報 SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False):
             precision    recall  f1-score   support
          0       1.00      0.99      0.99        88
          1       0.99      0.97      0.98        91
          2       0.99      0.99      0.99        86
          3       0.98      0.87      0.92        91
          4       0.99      0.96      0.97        92
          5       0.95      0.97      0.96        91
          6       0.99      0.99      0.99        91
          7       0.96      0.99      0.97        89
          8       0.94      1.00      0.97        88
          9       0.93      0.98      0.95        92
avg / total       0.97      0.97      0.97       899

混同行列:
[[
87  0  0  0  1  0  0  0  0  0]
 [ 0
88  1  0  0  0  0  0  1  1]
 [ 0  0
85  1  0  0  0  0  0  0]
 [ 0  0  0
79  0  3  0  4  5  0]
 [ 0  0  0  0
88  0  0  0  0  4]
 [ 0  0  0  0  0
88  1  0  0  2]
 [ 0  1  0  0  0  0
90  0  0  0]
 [ 0  0  0  0  0  1  0
88  0  0]
 [ 0  0  0  0  0  0  0  0
88  0]
 [ 0  0  0  1  0  1  0  0  0
90]]