非線形に対応しているのはSVRのRBFであることは承知で
曲線SIN(X)にノイズを加え、元の曲線にどれだけ近づけ
より良い回帰が可能か視覚的に比較してみます。
プログラム
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt
from sklearn.svm import SVR
import matplotlib.pyplot as plt
# インプットを乱数で生成
X = np.sort(10 * np.random.rand(80, 1), axis=0)
# アウトプットはnumpy sin関数
y = np.sin(X).ravel()
X = np.sort(10 * np.random.rand(80, 1), axis=0)
# アウトプットはnumpy sin関数
y = np.sin(X).ravel()
# アウトプットにノイズを与える
y[::5] += 3 * (0.5 - np.random.rand(16))
y[::5] += 3 * (0.5 - np.random.rand(16))
# RBFカーネル、線形、多項式でフィッティング
svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.15)
svr_lin = SVR(kernel='linear', C=1e3)
svr_poly = SVR(kernel='poly', C=1e3, degree=2)
y_rbf = svr_rbf.fit(X, y).predict(X)
y_lin = svr_lin.fit(X, y).predict(X)
y_poly = svr_poly.fit(X, y).predict(X)
svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.15)
svr_lin = SVR(kernel='linear', C=1e3)
svr_poly = SVR(kernel='poly', C=1e3, degree=2)
y_rbf = svr_rbf.fit(X, y).predict(X)
y_lin = svr_lin.fit(X, y).predict(X)
y_poly = svr_poly.fit(X, y).predict(X)
# プロットする
plt.figure(figsize=[10, 5])
plt.rcParams["font.size"] = 24
plt.scatter(X, y, c='c', label='データ', s=60, marker='s')
plt.hold('on')
plt.plot(X, np.sin(X), c='y', linestyle='--', label='元の曲線sin(x)', linewidth=8)
plt.plot(X, y_rbf, c='g', label='RBFモデル', linewidth=6)
plt.plot(X, y_lin, c='m', label='線形モデル', linewidth=6)
plt.plot(X, y_poly, c='b', label='多項式モデル', linewidth=6)
plt.xlabel('データ')
plt.ylabel('ターゲット')
plt.title('SVRカーネル別比較')
plt.legend()
plt.show()
plt.figure(figsize=[10, 5])
plt.rcParams["font.size"] = 24
plt.scatter(X, y, c='c', label='データ', s=60, marker='s')
plt.hold('on')
plt.plot(X, np.sin(X), c='y', linestyle='--', label='元の曲線sin(x)', linewidth=8)
plt.plot(X, y_rbf, c='g', label='RBFモデル', linewidth=6)
plt.plot(X, y_lin, c='m', label='線形モデル', linewidth=6)
plt.plot(X, y_poly, c='b', label='多項式モデル', linewidth=6)
plt.xlabel('データ')
plt.ylabel('ターゲット')
plt.title('SVRカーネル別比較')
plt.legend()
plt.show()
