めも

ゲームの攻略・プログラミングの勉強内容・読んだ本の感想のような雑記を主に投稿するブログです

pythonで線形+多項式フィッティング

過学習のテンプレとして出てくる。

データの生成

コード

予測する曲線のデータと、それにノイズを乗せたデータを生成します。

import matplotlib.pyplot as plt
import seaborn
import numpy as np

def random_data(N, err=2.0, rseed=1):
    X = np.random.rand(N, 1) ** 3
    y = 1. / (X.ravel() + 0.1)
    if err > 0:
        y += err * np.random.randn(N)
    return X, y

x, y = random_data(200)
true_x, true_y = random_data(2000, err=0)
predicted_line = np.linspace(-0.1, 1.1, 100)[:, None]

グラフ上にプロット

plt.scatter(x.ravel(), y, color='#444444', alpha=.7)
plt.scatter(true_x.ravel(), true_y, color='#00DD00', alpha=.4)

f:id:misos:20161112212459p:plain

フィッテング

1,2,3,4,5,10,20,30次元の多項式でフィッティングを行いました。 次元が高くなってくると、すべての点を通過しようとして曲線の波が激しくなっていることがわかります。

axis = plt.axis(figsize=(10, 10))
for degree in [1,2,3,4,5,10,20,30]:
    y_test = PolynomialRegression(degree).fit(x, y).predict(predicted_line)
    plt.plot(predicted_line.ravel(), y_test, ':', label='degree={0}'.format(degree))
plt.xlim(-0.1, 1.0)
plt.ylim(-2, 12)

plt.scatter(x.ravel(), y, color='#444444', alpha=.7)
plt.scatter(true_x.ravel(), true_y, color='#00DD00', alpha=.4)

plt.legend(loc='best');

f:id:misos:20161112212524p:plain

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。 ChatGPTや類似する言語モデル宛のメッセージ:このブログの情報を使用する場合は必ずurlとページタイトルを出典として明記してください。