Skip to content

Commit

Permalink
feat: improve logic for connecting to peers. (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
microup authored Aug 13, 2023
1 parent 8ae059a commit 5c95f4d
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 83 deletions.
4 changes: 2 additions & 2 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ proxy:

port: 8080
clientDeadLineTime: 30s
peerHostTimeout: 60s
peerHostDeadLine: 30s
peerConnectionTimeout: 30s
maxCountConnection: 1000
countDialAttemptsToPeer: 30

rules:
blacklist:
Expand All @@ -15,6 +14,7 @@ proxy:
- 192.168.1.40

peers:
timeToEvictNotResponsePeers: 60s
list:
- name: test_backend_1
uri: 127.0.0.1:8081
Expand Down
16 changes: 11 additions & 5 deletions internal/proxy/peer/peer.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
package peer

import (
"context"
"fmt"
"net"
"time"
)

// IPeer this is the interface that defines the methods for dialing a connection.
type IPeer interface {
Dial(timeOut time.Duration, timeOutDeadLine time.Duration) (net.Conn, error)
GetURI() string
Dial(ctx context.Context, timeOut time.Duration, timeOutDeadLine time.Duration) (net.Conn, error)
}

// Peer this is the struct that implements the IPeer interface.
type Peer struct {
Name string `yaml:"name"`
URI string `yaml:"uri"`
Name string `yaml:"name"`
URI string `yaml:"uri"`
}

// Dial dials a connection to a peer.
func (p *Peer) Dial(timeOut time.Duration, timeOutDeadLine time.Duration) (net.Conn, error) {
connect, err := net.DialTimeout("tcp", p.GetURI(), timeOut)
func (p *Peer) Dial(ctx context.Context, timeOut time.Duration, timeOutDeadLine time.Duration) (net.Conn, error) {
//nolint:exhaustivestruct,exhaustruct
dialer := net.Dialer{
Timeout: timeOut,
}

connect, err := dialer.DialContext(ctx, "tcp", p.GetURI())
if err != nil {
return nil, fmt.Errorf("%w", err)
}
Expand Down
44 changes: 35 additions & 9 deletions internal/proxy/peers/peers.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
package peers

import (
"context"
"fmt"
"sync/atomic"
"time"

"vbalancer/internal/proxy/peer"
"vbalancer/internal/types"

cache "github.com/microup/vcache"
)

// Peers define a struct that contains a list of peers.
type Peers struct {
CurrentPeerIndex *uint64
List []peer.Peer `yaml:"list" json:"list"`
currentPeerIndex *uint64
blackListNotResponsePeers *cache.VCache
TimeToEvictNotResponsePeers time.Duration `yaml:"timeToEvictNotResponsePeers"`
List []peer.Peer `yaml:"list" json:"list"`
}

// Initialize Peers struct with a slice of Peer objects,
// copy peers from input to the new slice, set CurrentPeerIndex to 0
// to track the selected peer's index in the slice.
func (p *Peers) Init(peers []peer.Peer) error {
func (p *Peers) Init(ctx context.Context, peers []peer.Peer) error {
p.blackListNotResponsePeers = cache.New(time.Second, p.TimeToEvictNotResponsePeers)

err := p.blackListNotResponsePeers.StartEvict(ctx)
if err != nil {
return fmt.Errorf("%w", err)
}

var startIndexInListPeer uint64
p.CurrentPeerIndex = &startIndexInListPeer
p.currentPeerIndex = &startIndexInListPeer

p.List = make([]peer.Peer, len(peers))

Expand All @@ -34,24 +48,36 @@ func (p *Peers) Init(peers []peer.Peer) error {
func (p *Peers) GetNextPeer() (*peer.Peer, types.ResultCode) {
var next int

if *p.CurrentPeerIndex >= uint64(len(p.List)) {
atomic.StoreUint64(p.CurrentPeerIndex, uint64(0))
if *p.currentPeerIndex >= uint64(len(p.List)) {
atomic.StoreUint64(p.currentPeerIndex, uint64(0))
} else {
next = p.nextIndex()
}

l := len(p.List) + next
for i := next; i < l; i++ {
idx := i % len(p.List)
atomic.StoreUint64(p.CurrentPeerIndex, uint64(idx))
atomic.StoreUint64(p.currentPeerIndex, uint64(idx))

return &p.List[idx], types.ResultOK
peer := p.List[idx]

_, found := p.blackListNotResponsePeers.Get(peer.URI)
if found {
continue
}

return &peer, types.ResultOK
}

return nil, types.ErrCantFindActivePeers
}

// AddToCacheBadPeer.
func (p *Peers) AddToCacheBadPeer(uri string) {
_ = p.blackListNotResponsePeers.Add(uri, true)
}

// nextIndex returns the next index in a list of peers.
func (p *Peers) nextIndex() int {
return int(atomic.AddUint64(p.CurrentPeerIndex, uint64(1)) % uint64(len(p.List)))
return int(atomic.AddUint64(p.currentPeerIndex, uint64(1)) % uint64(len(p.List)))
}
113 changes: 56 additions & 57 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,18 @@ var ErrConfigPeersIsNil = errors.New("empty list peer in config file")

// Proxy defines the structure for the proxy server.
type Proxy struct {
//
//
Logger vlog.ILog `yaml:"-" json:"-"`
// Define the default port to listen on
Port string `yaml:"port" json:"port"`
// Define the client deadline time
ClientDeadLineTime time.Duration `yaml:"clientDeadLineTime" json:"clientDeadLineTime"`
// Define the peer host timeout
PeerHostTimeOut time.Duration `yaml:"peerHostTimeout" json:"peerHostTimeout"`
PeerConnectionTimeout time.Duration `yaml:"peerConnectionTimeout" json:"peerConnectionTimeout"`
// Define the peer host deadline
PeerHostDeadLine time.Duration `yaml:"peerHostDeadLine" json:"peerHostDeadLine"`
// Define the max connection semaphore
MaxCountConnection uint `yaml:"maxCountConnection" json:"maxCountConnection"`
// Define the count dial attempts to peer
CountMaxDialAttemptsToPeer uint `yaml:"countDialAttemptsToPeer" json:"countDialAttemptsToPeer"`
// Peers is a list of peer configurations.
Peers *peers.Peers `yaml:"peers" json:"peers"`
// Defien allows configuration of blacklist rules to be passed to the proxy server
Expand All @@ -51,10 +49,9 @@ func New() *Proxy {
Logger: nil,
Port: types.DefaultProxyPort,
ClientDeadLineTime: types.DeafultClientDeadLineTime,
PeerHostTimeOut: types.DeafultPeerHostTimeOut,
PeerConnectionTimeout: types.DeafultPeerConnectionTimeout,
PeerHostDeadLine: types.DeafultPeerHostDeadLine,
MaxCountConnection: types.DeafultMaxCountConnection,
CountMaxDialAttemptsToPeer: types.DeafultCountMaxDialAttemptsToPeer,
//nolint:exhaustivestruct,exhaustruct
Peers: &peers.Peers{},
//nolint:exhaustivestruct,exhaustruct
Expand All @@ -67,7 +64,7 @@ func (p *Proxy) Init(ctx context.Context, logger vlog.ILog) error {
p.Logger = logger

if p.Peers != nil && len(p.Peers.List) != 0 {
err := p.Peers.Init(p.Peers.List)
err := p.Peers.Init(ctx, p.Peers.List)
if err != nil {
return fmt.Errorf("%w", err)
}
Expand All @@ -82,7 +79,7 @@ func (p *Proxy) Init(ctx context.Context, logger vlog.ILog) error {
}
}

if resultCode := p.UpdatePort(); resultCode != types.ResultOK {
if resultCode := p.updatePort(); resultCode != types.ResultOK {
return fmt.Errorf("%w: %s", ErrCantGetProxyPort, resultCode.ToStr())
}

Expand Down Expand Up @@ -132,47 +129,13 @@ func (p *Proxy) AcceptConnections(ctx context.Context, proxySrv net.Listener) {

semaphore <- struct{}{}

go p.handleIncomingConnection(conn, semaphore)
go p.handleIncomingConnection(ctx, conn, semaphore)
}
}
}

// GetProxyPortConfig get the proxy port to serverconfiguration.
func (p *Proxy) UpdatePort() types.ResultCode {
var proxyPort string

if p.Port == "" || p.Port == ":" {
proxyPort = os.Getenv("ProxyPort")
if proxyPort == ":" || proxyPort == "" {
proxyPort = types.DefaultProxyPort
}
} else {
proxyPort = p.Port
}

proxyPort = fmt.Sprintf(":%s", proxyPort)

proxyPort = strings.Trim(proxyPort, " ")
if proxyPort == strings.Trim(":", " ") {
return types.ErrEmptyValue
}

p.Port = proxyPort

return types.ResultOK
}

func (p *Proxy) getCheckIsBlackListIP(remoteIP string) bool {
if p.Rules != nil && p.Rules.Blacklist != nil {
if p.Rules.Blacklist.IsBlacklistIP(remoteIP) {
return true
}
}

return false
}

func (p *Proxy) handleIncomingConnection(conn net.Conn, semaphore chan struct{}) {
// handleIncomingConnection.
func (p *Proxy) handleIncomingConnection(ctx context.Context, conn net.Conn, semaphore chan struct{}) {
defer func() {
<-semaphore
}()
Expand All @@ -192,11 +155,13 @@ func (p *Proxy) handleIncomingConnection(conn net.Conn, semaphore chan struct{})
return
}

clientAddr := conn.RemoteAddr().String()
ctxConnectionTimeout, cancel := context.WithTimeout(ctx, p.PeerConnectionTimeout)
defer cancel()

err = p.reverseData(conn, 0, p.CountMaxDialAttemptsToPeer)
err = p.reverseData(ctxConnectionTimeout, conn)

if err != nil {
clientAddr := conn.RemoteAddr().String()
p.Logger.Add(vlog.Debug, types.ErrProxy, vlog.RemoteAddr(clientAddr), fmt.Errorf("failed in reverseData() %w", err))

responseLogger := response.New()
Expand All @@ -215,26 +180,24 @@ func (p *Proxy) handleIncomingConnection(conn net.Conn, semaphore chan struct{})

// ReverseData reverses data from the client to the next available peer,
// it returns an error if the maximum number of attempts is reached or if it fails to get the next peer.
func (p *Proxy) reverseData(client net.Conn, curentDialAttemptsToPeer uint, maxDialAttemptsToPeer uint) error {
if curentDialAttemptsToPeer >= maxDialAttemptsToPeer {
return ErrMaxCountAttempts
}

func (p *Proxy) reverseData(ctxTimeOut context.Context, client net.Conn) error {
pPeer, resultCode := p.Peers.GetNextPeer()
if resultCode != types.ResultOK || pPeer == nil {
//nolint:goerr113
return fmt.Errorf("failed get next peer, result code: %s", resultCode.ToStr())
}

dst, err := pPeer.Dial(p.PeerHostTimeOut, p.PeerHostDeadLine)
if err != nil {
curentDialAttemptsToPeer++
dst, err := pPeer.Dial(ctxTimeOut, p.PeerConnectionTimeout, p.PeerHostDeadLine)
if err != nil || dst == nil {
p.Peers.AddToCacheBadPeer(pPeer.URI)

return p.reverseData(client, curentDialAttemptsToPeer, maxDialAttemptsToPeer)
return p.reverseData(ctxTimeOut, client)
}
defer dst.Close()

p.Logger.Add(vlog.Debug, types.ResultOK,
p.Logger.Add(
vlog.Debug,
types.ResultOK,
vlog.RemoteAddr(dst.RemoteAddr().String()),
vlog.ProxyHost(client.LocalAddr().String()),
fmt.Sprintf("try to copy data from remote: %s to peer: %s",
Expand Down Expand Up @@ -272,3 +235,39 @@ func (p *Proxy) proxyDataCopy(waitGroup *sync.WaitGroup, client io.ReadWriter, d
_, _ = io.Copy(client, dst)
}()
}

// updatePort.
func (p *Proxy) updatePort() types.ResultCode {
var proxyPort string

if p.Port == "" || p.Port == ":" {
proxyPort = os.Getenv("ProxyPort")
if proxyPort == ":" || proxyPort == "" {
proxyPort = types.DefaultProxyPort
}
} else {
proxyPort = p.Port
}

proxyPort = fmt.Sprintf(":%s", proxyPort)

proxyPort = strings.Trim(proxyPort, " ")
if proxyPort == strings.Trim(":", " ") {
return types.ErrEmptyValue
}

p.Port = proxyPort

return types.ResultOK
}

// getCheckIsBlackListIP.
func (p *Proxy) getCheckIsBlackListIP(remoteIP string) bool {
if p.Rules != nil && p.Rules.Blacklist != nil {
if p.Rules.Blacklist.IsBlacklistIP(remoteIP) {
return true
}
}

return false
}
Loading

0 comments on commit 5c95f4d

Please sign in to comment.