Skip to content

Commit

Permalink
Unidirectional connect without automatic error to disconnect transform (
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia authored Oct 26, 2024
1 parent 612c8bd commit a3a77cc
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 39 deletions.
1 change: 1 addition & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ ignore:
- "internal/controlpb/control*" # generated code
- "client_experimental.go" # experimental code
- "internal/websocket/*" # embedded Gorilla WebSocket fork
- "internal/websocket/examples/autobahn/*" # embedded Gorilla WebSocket fork
5 changes: 5 additions & 0 deletions broker_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,11 @@ func TestClientSubscribeRecover(t *testing.T) {
t.Run(tt.Name, func(t *testing.T) {
node := defaultNodeNoHandlers()
node.config.RecoveryMaxPublicationLimit = tt.Limit

node.OnCacheEmpty(func(event CacheEmptyEvent) (CacheEmptyReply, error) {
return CacheEmptyReply{}, nil
})

node.OnConnect(func(client *Client) {
client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) {
opts := SubscribeOptions{EnableRecovery: true, RecoveryMode: tt.RecoveryMode}
Expand Down
90 changes: 59 additions & 31 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,24 +330,38 @@ func extractUnidirectionalDisconnect(err error) Disconnect {
}
}

// Connect supposed to be called from unidirectional transport layer to pass
// initial information about connection and thus initiate Node.OnConnecting
// Connect supposed to be called only from a unidirectional transport layer
// to pass initial information about connection and thus initiate Node.OnConnecting
// event. Bidirectional transport initiate connecting workflow automatically
// since client passes Connect command upon successful connection establishment
// with a server.
// with a server. If there is an error during connect method processing Centrifuge
// extracts Disconnect from it and closes the connection with that Disconnect message.
func (c *Client) Connect(req ConnectRequest) {
c.unidirectionalConnect(req.toProto(), 0)
// unidirectionalConnect never returns errors when errorToDisconnect is true.
_ = c.unidirectionalConnect(req.toProto(), 0, true)
}

// ConnectNoErrorToDisconnect is the same as Client.Connect but does not try to extract
// Disconnect code from the error returned by the connect logic, instead it just returns
// the error to the caller. This error must be handled by the caller on the Transport level,
// and the connection must be closed on Transport level upon receiving an error.
func (c *Client) ConnectNoErrorToDisconnect(req ConnectRequest) error {
return c.unidirectionalConnect(req.toProto(), 0, false)
}

func (c *Client) getDisconnectPushReply(d Disconnect) ([]byte, error) {
disconnect := &protocol.Disconnect{
Code: d.Code,
Reason: d.Reason,
}
push := &protocol.Push{
Disconnect: disconnect,
}
if c.node.LogEnabled(LogLevelTrace) {
c.traceOutPush(push)
}
return c.encodeReply(&protocol.Reply{
Push: &protocol.Push{
Disconnect: disconnect,
},
Push: push,
})
}

Expand All @@ -371,7 +385,7 @@ func (c *Client) issueCommandProcessedEvent(event CommandProcessedEvent) {
}
}

func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest, connectCmdSize int) {
func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest, connectCmdSize int, errorToDisconnect bool) error {
started := time.Now()

var cmd *protocol.Command
Expand All @@ -385,28 +399,35 @@ func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest,
cmd = &protocol.Command{Id: 1, Connect: connectRequest}
err := c.issueCommandReadEvent(cmd, connectCmdSize)
if err != nil {
d := extractUnidirectionalDisconnect(err)
go func() { _ = c.close(d) }()
if c.node.clientEvents.commandProcessedHandler != nil {
c.handleCommandFinished(cmd, protocol.FrameTypeConnect, &d, nil, started)
c.handleCommandFinished(cmd, protocol.FrameTypeConnect, err, nil, started)
}
return
if errorToDisconnect {
d := extractUnidirectionalDisconnect(err)
go func() { _ = c.close(d) }()
return nil
}
return err
}
}
_, err := c.connectCmd(connectRequest, nil, time.Time{}, nil)
if err != nil {
d := extractUnidirectionalDisconnect(err)
go func() { _ = c.close(d) }()
if c.node.clientEvents.commandProcessedHandler != nil {
c.handleCommandFinished(cmd, protocol.FrameTypeConnect, &d, nil, started)
c.handleCommandFinished(cmd, protocol.FrameTypeConnect, err, nil, started)
}
return
if errorToDisconnect {
d := extractUnidirectionalDisconnect(err)
go func() { _ = c.close(d) }()
return nil
}
return err
}
if c.node.clientEvents.commandProcessedHandler != nil {
c.handleCommandFinished(cmd, protocol.FrameTypeConnect, nil, nil, started)
}
c.triggerConnect()
c.scheduleOnConnectTimers()
return nil
}

func (c *Client) onTimerOp() {
Expand Down Expand Up @@ -1112,12 +1133,12 @@ func isPong(cmd *protocol.Command) bool {
return cmd.Id == 0 && cmd.Send == nil
}

func (c *Client) handleCommandFinished(cmd *protocol.Command, frameType protocol.FrameType, disconnect *Disconnect, reply *protocol.Reply, started time.Time) {
func (c *Client) handleCommandFinished(cmd *protocol.Command, frameType protocol.FrameType, err error, reply *protocol.Reply, started time.Time) {
defer func() {
c.node.metrics.observeCommandDuration(frameType, time.Since(started))
}()
if c.node.clientEvents.commandProcessedHandler != nil {
event := newCommandProcessedEvent(cmd, disconnect, reply, started)
event := newCommandProcessedEvent(cmd, err, reply, started)
c.issueCommandProcessedEvent(event)
}
}
Expand All @@ -1129,13 +1150,13 @@ func (c *Client) handleCommandDispatchError(ch string, cmd *protocol.Command, fr
switch t := err.(type) {
case *Disconnect:
if c.node.clientEvents.commandProcessedHandler != nil {
event := newCommandProcessedEvent(cmd, t, nil, started)
event := newCommandProcessedEvent(cmd, err, nil, started)
c.issueCommandProcessedEvent(event)
}
return t, false
case Disconnect:
if c.node.clientEvents.commandProcessedHandler != nil {
event := newCommandProcessedEvent(cmd, &t, nil, started)
event := newCommandProcessedEvent(cmd, err, nil, started)
c.issueCommandProcessedEvent(event)
}
return &t, false
Expand All @@ -1148,7 +1169,7 @@ func (c *Client) handleCommandDispatchError(ch string, cmd *protocol.Command, fr
errorReply := &protocol.Reply{Error: toClientErr(err).toProto()}
c.writeError(ch, frameType, cmd, errorReply, nil)
if c.node.clientEvents.commandProcessedHandler != nil {
event := newCommandProcessedEvent(cmd, nil, errorReply, started)
event := newCommandProcessedEvent(cmd, err, errorReply, started)
c.issueCommandProcessedEvent(event)
}
return nil, cmd.Connect == nil
Expand Down Expand Up @@ -2089,30 +2110,30 @@ func (c *Client) writeError(ch string, frameType protocol.FrameType, cmd *protoc
c.writeEncodedCommandReply(ch, frameType, cmd, errorReply, rw)
}

func (c *Client) writeDisconnectOrErrorFlush(ch string, frameType protocol.FrameType, cmd *protocol.Command, replyError error, started time.Time, rw *replyWriter) {
func (c *Client) writeDisconnectOrErrorFlush(ch string, frameType protocol.FrameType, cmd *protocol.Command, err error, started time.Time, rw *replyWriter) {
defer func() {
c.node.metrics.observeCommandDuration(frameType, time.Since(started))
}()
switch t := replyError.(type) {
switch t := err.(type) {
case *Disconnect:
go func() { _ = c.close(*t) }()
if c.node.clientEvents.commandProcessedHandler != nil {
event := newCommandProcessedEvent(cmd, t, nil, started)
event := newCommandProcessedEvent(cmd, err, nil, started)
c.issueCommandProcessedEvent(event)
}
return
case Disconnect:
go func() { _ = c.close(t) }()
if c.node.clientEvents.commandProcessedHandler != nil {
event := newCommandProcessedEvent(cmd, &t, nil, started)
event := newCommandProcessedEvent(cmd, err, nil, started)
c.issueCommandProcessedEvent(event)
}
return
default:
errorReply := &protocol.Reply{Error: toClientErr(replyError).toProto()}
errorReply := &protocol.Reply{Error: toClientErr(err).toProto()}
c.writeError(ch, frameType, cmd, errorReply, rw)
if c.node.clientEvents.commandProcessedHandler != nil {
event := newCommandProcessedEvent(cmd, nil, errorReply, started)
event := newCommandProcessedEvent(cmd, err, errorReply, started)
c.issueCommandProcessedEvent(event)
}
}
Expand Down Expand Up @@ -2430,6 +2451,9 @@ func (c *Client) connectCmd(req *protocol.ConnectRequest, cmd *protocol.Command,
return nil, DisconnectServerError
}
c.writeEncodedPush(protoReply, rw, "", protocol.FrameTypePushConnect)
if c.node.LogEnabled(LogLevelTrace) {
c.traceOutPush(&protocol.Push{Connect: protoReply.Push.Connect})
}
}
} else {
protoReply, err := c.getConnectCommandReply(res)
Expand Down Expand Up @@ -2626,11 +2650,15 @@ func (c *Client) getSubscribePushReply(channel string, res *protocol.SubscribeRe
Positioned: res.GetPositioned(),
Data: res.Data,
}
push := &protocol.Push{
Channel: channel,
Subscribe: sub,
}
if c.node.LogEnabled(LogLevelTrace) {
c.traceOutPush(push)
}
return c.encodeReply(&protocol.Reply{
Push: &protocol.Push{
Channel: channel,
Subscribe: sub,
},
Push: push,
})
}

Expand Down
106 changes: 106 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3121,6 +3121,21 @@ func TestClientCheckPosition(t *testing.T) {
// not initial, not time to check.
got = client.checkPosition(300*time.Second, "channel", ChannelContext{positionCheckTime: 50, flags: flagPositioning})
require.True(t, got)

// not subscribed.
got = client.checkPosition(100*time.Second, "channel", ChannelContext{
positionCheckTime: 50, flags: flagPositioning, metaTTLSeconds: 10,
})
require.True(t, got)

// closed client.
client.mu.Lock()
client.status = statusClosed
client.mu.Unlock()
got = client.checkPosition(100*time.Second, "channel", ChannelContext{
positionCheckTime: 50, flags: flagPositioning, metaTTLSeconds: 10,
})
require.True(t, got)
}

func TestErrLogLevel(t *testing.T) {
Expand Down Expand Up @@ -3984,3 +3999,94 @@ func TestClientUnsubscribeDuringSubscribeCorrectChannels(t *testing.T) {
err := client.close(DisconnectForceNoReconnect)
require.NoError(t, err)
}

func TestClientConnectNoErrorToDisconnect(t *testing.T) {
t.Parallel()
errBoom := errors.New("boom")

testCases := []struct {
Name string
Err error
CommandReadErr error
}{
{"nil", nil, nil},
{"error", errBoom, nil},
{"cmd_read_error", nil, errBoom},
}

for _, tt := range testCases {
t.Run(tt.Name, func(t *testing.T) {
node := defaultTestNode()
defer func() { _ = node.Shutdown(context.Background()) }()

node.OnCommandRead(func(client *Client, event CommandReadEvent) error {
require.NotNil(t, event.Command.Connect)
return tt.CommandReadErr
})

node.OnConnecting(func(context.Context, ConnectEvent) (ConnectReply, error) {
return ConnectReply{}, tt.Err
})
transport := newTestTransport(func() {})
transport.setUnidirectional(true)
transport.sink = make(chan []byte, 100)
ctx := context.Background()
newCtx := SetCredentials(ctx, &Credentials{UserID: "42"})
client, _ := newClient(newCtx, node, transport)
err := client.ConnectNoErrorToDisconnect(ConnectRequest{})
if tt.Err != nil {
require.Equal(t, tt.Err, err)
} else if tt.CommandReadErr != nil {
require.Equal(t, tt.CommandReadErr, err)
} else {
require.NoError(t, err)
}
})
}
}

func TestClientConnect(t *testing.T) {
t.Parallel()
errBoom := errors.New("boom")

testCases := []struct {
Name string
Err error
CommandReadErr error
}{
{"nil", nil, nil},
{"error", errBoom, nil},
{"cmd_read_error", nil, errBoom},
}

for _, tt := range testCases {
t.Run(tt.Name, func(t *testing.T) {
node := defaultTestNode()
defer func() { _ = node.Shutdown(context.Background()) }()

node.OnCommandRead(func(client *Client, event CommandReadEvent) error {
require.NotNil(t, event.Command.Connect)
return tt.CommandReadErr
})

node.OnConnecting(func(context.Context, ConnectEvent) (ConnectReply, error) {
return ConnectReply{}, tt.Err
})

transport := newTestTransport(func() {})
transport.setUnidirectional(true)
transport.sink = make(chan []byte, 100)
ctx := context.Background()
newCtx := SetCredentials(ctx, &Credentials{UserID: "42"})
client, _ := newClient(newCtx, node, transport)
client.Connect(ConnectRequest{})
if tt.Err != nil || tt.CommandReadErr != nil {
msg := <-transport.sink
require.True(t, strings.HasPrefix(string(msg), `{"disconnect"`))
} else {
msg := <-transport.sink
require.True(t, strings.HasPrefix(string(msg), `{"connect"`))
}
})
}
}
16 changes: 8 additions & 8 deletions events.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,22 +463,22 @@ type CommandReadHandler func(*Client, CommandReadEvent) error
type CommandProcessedEvent struct {
// Command which was processed. May be pooled - see comment of CommandProcessedEvent.
Command *protocol.Command
// Disconnect may be set if Command processing resulted into disconnection.
Disconnect *Disconnect
// Error may be set to non-nil if Command processing resulted into error.
Error error
// Reply to the command. Reply may be pooled - see comment of CommandProcessedEvent.
// This Reply may be nil in the following cases:
// This Reply may be nil, for example in the following cases:
// 1. For Send command since send commands do not have replies
// 2. When Disconnect field of CommandProcessedEvent is not nil
// 3. When unidirectional transport connects (we create Connect Command artificially
// with id: 1 and we never send replies to unidirectional transport, only pushes).
// 2. When command processing resulted into disconnection of the client without sending a reply.
// 3. When unidirectional transport connects (Centrifuge creates Connect Command artificially
// with id: 1 and never sends replies to the unidirectional transport, only pushes).
Reply *protocol.Reply
// Started is a time command was passed to Client for processing.
Started time.Time
}

// newCommandProcessedEvent is a helper to create CommandProcessedEvent.
func newCommandProcessedEvent(command *protocol.Command, disconnect *Disconnect, reply *protocol.Reply, started time.Time) CommandProcessedEvent {
return CommandProcessedEvent{Command: command, Disconnect: disconnect, Reply: reply, Started: started}
func newCommandProcessedEvent(command *protocol.Command, err error, reply *protocol.Reply, started time.Time) CommandProcessedEvent {
return CommandProcessedEvent{Command: command, Error: err, Reply: reply, Started: started}
}

// CommandProcessedHandler allows setting a callback which will be called after
Expand Down
Loading

0 comments on commit a3a77cc

Please sign in to comment.