diff --git a/pkg/api/message/service.go b/pkg/api/message/service.go index f2751cfe..f94ba9ad 100644 --- a/pkg/api/message/service.go +++ b/pkg/api/message/service.go @@ -180,21 +180,24 @@ func (s *Service) sendEnvelopes( NodeIdToSequenceId: cursor, } } + envsToSend := make([]*envelopesProto.OriginatorEnvelope, 0, len(envs)) for _, env := range envs { if cursor[uint32(env.OriginatorNodeID())] >= env.OriginatorSequenceID() { continue } - // TODO(rich): Either batch send envelopes, or modify stream proto to - // send one envelope at a time. - err := stream.Send(&message_api.SubscribeEnvelopesResponse{ - Envelopes: []*envelopesProto.OriginatorEnvelope{env.Proto()}, - }) - if err != nil { - return status.Errorf(codes.Internal, "error sending envelope: %v", err) - } + envsToSend = append(envsToSend, env.Proto()) cursor[uint32(env.OriginatorNodeID())] = env.OriginatorSequenceID() } + if len(envsToSend) == 0 { + return nil + } + err := stream.Send(&message_api.SubscribeEnvelopesResponse{ + Envelopes: envsToSend, + }) + if err != nil { + return status.Errorf(codes.Internal, "error sending envelopes: %v", err) + } return nil } diff --git a/pkg/api/message/subscribeWorker.go b/pkg/api/message/subscribeWorker.go index 8d818d55..e963cf3a 100644 --- a/pkg/api/message/subscribeWorker.go +++ b/pkg/api/message/subscribeWorker.go @@ -3,7 +3,6 @@ package message import ( "context" "database/sql" - "encoding/hex" "sync" "time" @@ -11,6 +10,7 @@ import ( "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" + "github.com/xmtp/xmtpd/pkg/topic" "go.uber.org/zap" ) @@ -31,6 +31,7 @@ type listener struct { func newListener( ctx context.Context, + logger *zap.Logger, query *message_api.EnvelopesQuery, ch chan<- []*envelopes.OriginatorEnvelope, ) *listener { @@ -49,9 +50,14 @@ func newListener( return l } - for _, topic := range topics { - topicStr := hex.EncodeToString(topic) - l.topics[topicStr] = struct{}{} + for 
_, t := range topics { + validatedTopic, err := topic.ParseTopic(t) + if err != nil { + logger.Warn("Skipping invalid topic", zap.Binary("topicBytes", t)) + continue + } + logger.Debug("Adding topic listener", zap.String("topic", validatedTopic.String())) + l.topics[validatedTopic.String()] = struct{}{} } for _, originator := range originators { @@ -122,6 +128,12 @@ func (lm *listenersMap[K]) getListeners(key K) *listenerSet { return nil } +func (lm *listenersMap[K]) rangeKeys(fn func(key K, listeners *listenerSet) bool) { + lm.data.Range(func(key, value any) bool { + return fn(key.(K), value.(*listenerSet)) + }) +} + // A worker that listens for new envelopes in the DB and sends them to subscribers // Assumes that there are many listeners - non-blocking updates are sent on buffered channels // and may be dropped if full @@ -142,6 +154,7 @@ func startSubscribeWorker( store *sql.DB, ) (*subscribeWorker, error) { log = log.Named("subscribeWorker") + log.Info("Starting subscribe worker") q := queries.New(store) pollableQuery := func(ctx context.Context, lastSeen db.VectorClock, numRows int32) ([]queries.GatewayEnvelope, db.VectorClock, error) { envs, err := q. 
@@ -198,32 +211,57 @@ func (s *subscribeWorker) start() { case <-s.ctx.Done(): return case new_batch := <-s.dbSubscription: + s.log.Debug("Received new batch", zap.Int("numEnvelopes", len(new_batch))) + envs := make([]*envelopes.OriginatorEnvelope, 0, len(new_batch)) for _, row := range new_batch { - s.dispatch(&row) + env, err := envelopes.NewOriginatorEnvelopeFromBytes(row.OriginatorEnvelope) + if err != nil { + s.log.Error("Failed to unmarshal envelope", zap.Error(err)) + continue + } + envs = append(envs, env) } + s.dispatchToOriginators(envs) + s.dispatchToTopics(envs) + s.dispatchToGlobals(envs) } } } -func (s *subscribeWorker) dispatch(row *queries.GatewayEnvelope) { - env, err := envelopes.NewOriginatorEnvelopeFromBytes(row.OriginatorEnvelope) - if err != nil { - s.log.Error("Failed to unmarshal envelope", zap.Error(err)) - return +func (s *subscribeWorker) dispatchToOriginators(envs []*envelopes.OriginatorEnvelope) { + // We use nested loops here because the number of originators is expected to be small + // Possible future optimization: Set up multiple DB subscriptions instead of one, + // and have the DB group by originator, topic, and global. 
+ s.originatorListeners.rangeKeys(func(originator uint32, listeners *listenerSet) bool { + filteredEnvs := make([]*envelopes.OriginatorEnvelope, 0, len(envs)) + for _, env := range envs { + if env.OriginatorNodeID() == originator { + filteredEnvs = append(filteredEnvs, env) + } + } + s.dispatchToListeners(listeners, filteredEnvs) + return true + }) +} + +func (s *subscribeWorker) dispatchToTopics(envs []*envelopes.OriginatorEnvelope) { + // We iterate envelopes one-by-one, because we expect the number of envelopes + // per-topic to be small in each tick + for _, env := range envs { + listeners := s.topicListeners.getListeners(env.TargetTopic().String()) + s.dispatchToListeners(listeners, []*envelopes.OriginatorEnvelope{env}) } +} - originatorListeners := s.originatorListeners.getListeners(uint32(row.OriginatorNodeID)) - topicListeners := s.topicListeners.getListeners(hex.EncodeToString(row.Topic)) - s.dispatchToListeners(originatorListeners, env) - s.dispatchToListeners(topicListeners, env) - s.dispatchToListeners(&s.globalListeners, env) +func (s *subscribeWorker) dispatchToGlobals(envs []*envelopes.OriginatorEnvelope) { + s.dispatchToListeners(&s.globalListeners, envs) } func (s *subscribeWorker) dispatchToListeners( listeners *listenerSet, - env *envelopes.OriginatorEnvelope, + envs []*envelopes.OriginatorEnvelope, ) { - if listeners == nil { + if listeners == nil || len(envs) == 0 { return } listeners.Range(func(key, _ any) bool { @@ -238,7 +276,7 @@ func (s *subscribeWorker) dispatchToListeners( s.closeListener(l) default: select { - case l.ch <- []*envelopes.OriginatorEnvelope{env}: + case l.ch <- envs: default: s.log.Info("Channel full, removing listener", zap.Any("listener", l.ch)) s.closeListener(l) @@ -269,7 +307,7 @@ func (s *subscribeWorker) listen( query *message_api.EnvelopesQuery, ) <-chan []*envelopes.OriginatorEnvelope { ch := make(chan []*envelopes.OriginatorEnvelope, subscriptionBufferSize) - l := newListener(ctx, query, ch) + l := 
newListener(ctx, s.log, query, ch) if l.isGlobal { s.globalListeners.Store(l, struct{}{}) diff --git a/pkg/api/message/subscribe_test.go b/pkg/api/message/subscribe_test.go index e79c4cc9..f9dd3ce2 100644 --- a/pkg/api/message/subscribe_test.go +++ b/pkg/api/message/subscribe_test.go @@ -15,8 +15,14 @@ import ( "github.com/xmtp/xmtpd/pkg/testutils" testUtilsApi "github.com/xmtp/xmtpd/pkg/testutils/api" envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes" + "github.com/xmtp/xmtpd/pkg/topic" ) +var ( + topicA = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicA")).Bytes() + topicB = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicB")).Bytes() + topicC = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicC")).Bytes() +) var allRows []queries.InsertGatewayEnvelopeParams func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) { @@ -25,47 +31,47 @@ func setupTest(t *testing.T) (message_api.ReplicationApiClient, *sql.DB, func()) { OriginatorNodeID: 1, OriginatorSequenceID: 1, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 1), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), ), }, { OriginatorNodeID: 2, OriginatorSequenceID: 1, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 1), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), ), }, // Later rows { OriginatorNodeID: 1, OriginatorSequenceID: 2, - Topic: []byte("topicB"), + Topic: topicB, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 2), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), ), }, { OriginatorNodeID: 2, OriginatorSequenceID: 2, - Topic: []byte("topicB"), + Topic: topicB, OriginatorEnvelope: testutils.Marshal( t, - 
envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 2), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), ), }, { OriginatorNodeID: 1, OriginatorSequenceID: 3, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 3), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), ), }, } @@ -140,7 +146,7 @@ func TestSubscribeEnvelopesByTopic(t *testing.T) { ctx, &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")}, + Topics: []db.Topic{topicA, topicC}, LastSeen: nil, }, }, @@ -192,7 +198,7 @@ func TestSimultaneousSubscriptions(t *testing.T) { ctx, &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicB")}, + Topics: []db.Topic{topicB}, LastSeen: nil, }, }, @@ -227,7 +233,7 @@ func TestSubscribeEnvelopesFromCursor(t *testing.T) { ctx, &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")}, + Topics: []db.Topic{topicA, topicC}, LastSeen: &envelopes.VectorClock{NodeIdToSequenceId: map[uint32]uint64{1: 1}}, }, }, @@ -249,7 +255,7 @@ func TestSubscribeEnvelopesFromEmptyCursor(t *testing.T) { ctx, &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA"), []byte("topicC")}, + Topics: []db.Topic{topicA, topicC}, LastSeen: &envelopes.VectorClock{NodeIdToSequenceId: map[uint32]uint64{}}, }, }, @@ -268,7 +274,7 @@ func TestSubscribeEnvelopesInvalidRequest(t *testing.T) { context.Background(), &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA")}, + Topics: []db.Topic{topicA}, OriginatorNodeIds: []uint32{1}, LastSeen: nil, }, diff --git a/pkg/api/query_test.go b/pkg/api/query_test.go index c1499772..b63d755e 100644 --- 
a/pkg/api/query_test.go +++ b/pkg/api/query_test.go @@ -13,6 +13,13 @@ import ( "github.com/xmtp/xmtpd/pkg/testutils" apiTestUtils "github.com/xmtp/xmtpd/pkg/testutils/api" envelopeTestUtils "github.com/xmtp/xmtpd/pkg/testutils/envelopes" + "github.com/xmtp/xmtpd/pkg/topic" +) + +var ( + topicA = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicA")).Bytes() + topicB = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicB")).Bytes() + topicC = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicC")).Bytes() ) func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopeParams { @@ -20,46 +27,46 @@ func setupQueryTest(t *testing.T, db *sql.DB) []queries.InsertGatewayEnvelopePar { OriginatorNodeID: 1, OriginatorSequenceID: 1, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 1), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 1, topicA), ), }, { OriginatorNodeID: 2, OriginatorSequenceID: 1, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 1), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 1, topicA), ), }, { OriginatorNodeID: 1, OriginatorSequenceID: 2, - Topic: []byte("topicB"), + Topic: topicB, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 2), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 2, topicB), ), }, { OriginatorNodeID: 2, OriginatorSequenceID: 2, - Topic: []byte("topicB"), + Topic: topicB, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 2, 2), + envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 2, 2, topicB), ), }, { OriginatorNodeID: 1, OriginatorSequenceID: 3, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: testutils.Marshal( t, - envelopeTestUtils.CreateOriginatorEnvelope(t, 1, 3), + 
envelopeTestUtils.CreateOriginatorEnvelopeWithTopic(t, 1, 3, topicA), ), }, } @@ -127,7 +134,7 @@ func TestQueryEnvelopesByTopic(t *testing.T) { context.Background(), &message_api.QueryEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA")}, + Topics: []db.Topic{topicA}, LastSeen: nil, }, Limit: 0, @@ -164,7 +171,7 @@ func TestQueryTopicFromLastSeen(t *testing.T) { context.Background(), &message_api.QueryEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA")}, + Topics: []db.Topic{topicA}, LastSeen: &envelopes.VectorClock{ NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1}, }, @@ -185,7 +192,7 @@ func TestQueryMultipleTopicsFromLastSeen(t *testing.T) { context.Background(), &message_api.QueryEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA"), db.Topic("topicB")}, + Topics: []db.Topic{topicA, topicB}, LastSeen: &envelopes.VectorClock{ NodeIdToSequenceId: map[uint32]uint64{1: 2, 2: 1}, }, @@ -227,7 +234,7 @@ func TestQueryEnvelopesWithEmptyResult(t *testing.T) { context.Background(), &message_api.QueryEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicC")}, + Topics: []db.Topic{topicC}, }, Limit: 0, }, @@ -245,7 +252,7 @@ func TestInvalidQuery(t *testing.T) { context.Background(), &message_api.QueryEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - Topics: []db.Topic{db.Topic("topicA")}, + Topics: []db.Topic{topicA}, OriginatorNodeIds: []uint32{1}, }, Limit: 0, diff --git a/pkg/db/subscription_test.go b/pkg/db/subscription_test.go index 03c8ff6f..f53e6d4b 100644 --- a/pkg/db/subscription_test.go +++ b/pkg/db/subscription_test.go @@ -11,9 +11,12 @@ import ( "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" "github.com/xmtp/xmtpd/pkg/testutils" + "github.com/xmtp/xmtpd/pkg/topic" "go.uber.org/zap" ) +var topicA = topic.NewTopic(topic.TOPIC_KIND_GROUP_MESSAGES_V1, []byte("topicA")).Bytes() 
+ func setup(t *testing.T) (*sql.DB, *zap.Logger, func()) { ctx := context.Background() store, _, storeCleanup := testutils.NewDB(t, ctx) @@ -28,13 +31,13 @@ func insertInitialRows(t *testing.T, store *sql.DB) { { OriginatorNodeID: 1, OriginatorSequenceID: 1, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: []byte("envelope1"), }, { OriginatorNodeID: 2, OriginatorSequenceID: 1, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: []byte("envelope2"), }, }) @@ -62,19 +65,19 @@ func insertAdditionalRows(t *testing.T, store *sql.DB, notifyChan ...chan bool) { OriginatorNodeID: 1, OriginatorSequenceID: 2, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: []byte("envelope3"), }, { OriginatorNodeID: 2, OriginatorSequenceID: 2, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: []byte("envelope4"), }, { OriginatorNodeID: 1, OriginatorSequenceID: 3, - Topic: []byte("topicA"), + Topic: topicA, OriginatorEnvelope: []byte("envelope5"), }, }, notifyChan...) 
diff --git a/pkg/envelopes/originator.go b/pkg/envelopes/originator.go index 76286bcd..ad2335c0 100644 --- a/pkg/envelopes/originator.go +++ b/pkg/envelopes/originator.go @@ -4,6 +4,7 @@ import ( "errors" envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" + "github.com/xmtp/xmtpd/pkg/topic" "github.com/xmtp/xmtpd/pkg/utils" "google.golang.org/protobuf/proto" ) @@ -56,3 +57,7 @@ func (o *OriginatorEnvelope) OriginatorNodeID() uint32 { func (o *OriginatorEnvelope) OriginatorSequenceID() uint64 { return o.UnsignedOriginatorEnvelope.OriginatorSequenceID() } + +func (o *OriginatorEnvelope) TargetTopic() topic.Topic { + return o.UnsignedOriginatorEnvelope.TargetTopic() +} diff --git a/pkg/envelopes/payer.go b/pkg/envelopes/payer.go index 22ec7183..13cf638f 100644 --- a/pkg/envelopes/payer.go +++ b/pkg/envelopes/payer.go @@ -6,6 +6,7 @@ import ( "github.com/ethereum/go-ethereum/common" ethcrypto "github.com/ethereum/go-ethereum/crypto" envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" + "github.com/xmtp/xmtpd/pkg/topic" "github.com/xmtp/xmtpd/pkg/utils" "google.golang.org/protobuf/proto" ) @@ -55,3 +56,7 @@ func (p *PayerEnvelope) RecoverSigner() (*common.Address, error) { return &address, nil } + +func (p *PayerEnvelope) TargetTopic() topic.Topic { + return p.ClientEnvelope.TargetTopic() +} diff --git a/pkg/envelopes/unsignedOriginator.go b/pkg/envelopes/unsignedOriginator.go index a98517e7..bd4a3d6a 100644 --- a/pkg/envelopes/unsignedOriginator.go +++ b/pkg/envelopes/unsignedOriginator.go @@ -4,6 +4,7 @@ import ( "errors" envelopesProto "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" + "github.com/xmtp/xmtpd/pkg/topic" "github.com/xmtp/xmtpd/pkg/utils" ) @@ -55,3 +56,7 @@ func NewUnsignedOriginatorEnvelopeFromBytes(bytes []byte) (*UnsignedOriginatorEn func (u *UnsignedOriginatorEnvelope) Proto() *envelopesProto.UnsignedOriginatorEnvelope { return u.proto } + +func (u *UnsignedOriginatorEnvelope) TargetTopic() topic.Topic { + return 
u.PayerEnvelope.TargetTopic() +} diff --git a/pkg/testutils/envelopes/envelopes.go b/pkg/testutils/envelopes/envelopes.go index b1a41699..8a3eca25 100644 --- a/pkg/testutils/envelopes/envelopes.go +++ b/pkg/testutils/envelopes/envelopes.go @@ -119,3 +119,20 @@ func CreateOriginatorEnvelope( Proof: nil, } } + +func CreateOriginatorEnvelopeWithTopic( + t *testing.T, + originatorNodeID uint32, + originatorSequenceID uint64, + topic []byte, +) *envelopes.OriginatorEnvelope { + payerEnv := CreatePayerEnvelope(t, CreateClientEnvelope( + &envelopes.AuthenticatedData{ + TargetTopic: topic, + TargetOriginator: originatorNodeID, + LastSeen: nil, + }, + )) + + return CreateOriginatorEnvelope(t, originatorNodeID, originatorSequenceID, payerEnv) +}