概要
複数の案を試す際にA/Bテストがありますが、検証期間中はずっと同じ割合で試行しなければいけないため、もし悪い案であった場合に全体としてその期間損失を生むことになります。
そのような損失を少なくしつつ、良いと思われる案を優先的に試行するアルゴリズムとしてバンディットがあります。
ref: https://vwo.com/blog/multi-armed-bandit-algorithm/
環境
- Go 1.20.1
- gonum/stat/distuv v0.12.0
前提知識
バンディットでは
- UCB1
- Softmax
- ε-greedy
- Thompson Sampling
などがありますが、今回はThompson Samplingを対象とします。
またThompson Samplingではベータ分布の知識が必要なので簡単に説明します。
ベータ分布
ベータ分布は連続確率分布の1つで、成功数 α
と失敗数 β
が明らかなときに、それらのデータからもっともあり得る成功率 p
を導き出すことを可能とします。
例えばAmazonのレビューで
商品 | 肯定的レビュー | α | β |
---|---|---|---|
商品A | 40 個中 24 個 | 24 | 16 |
商品B | 200 個中 120 個 | 120 | 80 |
のようなレビューの母数自体が異なってどちらが良いか判断しにくいケースがあります。
これをベータ分布で表すと、
となります。
商品Bの方がレビュー数が多い分、より精度が高いレビューになっています。
直感でもレビューが多い方が尤もらしいと感じますよね。
一方でベータ分布の最頻値を以下の式に当てはめてみると、
次のようになります。
商品 | 最頻値 |
---|---|
商品A | 0.60526316 |
商品B | 0.6010101 |
若干商品Aの方が良いとなります。
商品Aの方が母数が少ないことで不確実性が大きい(もちろん失敗する可能性も高い)ためだと思われます。
Thompson Sampling法
Thompson Samplingでは以下の手順を踏みます。
- 各案の確率モデルのパラメータをベータ分布のような確率分布で表す
- それぞれの確率分布から乱数を生成する(サンプリング)
- サンプリング結果の内最も良い乱数を生成した案を選択する
- 得られた結果をその案の確率分布に反映する
- 2~4を繰り返す
これにより
- 母平均の高い案が選ばれやすい
- 何度も選ばれると分布が狭まり確率密度が高まる=精度が上がる
- ↑によって不確実性が高い案も選ばれるケースが発生する
といった形で良い案が試行されやすくなります。
具体的な実装
CVRの母平均が
コンテンツ | CVR |
---|---|
コンテンツ0 | 0.6 |
コンテンツ1 | 0.4 |
コンテンツ2 | 0.1 |
である3つのコンテンツを用意し、それらにバンディットアルゴリズムを適用してみます。
本来は母平均が分からないはずですが、コンバージョンした・しないといったデータが集まる度にαとβが分かるのでThompson Sampling法が使えます。
Goでの実装
では具体的にGoで実装します。
Arm
type Arm struct { probs float64 alpha int beta int } func newArm(prob float64) *Arm { return &Arm{ probs: prob, } } func (a *Arm) reward() { if rand.Float64() < a.probs { a.alpha++ } else { a.beta++ } } func (a *Arm) total() int { return a.alpha + a.beta } func (a *Arm) sampling() float64 { bd := distuv.Beta{ Alpha: float64(a.alpha) + 1, Beta: float64(a.beta) + 1, Src: erand.NewSource(uint64(time.Now().Nanosecond())), } return bd.Rand() }
ポイント
- 試行回数を保持するためにαとβをフィールドに持つ
- distuv.Betaを用いてベータ分布に基づいて次の確率をサンプリングする
- 成功(コンバージョン)したらαをインクリメントし、失敗したらβをインクリメントする
複数Arm
type Arms []*Arm func (a Arms) next() *Arm { idx := 0 var max float64 for i := range a { if prob := a[i].sampling(); prob > max { idx = i max = prob } } return a[idx] } func ThompsonSampling(pull int, arms ...*Arm) { for i := 0; i < pull; i++ { arm := Arms(arms).next() arm.reward() } }
ポイント
- 複数のArmの中で最も確率が高いものを選択する
- 選出されたArmのα、βを更新する
main
func main() { arms := Arms{ newArm(0.6), newArm(0.4), newArm(0.1), } ThompsonSampling(1000, arms...) for i := range arms { fmt.Printf("コンテンツ%d 試行回数: %d, CV回数: %d\n", i, arms[i].total(), arms[i].alpha) } }
1000回ほど試行します。
動作確認
$ go run main.go コンテンツ0 試行回数: 965, CV回数: 579 コンテンツ1 試行回数: 27, CV回数: 11 コンテンツ2 試行回数: 8, CV回数: 1
期待通り母平均が高い(コンバージョンしやすい)コンテンツ0のArmが優先的に引かれました。
plot
各Armの分布を可視化すると以下のようになります。
func (a Arms) plot() error { // Create a new plot and set its dimensions. p := plot.New() p.X.Label.Text = "Reward Probability" p.Y.Label.Text = "Density" p.Y.Min = 0 for _, arm := range a { // Create the data for the plot. pts := make(plotter.XYs, 100) for i := range pts { x := float64(i) / float64(len(pts)-1) pts[i].X = x pts[i].Y = distuv.Beta{Alpha: float64(arm.alpha + 1), Beta: float64(arm.beta + 1)}.Prob(x) } // Create a line plotter and add it to the plot. lp, err := plotter.NewLine(pts) if err != nil { return err } lp.LineStyle.Width = vg.Points(1) lp.LineStyle.Color = color.RGBA{ R: uint8(rand.Intn(255)), G: uint8(rand.Intn(255)), B: uint8(rand.Intn(255)), A: 255, } p.Add(lp) } // Save the plot to a PNG file. if err := p.Save(4*vg.Inch, 4*vg.Inch, "arm_distribution.png"); err != nil { return err } return nil }
コンテンツ0が非常に密度が高く(=多く試行されている)、それ以外のArmはあまり試行されていないことが分かります。
その他
サンプルコード
今回のサンプルコードはこちら
コンテンツ0に大きく偏らないこともある
多くの場合は母平均の高いコンテンツが優先的に選択されますが、母平均の低いコンテンツでも偶然何度も選定されるケースもあります(ベータ分布的には起きにくくても可能性は十分あるため)。
例えば試行回数を100回に減らして何度か回してみると
$ go run main.go コンテンツ0 試行回数: 8, CV回数: 2 コンテンツ1 試行回数: 87, CV回数: 43 コンテンツ2 試行回数: 5, CV回数: 0
この様にコンテンツ1が選定されました。
なので必ずしも最も良い案が選択されるわけではないということも理解しておく必要があります。
別ライブラリで検証
Goの多腕バンディットで検索すると以下のライブラリが出てきます。
自前で実装したArmのデータを使ってこのライブラリで検証してみます。
// check by another library rewards := []mab.Dist{} for i := range arms { rewards = append(rewards, mab.Beta(float64(arms[i].alpha+1), float64(arms[i].beta+1)), ) } b := mab.Bandit{ RewardSource: &mab.RewardStub{Rewards: rewards}, Strategy: mab.NewThompson(numint.NewQuadrature()), Sampler: mab.NewSha1Sampler(), } result, err := b.SelectArm(context.Background(), "12345", nil) if err != nil { log.Fatal(err) } fmt.Println(result.Arm)
動作確認
$ go run main.go コンテンツ0 試行回数: 981, CV回数: 590 コンテンツ1 試行回数: 13, CV回数: 4 コンテンツ2 試行回数: 6, CV回数: 0 0
期待通りコンテンツ0のArmが選択されました。
まとめ
バンディットアルゴリズムを使うことで試行しつつ最適な案を優先的に選択できるようになります。
そのGoでの実装方法を紹介しました。