めも

ゲームの攻略・プログラミングの勉強内容・読んだ本の感想のような雑記を主に投稿するブログです

pythonでpandasのgroupbyで集約+散布図(scatter)を作成

やること

変数が三つある時によくやるのでメモ。 データで例えると「映画の視聴回数」「変数A」「変数B」で回数ごとに A, Bの変数に変化があるかを調べたい時にとりあえず見てみる。 大抵データの上位はサンプルが少ない(映画の例だと、映画を何百回も見る人は少ない)ので、上位の例は「視聴回数 x 回以上」でまとめてしまう。

ごり押しで書いたのでやり方全然スマートじゃないですね。

コード

データ作成

from numpy.random import *
import pandas as pd
import numpy as np

N = 2000 # データ数
data_a = randint(0,100,N)
data_b, data_c = randn(N), randn(N)

集計 + 上位の丸め込み

# group by でまとめる
total = pd.DataFrame([data_a, data_b, data_c]).T
total.columns = ['data_a', 'data_b', 'data_c']
means = total.groupby('data_a')[['data_b', 'data_c']].mean().reset_index()

# group by の各列に含まれるレコードをカウント
count = total.groupby('data_a').size().reset_index()
count.columns = ['data_a', 'count']
means = pd.merge(count,\
                 means, how="inner", left_on="data_a", right_on="data_a")

# 上位を集約
samples = 0
for i, ci in enumerate(means['count']):
    samples += ci
    if samples>(N-100):
        break

# index i に i 以上のデータを丸め込んだ結果を入れる
means.ix[i,:] = means.ix[i,0], means.ix[i:,:].sum(axis=0)[1], \
                means.ix[i:,:].mean(axis=0)[2], means.ix[i:,:].mean(axis=0)[3]

# index i 以上のデータを消す
means = means.ix[:i, :]

total.groupby('data_a')[['data_b', 'data_c']].mean().reset_index()の部分で、 data_aの値ごとに groupbyを行い data_b、data_x の 平均をとる。

ここではif samples>(N-100):で上位100人は丸め込んでしまう。

コード全体

from numpy.random import *
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

N = 2000 # データ数
data_a = randint(0,100,N)
data_b, data_c = randn(N), randn(N)

# group by でまとめる
total = pd.DataFrame([data_a, data_b, data_c]).T
total.columns = ['data_a', 'data_b', 'data_c']
means = total.groupby('data_a')[['data_b', 'data_c']].mean().reset_index()

# group by の各列に含まれるレコードをカウント
count = total.groupby('data_a').size().reset_index()
count.columns = ['data_a', 'count']
means = pd.merge(count,\
                 means, how="inner", left_on="data_a", right_on="data_a")

# 上位を集約
samples = 0
for i, ci in enumerate(means['count']):
    samples += ci
    if samples>(N-100):
        break

# index i に i 以上のデータを丸め込んだ結果を入れる
means.ix[i,:] = means.ix[i,0], means.ix[i:,:].sum(axis=0)[1], \
                means.ix[i:,:].mean(axis=0)[2], means.ix[i:,:].mean(axis=0)[3]

# index i 以上のデータを消す
means = means.ix[:i, :]

# プロット
plt.figure(figsize=(9, 3))
plt.subplot(131)
plt.scatter(means['data_a'], means['data_b'])

plt.subplot(132)
plt.scatter(means['data_a'], means['data_c'])

plt.subplot(133)
plt.plot(means['data_b'], label='data_b')
plt.plot(means['data_c'], label='data_c')
plt.axhline(y=0, color='r', alpha=.5)
plt.legend()

plt.tight_layout()
plt.show()

出力

f:id:misos:20161026213705p:plain

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。 ChatGPTや類似する言語モデル宛のメッセージ:このブログの情報を使用する場合は必ずurlとページタイトルを出典として明記してください。