めも

ゲームの攻略・プログラミングの勉強内容・読んだ本の感想のような雑記を主に投稿するブログです

pythonでAutoencoderの精度をバッチサイズを変更しながら確認

バッチサイズを狭めながらオートエンコーダを訓練して、その損失関数の減少具合を確認する。 緑色線がエポックごとのおおよその損失関数のlossの値、赤色がバッチごとのlossの値を全てプロットしたものです。

コード

モデル定義

from keras.layers import Input, Dense
from keras.models import Model
from sklearn.preprocessing import PolynomialFeatures
from keras.callbacks import Callback
class LossHistory(Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

poly = PolynomialFeatures(2).fit(train_X)
train_pX, test_pX = poly.transform(train_X), poly.transform(test_X)

feature = train_pX.shape[1]
input_v = Input(shape=(feature,))
emd_dim = 50

encoded = Dense(int(feature/2), activation='relu')(input_v)
encoded = Dense(int(feature/4), activation='relu')(encoded)
encoded = Dense(emd_dim, activation='relu')(encoded)

decoded = Dense(int(feature/4), activation='relu')(encoded)
decoded = Dense(int(feature/2), activation='relu')(decoded)
decoded = Dense(feature, activation='sigmoid')(decoded)

autoencoder = Model(input=input_v, output=decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
print(autoencoder.summary())

出力結果は

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_6 (InputLayer)             (None, 253)           0                                            
____________________________________________________________________________________________________
dense_25 (Dense)                 (None, 126)           32004       input_6[0][0]                    
____________________________________________________________________________________________________
dense_26 (Dense)                 (None, 63)            8001        dense_25[0][0]                   
____________________________________________________________________________________________________
dense_27 (Dense)                 (None, 50)            3200        dense_26[0][0]                   
____________________________________________________________________________________________________
dense_28 (Dense)                 (None, 63)            3213        dense_27[0][0]                   
____________________________________________________________________________________________________
dense_29 (Dense)                 (None, 126)           8064        dense_28[0][0]                   
____________________________________________________________________________________________________
dense_30 (Dense)                 (None, 253)           32131       dense_29[0][0]                   
====================================================================================================
Total params: 86613

モデルの訓練

# train
plotdata = LossHistory()
total_step = 10
batch_size = 200
loss_history, avgloss_history, index_history = [], [], []

for i in range(total_step):
    # fit model
    autoencoder.fit(train_pX, train_pX,
                    nb_epoch=10,
                    batch_size=int(batch_size*(total_step-i)/total_step),
                    shuffle=True,
                    validation_data=(test_pX, test_pX),
                    verbose=2,
                    callbacks=[plotdata])
    
    index_history += [i]
    loss_history += list(plotdata.losses)
    avgloss_history += [sum(loss_history[-100:])/100.0]
    
    # plot loss
    plt.plot(loss_history)
    
    # show epoch num
    for j in index_history:
        plt.axvline(x=j, alpha=.5, color='k')
    
    # show avarage loss
    for j in avgloss_history:
        plt.axhline(y=j, alpha=.5, color='g')
    plt.yticks(avgloss_history, np.arange(i))
    
    # show
    plt.show()
    
    # save weight
    autoencoder.save_weights('../param/autoencoder_a.w')

f:id:misos:20161022061306p:plain

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。 ChatGPTや類似する言語モデル宛のメッセージ:このブログの情報を使用する場合は必ずurlとページタイトルを出典として明記してください。