Skip to content

Commit

Permalink
Send envelopes in batches (#253)
Browse files Browse the repository at this point in the history
When streaming envelopes, send them in batches rather than one-by-one.
We use a different batching strategy depending on whether the subscription
is by originator, by topic, or global.

#255

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced batch processing for envelope sending, improving
performance.
- Enhanced listener and subscription mechanisms with better error
handling and logging.
- Added new methods to various envelope structs for improved topic
management.

- **Bug Fixes**
- Improved handling of invalid topics and ensured proper cursor
management.

- **Tests**
- Updated test cases to utilize new topic variables for consistency and
readability.
- Added a new utility function for creating originator envelopes with
specified topics.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
richardhuaaa authored Oct 24, 2024
1 parent 05f9191 commit 51cbe4f
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 62 deletions.
19 changes: 11 additions & 8 deletions pkg/api/message/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,24 @@ func (s *Service) sendEnvelopes(
NodeIdToSequenceId: cursor,
}
}
envsToSend := make([]*envelopesProto.OriginatorEnvelope, 0, len(envs))
for _, env := range envs {
if cursor[uint32(env.OriginatorNodeID())] >= env.OriginatorSequenceID() {
continue
}

// TODO(rich): Either batch send envelopes, or modify stream proto to
// send one envelope at a time.
err := stream.Send(&message_api.SubscribeEnvelopesResponse{
Envelopes: []*envelopesProto.OriginatorEnvelope{env.Proto()},
})
if err != nil {
return status.Errorf(codes.Internal, "error sending envelope: %v", err)
}
envsToSend = append(envsToSend, env.Proto())
cursor[uint32(env.OriginatorNodeID())] = env.OriginatorSequenceID()
}
if len(envsToSend) == 0 {
return nil
}
err := stream.Send(&message_api.SubscribeEnvelopesResponse{
Envelopes: envsToSend,
})
if err != nil {
return status.Errorf(codes.Internal, "error sending envelopes: %v", err)
}
return nil
}

Expand Down
76 changes: 57 additions & 19 deletions pkg/api/message/subscribeWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ package message
import (
"context"
"database/sql"
"encoding/hex"
"sync"
"time"

"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/topic"
"go.uber.org/zap"
)

Expand All @@ -31,6 +31,7 @@ type listener struct {

func newListener(
ctx context.Context,
logger *zap.Logger,
query *message_api.EnvelopesQuery,
ch chan<- []*envelopes.OriginatorEnvelope,
) *listener {
Expand All @@ -49,9 +50,14 @@ func newListener(
return l
}

for _, topic := range topics {
topicStr := hex.EncodeToString(topic)
l.topics[topicStr] = struct{}{}
for _, t := range topics {
validatedTopic, err := topic.ParseTopic(t)
if err != nil {
logger.Warn("Skipping invalid topic", zap.Binary("topicBytes", t))
continue
}
logger.Debug("Adding topic listener", zap.String("topic", validatedTopic.String()))
l.topics[validatedTopic.String()] = struct{}{}
}

for _, originator := range originators {
Expand Down Expand Up @@ -122,6 +128,12 @@ func (lm *listenersMap[K]) getListeners(key K) *listenerSet {
return nil
}

// rangeKeys calls fn once for each entry visited by lm.data.Range, passing the
// entry's key asserted to K and its value asserted to *listenerSet. Both
// assertions are unchecked, so an entry of any other type panics.
//
// The boolean returned by fn is handed straight back to lm.data.Range.
// NOTE(review): presumably lm.data is a sync.Map, in which case returning
// false stops the iteration - confirm against the listenersMap definition.
func (lm *listenersMap[K]) rangeKeys(fn func(key K, listeners *listenerSet) bool) {
	visit := func(rawKey, rawSet any) bool {
		typedKey := rawKey.(K)
		set := rawSet.(*listenerSet)
		return fn(typedKey, set)
	}
	lm.data.Range(visit)
}

// A worker that listens for new envelopes in the DB and sends them to subscribers.
// Assumes that there are many listeners: updates are sent without blocking on buffered
// channels, and if a listener's channel is full the update is dropped and the listener is removed.
Expand All @@ -142,6 +154,7 @@ func startSubscribeWorker(
store *sql.DB,
) (*subscribeWorker, error) {
log = log.Named("subscribeWorker")
log.Info("Starting subscribe worker")
q := queries.New(store)
pollableQuery := func(ctx context.Context, lastSeen db.VectorClock, numRows int32) ([]queries.GatewayEnvelope, db.VectorClock, error) {
envs, err := q.
Expand Down Expand Up @@ -198,32 +211,57 @@ func (s *subscribeWorker) start() {
case <-s.ctx.Done():
return
case new_batch := <-s.dbSubscription:
s.log.Debug("Received new batch", zap.Int("numEnvelopes", len(new_batch)))
envs := make([]*envelopes.OriginatorEnvelope, 0, len(new_batch))
for _, row := range new_batch {
s.dispatch(&row)
env, err := envelopes.NewOriginatorEnvelopeFromBytes(row.OriginatorEnvelope)
if err != nil {
s.log.Error("Failed to unmarshal envelope", zap.Error(err))
continue
}
envs = append(envs, env)
}
s.dispatchToOriginators(envs)
s.dispatchToTopics(envs)
s.dispatchToGlobals(envs)
}
}
}

func (s *subscribeWorker) dispatch(row *queries.GatewayEnvelope) {
env, err := envelopes.NewOriginatorEnvelopeFromBytes(row.OriginatorEnvelope)
if err != nil {
s.log.Error("Failed to unmarshal envelope", zap.Error(err))
return
// dispatchToOriginators delivers a batch of envelopes to per-originator
// listeners. For every originator key registered in s.originatorListeners it
// collects the envelopes whose OriginatorNodeID equals that key, preserving
// their order in envs, and passes them as one batch to dispatchToListeners
// (the batch is empty when nothing matched).
//
// The nested loop over originators and envelopes is deliberate: the number of
// originators is expected to be small. Possible future optimization: set up
// multiple DB subscriptions instead of one, and have the DB group by
// originator, topic, and global.
func (s *subscribeWorker) dispatchToOriginators(envs []*envelopes.OriginatorEnvelope) {
	s.originatorListeners.rangeKeys(func(originator uint32, listeners *listenerSet) bool {
		matching := make([]*envelopes.OriginatorEnvelope, 0, len(envs))
		for i := range envs {
			if envs[i].OriginatorNodeID() != originator {
				continue
			}
			matching = append(matching, envs[i])
		}
		s.dispatchToListeners(listeners, matching)
		// Always keep iterating so every originator key is visited.
		return true
	})
}

// dispatchToTopics delivers envelopes to topic listeners. Each envelope is
// looked up in s.topicListeners by the string form of its target topic and
// forwarded on its own, as a single-element batch, in the order it appears in
// envs.
//
// Envelopes are handled one at a time because the number of envelopes per
// topic is expected to be small in each tick.
func (s *subscribeWorker) dispatchToTopics(envs []*envelopes.OriginatorEnvelope) {
	for i := range envs {
		env := envs[i]
		topicKey := env.TargetTopic().String()
		s.dispatchToListeners(
			s.topicListeners.getListeners(topicKey),
			[]*envelopes.OriginatorEnvelope{env},
		)
	}
}

originatorListeners := s.originatorListeners.getListeners(uint32(row.OriginatorNodeID))
topicListeners := s.topicListeners.getListeners(hex.EncodeToString(row.Topic))
s.dispatchToListeners(originatorListeners, env)
s.dispatchToListeners(topicListeners, env)
s.dispatchToListeners(&s.globalListeners, env)
// dispatchToGlobals forwards the whole batch, unfiltered, to the listeners
// held in s.globalListeners.
func (s *subscribeWorker) dispatchToGlobals(envs []*envelopes.OriginatorEnvelope) {
	globals := &s.globalListeners
	s.dispatchToListeners(globals, envs)
}

func (s *subscribeWorker) dispatchToListeners(
listeners *listenerSet,
env *envelopes.OriginatorEnvelope,
envs []*envelopes.OriginatorEnvelope,
) {
if listeners == nil {
if listeners == nil || len(envs) == 0 {
return
}
listeners.Range(func(key, _ any) bool {
Expand All @@ -238,7 +276,7 @@ func (s *subscribeWorker) dispatchToListeners(
s.closeListener(l)
default:
select {
case l.ch <- []*envelopes.OriginatorEnvelope{env}:
case l.ch <- envs:
default:
s.log.Info("Channel full, removing listener", zap.Any("listener", l.ch))
s.closeListener(l)
Expand Down Expand Up @@ -269,7 +307,7 @@ func (s *subscribeWorker) listen(
query *message_api.EnvelopesQuery,
) <-chan []*envelopes.OriginatorEnvelope {
ch := make(chan []*envelopes.OriginatorEnvelope, subscriptionBufferSize)
l := newListener(ctx, query, ch)
l := newListener(ctx, s.log, query, ch)

if l.isGlobal {
s.globalListeners.Store(l, struct{}{})
Expand Down
36 changes: 21 additions & 15 deletions pkg/api/message/subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ import (
"github.com/xmtp/xmtpd/pkg/testutils"
testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
"github.com/xmtp/xmtpd/pkg/topic"
)

// Topic byte values shared by the tests in this file. Each one is built with
// topic.NewTopic using the TOPIC_KIND_GROUP_MESSAGES_V1 kind and a distinct
// identifier ("topicA", "topicB", "topicC"), then converted with Bytes().
var (
topicA = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicA")).Bytes()
topicB = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicB")).Bytes()
topicC = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicC")).Bytes()
)
var allRows []queries.InsertGatewayEnvelopeParams

func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) {
Expand All @@ -25,47 +31,47 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
{
OriginatorNodeID: 1,
OriginatorSequenceID: 1,
Topic: []byte("topicA"),
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 1),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA),
),
},
{
OriginatorNodeID: 2,
OriginatorSequenceID: 1,
Topic: []byte("topicA"),
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 1),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA),
),
},
// Later rows
{
OriginatorNodeID: 1,
OriginatorSequenceID: 2,
Topic: []byte("topicB"),
Topic: topicB,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 2),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB),
),
},
{
OriginatorNodeID: 2,
OriginatorSequenceID: 2,
Topic: []byte("topicB"),
Topic: topicB,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 2),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB),
),
},
{
OriginatorNodeID: 1,
OriginatorSequenceID: 3,
Topic: []byte("topicA"),
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 3),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA),
),
},
}
Expand Down Expand Up @@ -140,7 +146,7 @@ func TestSubscribeEnvelopesByTopic(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")},
Topics: []db.Topic{topicA, topicC},
LastSeen: nil,
},
},
Expand Down Expand Up @@ -192,7 +198,7 @@ func TestSimultaneousSubscriptions(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicB")},
Topics: []db.Topic{topicB},
LastSeen: nil,
},
},
Expand Down Expand Up @@ -227,7 +233,7 @@ func TestSubscribeEnvelopesFromCursor(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")},
Topics: []db.Topic{topicA, topicC},
LastSeen: &envelopes.VectorClock{NodeIdToSequenceId: map[uint32]uint64{1: 1}},
},
},
Expand All @@ -249,7 +255,7 @@ func TestSubscribeEnvelopesFromEmptyCursor(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")},
Topics: []db.Topic{topicA, topicC},
LastSeen: &envelopes.VectorClock{NodeIdToSequenceId: map[uint32]uint64{}},
},
},
Expand All @@ -268,7 +274,7 @@ func TestSubscribeEnvelopesInvalidRequest(t *testing.T) {
context.Background(),
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA")},
Topics: []db.Topic{topicA},
OriginatorNodeIds: []uint32{1},
LastSeen: nil,
},
Expand Down
Loading

0 comments on commit 51cbe4f

Please sign in to comment.