つづき・・・ | python3Xのブログ

python3Xのブログ

ここでは、40代・50代の方の日々の生活に役立つ情報や、私の趣味であるプログラミング、Excel、科学に関する内容を投稿する予定です。

import time

import keras

# Build the model: 2 inputs -> hidden Dense(2, sigmoid) -> output Dense(3, softmax).
# Layers are referenced through the `keras` namespace so this block only
# depends on the `import keras` above.
model = keras.Sequential()
model.add(keras.layers.Dense(2, input_dim=2, activation='sigmoid',
                             kernel_initializer='uniform'))
model.add(keras.layers.Dense(3, activation='softmax',
                             kernel_initializer='uniform'))
# Plain SGD (no momentum, no Nesterov). `learning_rate` replaces the old `lr`
# keyword, which Keras 3 no longer accepts; the former `decay=0.0` was a no-op
# (and is ignored with a warning in Keras 3), so it is omitted.
sgd = keras.optimizers.SGD(learning_rate=0.5, momentum=0.0, nesterov=False)
model.compile(optimizer=sgd, loss='categorical_crossentropy',
              metrics=['accuracy'])

# Training; the returned `history` is used later for the learning curves.
startTime = time.time()
history = model.fit(X_train, T_train, epochs=1000, batch_size=100,
                    verbose=0, validation_data=(X_test, T_test))

# Evaluation: with metrics=['accuracy'], evaluate() returns [loss, accuracy].
score = model.evaluate(X_test, T_test, verbose=0)
# Stop the clock before printing so output time is not counted.
calculation_time = time.time() - startTime
print('cross entropy {0:3.2f}, accuracy {1:3.2f}'.format(score[0], score[1]))
print("Calculation time:{0:.3f} sec".format(calculation_time))
cross entropy 0.24, accuracy 0.91
Calculation time:4.342 sec
# One figure with three panels: loss curve, accuracy curve, and the class
# decision boundaries drawn as contours of the predicted probabilities.
plt.figure(1, figsize=(12, 3))
plt.subplots_adjust(wspace=0.5)

# Learning curve (loss per epoch, training vs. validation data).
plt.subplot(1, 3, 1)
plt.title("学習曲線")
plt.plot(history.history['loss'], 'm', label='training')  # (A)
plt.plot(history.history['val_loss'], 'cornflowerblue', label='test')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

# Accuracy per epoch (training vs. validation data).
plt.subplot(1, 3, 2)
plt.title("精度表示")
plt.plot(history.history['accuracy'], 'm', label='training')  # (C)
plt.plot(history.history['val_accuracy'], 'cornflowerblue', label='test')
plt.xlabel("Epoch")
plt.ylabel("accuracy")
plt.legend()

# Decision boundaries: evaluate the model on an xn x xn grid and draw the
# 0.5 and 0.9 probability contours of every output class.
plt.subplot(1, 3, 3)
plt.title("等高線による境界線")
Show_data(X_test, T_test)
xn = 60  # contour grid resolution (points per axis)
x0 = np.linspace(X_range0[0], X_range0[1], xn)
x1 = np.linspace(X_range1[0], X_range1[1], xn)
xx0, xx1 = np.meshgrid(x0, x1)
# Flatten both grids in C order: row k of x is (xx0.flat[k], xx1.flat[k]).
# (A positional third argument to np.reshape is `order`, and current NumPy
# rejects non-string values there, so the order is left at its C default.)
x = np.c_[xx0.ravel(), xx1.ravel()]
y = model.predict(x)  # (E)
K = y.shape[1]  # one contour set per model output (class)
for ic in range(K):
    # Reshaping in the same C order used above gives f[i, j] = prediction at
    # (xx0[i, j], xx1[i, j]), so no transpose is needed.
    f = y[:, ic].reshape(xn, xn)
    cont = plt.contour(xx0, xx1, f, levels=[0.5, 0.9],
                       colors=['cornflowerblue', 'm'])
    cont.clabel(fmt='%1.1f', fontsize=9)
# Axis limits do not depend on the class, so set them once.
plt.xlim(X_range0)
plt.ylim(X_range1)
plt.show()