Skip to content

Commit

Permalink
Add DOS Protection (max subscriptions per client) (#31)
Browse files Browse the repository at this point in the history
* Add DOS Protection (max subscriptions per client)

* Fix panic on topic subscribe if no milestone received yet
  • Loading branch information
muXxer authored Jul 8, 2022
1 parent b453e69 commit c202c43
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 28 deletions.
7 changes: 5 additions & 2 deletions config_template.json
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
"mqtt": {
"bufferSize": 0,
"bufferBlockSize": 0,
"topicCleanupThresholdCount": 10000,
"topicCleanupThresholdRatio": 1.0,
"subscriptions": {
"maxTopicSubscriptionsPerClient": 1000,
"topicsCleanupThresholdCount": 10000,
"topicsCleanupThresholdRatio": 1
},
"websocket": {
"enabled": true,
"bindAddress": "localhost:1888"
Expand Down
5 changes: 3 additions & 2 deletions core/mqtt/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ func provide(c *dig.Container) error {
deps.ShutdownHandler,
mqtt.WithBufferSize(ParamsMQTT.BufferSize),
mqtt.WithBufferBlockSize(ParamsMQTT.BufferBlockSize),
mqtt.WithTopicCleanupThresholdCount(ParamsMQTT.TopicCleanupThresholdCount),
mqtt.WithTopicCleanupThresholdRatio(ParamsMQTT.TopicCleanupThresholdRatio),
mqtt.WithMaxTopicSubscriptionsPerClient(ParamsMQTT.Subscriptions.MaxTopicSubscriptionsPerClient),
mqtt.WithTopicCleanupThresholdCount(ParamsMQTT.Subscriptions.TopicsCleanupThresholdCount),
mqtt.WithTopicCleanupThresholdRatio(ParamsMQTT.Subscriptions.TopicsCleanupThresholdRatio),
mqtt.WithWebsocketEnabled(ParamsMQTT.Websocket.Enabled),
mqtt.WithWebsocketBindAddress(ParamsMQTT.Websocket.BindAddress),
mqtt.WithTCPEnabled(ParamsMQTT.TCP.Enabled),
Expand Down
15 changes: 10 additions & 5 deletions core/mqtt/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ package mqtt
import "github.com/iotaledger/hive.go/app"

type ParametersMQTT struct {
BufferSize int `default:"0" usage:"the size of the client buffers in bytes"`
BufferBlockSize int `default:"0" usage:"the size per client buffer R/W block in bytes"`
TopicCleanupThresholdCount int `default:"10000" usage:"the number of deleted topics that trigger a garbage collection of the subscription manager"`
TopicCleanupThresholdRatio float32 `default:"1.0" usage:"the ratio of subscribed topics to deleted topics that trigger a garbage collection of the subscription manager"`
Websocket struct {
BufferSize int `default:"0" usage:"the size of the client buffers in bytes"`
BufferBlockSize int `default:"0" usage:"the size per client buffer R/W block in bytes"`

Subscriptions struct {
MaxTopicSubscriptionsPerClient int `default:"1000" usage:"the maximum number of topic subscriptions per client before the client gets dropped (DOS protection)"`
TopicsCleanupThresholdCount int `default:"10000" usage:"the number of deleted topics that trigger a garbage collection of the subscription manager"`
TopicsCleanupThresholdRatio float32 `default:"1.0" usage:"the ratio of subscribed topics to deleted topics that trigger a garbage collection of the subscription manager"`
}

Websocket struct {
Enabled bool `default:"true" usage:"whether to enable the websocket connection of the MQTT broker"`
BindAddress string `default:"localhost:1888" usage:"the websocket bind address on which the MQTT broker listens on"`
}
Expand Down
4 changes: 4 additions & 0 deletions core/mqtt/publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ func (s *Server) PublishOnTopic(topic string, payload interface{}) {
}

func (s *Server) PublishMilestoneOnTopic(topic string, ms *nodebridge.Milestone) {
if ms == nil || ms.Milestone == nil {
return
}

s.PublishOnTopicIfSubscribed(topic, &milestoneInfoPayload{
Index: ms.Milestone.Index,
Time: ms.Milestone.Timestamp,
Expand Down
16 changes: 16 additions & 0 deletions pkg/mqtt/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,27 @@ func NewBroker(onClientConnect OnClientConnectFunc, onClientDisconnect OnClientD
}
}

// this function is used to drop malicious clients
dropClient := func(clientID string, reason error) {
client, exists := broker.Clients.Get(clientID)
if !exists {
return
}

// stop the client connection
client.Stop(reason)

// delete the client from the broker
broker.Clients.Delete(clientID)
}

s := NewSubscriptionManager(
onClientConnect,
onClientDisconnect,
onTopicSubscribe,
onTopicUnsubscribe,
dropClient,
brokerOpts.MaxTopicSubscriptionsPerClient,
brokerOpts.TopicCleanupThresholdCount,
brokerOpts.TopicCleanupThresholdRatio,
)
Expand Down
16 changes: 13 additions & 3 deletions pkg/mqtt/broker_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ type BrokerOptions struct {
BufferSize int
// BufferBlockSize is the size per client buffer R/W block in bytes.
BufferBlockSize int
// TopicCleanupThresholdCount the number of deleted topics that trigger a garbage collection of the SubscriptionManager.

// MaxTopicSubscriptionsPerClient defines the maximum number of topic subscriptions per client before the client gets dropped (DOS protection).
MaxTopicSubscriptionsPerClient int
// TopicCleanupThresholdCount defines the number of deleted topics that trigger a garbage collection of the SubscriptionManager.
TopicCleanupThresholdCount int
// TopicCleanupThresholdRatio the ratio of subscribed topics to deleted topics that trigger a garbage collection of the SubscriptionManager.
// TopicCleanupThresholdRatio defines the ratio of subscribed topics to deleted topics that trigger a garbage collection of the SubscriptionManager.
TopicCleanupThresholdRatio float32

// WebsocketEnabled defines whether to enable the websocket connection of the MQTT broker.
Expand Down Expand Up @@ -83,7 +86,14 @@ func WithBufferBlockSize(bufferBlockSize int) BrokerOption {
}
}

// WithTopicCleanupThreshold sets the number of deleted topics that trigger a garbage collection of the SubscriptionManager.
// WithMaxTopicSubscriptionsPerClient sets the maximum number of topic subscriptions per client before the client gets dropped (DOS protection).
func WithMaxTopicSubscriptionsPerClient(maxTopicSubscriptionsPerClient int) BrokerOption {
return func(options *BrokerOptions) {
options.MaxTopicSubscriptionsPerClient = maxTopicSubscriptionsPerClient
}
}

// WithTopicCleanupThresholdCount sets the number of deleted topics that trigger a garbage collection of the SubscriptionManager.
func WithTopicCleanupThresholdCount(topicCleanupThresholdCount int) BrokerOption {
return func(options *BrokerOptions) {
options.TopicCleanupThresholdCount = topicCleanupThresholdCount
Expand Down
39 changes: 31 additions & 8 deletions pkg/mqtt/subscription_manager.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package mqtt

import (
"errors"
"sync"
)

type OnClientConnectFunc func(clientID string)
type OnClientDisconnectFunc func(clientID string)
type OnTopicSubscribeFunc func(clientID string, topic string)
type OnTopicUnsubscribeFunc func(clientID string, topic string)
type DropClientFunc func(clientID string, reason error)

var (
ErrMaxTopicSubscriptionsPerClientReached = errors.New("maximum amount of topic subscriptions per client reached")
)

// SubscriptionManager keeps track of subscribed topics of clients of the
// mqtt broker by subscribing to broker events.
Expand All @@ -19,20 +25,24 @@ type SubscriptionManager struct {
subscribers *ShrinkingMap[string, *ShrinkingMap[string, int]]
subscribersLock sync.RWMutex

cleanupThresholdCount int
cleanupThresholdRatio float32
maxTopicSubscriptionsPerClient int
cleanupThresholdCount int
cleanupThresholdRatio float32

onClientConnect OnClientConnectFunc
onClientDisconnect OnClientDisconnectFunc
onTopicSubscribe OnTopicSubscribeFunc
onTopicUnsubscribe OnTopicUnsubscribeFunc
dropClient DropClientFunc
}

func NewSubscriptionManager(
onClientConnect OnClientConnectFunc,
onClientDisconnect OnClientDisconnectFunc,
onTopicSubscribe OnTopicSubscribeFunc,
onTopicUnsubscribe OnTopicUnsubscribeFunc,
dropClient DropClientFunc,
maxTopicSubscriptionsPerClient int,
cleanupThresholdCount int,
cleanupThresholdRatio float32) *SubscriptionManager {

Expand All @@ -41,12 +51,14 @@ func NewSubscriptionManager(
WithShrinkingThresholdRatio(cleanupThresholdRatio),
WithShrinkingThresholdCount(cleanupThresholdCount),
),
onClientConnect: onClientConnect,
onClientDisconnect: onClientDisconnect,
onTopicSubscribe: onTopicSubscribe,
onTopicUnsubscribe: onTopicUnsubscribe,
cleanupThresholdCount: cleanupThresholdCount,
cleanupThresholdRatio: cleanupThresholdRatio,
onClientConnect: onClientConnect,
onClientDisconnect: onClientDisconnect,
onTopicSubscribe: onTopicSubscribe,
onTopicUnsubscribe: onTopicUnsubscribe,
dropClient: dropClient,
maxTopicSubscriptionsPerClient: maxTopicSubscriptionsPerClient,
cleanupThresholdCount: cleanupThresholdCount,
cleanupThresholdRatio: cleanupThresholdRatio,
}
}

Expand Down Expand Up @@ -95,7 +107,18 @@ func (s *SubscriptionManager) Subscribe(clientID string, topic string) {
if has {
subscribedTopics.Set(topic, count+1)
} else {
// add a new topic
subscribedTopics.Set(topic, 1)

// check if the client has reached the max number of subscriptions
if subscribedTopics.Size() >= s.maxTopicSubscriptionsPerClient {
// cleanup the client
s.cleanupClientWithoutLocking(clientID)
// drop the client
if s.dropClient != nil {
s.dropClient(clientID, ErrMaxTopicSubscriptionsPerClientReached)
}
}
}

if s.onTopicSubscribe != nil {
Expand Down
69 changes: 61 additions & 8 deletions pkg/mqtt/subscription_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ const (

topic_1 = "topic1"
topic_2 = "topic2"
topic_3 = "topic3"
topic_4 = "topic4"
topic_5 = "topic5"
topic_6 = "topic6"
)

func TestSubscriptionManager_ConnectWithNoTopics(t *testing.T) {
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, nil, 1000, 0, 0.0)

require.Equal(t, manager.SubscribersSize(), 0)
require.Equal(t, manager.TopicsSize(), 0)
Expand All @@ -32,7 +36,7 @@ func TestSubscriptionManager_ConnectWithNoTopics(t *testing.T) {
}

func TestSubscriptionManager_ConnectWithSameID(t *testing.T) {
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, nil, 1000, 0, 0.0)

manager.Connect(clientID_1)
manager.Subscribe(clientID_1, topic_1)
Expand All @@ -50,15 +54,15 @@ func TestSubscriptionManager_ConnectWithSameID(t *testing.T) {
}

func TestSubscriptionManager_SubscribeWithoutConnect(t *testing.T) {
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, nil, 1000, 0, 0.0)

manager.Subscribe(clientID_1, topic_1)
require.Equal(t, manager.SubscribersSize(), 0)
require.Equal(t, manager.TopicsSize(), 0)
}

func TestSubscriptionManager_SubscribeWithSameTopic(t *testing.T) {
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, nil, 1000, 0, 0.0)

manager.Connect(clientID_1)
manager.Subscribe(clientID_1, topic_1)
Expand All @@ -71,15 +75,15 @@ func TestSubscriptionManager_SubscribeWithSameTopic(t *testing.T) {
}

func TestSubscriptionManager_UnsubscribeWithoutConnect(t *testing.T) {
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, nil, 1000, 0, 0.0)

manager.Unsubscribe(clientID_1, topic_1)
require.Equal(t, manager.SubscribersSize(), 0)
require.Equal(t, manager.TopicsSize(), 0)
}

func TestSubscriptionManager_UnsubscribeWithSameTopic(t *testing.T) {
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, nil, 1000, 0, 0.0)

manager.Connect(clientID_1)
manager.Subscribe(clientID_1, topic_1)
Expand All @@ -96,7 +100,7 @@ func TestSubscriptionManager_UnsubscribeWithSameTopic(t *testing.T) {
}

func TestSubscriptionManager_Subscribers(t *testing.T) {
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, nil, 1000, 0, 0.0)

manager.Connect(clientID_1)
manager.Connect(clientID_1)
Expand Down Expand Up @@ -134,7 +138,7 @@ func TestSubscriptionManager_ClientCleanup(t *testing.T) {
}
}

manager := mqtt.NewSubscriptionManager(nil, nil, onTopicsSubscribe, onTopicsUnsubscribe, 0, 0.0)
manager := mqtt.NewSubscriptionManager(nil, nil, onTopicsSubscribe, onTopicsUnsubscribe, nil, 1000, 0, 0.0)

manager.Connect(clientID_1)
manager.Subscribe(clientID_1, topic_1)
Expand Down Expand Up @@ -163,3 +167,52 @@ func TestSubscriptionManager_ClientCleanup(t *testing.T) {
require.Equal(t, subscribe_client_1, 4)
require.Equal(t, unsubscribe_client_1, 4)
}

func TestSubscriptionManager_MaxTopicSubscriptionsPerClient(t *testing.T) {

clientDropped := false
dropClient := func(clientID string, reason error) {
clientDropped = true
}

manager := mqtt.NewSubscriptionManager(nil, nil, nil, nil, dropClient, 5, 0, 0.0)

require.Equal(t, manager.SubscribersSize(), 0)
require.Equal(t, manager.TopicsSize(), 0)
require.Equal(t, clientDropped, false)

manager.Connect(clientID_1)
require.Equal(t, manager.SubscribersSize(), 1)
require.Equal(t, manager.TopicsSize(), 0)
require.Equal(t, clientDropped, false)

manager.Subscribe(clientID_1, topic_1)
require.Equal(t, manager.SubscribersSize(), 1)
require.Equal(t, manager.TopicsSize(), 1)
require.Equal(t, clientDropped, false)

manager.Subscribe(clientID_1, topic_2)
require.Equal(t, manager.SubscribersSize(), 1)
require.Equal(t, manager.TopicsSize(), 2)
require.Equal(t, clientDropped, false)

manager.Subscribe(clientID_1, topic_3)
require.Equal(t, manager.SubscribersSize(), 1)
require.Equal(t, manager.TopicsSize(), 3)
require.Equal(t, clientDropped, false)

manager.Subscribe(clientID_1, topic_4)
require.Equal(t, manager.SubscribersSize(), 1)
require.Equal(t, manager.TopicsSize(), 4)
require.Equal(t, clientDropped, false)

manager.Subscribe(clientID_1, topic_5)
require.Equal(t, manager.SubscribersSize(), 0)
require.Equal(t, manager.TopicsSize(), 0)
require.Equal(t, clientDropped, true)

manager.Subscribe(clientID_1, topic_6)
require.Equal(t, manager.SubscribersSize(), 0)
require.Equal(t, manager.TopicsSize(), 0)
require.Equal(t, clientDropped, true)
}

0 comments on commit c202c43

Please sign in to comment.