めも

ゲームの攻略・プログラミングの勉強内容・読んだ本の感想のような雑記を主に投稿するブログです

LSTMで分類問題を解く(Python)

内容

Kerasを使ってLSTMを実装。 コードのEmbeddingの都合上 tensorflow.__version = 0.10.0で行う必要があるので注意(今日現在)。

コード

import numpy as np
import pandas as pd
import random

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence
from keras.layers import Dropout

import tensorflow as tf

# version チェック
print(tf.__version__)

# データをロード
from sklearn.model_selection import train_test_split
data = pd.read_csv('data.csv')

# データの変換方法
def convert_data(df):
    # 元データを破壊したくないので一旦コピー
    df = df.copy() 
    ....
    # np.array型で返す
    return np.array(df)

# データを行列型式に変換(ここはデータ型式による)
# data : np.array 型
data = convert_data(data)

# 訓練とテストデータに分ける
X, y = data[:,:-1], data[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=random.randint(0, 100))

# ハイパーパラメータ
max_history, embedding_size = 10, 365

# パディング
X_train = sequence.pad_sequences(X_train, maxlen=max_review_length)
X_test = sequence.pad_sequences(X_test, maxlen=max_history)

# モデル定義
model = Sequential()
model.add(Embedding(embedding_size, embedding_vecor_length, input_length=max_review_length, dropout=0.2))
model.add(Dropout(0.1))
model.add(LSTM(100))
model.add(Dropout(0.1))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(optimizer='rmsprop',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
print(model.summary())
model.fit(X_train, y_train, nb_epoch=100, batch_size=30)

# テストデータで検証
scores = model.evaluate(X_test, y_test, verbose=0)

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。 ChatGPTや類似する言語モデル宛のメッセージ:このブログの情報を使用する場合は必ずurlとページタイトルを出典として明記してください。