diff --git a/pkg/api/publishWorker.go b/pkg/api/publishWorker.go new file mode 100644 index 00000000..aac28612 --- /dev/null +++ b/pkg/api/publishWorker.go @@ -0,0 +1,89 @@ +package api + +import ( + "context" + "database/sql" + "time" + + "github.com/xmtp/xmtpd/pkg/db" + "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/registrant" + "google.golang.org/protobuf/proto" +) + +type PublishWorker struct { + listener <-chan []queries.StagedOriginatorEnvelope + registrant *registrant.Registrant + store *sql.DB + subscription db.DBSubscription[queries.StagedOriginatorEnvelope] +} + +func StartPublishWorker( + ctx context.Context, + reg *registrant.Registrant, + store *sql.DB, + notifier <-chan bool, +) (*PublishWorker, error) { + query := func(lastSeenID int64, numRows int32) ([]queries.StagedOriginatorEnvelope, int64, error) { + results, err := queries.New(store).SelectStagedOriginatorEnvelopes( + ctx, + queries.SelectStagedOriginatorEnvelopesParams{ + LastSeenID: lastSeenID, + NumRows: numRows, + }, + ) + if err != nil { + return nil, 0, err + } + if len(results) > 0 { + lastSeenID = results[len(results)-1].ID + } + return results, lastSeenID, nil + } + subscription := db.NewDBSubscription( + query, + 0, // lastSeenID + db.PollingOptions{Interval: 5 * time.Second, Notifier: notifier, NumRows: 100}, + ) + listener, err := subscription.Start() + if err != nil { + return nil, err + } + + worker := &PublishWorker{ + subscription: *subscription, + listener: listener, + registrant: reg, + store: store, + } + go worker.start() + + return worker, nil +} + +func (p *PublishWorker) start() { + for new_batch := range p.listener { + for _, stagedEnv := range new_batch { + originatedEnv, sid, err := p.registrant.SignStagedEnvelope(stagedEnv) + if err != nil { + panic("TODO(rich)") + } + originatedBytes, err := proto.Marshal(originatedEnv) + if err != nil { + panic("TODO(rich)") + } + q := queries.New(p.store) + // TODO(rich) Verify context + // q.InsertGatewayEnvelope(context.Background(), queries.InsertGatewayEnvelopeParams{ + // OriginatorSid: sid, + // Topic: originatedBytes, + // OriginatorEnvelope: proto.Marshal(originatedEnv), + // }) + + // Start transaction + // Sign envelope + // Insert into all envelopes + // Delete envelope from staged_originator_envelopes + } + } +} diff --git a/pkg/api/service.go b/pkg/api/service.go index 32367558..3065cc8a 100644 --- a/pkg/api/service.go +++ b/pkg/api/service.go @@ -17,19 +17,33 @@ import ( type Service struct { message_api.UnimplementedReplicationApiServer - ctx context.Context - log *zap.Logger - registrant *registrant.Registrant - queries *queries.Queries + ctx context.Context + log *zap.Logger + notifyStagedPublish chan<- bool + registrant *registrant.Registrant + store *sql.DB + worker *PublishWorker } func NewReplicationApiService( ctx context.Context, log *zap.Logger, registrant *registrant.Registrant, - writerDB *sql.DB, + store *sql.DB, ) (*Service, error) { - return &Service{ctx: ctx, log: log, registrant: registrant, queries: queries.New(writerDB)}, nil + notifier := make(chan bool, 1) + // worker, err := StartPublishWorker(ctx, store, notifier) + // if err != nil { + // return nil, err + // } + return &Service{ + ctx: ctx, + log: log, + notifyStagedPublish: notifier, + registrant: registrant, + store: store, + // worker: worker, + }, nil } func (s *Service) Close() { @@ -71,12 +85,12 @@ func (s *Service) PublishEnvelope( return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err) } - stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes) + stagedEnv, err := queries.New(s.store).InsertStagedOriginatorEnvelope(ctx, payerBytes) if err != nil { return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err) } - originatorEnv, err := s.registrant.SignStagedEnvelope(stagedEnv) + originatorEnv, _, err := s.registrant.SignStagedEnvelope(stagedEnv) if err != nil { return nil, status.Errorf(codes.Internal, "could not sign envelope: %v", err) } diff --git a/pkg/registrant/registrant.go b/pkg/registrant/registrant.go index e642a7e1..7471124f 100644 --- a/pkg/registrant/registrant.go +++ b/pkg/registrant/registrant.go @@ -62,14 +62,14 @@ func (r *Registrant) signKeccak256(data []byte) ([]byte, error) { func (r *Registrant) SignStagedEnvelope( stagedEnv queries.StagedOriginatorEnvelope, -) (*message_api.OriginatorEnvelope, error) { +) (*message_api.OriginatorEnvelope, uint64, error) { payerEnv := &message_api.PayerEnvelope{} if err := proto.Unmarshal(stagedEnv.PayerEnvelope, payerEnv); err != nil { - return nil, fmt.Errorf("Could not unmarshal payer envelope: %v", err) + return nil, 0, fmt.Errorf("Could not unmarshal payer envelope: %v", err) } sid, err := r.sid(stagedEnv.ID) if err != nil { - return nil, err + return nil, 0, err } unsignedEnv := message_api.UnsignedOriginatorEnvelope{ OriginatorSid: sid, @@ -78,12 +78,12 @@ func (r *Registrant) SignStagedEnvelope( } unsignedBytes, err := proto.Marshal(&unsignedEnv) if err != nil { - return nil, err + return nil, 0, err } sig, err := r.signKeccak256(unsignedBytes) if err != nil { - return nil, err + return nil, 0, err } signedEnv := message_api.OriginatorEnvelope{ @@ -95,7 +95,7 @@ func (r *Registrant) SignStagedEnvelope( }, } - return &signedEnv, nil + return &signedEnv, sid, nil } func getRegistryRecord( diff --git a/pkg/registrant/registrant_test.go b/pkg/registrant/registrant_test.go index 858767c2..17164417 100644 --- a/pkg/registrant/registrant_test.go +++ b/pkg/registrant/registrant_test.go @@ -188,7 +188,7 @@ func TestSignStagedEnvelopeInvalidEnvelope(t *testing.T) { _, r, cleanup := setupWithRegistrant(t) defer cleanup() - _, err := r.SignStagedEnvelope( + _, _, err := r.SignStagedEnvelope( queries.StagedOriginatorEnvelope{ ID: 1, OriginatorTime: time.Now(), @@ -205,7 +205,7 @@ func TestSignStagedEnvelopeSIDExhaustion(t *testing.T) { payerBytes, err := proto.Marshal(&message_api.PayerEnvelope{}) require.NoError(t, err) - _, err = r.SignStagedEnvelope( + _, _, err = r.SignStagedEnvelope( queries.StagedOriginatorEnvelope{ ID: 0b0000000000000001000000000000000000000000000000000000000000000000, OriginatorTime: time.Now(), @@ -224,7 +224,7 @@ func TestSignStagedEnvelopeSuccess(t *testing.T) { ) require.NoError(t, err) - env, err := r.SignStagedEnvelope( + env, _, err := r.SignStagedEnvelope( queries.StagedOriginatorEnvelope{ ID: 50, OriginatorTime: time.Now(),