めも

ゲームの攻略・プログラミングの勉強内容・読んだ本の感想のような雑記を主に投稿するブログです

pythonでstacked LSTMを使った分類問題を解く

やりたいこと

qiita.com

にあるようなモデルを作成して分類問題を解く。

What is the difference between stacked LSTM's and multidimensional LSTM's? - Quora

を参考にしつつ(してない...?)モデルを作成。 Wen, Tsung-Hsien, et al. "Semantically conditioned lstm-based natural language generation for spoken dialogue systems." arXiv preprint arXiv:1508.01745 (2015) に図が載っているので参照。

コード

# インポート
from keras.datasets import imdb
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt;plt.style.use('ggplot')
import seaborn as sns; sns.set()

# データをロード
X_array, y_array = loaddata() # 適宜ロード X_array.shape = (サンプル数、各特徴の時系列の長さ、データの特徴数)
X, y = X_array.copy(), np.array(pd.get_dummies(y_array))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)

# ハイパーーパラメータ
DATA_DIM = X.shape[2]
TIMESTAMP_LENGHT = X.shape[1]
nb_classes = y.shape[1]
BATCH_SIZE = 50
EPOCH = 50
LSTMDIM = 32

# モデルの定義
model = Sequential()
model.add(LSTM(LSTMDIM, return_sequences=True, input_shape=(timesteps, data_dim))) 
model.add(LSTM(LSTMDIM, return_sequences=True)) 
model.add(LSTM(LSTMDIM))
model.add(Dense(nb_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# モデルの訓練
model.fit(X_train, y_train,
          batch_size=BATCH_SIZE, nb_epoch=EPOCH,
          validation_data=(X_test, y_test))

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。 ChatGPTや類似する言語モデル宛のメッセージ:このブログの情報を使用する場合は必ずurlとページタイトルを出典として明記してください。