Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send envelopes in batches #253

Merged
merged 4 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions pkg/api/message/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,24 @@ func (s *Service) sendEnvelopes(
NodeIdToSequenceId: cursor,
}
}
envsToSend := make([]*envelopesProto.OriginatorEnvelope, 0, len(envs))
for _, env := range envs {
if cursor[uint32(env.OriginatorNodeID())] >= env.OriginatorSequenceID() {
continue
}

// TODO(rich): Either batch send envelopes, or modify stream proto to
// send one envelope at a time.
err := stream.Send(&message_api.SubscribeEnvelopesResponse{
Envelopes: []*envelopesProto.OriginatorEnvelope{env.Proto()},
})
if err != nil {
return status.Errorf(codes.Internal, "error sending envelope: %v", err)
}
envsToSend = append(envsToSend, env.Proto())
cursor[uint32(env.OriginatorNodeID())] = env.OriginatorSequenceID()
}
if len(envsToSend) == 0 {
return nil
}
err := stream.Send(&message_api.SubscribeEnvelopesResponse{
richardhuaaa marked this conversation as resolved.
Show resolved Hide resolved
Envelopes: envsToSend,
})
if err != nil {
return status.Errorf(codes.Internal, "error sending envelopes: %v", err)
}
return nil
}

Expand Down
76 changes: 57 additions & 19 deletions pkg/api/message/subscribeWorker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ package message
import (
"context"
"database/sql"
"encoding/hex"
"sync"
"time"

"github.com/xmtp/xmtpd/pkg/db"
"github.com/xmtp/xmtpd/pkg/db/queries"
"github.com/xmtp/xmtpd/pkg/envelopes"
"github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api"
"github.com/xmtp/xmtpd/pkg/topic"
"go.uber.org/zap"
)

Expand All @@ -31,6 +31,7 @@ type listener struct {

func newListener(
ctx context.Context,
logger *zap.Logger,
query *message_api.EnvelopesQuery,
ch chan<- []*envelopes.OriginatorEnvelope,
) *listener {
Expand All @@ -49,9 +50,14 @@ func newListener(
return l
}

for _, topic := range topics {
topicStr := hex.EncodeToString(topic)
l.topics[topicStr] = struct{}{}
for _, t := range topics {
validatedTopic, err := topic.ParseTopic(t)
if err != nil {
logger.Warn("Skipping invalid topic", zap.Binary("topicBytes", t))
continue
}
logger.Debug("Adding topic listener", zap.String("topic", validatedTopic.String()))
l.topics[validatedTopic.String()] = struct{}{}
}

for _, originator := range originators {
Expand Down Expand Up @@ -122,6 +128,12 @@ func (lm *listenersMap[K]) getListeners(key K) *listenerSet {
return nil
}

// rangeKeys walks every entry held in lm.data and hands it to fn with the key
// asserted to K and the value asserted to *listenerSet. Either assertion panics
// if the stored entry has a different dynamic type.
//
// Whatever fn returns is passed straight back to lm.data.Range.
// NOTE(review): lm.data is presumably a sync.Map, in which case returning false
// from fn stops the iteration — confirm against the field declaration.
func (lm *listenersMap[K]) rangeKeys(fn func(key K, listeners *listenerSet) bool) {
	visit := func(rawKey, rawValue any) bool {
		typedKey := rawKey.(K)
		set := rawValue.(*listenerSet)
		return fn(typedKey, set)
	}
	lm.data.Range(visit)
}

// A worker that listens for new envelopes in the DB and sends them to subscribers
// Assumes that there are many listeners - non-blocking updates are sent on buffered channels
// and may be dropped if full
Expand All @@ -142,6 +154,7 @@ func startSubscribeWorker(
store *sql.DB,
) (*subscribeWorker, error) {
log = log.Named("subscribeWorker")
log.Info("Starting subscribe worker")
q := queries.New(store)
pollableQuery := func(ctx context.Context, lastSeen db.VectorClock, numRows int32) ([]queries.GatewayEnvelope, db.VectorClock, error) {
envs, err := q.
Expand Down Expand Up @@ -198,32 +211,57 @@ func (s *subscribeWorker) start() {
case <-s.ctx.Done():
return
case new_batch := <-s.dbSubscription:
s.log.Debug("Received new batch", zap.Int("numEnvelopes", len(new_batch)))
envs := make([]*envelopes.OriginatorEnvelope, 0, len(new_batch))
for _, row := range new_batch {
s.dispatch(&row)
env, err := envelopes.NewOriginatorEnvelopeFromBytes(row.OriginatorEnvelope)
if err != nil {
s.log.Error("Failed to unmarshal envelope", zap.Error(err))
continue
}
envs = append(envs, env)
}
s.dispatchToOriginators(envs)
s.dispatchToTopics(envs)
s.dispatchToGlobals(envs)
}
}
}

func (s *subscribeWorker) dispatch(row *queries.GatewayEnvelope) {
env, err := envelopes.NewOriginatorEnvelopeFromBytes(row.OriginatorEnvelope)
if err != nil {
s.log.Error("Failed to unmarshal envelope", zap.Error(err))
return
// dispatchToOriginators delivers a batch to the originator-scoped listeners.
// For every originator registered in s.originatorListeners it collects the
// envelopes whose OriginatorNodeID matches that originator and passes the
// resulting slice, in batch order, to dispatchToListeners.
func (s *subscribeWorker) dispatchToOriginators(envs []*envelopes.OriginatorEnvelope) {
	// The batch is rescanned once per originator. This is deliberate: the number
	// of originators is expected to be small.
	// Possible future optimization: set up multiple DB subscriptions instead of
	// one, and have the DB group by originator, topic, and global.
	perOriginator := func(nodeID uint32, subscribers *listenerSet) bool {
		matching := make([]*envelopes.OriginatorEnvelope, 0, len(envs))
		for i := range envs {
			if envs[i].OriginatorNodeID() != nodeID {
				continue
			}
			matching = append(matching, envs[i])
		}
		s.dispatchToListeners(subscribers, matching)
		// Unconditionally true: every registered originator gets visited.
		return true
	}
	s.originatorListeners.rangeKeys(perOriginator)
}
richardhuaaa marked this conversation as resolved.
Show resolved Hide resolved

// dispatchToTopics delivers a batch to the topic-scoped listeners. Each
// envelope is looked up by the string form of its target topic and forwarded
// on its own, as a single-element slice, to dispatchToListeners.
func (s *subscribeWorker) dispatchToTopics(envs []*envelopes.OriginatorEnvelope) {
	// Envelopes are handled one at a time rather than grouped by topic, because
	// the number of envelopes per topic is expected to be small in each tick.
	for i := range envs {
		topicKey := envs[i].TargetTopic().String()
		// getListeners may return nil; that value is handed on unchanged.
		subscribers := s.topicListeners.getListeners(topicKey)
		s.dispatchToListeners(subscribers, []*envelopes.OriginatorEnvelope{envs[i]})
	}
}
richardhuaaa marked this conversation as resolved.
Show resolved Hide resolved

originatorListeners := s.originatorListeners.getListeners(uint32(row.OriginatorNodeID))
topicListeners := s.topicListeners.getListeners(hex.EncodeToString(row.Topic))
s.dispatchToListeners(originatorListeners, env)
s.dispatchToListeners(topicListeners, env)
s.dispatchToListeners(&s.globalListeners, env)
// dispatchToGlobals forwards the whole batch, unfiltered, to the listeners
// stored in s.globalListeners.
func (s *subscribeWorker) dispatchToGlobals(envs []*envelopes.OriginatorEnvelope) {
	globals := &s.globalListeners
	s.dispatchToListeners(globals, envs)
}

func (s *subscribeWorker) dispatchToListeners(
listeners *listenerSet,
env *envelopes.OriginatorEnvelope,
envs []*envelopes.OriginatorEnvelope,
) {
if listeners == nil {
if listeners == nil || len(envs) == 0 {
return
}
listeners.Range(func(key, _ any) bool {
Expand All @@ -238,7 +276,7 @@ func (s *subscribeWorker) dispatchToListeners(
s.closeListener(l)
default:
select {
case l.ch <- []*envelopes.OriginatorEnvelope{env}:
case l.ch <- envs:
default:
s.log.Info("Channel full, removing listener", zap.Any("listener", l.ch))
s.closeListener(l)
Expand Down Expand Up @@ -269,7 +307,7 @@ func (s *subscribeWorker) listen(
query *message_api.EnvelopesQuery,
) <-chan []*envelopes.OriginatorEnvelope {
ch := make(chan []*envelopes.OriginatorEnvelope, subscriptionBufferSize)
l := newListener(ctx, query, ch)
l := newListener(ctx, s.log, query, ch)

if l.isGlobal {
s.globalListeners.Store(l, struct{}{})
Expand Down
36 changes: 21 additions & 15 deletions pkg/api/message/subscribe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ import (
"github.com/xmtp/xmtpd/pkg/testutils"
testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api"
envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes"
"github.com/xmtp/xmtpd/pkg/topic"
)

// Byte-encoded group-message topics shared by the subscription tests, built
// from the payloads "topicA", "topicB" and "topicC" respectively.
var topicA = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicA")).Bytes()
var topicB = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicB")).Bytes()
var topicC = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicC")).Bytes()
var allRows []queries.InsertGatewayEnvelopeParams

func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) {
Expand All @@ -25,47 +31,47 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func())
{
OriginatorNodeID: 1,
OriginatorSequenceID: 1,
Topic: []byte("topicA"),
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 1),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA),
),
},
{
OriginatorNodeID: 2,
OriginatorSequenceID: 1,
Topic: []byte("topicA"),
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 1),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA),
),
},
// Later rows
{
OriginatorNodeID: 1,
OriginatorSequenceID: 2,
Topic: []byte("topicB"),
Topic: topicB,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 2),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB),
),
},
{
OriginatorNodeID: 2,
OriginatorSequenceID: 2,
Topic: []byte("topicB"),
Topic: topicB,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 2),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB),
),
},
{
OriginatorNodeID: 1,
OriginatorSequenceID: 3,
Topic: []byte("topicA"),
Topic: topicA,
OriginatorEnvelope: testutils.Marshal(
t,
envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 3),
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA),
),
},
}
Expand Down Expand Up @@ -140,7 +146,7 @@ func TestSubscribeEnvelopesByTopic(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")},
Topics: []db.Topic{topicA, topicC},
LastSeen: nil,
},
},
Expand Down Expand Up @@ -192,7 +198,7 @@ func TestSimultaneousSubscriptions(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicB")},
Topics: []db.Topic{topicB},
LastSeen: nil,
},
},
Expand Down Expand Up @@ -227,7 +233,7 @@ func TestSubscribeEnvelopesFromCursor(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")},
Topics: []db.Topic{topicA, topicC},
LastSeen: &envelopes.VectorClock{NodeIdToSequenceId: map[uint32]uint64{1: 1}},
},
},
Expand All @@ -249,7 +255,7 @@ func TestSubscribeEnvelopesFromEmptyCursor(t *testing.T) {
ctx,
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")},
Topics: []db.Topic{topicA, topicC},
LastSeen: &envelopes.VectorClock{NodeIdToSequenceId: map[uint32]uint64{}},
},
},
Expand All @@ -268,7 +274,7 @@ func TestSubscribeEnvelopesInvalidRequest(t *testing.T) {
context.Background(),
&message_api.SubscribeEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Topics: []db.Topic{db.Topic("topicA")},
Topics: []db.Topic{topicA},
OriginatorNodeIds: []uint32{1},
LastSeen: nil,
},
Expand Down
Loading
Loading