Skip to content

Commit

Permalink
add Config.UnidirectionalCodeToDisconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia committed Nov 3, 2024
1 parent 106d43f commit ea2832c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
15 changes: 10 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,21 +307,26 @@ func NewClient(ctx context.Context, n *Node, t Transport) (*Client, ClientCloseF
return client, func() error { return client.close(DisconnectConnectionClosed) }, nil
}

var uniErrorCodeToDisconnect = map[uint32]Disconnect{
var defaultUniErrorCodeToDisconnect = map[uint32]Disconnect{
ErrorExpired.Code: DisconnectExpired,
ErrorTokenExpired.Code: DisconnectExpired,
ErrorTooManyRequests.Code: DisconnectTooManyRequests,
ErrorPermissionDenied.Code: DisconnectPermissionDenied,
}

func extractUnidirectionalDisconnect(err error) Disconnect {
func (c *Client) extractUnidirectionalDisconnect(err error) Disconnect {
switch t := err.(type) {
case *Disconnect:
return *t
case Disconnect:
return t
case *Error:
if d, ok := uniErrorCodeToDisconnect[t.Code]; ok {
if c.node.config.UnidirectionalCodeToDisconnect != nil {
if d, ok := c.node.config.UnidirectionalCodeToDisconnect[t.Code]; ok {
return d
}
}
if d, ok := defaultUniErrorCodeToDisconnect[t.Code]; ok {
return d
}
return DisconnectServerError
Expand Down Expand Up @@ -403,7 +408,7 @@ func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest,
c.handleCommandFinished(cmd, protocol.FrameTypeConnect, err, nil, started)
}
if errorToDisconnect {
d := extractUnidirectionalDisconnect(err)
d := c.extractUnidirectionalDisconnect(err)
go func() { _ = c.close(d) }()
return nil
}
Expand All @@ -416,7 +421,7 @@ func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest,
c.handleCommandFinished(cmd, protocol.FrameTypeConnect, err, nil, started)
}
if errorToDisconnect {
d := extractUnidirectionalDisconnect(err)
d := c.extractUnidirectionalDisconnect(err)
go func() { _ = c.close(d) }()
return nil
}
Expand Down
22 changes: 17 additions & 5 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2233,16 +2233,28 @@ func TestClientCloseUnauthenticated(t *testing.T) {
}

func TestExtractUnidirectionalDisconnect(t *testing.T) {
d := extractUnidirectionalDisconnect(errors.New("test"))
t.Parallel()
node := defaultTestNode()
defer func() { _ = node.Shutdown(context.Background()) }()

client := newTestClient(t, node, "42")
d := client.extractUnidirectionalDisconnect(errors.New("test"))
require.Equal(t, DisconnectServerError, d)
d = extractUnidirectionalDisconnect(ErrorLimitExceeded)
d = client.extractUnidirectionalDisconnect(ErrorLimitExceeded)
require.Equal(t, DisconnectServerError, d)
d = extractUnidirectionalDisconnect(DisconnectChannelLimit)
d = client.extractUnidirectionalDisconnect(DisconnectChannelLimit)
require.Equal(t, DisconnectChannelLimit, d)
d = extractUnidirectionalDisconnect(DisconnectServerError)
d = client.extractUnidirectionalDisconnect(DisconnectServerError)
require.Equal(t, DisconnectServerError, d)
d = extractUnidirectionalDisconnect(ErrorExpired)
d = client.extractUnidirectionalDisconnect(ErrorExpired)
require.Equal(t, DisconnectExpired, d)

// Test additional mapping through the Config.
node.config.UnidirectionalCodeToDisconnect = map[uint32]Disconnect{
400: DisconnectBadRequest,
}
d = client.extractUnidirectionalDisconnect(&Error{Code: 400})
require.Equal(t, DisconnectBadRequest, d)
}

func TestClientHandleEmptyData(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ type Config struct {
// know about custom PresenceManager instances. When GetPresenceManager returns false as the second
// argument then Node will use the default PresenceManager for the channel.
GetPresenceManager func(channel string) (PresenceManager, bool)
// Tell Centrifuge how to transform connect error codes to disconnect objects for unidirectional
// transports. If not set then the default mapping is used.
UnidirectionalCodeToDisconnect map[uint32]Disconnect
}

const (
Expand Down

0 comments on commit ea2832c

Please sign in to comment.