Carpe Diem

備忘録

Goでバンディットアルゴリズムを実装する

概要

複数の案を試す際にA/Bテストがありますが、検証期間中はずっと同じ割合で試行しなければいけないため、もし悪い案であった場合に全体としてその期間損失を生むことになります。

そのような損失を少なくしつつ、良いと思われる案を優先的に試行するアルゴリズムとしてバンディットがあります。

ref: https://vwo.com/blog/multi-armed-bandit-algorithm/

環境

  • Go 1.20.1
  • gonum/stat/distuv v0.12.0

前提知識

バンディットでは

  • UCB1
  • Softmax
  • ε-greedy
  • Thompson Sampling

などがありますが、今回はThompson Samplingを対象とします。
またThompson Samplingではベータ分布の知識が必要なので簡単に説明します。

ベータ分布

ベータ分布は連続確率分布の1つで、成功数 α と失敗数 β が明らかなときに、それらのデータからもっともあり得る成功率 p を導き出すことを可能とします。

例えばAmazonのレビューで

商品 肯定的レビュー α β
商品A 40 個中 24 個 24 16
商品B 200 個中 120 個 120 80

のようなレビューの母数自体が異なってどちらが良いか判断しにくいケースがあります。
これをベータ分布で表すと、

商品A
商品B

となります。
商品Bの方がレビュー数が多い分、より精度が高いレビューになっています。
直感でもレビューが多い方が尤もらしいと感じますよね。

一方でベータ分布の最頻値を以下の式に当てはめてみると、

Mode(p)=\dfrac{\alpha-1}{\alpha+\beta-2}

次のようになります。

商品 最頻値
商品A 0.60526316
商品B 0.6010101

若干商品Aの方が良いとなります。
商品Aの方が母数が少ないことで不確実性が大きい(もちろん失敗する可能性も高い)ためだと思われます。

Thompson Sampling法

Thompson Samplingでは以下の手順を踏みます。

  1. 各案の確率モデルのパラメータをベータ分布のような確率分布で表す
  2. それぞれの確率分布から乱数を生成する(サンプリング)
  3. サンプリング結果の内最も良い乱数を生成した案を選択する
  4. 得られた結果をその案の確率分布に反映する
  5. 2~4を繰り返す

これにより

  • 母平均の高い案が選ばれやすい
  • 何度も選ばれると分布が狭まり確率密度が高まる=精度が上がる
  • ↑によって不確実性が高い案も選ばれるケースが発生する

といった形で良い案が試行されやすくなります。

具体的な実装

CVRの母平均が

コンテンツ CVR
コンテンツ0 0.6
コンテンツ1 0.4
コンテンツ2 0.1

である3つのコンテンツを用意し、それらにバンディットアルゴリズムを適用してみます。

本来は母平均が分からないはずですが、コンバージョンした・しないといったデータが集まる度にαとβが分かるのでThompson Sampling法が使えます。

Goでの実装

では具体的にGoで実装します。

Arm

type Arm struct {
        probs float64
        alpha int
        beta  int
}

func newArm(prob float64) *Arm {
        return &Arm{
                probs: prob,
        }
}

func (a *Arm) reward() {
        if rand.Float64() < a.probs {
                a.alpha++
        } else {
                a.beta++
        }
}

func (a *Arm) total() int {
        return a.alpha + a.beta
}

func (a *Arm) sampling() float64 {
        bd := distuv.Beta{
                Alpha: float64(a.alpha) + 1,
                Beta:  float64(a.beta) + 1,
                Src:   erand.NewSource(uint64(time.Now().Nanosecond())),
        }
        return bd.Rand()
}

ポイント

  • 試行回数を保持するためにαとβをフィールドに持つ
  • distuv.Betaを用いてベータ分布に基づいて次の確率をサンプリングする
  • 成功(コンバージョン)したらαをインクリメントし、失敗したらβをインクリメントする

複数Arm

type Arms []*Arm

func (a Arms) next() *Arm {
        idx := 0
        var max float64
        for i := range a {
                if prob := a[i].sampling(); prob > max {
                        idx = i
                        max = prob
                }
        }
        return a[idx]
}

func ThompsonSampling(pull int, arms ...*Arm) {
        for i := 0; i < pull; i++ {
                arm := Arms(arms).next()
                arm.reward()
        }
}

ポイント

  • 複数のArmの中で最も確率が高いものを選択する
  • 選出されたArmのα、βを更新する

main

func main() {
        arms := Arms{
                newArm(0.6),
                newArm(0.4),
                newArm(0.1),
        }

        ThompsonSampling(1000, arms...)

        for i := range arms {
                fmt.Printf("コンテンツ%d 試行回数: %d, CV回数: %d\n", i, arms[i].total(), arms[i].alpha)
        }
}

1000回ほど試行します。

動作確認

$ go run main.go
コンテンツ0 試行回数: 965, CV回数: 579
コンテンツ1 試行回数: 27, CV回数: 11
コンテンツ2 試行回数: 8, CV回数: 1

期待通り母平均が高い(コンバージョンしやすい)コンテンツ0のArmが優先的に引かれました。

plot

各Armの分布を可視化すると以下のようになります。

func (a Arms) plot() error {
        // Create a new plot and set its dimensions.
        p := plot.New()
        p.X.Label.Text = "Reward Probability"
        p.Y.Label.Text = "Density"
        p.Y.Min = 0

        for _, arm := range a {
                // Create the data for the plot.
                pts := make(plotter.XYs, 100)
                for i := range pts {
                        x := float64(i) / float64(len(pts)-1)
                        pts[i].X = x
                        pts[i].Y = distuv.Beta{Alpha: float64(arm.alpha + 1), Beta: float64(arm.beta + 1)}.Prob(x)
                }

                // Create a line plotter and add it to the plot.
                lp, err := plotter.NewLine(pts)
                if err != nil {
                        return err
                }
                lp.LineStyle.Width = vg.Points(1)
                lp.LineStyle.Color = color.RGBA{
                        R: uint8(rand.Intn(255)),
                        G: uint8(rand.Intn(255)),
                        B: uint8(rand.Intn(255)),
                        A: 255,
                }
                p.Add(lp)
        }

        // Save the plot to a PNG file.
        if err := p.Save(4*vg.Inch, 4*vg.Inch, "arm_distribution.png"); err != nil {
                return err
        }
        return nil
}

コンテンツ0が非常に密度が高く(=多く試行されている)、それ以外のArmはあまり試行されていないことが分かります。

その他

サンプルコード

今回のサンプルコードはこちら

github.com

コンテンツ0に大きく偏らないこともある

多くの場合は母平均の高いコンテンツが優先的に選択されますが、母平均の低いコンテンツでも偶然何度も選定されるケースもあります(ベータ分布的には起きにくくても可能性は十分あるため)。

例えば試行回数を100回に減らして何度か回してみると

$ go run main.go
コンテンツ0 試行回数: 8, CV回数: 2
コンテンツ1 試行回数: 87, CV回数: 43
コンテンツ2 試行回数: 5, CV回数: 0

この様にコンテンツ1が選定されました。

なので必ずしも最も良い案が選択されるわけではないということも理解しておく必要があります。

別ライブラリで検証

Goの多腕バンディットで検索すると以下のライブラリが出てきます。

github.com

自前で実装したArmのデータを使ってこのライブラリで検証してみます。

        // check by another library
        rewards := []mab.Dist{}
        for i := range arms {
                rewards = append(rewards,
                        mab.Beta(float64(arms[i].alpha+1), float64(arms[i].beta+1)),
                )
        }

        b := mab.Bandit{
                RewardSource: &mab.RewardStub{Rewards: rewards},
                Strategy:     mab.NewThompson(numint.NewQuadrature()),
                Sampler:      mab.NewSha1Sampler(),
        }

        result, err := b.SelectArm(context.Background(), "12345", nil)
        if err != nil {
                log.Fatal(err)
        }
        fmt.Println(result.Arm)

動作確認

$ go run main.go
コンテンツ0 試行回数: 981, CV回数: 590
コンテンツ1 試行回数: 13, CV回数: 4
コンテンツ2 試行回数: 6, CV回数: 0
0

期待通りコンテンツ0のArmが選択されました。

まとめ

バンディットアルゴリズムを使うことで試行しつつ最適な案を優先的に選択できるようになります。
そのGoでの実装方法を紹介しました。

参考