diff --git a/.codecov.yml b/.codecov.yml index 87940329..be286b01 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -8,3 +8,4 @@ ignore: - "internal/controlpb/control*" # generated code - "client_experimental.go" # experimental code - "internal/websocket/*" # embedded Gorilla WebSocket fork + - "internal/websocket/examples/autobahn/*" # embedded Gorilla WebSocket fork diff --git a/broker_memory_test.go b/broker_memory_test.go index ba4361eb..7d79f126 100644 --- a/broker_memory_test.go +++ b/broker_memory_test.go @@ -588,6 +588,11 @@ func TestClientSubscribeRecover(t *testing.T) { t.Run(tt.Name, func(t *testing.T) { node := defaultNodeNoHandlers() node.config.RecoveryMaxPublicationLimit = tt.Limit + + node.OnCacheEmpty(func(event CacheEmptyEvent) (CacheEmptyReply, error) { + return CacheEmptyReply{}, nil + }) + node.OnConnect(func(client *Client) { client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) { opts := SubscribeOptions{EnableRecovery: true, RecoveryMode: tt.RecoveryMode} diff --git a/client.go b/client.go index d3f7e7ce..ecf897f9 100644 --- a/client.go +++ b/client.go @@ -330,13 +330,23 @@ func extractUnidirectionalDisconnect(err error) Disconnect { } } -// Connect supposed to be called from unidirectional transport layer to pass -// initial information about connection and thus initiate Node.OnConnecting +// Connect supposed to be called only from a unidirectional transport layer +// to pass initial information about connection and thus initiate Node.OnConnecting // event. Bidirectional transport initiate connecting workflow automatically // since client passes Connect command upon successful connection establishment -// with a server. +// with a server. If there is an error during connect method processing Centrifuge +// extracts Disconnect from it and closes the connection with that Disconnect message. func (c *Client) Connect(req ConnectRequest) { - c.unidirectionalConnect(req.toProto(), 0) + // unidirectionalConnect never returns errors when errorToDisconnect is true. + _ = c.unidirectionalConnect(req.toProto(), 0, true) +} + +// ConnectNoErrorToDisconnect is the same as Client.Connect but does not try to extract +// Disconnect code from the error returned by the connect logic, instead it just returns +// the error to the caller. This error must be handled by the caller on the Transport level, +// and the connection must be closed on Transport level upon receiving an error. +func (c *Client) ConnectNoErrorToDisconnect(req ConnectRequest) error { + return c.unidirectionalConnect(req.toProto(), 0, false) } func (c *Client) getDisconnectPushReply(d Disconnect) ([]byte, error) { @@ -344,10 +354,14 @@ func (c *Client) getDisconnectPushReply(d Disconnect) ([]byte, error) { Code: d.Code, Reason: d.Reason, } + push := &protocol.Push{ + Disconnect: disconnect, + } + if c.node.LogEnabled(LogLevelTrace) { + c.traceOutPush(push) + } return c.encodeReply(&protocol.Reply{ - Push: &protocol.Push{ - Disconnect: disconnect, - }, + Push: push, }) } @@ -371,7 +385,7 @@ func (c *Client) issueCommandProcessedEvent(event CommandProcessedEvent) { } } -func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest, connectCmdSize int) { +func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest, connectCmdSize int, errorToDisconnect bool) error { started := time.Now() var cmd *protocol.Command @@ -385,28 +399,35 @@ func (c *Client) unidirectionalConnect(connectRequest *protocol.ConnectRequest, cmd = &protocol.Command{Id: 1, Connect: connectRequest} err := c.issueCommandReadEvent(cmd, connectCmdSize) if err != nil { - d := extractUnidirectionalDisconnect(err) - go func() { _ = c.close(d) }() if c.node.clientEvents.commandProcessedHandler != nil { - c.handleCommandFinished(cmd, protocol.FrameTypeConnect, &d, nil, started) + c.handleCommandFinished(cmd, protocol.FrameTypeConnect, err, nil, started) } - return + if errorToDisconnect { + d := extractUnidirectionalDisconnect(err) + go func() { _ = c.close(d) }() + return nil + } + return err } } _, err := c.connectCmd(connectRequest, nil, time.Time{}, nil) if err != nil { - d := extractUnidirectionalDisconnect(err) - go func() { _ = c.close(d) }() if c.node.clientEvents.commandProcessedHandler != nil { - c.handleCommandFinished(cmd, protocol.FrameTypeConnect, &d, nil, started) + c.handleCommandFinished(cmd, protocol.FrameTypeConnect, err, nil, started) } - return + if errorToDisconnect { + d := extractUnidirectionalDisconnect(err) + go func() { _ = c.close(d) }() + return nil + } + return err } if c.node.clientEvents.commandProcessedHandler != nil { c.handleCommandFinished(cmd, protocol.FrameTypeConnect, nil, nil, started) } c.triggerConnect() c.scheduleOnConnectTimers() + return nil } func (c *Client) onTimerOp() { @@ -1112,12 +1133,12 @@ func isPong(cmd *protocol.Command) bool { return cmd.Id == 0 && cmd.Send == nil } -func (c *Client) handleCommandFinished(cmd *protocol.Command, frameType protocol.FrameType, disconnect *Disconnect, reply *protocol.Reply, started time.Time) { +func (c *Client) handleCommandFinished(cmd *protocol.Command, frameType protocol.FrameType, err error, reply *protocol.Reply, started time.Time) { defer func() { c.node.metrics.observeCommandDuration(frameType, time.Since(started)) }() if c.node.clientEvents.commandProcessedHandler != nil { - event := newCommandProcessedEvent(cmd, disconnect, reply, started) + event := newCommandProcessedEvent(cmd, err, reply, started) c.issueCommandProcessedEvent(event) } } @@ -1129,13 +1150,13 @@ func (c *Client) handleCommandDispatchError(ch string, cmd *protocol.Command, fr switch t := err.(type) { case *Disconnect: if c.node.clientEvents.commandProcessedHandler != nil { - event := newCommandProcessedEvent(cmd, t, nil, started) + event := newCommandProcessedEvent(cmd, err, nil, started) c.issueCommandProcessedEvent(event) } return t, false case Disconnect: if c.node.clientEvents.commandProcessedHandler != nil { - event := newCommandProcessedEvent(cmd, &t, nil, started) + event := newCommandProcessedEvent(cmd, err, nil, started) c.issueCommandProcessedEvent(event) } return &t, false @@ -1148,7 +1169,7 @@ func (c *Client) handleCommandDispatchError(ch string, cmd *protocol.Command, fr errorReply := &protocol.Reply{Error: toClientErr(err).toProto()} c.writeError(ch, frameType, cmd, errorReply, nil) if c.node.clientEvents.commandProcessedHandler != nil { - event := newCommandProcessedEvent(cmd, nil, errorReply, started) + event := newCommandProcessedEvent(cmd, err, errorReply, started) c.issueCommandProcessedEvent(event) } return nil, cmd.Connect == nil @@ -2089,30 +2110,30 @@ func (c *Client) writeError(ch string, frameType protocol.FrameType, cmd *protoc c.writeEncodedCommandReply(ch, frameType, cmd, errorReply, rw) } -func (c *Client) writeDisconnectOrErrorFlush(ch string, frameType protocol.FrameType, cmd *protocol.Command, replyError error, started time.Time, rw *replyWriter) { +func (c *Client) writeDisconnectOrErrorFlush(ch string, frameType protocol.FrameType, cmd *protocol.Command, err error, started time.Time, rw *replyWriter) { defer func() { c.node.metrics.observeCommandDuration(frameType, time.Since(started)) }() - switch t := replyError.(type) { + switch t := err.(type) { case *Disconnect: go func() { _ = c.close(*t) }() if c.node.clientEvents.commandProcessedHandler != nil { - event := newCommandProcessedEvent(cmd, t, nil, started) + event := newCommandProcessedEvent(cmd, err, nil, started) c.issueCommandProcessedEvent(event) } return case Disconnect: go func() { _ = c.close(t) }() if c.node.clientEvents.commandProcessedHandler != nil { - event := newCommandProcessedEvent(cmd, &t, nil, started) + event := newCommandProcessedEvent(cmd, err, nil, started) c.issueCommandProcessedEvent(event) } return default: - errorReply := &protocol.Reply{Error: toClientErr(replyError).toProto()} + errorReply := &protocol.Reply{Error: toClientErr(err).toProto()} c.writeError(ch, frameType, cmd, errorReply, rw) if c.node.clientEvents.commandProcessedHandler != nil { - event := newCommandProcessedEvent(cmd, nil, errorReply, started) + event := newCommandProcessedEvent(cmd, err, errorReply, started) c.issueCommandProcessedEvent(event) } } @@ -2430,6 +2451,9 @@ func (c *Client) connectCmd(req *protocol.ConnectRequest, cmd *protocol.Command, return nil, DisconnectServerError } c.writeEncodedPush(protoReply, rw, "", protocol.FrameTypePushConnect) + if c.node.LogEnabled(LogLevelTrace) { + c.traceOutPush(&protocol.Push{Connect: protoReply.Push.Connect}) + } } } else { protoReply, err := c.getConnectCommandReply(res) @@ -2626,11 +2650,15 @@ func (c *Client) getSubscribePushReply(channel string, res *protocol.SubscribeRe Positioned: res.GetPositioned(), Data: res.Data, } + push := &protocol.Push{ + Channel: channel, + Subscribe: sub, + } + if c.node.LogEnabled(LogLevelTrace) { + c.traceOutPush(push) + } return c.encodeReply(&protocol.Reply{ - Push: &protocol.Push{ - Channel: channel, - Subscribe: sub, - }, + Push: push, }) } diff --git a/client_test.go b/client_test.go index a0fbafed..09ca2788 100644 --- a/client_test.go +++ b/client_test.go @@ -3121,6 +3121,21 @@ func TestClientCheckPosition(t *testing.T) { // not initial, not time to check. got = client.checkPosition(300*time.Second, "channel", ChannelContext{positionCheckTime: 50, flags: flagPositioning}) require.True(t, got) + + // not subscribed. + got = client.checkPosition(100*time.Second, "channel", ChannelContext{ + positionCheckTime: 50, flags: flagPositioning, metaTTLSeconds: 10, + }) + require.True(t, got) + + // closed client. + client.mu.Lock() + client.status = statusClosed + client.mu.Unlock() + got = client.checkPosition(100*time.Second, "channel", ChannelContext{ + positionCheckTime: 50, flags: flagPositioning, metaTTLSeconds: 10, + }) + require.True(t, got) } func TestErrLogLevel(t *testing.T) { @@ -3984,3 +3999,94 @@ func TestClientUnsubscribeDuringSubscribeCorrectChannels(t *testing.T) { err := client.close(DisconnectForceNoReconnect) require.NoError(t, err) } + +func TestClientConnectNoErrorToDisconnect(t *testing.T) { + t.Parallel() + errBoom := errors.New("boom") + + testCases := []struct { + Name string + Err error + CommandReadErr error + }{ + {"nil", nil, nil}, + {"error", errBoom, nil}, + {"cmd_read_error", nil, errBoom}, + } + + for _, tt := range testCases { + t.Run(tt.Name, func(t *testing.T) { + node := defaultTestNode() + defer func() { _ = node.Shutdown(context.Background()) }() + + node.OnCommandRead(func(client *Client, event CommandReadEvent) error { + require.NotNil(t, event.Command.Connect) + return tt.CommandReadErr + }) + + node.OnConnecting(func(context.Context, ConnectEvent) (ConnectReply, error) { + return ConnectReply{}, tt.Err + }) + transport := newTestTransport(func() {}) + transport.setUnidirectional(true) + transport.sink = make(chan []byte, 100) + ctx := context.Background() + newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) + client, _ := newClient(newCtx, node, transport) + err := client.ConnectNoErrorToDisconnect(ConnectRequest{}) + if tt.Err != nil { + require.Equal(t, tt.Err, err) + } else if tt.CommandReadErr != nil { + require.Equal(t, tt.CommandReadErr, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestClientConnect(t *testing.T) { + t.Parallel() + errBoom := errors.New("boom") + + testCases := []struct { + Name string + Err error + CommandReadErr error + }{ + {"nil", nil, nil}, + {"error", errBoom, nil}, + {"cmd_read_error", nil, errBoom}, + } + + for _, tt := range testCases { + t.Run(tt.Name, func(t *testing.T) { + node := defaultTestNode() + defer func() { _ = node.Shutdown(context.Background()) }() + + node.OnCommandRead(func(client *Client, event CommandReadEvent) error { + require.NotNil(t, event.Command.Connect) + return tt.CommandReadErr + }) + + node.OnConnecting(func(context.Context, ConnectEvent) (ConnectReply, error) { + return ConnectReply{}, tt.Err + }) + + transport := newTestTransport(func() {}) + transport.setUnidirectional(true) + transport.sink = make(chan []byte, 100) + ctx := context.Background() + newCtx := SetCredentials(ctx, &Credentials{UserID: "42"}) + client, _ := newClient(newCtx, node, transport) + client.Connect(ConnectRequest{}) + if tt.Err != nil || tt.CommandReadErr != nil { + msg := <-transport.sink + require.True(t, strings.HasPrefix(string(msg), `{"disconnect"`)) + } else { + msg := <-transport.sink + require.True(t, strings.HasPrefix(string(msg), `{"connect"`)) + } + }) + } +} diff --git a/events.go b/events.go index 0fe10e98..886661e4 100644 --- a/events.go +++ b/events.go @@ -463,22 +463,22 @@ type CommandReadHandler func(*Client, CommandReadEvent) error type CommandProcessedEvent struct { // Command which was processed. May be pooled - see comment of CommandProcessedEvent. Command *protocol.Command - // Disconnect may be set if Command processing resulted into disconnection. - Disconnect *Disconnect + // Error may be set to non-nil if Command processing resulted into error. + Error error // Reply to the command. Reply may be pooled - see comment of CommandProcessedEvent. - // This Reply may be nil in the following cases: + // This Reply may be nil, for example in the following cases: // 1. For Send command since send commands do not have replies - // 2. When Disconnect field of CommandProcessedEvent is not nil - // 3. When unidirectional transport connects (we create Connect Command artificially - // with id: 1 and we never send replies to unidirectional transport, only pushes). + // 2. When command processing resulted into disconnection of the client without sending a reply. + // 3. When unidirectional transport connects (Centrifuge creates Connect Command artificially + // with id: 1 and never sends replies to the unidirectional transport, only pushes). Reply *protocol.Reply // Started is a time command was passed to Client for processing. Started time.Time } // newCommandProcessedEvent is a helper to create CommandProcessedEvent. -func newCommandProcessedEvent(command *protocol.Command, disconnect *Disconnect, reply *protocol.Reply, started time.Time) CommandProcessedEvent { - return CommandProcessedEvent{Command: command, Disconnect: disconnect, Reply: reply, Started: started} +func newCommandProcessedEvent(command *protocol.Command, err error, reply *protocol.Reply, started time.Time) CommandProcessedEvent { + return CommandProcessedEvent{Command: command, Error: err, Reply: reply, Started: started} } // CommandProcessedHandler allows setting a callback which will be called after diff --git a/handler_http_stream_test.go b/handler_http_stream_test.go index 1dafae06..897df9b1 100644 --- a/handler_http_stream_test.go +++ b/handler_http_stream_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "encoding/json" "io" "net/http" @@ -66,6 +67,56 @@ func TestHTTPStreamHandler(t *testing.T) { } } +func TestHTTPStreamHandler_Protobuf(t *testing.T) { + t.Parallel() + n, _ := New(Config{ + LogLevel: LogLevelDebug, + }) + + n.OnConnecting(func(ctx context.Context, event ConnectEvent) (ConnectReply, error) { + return ConnectReply{Credentials: &Credentials{ + UserID: "test", + }}, nil + }) + + require.NoError(t, n.Run()) + defer func() { _ = n.Shutdown(context.Background()) }() + mux := http.NewServeMux() + mux.Handle("/connection/http_stream", NewHTTPStreamHandler(n, HTTPStreamConfig{})) + server := httptest.NewServer(mux) + defer server.Close() + + url := server.URL + "/connection/http_stream" + client := &http.Client{Timeout: 5 * time.Second} + command := &protocol.Command{ + Id: 1, + Connect: &protocol.ConnectRequest{}, + } + enc := protocol.NewProtobufCommandEncoder() + protoData, err := enc.Encode(command) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(protoData)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer func() { _ = resp.Body.Close() }() + + dec := newProtobufStreamCommandDecoder(resp.Body) + for { + reply, _, err := dec.decode() + require.NoError(t, err) + require.NotNil(t, reply.Connect) + require.Equal(t, uint32(1), reply.Id) + require.NotZero(t, reply.Connect.Session) + require.NotZero(t, reply.Connect.Node) + break + } +} + func TestHTTPStreamHandler_RequestTooLarge(t *testing.T) { t.Parallel() n, _ := New(Config{}) @@ -130,3 +181,33 @@ func (d *jsonStreamDecoder) decode() ([]byte, error) { line, _, err := d.r.ReadLine() return line, err } + +type protobufStreamCommandDecoder struct { + reader *bufio.Reader +} + +func newProtobufStreamCommandDecoder(reader io.Reader) *protobufStreamCommandDecoder { + return &protobufStreamCommandDecoder{reader: bufio.NewReader(reader)} +} + +func (d *protobufStreamCommandDecoder) decode() (*protocol.Reply, int, error) { + msgLength, err := binary.ReadUvarint(d.reader) + if err != nil { + return nil, 0, err + } + + b := make([]byte, msgLength) + n, err := io.ReadFull(d.reader, b) + if err != nil { + return nil, 0, err + } + if uint64(n) != msgLength { + return nil, 0, io.ErrShortBuffer + } + var c protocol.Reply + err = c.UnmarshalVT(b[:int(msgLength)]) + if err != nil { + return nil, 0, err + } + return &c, int(msgLength) + 8, nil +}