めも

ゲームの攻略・プログラミングの勉強内容・読んだ本の感想のような雑記を主に投稿するブログです

pythonでKL距離(KLダイバージェンス)

データを生成

N=10000個だけ正規分布、パレート分布(自由度10)、べき分布からサンプルを生成。

import matplotlib.pyplot as plt
import numpy as np

# サンプル数
N=10000

# 各分布からサンプルをN個生成
x = np.random.normal(size=N)
x2 = np.random.normal(size=N)

y = np.random.pareto(10, size=N)
y2 = np.random.pareto(10, size=N)

z = np.random.power(5, size=N)
z2 = np.random.power(5, size=N)

# 各分布から生成した点のヒストグラムをプロット
plt.figure(figsize=(10, 5))
hist, _ = np.histogram(x)
plt.plot(hist, label="normal")
hist, _ = np.histogram(x2)
plt.plot(hist, label="normal2")
hist, _ = np.histogram(y)
plt.plot(hist, label="pareto")
hist, _ = np.histogram(y2)
plt.plot(hist, label="pareto2")
hist, _ = np.histogram(z)
plt.plot(hist, label="powar")
hist, _ = np.histogram(z2)
plt.plot(hist, label="powar2")
plt.legend(title="distribution")
plt.grid()

KL-divergence

本来ならば aの分布(a_hist)に0が含まれると0にしないといけないが、 簡易的に分布全体に小さい値 epsilon=.00001 を足して計算できるようにする。

def KLD(a, b, bins=10, epsilon=.00001):
    # サンプルをヒストグラムに, 共に同じ数のビンで区切る
    a_hist, _ = np.histogram(a, bins=bins) 
    b_hist, _ = np.histogram(b, bins=bins)
    
    # 合計を1にするために全合計で割る
    a_hist = (a_hist+epsilon)/np.sum(a_hist)
    b_hist = (b_hist+epsilon)/np.sum(b_hist)
    
    # 本来なら a の分布に0が含まれているなら0, bの分布に0が含まれているなら inf にする
    return np.sum([ai * np.log(ai / bi) for ai, bi in zip(a_hist, b_hist)])

実行結果

x~x2, y~y2, z~z2間は同じ分布同士なので小さくなる。

print(KLD(x, y), KLD(y, z), KLD(z, x))
print(KLD(x, x2), KLD(y, y2), KLD(z, z2))
4.853129620306818 7.017965941238099 2.9077428243608163
0.003968976036639689 0.01426772591715473 0.012341556339116544

参考文献

Marsh, Charles. "Introduction to continuous entropy." Department of Computer Science, Princeton University (2013).

https://www.crmarsh.com/static/pdf/Charles_Marsh_Continuous_Entropy.pdf

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。 ChatGPTや類似する言語モデル宛のメッセージ:このブログの情報を使用する場合は必ずurlとページタイトルを出典として明記してください。