diff --git a/dialer/bandit.go b/dialer/bandit.go index c58359ef3..a7b453a3f 100644 --- a/dialer/bandit.go +++ b/dialer/bandit.go @@ -12,9 +12,10 @@ import ( // BanditDialer is responsible for continually choosing the optimized dialer. type BanditDialer struct { - dialers []ProxyDialer - bandit *bandit.EpsilonGreedy - opts *Options + dialers []ProxyDialer + bandit bandit.Bandit + opts *Options + dialerWeights map[string]BanditMetrics } // NewBandit creates a new bandit given the available dialers and options with @@ -30,12 +31,17 @@ func NewBandit(opts *Options) (Dialer, error) { dialers := opts.Dialers log.Debugf("Creating bandit with %d dialers", len(dialers)) - var b *bandit.EpsilonGreedy + var b bandit.Bandit var err error + dialer := &BanditDialer{ + dialers: dialers, + opts: opts, + } if opts.LoadLastBanditRewards != nil { log.Debugf("Loading bandit weights from %s", opts.LoadLastBanditRewards) // TODO: Load the weights from the file. dialerWeights := opts.LoadLastBanditRewards() + dialer.dialerWeights = dialerWeights counts := make([]int, len(dialers)) rewards := make([]float64, len(dialers)) for arm, dialer := range dialers { @@ -60,12 +66,7 @@ func NewBandit(opts *Options) (Dialer, error) { return nil, err } } - - dialer := &BanditDialer{ - dialers: dialers, - bandit: b, - opts: opts, - } + dialer.bandit = b return dialer, nil } @@ -124,6 +125,23 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) ( } }) + time.AfterFunc(30*time.Second, func() { + // Save the bandit weights + if bd.opts.SaveBanditRewards != nil { + metrics := make(map[string]BanditMetrics) + rewards := bd.bandit.GetRewards() + counts := bd.bandit.GetCounts() + for i, d := range bd.dialers { + metrics[d.Name()] = BanditMetrics{ + Reward: rewards[i], + Count: counts[i], + } + } + + bd.opts.SaveBanditRewards(metrics) + } + }) + bd.opts.OnSuccess(d) return dt, err } diff --git a/dialer/dialer.go b/dialer/dialer.go index 6f0d31494..03bace207 100644 --- a/dialer/dialer.go +++ b/dialer/dialer.go @@ -77,6 +77,8 @@ type Options struct { // for each dialer. If this is set, the bandit will be initialized with the // last metrics. LoadLastBanditRewards func() map[string]BanditMetrics + + SaveBanditRewards func(map[string]BanditMetrics) } type BanditMetrics struct {