Skip to content

Commit

Permalink
feat: adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WendelHime committed Nov 25, 2024
1 parent d1c7f73 commit 6a8d589
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 50 deletions.
114 changes: 64 additions & 50 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -737,62 +737,76 @@ func (client *Client) initDialers(proxies map[string]*commonconfig.ProxyConfig)
)
}
},
SaveBanditRewards: func(metrics map[string]dialer.BanditMetrics) {
dir := filepath.Join(configDir, "bandit")
if err := os.MkdirAll(dir, 0644); err != nil {
log.Errorf("unable to create bandit directory: %v", err)
return
}
file := filepath.Join(dir, "rewards.csv")
csv := new(strings.Builder)
csv.WriteString("dialerName,reward,count\n")
for dialerName, metric := range metrics {
csv.WriteString(fmt.Sprintf("%s,%f,%d\n", dialerName, metric.Reward, metric.Count))
SaveBanditRewards: saveBanditRewards(configDir),
LoadLastBanditRewards: loadLastBanditRewards(configDir),
})
return dialers, dialer, nil
}

func saveBanditRewards(dir string) func(map[string]dialer.BanditMetrics) {
return func(metrics map[string]dialer.BanditMetrics) {
dir := filepath.Join(dir, "bandit")
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("unable to create bandit directory: %v", err)
return
}
file := filepath.Join(dir, "rewards.csv")
csv := new(strings.Builder)
csv.WriteString("dialer,reward,count\n")
for dialerName, metric := range metrics {
csv.WriteString(fmt.Sprintf("%s,%f,%d\n", dialerName, metric.Reward, metric.Count))
}
f, err := os.Create(file)
if err != nil {
log.Errorf("unable to create bandit rewards file: %v", err)
return
}
defer f.Close()
if _, err := f.WriteString(csv.String()); err != nil {
log.Errorf("unable to write bandit rewards to file: %v", err)
}
}
}

func loadLastBanditRewards(outputDir string) func() map[string]dialer.BanditMetrics {
return func() map[string]dialer.BanditMetrics {
dir := filepath.Join(outputDir, "bandit")
file := filepath.Join(dir, "rewards.csv")
if _, err := os.Stat(file); os.IsNotExist(err) {
return nil
}
data, err := os.ReadFile(file)
if err != nil {
log.Errorf("unable to read bandit rewards from file: %v", err)
return nil
}
lines := strings.Split(string(data), "\n")
metrics := make(map[string]dialer.BanditMetrics)
for i, line := range lines {
if i == 0 {
continue
}
if err := os.WriteFile(file, []byte(csv.String()), 0644); err != nil {
log.Errorf("unable to write bandit rewards to file: %v", err)
parts := strings.Split(line, ",")
if len(parts) != 3 {
continue
}
},
LoadLastBanditRewards: func() map[string]dialer.BanditMetrics {
dir := filepath.Join(configDir, "bandit")
file := filepath.Join(dir, "rewards.csv")
if _, err := os.Stat(file); os.IsNotExist(err) {
return nil
reward, err := strconv.ParseFloat(parts[1], 64)
if err != nil {
log.Errorf("unable to parse reward from %s: %v", parts[1], err)
continue
}
data, err := os.ReadFile(file)
count, err := strconv.Atoi(parts[2])
if err != nil {
log.Errorf("unable to read bandit rewards from file: %v", err)
return nil
log.Errorf("unable to parse count from %s: %v", parts[2], err)
continue
}
lines := strings.Split(string(data), "\n")
metrics := make(map[string]dialer.BanditMetrics)
for i, line := range lines {
if i == 0 {
continue
}
parts := strings.Split(line, ",")
if len(parts) != 3 {
continue
}
reward, err := strconv.ParseFloat(parts[1], 64)
if err != nil {
log.Errorf("unable to parse reward from %s: %v", parts[1], err)
continue
}
count, err := strconv.Atoi(parts[2])
if err != nil {
log.Errorf("unable to parse count from %s: %v", parts[2], err)
continue
}
metrics[parts[0]] = dialer.BanditMetrics{
Reward: reward,
Count: count,
}
metrics[parts[0]] = dialer.BanditMetrics{
Reward: reward,
Count: count,
}
return metrics
},
})
return dialers, dialer, nil
}
return metrics
}
}

// Creates a local server to capture client hello messages from the browser and
Expand Down
85 changes: 85 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -422,6 +424,89 @@ func TestAccessingProxyPort(t *testing.T) {
assert.Equal(t, "0", resp.Header.Get("Content-Length"))
}

func TestSaveBanditRewards(t *testing.T) {
var tests = []struct {
name string
given map[string]dialer.BanditMetrics
assert func(t *testing.T, dir string)
}{
{
name: "it should save the rewards",
given: map[string]dialer.BanditMetrics{
"test-dialer": {
Reward: 1.0,
Count: 1,
},
},
assert: func(t *testing.T, dir string) {
f, err := os.Open(filepath.Join(dir, "bandit", "rewards.csv"))
require.NoError(t, err)
defer f.Close()
b, err := io.ReadAll(f)
require.NoError(t, err)

lines := strings.Split(string(b), "\n")
// check if headers are there
assert.Contains(t, lines[0], "dialer,reward,count")
// check if the data is there
assert.Contains(t, lines[1], "test-dialer,1.000000,1")
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tempDir, err := os.MkdirTemp("", "client_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)

f := saveBanditRewards(tempDir)
f(tt.given)
tt.assert(t, tempDir)
})
}
}

func TestLoadLastBanditRewards(t *testing.T) {
var tests = []struct {
name string
given string
assert func(t *testing.T, metrics map[string]dialer.BanditMetrics)
}{
{
name: "it should load the rewards",
given: "dialer,reward,count\ntest-dialer,1.000000,1\n",
assert: func(t *testing.T, metrics map[string]dialer.BanditMetrics) {
assert.Contains(t, metrics, "test-dialer")
assert.Equal(t, 1.0, metrics["test-dialer"].Reward)
assert.Equal(t, 1, metrics["test-dialer"].Count)
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tempDir, err := os.MkdirTemp("", "client_test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)

if err := os.MkdirAll(filepath.Join(tempDir, "bandit"), 0755); err != nil {
log.Errorf("unable to create bandit directory: %v", err)
return
}

f, err := os.Create(filepath.Join(tempDir, "bandit", "rewards.csv"))
require.NoError(t, err)
defer f.Close()
_, err = f.WriteString(tt.given)
require.NoError(t, err)

metrics := loadLastBanditRewards(tempDir)()
tt.assert(t, metrics)
})
}
}

// Assert that a testDialer is a bandit.Dialer
var _ dialer.ProxyDialer = &testDialer{}

Expand Down

0 comments on commit 6a8d589

Please sign in to comment.