Skip to content

Commit

Permalink
feat: add option for saving bandit metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
WendelHime committed Nov 22, 2024
1 parent 2dfdea4 commit 44f3ee6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
38 changes: 28 additions & 10 deletions dialer/bandit.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (

// BanditDialer is responsible for continually choosing the optimized dialer.
type BanditDialer struct {
dialers []ProxyDialer
bandit *bandit.EpsilonGreedy
opts *Options
dialers []ProxyDialer
bandit bandit.Bandit
opts *Options
dialerWeights map[string]BanditMetrics
}

// NewBandit creates a new bandit given the available dialers and options with
Expand All @@ -30,12 +31,17 @@ func NewBandit(opts *Options) (Dialer, error) {
dialers := opts.Dialers
log.Debugf("Creating bandit with %d dialers", len(dialers))

var b *bandit.EpsilonGreedy
var b bandit.Bandit
var err error
dialer := &BanditDialer{
dialers: dialers,
opts: opts,
}
if opts.LoadLastBanditRewards != nil {
log.Debugf("Loading bandit weights from %s", opts.LoadLastBanditRewards)
// TODO: Load the weights from the file.
dialerWeights := opts.LoadLastBanditRewards()
dialer.dialerWeights = dialerWeights
counts := make([]int, len(dialers))
rewards := make([]float64, len(dialers))
for arm, dialer := range dialers {
Expand All @@ -60,12 +66,7 @@ func NewBandit(opts *Options) (Dialer, error) {
return nil, err
}
}

dialer := &BanditDialer{
dialers: dialers,
bandit: b,
opts: opts,
}
dialer.bandit = b

return dialer, nil
}
Expand Down Expand Up @@ -124,6 +125,23 @@ func (bd *BanditDialer) DialContext(ctx context.Context, network, addr string) (
}
})

time.AfterFunc(30*time.Second, func() {
// Save the bandit weights
if bd.opts.SaveBanditRewards != nil {
metrics := make(map[string]BanditMetrics)
rewards := bd.bandit.GetRewards()
counts := bd.bandit.GetCounts()
for i, d := range bd.dialers {
metrics[d.Name()] = BanditMetrics{
Reward: rewards[i],
Count: counts[i],
}
}

bd.opts.SaveBanditRewards(metrics)
}
})

bd.opts.OnSuccess(d)
return dt, err
}
Expand Down
2 changes: 2 additions & 0 deletions dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ type Options struct {
// for each dialer. If this is set, the bandit will be initialized with the
// last metrics.
LoadLastBanditRewards func() map[string]BanditMetrics

SaveBanditRewards func(map[string]BanditMetrics)
}

type BanditMetrics struct {
Expand Down

0 comments on commit 44f3ee6

Please sign in to comment.