めも

メモ.

LSTMで回帰問題を解く(Python)

内容

Kerasを使ってLSTMを実装。 コードのEmbeddingの都合上 tensorflow.__version = 0.10.0で行う必要があるので注意(今日現在)。

コード

import numpy as np
import pandas as pd
import random

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence
from keras.layers import Dropout

import tensorflow as tf

# version チェック
print(tf.__version__)

# データをロード
from sklearn.model_selection import train_test_split
data = pd.read_csv('data.csv')

# データの変換方法
def convert_data(df):
    # 元データを破壊したくないので一旦コピー
    df = df.copy() 
    ....
    # np.array型で返す
    return np.array(df)

# データを行列型式に変換(ここはデータ型式による)
# data : np.array 型
data = convert_data(data)

# 訓練とテストデータに分ける
X, y = data[:,:-1], data[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=random.randint(0, 100))

# ハイパーパラメータ
max_history, embedding_size = 10, 365

# パディング
X_train = sequence.pad_sequences(X_train, maxlen=max_review_length)
X_test = sequence.pad_sequences(X_test, maxlen=max_history)

# モデル定義
model = Sequential()
model.add(Embedding(embedding_size, embedding_vecor_length, input_length=max_review_length, dropout=0.2))
model.add(Dropout(0.1))
model.add(LSTM(100))
model.add(Dropout(0.1))
model.add(Dense(1, activation='relu'))
model.compile(loss='mean_squared_error', optimizer='adam')
print(model.summary())
model.fit(X_train, y_train, nb_epoch=100, batch_size=30)

# テストデータで検証
scores = model.evaluate(X_test, y_test, verbose=0)

出力

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
embedding_11 (Embedding)         (None, 10, 20)        7300        embedding_input_11[0][0]         
____________________________________________________________________________________________________
dropout_21 (Dropout)             (None, 10, 20)        0           embedding_11[0][0]               
____________________________________________________________________________________________________
lstm_12 (LSTM)                   (None, 100)           48400       dropout_21[0][0]                 
____________________________________________________________________________________________________
dropout_22 (Dropout)             (None, 100)           0           lstm_12[0][0]                    
____________________________________________________________________________________________________
dense_11 (Dense)                 (None, 1)             101         dropout_22[0][0]                 
====================================================================================================
Total params: 55801
____________________________________________________________________________________________________
None
Epoch 1/100
7114/7114 [==============================] - 20s - loss: 53251.4766    
Epoch 2/100
7114/7114 [==============================] - 14s - loss: 45712.0247    
Epoch 3/100
7114/7114 [==============================] - 15s - loss: 40147.3010    
Epoch 4/100
7114/7114 [==============================] - 14s - loss: 35228.0845    
Epoch 5/100
7114/7114 [==============================] - 14s - loss: 30766.8991    
Epoch 6/100
7114/7114 [==============================] - 10s - loss: 499.8119    
...
Epoch 98/100
7114/7114 [==============================] - 10s - loss: 520.0829    
Epoch 99/100
7114/7114 [==============================] - 10s - loss: 499.9843    
Epoch 100/100
7114/7114 [==============================] - 11s - loss: 512.3762    
rmse: 428.11

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。