やること
変数が三つある時によくやるのでメモ。 データで例えると「映画の視聴回数」「変数A」「変数B」で回数ごとに A, Bの変数に変化があるかを調べたい時にとりあえず見てみる。 大抵データの上位はサンプルが少ない(映画の例だと、映画を何百回も見る人は少ない)ので、上位の例は「視聴回数 x 回以上」でまとめてしまう。
ごり押しで書いたのでやり方全然スマートじゃないですね。
コード
データ作成
from numpy.random import * import pandas as pd import numpy as np N = 2000 # データ数 data_a = randint(0,100,N) data_b, data_c = randn(N), randn(N)
集計 + 上位の丸め込み
# group by でまとめる total = pd.DataFrame([data_a, data_b, data_c]).T total.columns = ['data_a', 'data_b', 'data_c'] means = total.groupby('data_a')[['data_b', 'data_c']].mean().reset_index() # group by の各列に含まれるレコードをカウント count = total.groupby('data_a').size().reset_index() count.columns = ['data_a', 'count'] means = pd.merge(count,\ means, how="inner", left_on="data_a", right_on="data_a") # 上位を集約 samples = 0 for i, ci in enumerate(means['count']): samples += ci if samples>(N-100): break # index i に i 以上のデータを丸め込んだ結果を入れる means.ix[i,:] = means.ix[i,0], means.ix[i:,:].sum(axis=0)[1], \ means.ix[i:,:].mean(axis=0)[2], means.ix[i:,:].mean(axis=0)[3] # index i 以上のデータを消す means = means.ix[:i, :]
total.groupby('data_a')[['data_b', 'data_c']].mean().reset_index()
の部分で、 data_a
の値ごとに groupbyを行い data_b、data_x の 平均をとる。
ここではif samples>(N-100):
で上位100人は丸め込んでしまう。
コード全体
from numpy.random import * import matplotlib.pyplot as plt import pandas as pd import numpy as np N = 2000 # データ数 data_a = randint(0,100,N) data_b, data_c = randn(N), randn(N) # group by でまとめる total = pd.DataFrame([data_a, data_b, data_c]).T total.columns = ['data_a', 'data_b', 'data_c'] means = total.groupby('data_a')[['data_b', 'data_c']].mean().reset_index() # group by の各列に含まれるレコードをカウント count = total.groupby('data_a').size().reset_index() count.columns = ['data_a', 'count'] means = pd.merge(count,\ means, how="inner", left_on="data_a", right_on="data_a") # 上位を集約 samples = 0 for i, ci in enumerate(means['count']): samples += ci if samples>(N-100): break # index i に i 以上のデータを丸め込んだ結果を入れる means.ix[i,:] = means.ix[i,0], means.ix[i:,:].sum(axis=0)[1], \ means.ix[i:,:].mean(axis=0)[2], means.ix[i:,:].mean(axis=0)[3] # index i 以上のデータを消す means = means.ix[:i, :] # プロット plt.figure(figsize=(9, 3)) plt.subplot(131) plt.scatter(means['data_a'], means['data_b']) plt.subplot(132) plt.scatter(means['data_a'], means['data_c']) plt.subplot(133) plt.plot(means['data_b'], label='data_b') plt.plot(means['data_c'], label='data_c') plt.axhline(y=0, color='r', alpha=.5) plt.legend() plt.tight_layout() plt.show()
出力