diff --git a/cmd/starknet-id/store.go b/cmd/starknet-id/store.go index c81a439..205c106 100644 --- a/cmd/starknet-id/store.go +++ b/cmd/starknet-id/store.go @@ -90,11 +90,11 @@ func (s Store) Save(ctx context.Context, blockCtx *BlockContext) error { return nil } -func (s Store) saveAddresses(ctx context.Context, tx sdk.Transaction, blockCtx *BlockContext) error { +func (s Store) saveAddresses(ctx context.Context, tx postgres.Transaction, blockCtx *BlockContext) error { if blockCtx.addresses.Len() == 0 { return nil } - addresses := make([]any, 0) + addresses := make([]*storage.Address, 0) if err := blockCtx.addresses.Range(func(k string, v *storage.Address) (bool, error) { addresses = append(addresses, v) return false, nil @@ -102,7 +102,7 @@ func (s Store) saveAddresses(ctx context.Context, tx sdk.Transaction, blockCtx * return err } - if err := tx.BulkSave(ctx, addresses); err != nil { + if err := tx.SaveAddress(ctx, addresses...); err != nil { return errors.Wrap(err, "saving addresses") } return nil diff --git a/internal/storage/postgres/storage_test.go b/internal/storage/postgres/storage_test.go index 0ae2a18..319a1d3 100644 --- a/internal/storage/postgres/storage_test.go +++ b/internal/storage/postgres/storage_test.go @@ -141,6 +141,69 @@ func (s *StorageTestSuite) TestTxSaveState() { s.Require().EqualValues(101, response.LastHeight) } +func (s *StorageTestSuite) TestTxSaveAddress() { + ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer ctxCancel() + + tx, err := BeginTransaction(ctx, s.storage.Transactable) + s.Require().NoError(err) + defer tx.Close(ctx) + + classId := uint64(1) + err = tx.SaveAddress(ctx, &storage.Address{ + Id: 1, + Hash: []byte{1}, + Height: 100, + ClassId: &classId, + }) + s.Require().NoError(err) + + err = tx.Flush(ctx) + s.Require().NoError(err) + + address, err := s.storage.Addresses.GetByID(ctx, 1) + s.Require().NoError(err) + + s.Require().EqualValues(1, address.Id) + s.Require().EqualValues(100, address.Height) + s.Require().NotNil(address.ClassId) + s.Require().EqualValues(1, *address.ClassId) + s.Require().Equal([]byte{1}, address.Hash) +} + +func (s *StorageTestSuite) TestTxSaveAddressUpdate() { + ctx, ctxCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer ctxCancel() + + tx, err := BeginTransaction(ctx, s.storage.Transactable) + s.Require().NoError(err) + defer tx.Close(ctx) + + b, err := hex.DecodeString("020cfa74ee3564b4cd5435cdace0f9c4d43b939620e4a0bb5076105df0a626c6") + s.Require().NoError(err) + + classIdNew := uint64(2) + err = tx.SaveAddress(ctx, &storage.Address{ + Id: 2, + Hash: b, + Height: 101, + ClassId: &classIdNew, + }) + s.Require().NoError(err) + + err = tx.Flush(ctx) + s.Require().NoError(err) + + address, err := s.storage.Addresses.GetByID(ctx, 2) + s.Require().NoError(err) + + s.Require().EqualValues(2, address.Id) + s.Require().EqualValues(0, address.Height) + s.Require().NotNil(address.ClassId) + s.Require().EqualValues(2, *address.ClassId) + s.Require().Equal(b, address.Hash) +} + func TestSuiteStorage_Run(t *testing.T) { suite.Run(t, new(StorageTestSuite)) } diff --git a/internal/storage/postgres/transaction.go b/internal/storage/postgres/transaction.go index d72057b..5b19d4f 100644 --- a/internal/storage/postgres/transaction.go +++ b/internal/storage/postgres/transaction.go @@ -18,7 +18,7 @@ func BeginTransaction(ctx context.Context, tx storage.Transactable) (Transaction return Transaction{t}, err } -// SaveAddress - +// SaveState - func (t Transaction) SaveState(ctx context.Context, state *models.State) error { _, err := t.Tx().NewInsert().Model(state). On("CONFLICT (name) DO UPDATE"). @@ -27,3 +27,14 @@ func (t Transaction) SaveState(ctx context.Context, state *models.State) error { Exec(ctx) return err } + +func (t Transaction) SaveAddress(ctx context.Context, addresses ...*models.Address) error { + if len(addresses) == 0 { + return nil + } + _, err := t.Tx().NewInsert().Model(&addresses). + On("CONFLICT (id) DO UPDATE"). + Set("class_id = excluded.class_id"). + Exec(ctx) + return err +}