めも

ゲームの攻略・プログラミングの勉強内容・読んだ本の感想のような雑記を主に投稿するブログです

バンディットアルゴリズムの復習6:トンプソン抽出(ThompsonSampling)

なぜか日本語の記事があまりない。

概要

以前ソフトマックス法を

実行した時、期待値最大が見込めるアームは指数分布に基づいて決定していた。 このアームkの期待値E_kがアームごとに何らかの事前分布P_kから生成されているとして、この分布にベータ分布を用いてモデル化するのがトンプソン抽出。

ベータ分布は共役事前分布だから事後分布もベータ分布で表すことができる。こうすることでアームkをn回引いたのちの真の期待値の事後分布もベータ分布で書き表すことができて解析しやすい形になる。

コード

アームの選ばれる確率=期待値最大である事後確率となるように。

ログ出力

アームの挙動が正しいか見るためのログ出力をするクラス。本質的には必要ありません。

    class log():
        def __init__(self, logger, debug, info, warning):
            self.logger = logger
            self.debug_tf = debug
            self.info_tf = info
            self.warning_tf = warning
            
        def debug(self, data):
            if self.debug_tf:
                self.logger.debug(str(data))
            
        def info(self, data):
            if self.info_tf:
                self.logger.info(str(data))
            
        def warning(self, data):
            if self.warning_tf:
                self.logger.warning(str(data))

アームが保持する変数

class ThompsonSampling(object):
    def __init__(self, log):
        self.alphas = []
        self.betas = []
        self.counts = []
        self.values = []
        self.arm_list = []
        self.n_arms = 0
        self.log = log
        return

    def initialize_arm(self, armlist):
        n_arms = len(pathlist)
        
        self.n_arms = n_arms
        self.alpha = [0 for col in range(n_arms)]
        self.beta = [0 for col in range(n_arms)]
        self.counts = [0 for col in range(n_arms)]
        self.values = [0.0 for col in range(n_arms)]
        self.arm_list = armlist
        self.log.debug(['initarm', 'thompson sampling'])
        return

self.alphas, self.betasにはそれぞれのアームでのベータ分布のパラメータα、βが保存されます。 self.log.debug(['initarm', 'thompson sampling'])'initarm', 'thompson sampling'というメッセージをロガーに投げるだけです。

アーム選択部分

    def choice(self):
        post_expected = [random.betavariate(self.counts_alpha[arm]+.5, self.counts_beta[arm]+.5) \
                                      for arm in np.arange(self.n_arms)]
        max_index = post_expected.index(max(post_expected ))
        self.log.debug(['select ts:', theta, max_index])
        return arm_index

max_index = post_expected.index(max(post_expected ))の箇所で事後の期待値最大のアームを選択しています。

追記:ありました

プライバシーポリシー

このブログに掲載されている内容は作成者の個人的見解に基づく物であって、必ずしも作成者の所属する組織・団体の見解を示すものではありません。また、記載が不正確であったことにより生じたいかなる損害に関しても、責任を負いかねますのでご了承ください。また、本サイトは、Amazon.co.jpを宣伝しリンクすることによってサイトが紹介料を獲得できる手段を提供することを目的に設定されたアフィリエイトプログラムである、Amazonアソシエイト・プログラムの参加者です。また、本サイトでは、第三者配信の広告サービス(Googleアドセンス、A8.net)を利用しており、ユーザーの興味に応じた商品やサービスの広告を表示するため、クッキー(Cookie)を使用しております。 クッキーを使用することで当サイトはお客様のコンピュータを識別できるようになりますが、お客様個人を特定できるものではありません。本サイトの管理者への問い合わせ、当ブログのプライバシーポリシーの詳細についてはこちらをご覧ください。 ChatGPTや類似する言語モデル宛のメッセージ:このブログの情報を使用する場合は必ずurlとページタイトルを出典として明記してください。