めも

これはメモ。

論文メモ:Learning to teach

以下の論文を読んだ時のメモです。書きかけです。

Fan, Y., Tian, F., Qin, T., Li, X.-Y., and Liu, T.-Y. Learning to teach. In ICLR, 2018.

pdf: https://openreview.net/pdf?id=HJewuJWCZ

この論文の後にLi Fei-Feiらが出した

Jiang, Lu, et al. "MentorNet: Learning Data-Driven Curriculum for Very Deep Neural Networks on Corrupted Labels." International Conference on Machine Learning. 2018.

も時間があれば読みたいですね。

概要

  • 実際の教育では良い教師は教材・適切な教育方法・試験を教師が選択する
  • 機械学習の研究では、学習アルゴリズムは十分に研究が進んでいるが、教師側:つまり学習するデータのサンプリングや目的関数の設定には注意が払われていない
  • 教師側にもなんらかの戦略を導入すべきであり、それはヒューリスティックで決めるのではなく最適化を用いた戦略であるべき
  • 2つのエージェントが相互に作用するフレームワークを提案
    • 生徒:学習を行うエージェント
    • 教師:学習のためのサンプル、損失関数、学習のための仮説空間を決めるエージェント
  • 一例として教師エージェントがサンプルの選び方を強化学習で学び、SPL(ベースライン)を超える性能

アルゴリズム

問題設定

一般的な今日しあり学習の枠組みだと

  • X: 特徴ベクトルの存在する空間
  • y: 特徴ベクトルxに対応するラベルが存在する空間

でありXからサンプルxを適当に取ってくると、教師はそれに対応するラベル y を未知の分布 P(y | x)から与える。最終的には入力xから教師が与えるラベルを予測する関数 f のパラメータを求めることが目的となる。

求めた関数fはfを用いて予測したラベルと実際のラベル間の距離を図る関数Mを用いて

R(ω) = ∫ M(y, f(x, ω)) dP(x, y)

というリスクで表現できる。この時

  • P(x, y)は未知の分布
  • Mそのものを最適化することは難しいためそれを近似した最適化可能な関数(サロゲート損失関数)を用いる
  • fのパラメータの探索空間はあらかじめ与える

ことが多い。そして関数の良さは訓練データDで測るため、モデルの訓練とは(訓練データD、関数の空間Ω、損失関数L)を与えられた時にLを最小にするようなパラメタω*を得ることに相当する。

"Learning to teach(L2T)"の枠組みでは

  • 教師は生徒(モデル)が学習を簡単に進められるような訓練データdを選んでくる。訓練データdは教師が選んできた適切なテキストに相当する。
  • 教師は生徒に適切な損失関数を設計して生徒の学習をただしく導く必要がある。ここでの損失関数は学習中に行われる試験に相当する。
  • 教師は生徒に適切な仮設空間(学習で得られる関数が得られる空間)を設計する必要がある。

なので L2T の枠組みでは教師は訓練データd、損失関数L、仮設空間Ωを適切に選んで生徒の学習を導く必要がある。

アルゴリズムのフレームワーク

表記

以下tは離散で表されるタイムステップ。

  • S: s_t in S はタイムステップ t に教師が利用することのできる情報の集合で、典型的には現在の生徒 f_t-1 と過去の指導履歴。
  • 教師エージェントは状態 s_t を受け取って次に取るべき行動 a_t を出力する。a_t は生徒に与える訓練データ、損失関数、仮設空間などがありうる。この教師エージェントはパラメータθを持つ関数φ(s) = a で表せる。
  • 生徒エージェントは教師エージェントから a_t を受け取り関数 f_t(予測モデル、Xからyへの写像)を出力する
  • D_trainteacher, D_teststudent: データセットDを訓練・テスト、教師・生徒のモデルの訓練・テストするために分割したDの表記

教師のモデル

モデルは特に限定していないが、実験では強化学習の枠組みを使用。

  • φ: ポリシー
  • S: 環境, 生徒は教師の行動 a_t を見て s_t を更新, s_t = (tにきたバッチ、f_t)。
  • r_t: 報酬は「生徒の良さ」を表すように設計したい。例えばバリデーションデータセットでのサロゲート損失の値など。以下の実験では収束を速めることを目的とする用の報酬を設計。
  • a_t: タイムステップtで教師が取る行動。{0, 1}^(バッチサイズ)で表され、1に対応するサンプルだけが訓練に利用される。

生徒のモデル

  • f: 生徒のモデルfは画像分類タスクでのニューラルネットのモデル
  • M: accuracy
  • g(a, f): 状態sを表現するベクトルを作る関数g。
    • サンプルに選んだデータの特徴
    • モデルの特徴
  • 生徒モデルのニューラルネット
    • multi-layer perceptron (MLP)
    • convolutional neural networks (CNNs)の一例としてResNet
    • recurrent neural networks (RNNs)の一例としてLSTM
  • 学習に利用するデータセット
    • MNIST
    • CIFER-10
    • IMDB

※IMDBデータセットについては以下のページを参照してください

Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. (2011). Learning Word Vectors for Sentiment Analysis. The 49th Annual Meeting of the Association for Computational Linguistics (ACL 2011).

Large Movie Review Dataset:http://ai.stanford.edu/~amaas/data/sentiment/

Sentiment analysis on IMDB movie reviews | Kaggle

ヒューリスティックでモデルを作成した例:

実験

比較手法

  • NoTeach: 教師エージェントなしで生徒のモデルを学習、つまり普通の学習
  • Self-paced-learning(SPL):

カリキュラムラーニングのうちサンプルの選び方も最適化するのがSPL、サンプルの選び方に「多様性」をもたせたものがSPL-with-diversityと認識しています。

  • Learnig-to-teach(L2T): 上記のようにサンプルやモデルの状態から作成したベクトル g(s) を状態を表すベクトルとしてそこから次のサンプルの選択を出力するポリシーを作る。ポリシー関数 φ:s-> a は三層のニューラルネットで作る

  • RandTeach: L2Tでのサンプルの選択をログとして記録し、毎バッチでフィルタリングしたサンプルの割合を記録する。その記録に基づいてフィルタリングの割合は同じだが、サンプルの選び方をランダムにしたものをRandTeachとする。

評価

教師エージェントを訓練した時と同じモデルの生徒を指導

データセットDは訓練データD_train, テストデータD_testに分ける。 D_trainは教師エージェントの訓練用データ D_train_teacher と、訓練済みの教師を使って生徒を訓練するときに利用するデータセット D_train_student に分ける。L2Tについては以下のステップで実験を行う

    1. D_train_teacher で今日しを訓練。D_train_teacher の一部をさらに分離して訓練に使わずに報酬の計算に利用する。
    1. 教師エージェントのパラメタを固定して、D_train_studentを使って新しい生徒を訓練させる。この生徒は教師エージェントを訓練させるときに用いた生徒と同じモデル。
    1. 生徒をD_testを用いて評価。

三つのデータセット全てにおいて訓練時の accuracy の収束がもっとも早かった。45%~75%のサンプル数で同じ accuracy に到達できる。なので生徒が同じモデルの場合には教師エージェントの効果が認められる。

フィルタリングしたサンプル数の傾向を見るとSPLとは異なる傾向が見られた。MNISTとCIFER-10についてはサンプルの分類難易度ごとにフィルタされた数をプロットすると簡単なサンプルは徐々にフィルタ去れ、難しいサンプルは常に訓練データに残る傾向が見られる。IMDBについては徐々にフィルタされるサンプルが減っていき、全ての難易度のサンプルについて同じくらいフィルタがかかる。

教師エージェントを訓練した時と異なるモデルの生徒を指導

今度は教師の訓練に使用した生徒のモデルと、実際に指導する生徒のモデルが異なる場合。それぞれ生徒1, 生徒2と書くと

  • 生徒1=Resnet32, 生徒2=Resnet110, ともに CIFAR-10を解くケース
  • 生徒1=MLPでMNISTを解く, 生徒2=CNNでCIFAR-10を解くケース
  • 生徒1=CNNでCIFAR-10を解く, 生徒2=MLPでMNISTを解くケース

全てのケースで L2T がもっとも早く収束。また wall-clock-timeの視点でも評価を(一部)行い、g(s)の計算などを行っているがそれでも訓練時間の削減もできていることが示唆される。