Skip to content

Commit

Permalink
feat: Decode witness to SMT (#1363)
Browse files Browse the repository at this point in the history
* feat: Decode witness to SMT

* chore: warning fixes and simplifications

* test: use require in witness unit tests

* feat: simplifications in SMT state reader

* fix: address comment

* test: use requires

* Allocate array in getValueInBytes
  • Loading branch information
Stefan-Ethernal authored Nov 4, 2024
1 parent d1079c0 commit df6709f
Show file tree
Hide file tree
Showing 6 changed files with 628 additions and 71 deletions.
14 changes: 14 additions & 0 deletions smt/pkg/db/mdbx.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db

import (
"context"
"encoding/hex"
"math/big"

"fmt"
Expand Down Expand Up @@ -304,6 +305,19 @@ func (m *EriRoDb) GetCode(codeHash []byte) ([]byte, error) {
return data, nil
}

func (m *EriDb) AddCode(code []byte) error {
codeHash := utils.HashContractBytecode(hex.EncodeToString(code))

codeHashBytes, err := hex.DecodeString(strings.TrimPrefix(codeHash, "0x"))
if err != nil {
return err
}

codeHashBytes = utils.ResizeHashTo32BytesByPrefixingWithZeroes(codeHashBytes)

return m.tx.Put(kv.Code, codeHashBytes, code)
}

func (m *EriRoDb) PrintDb() {
err := m.kvTxRo.ForEach(TableSmt, []byte{}, func(k, v []byte) error {
println(string(k), string(v))
Expand Down
69 changes: 41 additions & 28 deletions smt/pkg/smt/entity_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,55 @@ import (
"github.com/ledgerwatch/erigon/smt/pkg/utils"
)

// SetAccountState sets the balance and nonce of an account
func (s *SMT) SetAccountState(ethAddr string, balance, nonce *big.Int) (*big.Int, error) {
_, err := s.SetAccountBalance(ethAddr, balance)
if err != nil {
return nil, err
}

auxOut, err := s.SetAccountNonce(ethAddr, nonce)
if err != nil {
return nil, err
}

return auxOut, nil
}

// SetAccountBalance sets the balance of an account
func (s *SMT) SetAccountBalance(ethAddr string, balance *big.Int) (*big.Int, error) {
keyBalance := utils.KeyEthAddrBalance(ethAddr)
keyNonce := utils.KeyEthAddrNonce(ethAddr)

if _, err := s.InsertKA(keyBalance, balance); err != nil {
response, err := s.InsertKA(keyBalance, balance)
if err != nil {
return nil, err
}

ks := utils.EncodeKeySource(utils.KEY_BALANCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
if err := s.Db.InsertKeySource(keyBalance, ks); err != nil {
err = s.Db.InsertKeySource(keyBalance, ks)
if err != nil {
return nil, err
}

auxRes, err := s.InsertKA(keyNonce, nonce)
return response.NewRootScalar.ToBigInt(), err
}

// SetAccountNonce sets the nonce of an account
func (s *SMT) SetAccountNonce(ethAddr string, nonce *big.Int) (*big.Int, error) {
keyNonce := utils.KeyEthAddrNonce(ethAddr)

response, err := s.InsertKA(keyNonce, nonce)
if err != nil {
return nil, err
}

ks = utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
if err := s.Db.InsertKeySource(keyNonce, ks); err != nil {
ks := utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
err = s.Db.InsertKeySource(keyNonce, ks)
if err != nil {
return nil, err
}

return auxRes.NewRootScalar.ToBigInt(), nil
return response.NewRootScalar.ToBigInt(), nil
}

func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) error {
Expand Down Expand Up @@ -80,13 +105,7 @@ func (s *SMT) SetContractBytecode(ethAddr string, bytecode string) error {

ks = utils.EncodeKeySource(utils.SC_LENGTH, utils.ConvertHexToAddress(ethAddr), common.Hash{})

err = s.Db.InsertKeySource(keyContractLength, ks)

if err != nil {
return err
}

return err
return s.Db.InsertKeySource(keyContractLength, ks)
}

func (s *SMT) SetContractStorage(ethAddr string, storage map[string]string, progressChan chan uint64) (*big.Int, error) {
Expand Down Expand Up @@ -203,7 +222,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
for addr, acc := range accChanges {
select {
case <-ctx.Done():
return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix))
return nil, nil, fmt.Errorf("[%s] Context done", logPrefix)
default:
}
ethAddr := addr.String()
Expand Down Expand Up @@ -250,7 +269,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
for addr, code := range codeChanges {
select {
case <-ctx.Done():
return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix))
return nil, nil, fmt.Errorf("[%s] Context done", logPrefix)
default:
}

Expand Down Expand Up @@ -295,7 +314,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
for addr, storage := range storageChanges {
select {
case <-ctx.Done():
return nil, nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix))
return nil, nil, fmt.Errorf("[%s] Context done", logPrefix)
default:
}
ethAddr := addr.String()
Expand All @@ -304,7 +323,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l

for k, v := range storage {
keyStoragePosition := utils.KeyContractStorage(ethAddrBigIngArray, k)
valueBigInt := convertStrintToBigInt(v)
valueBigInt := convertStringToBigInt(v)
keysBatchStorage = append(keysBatchStorage, &keyStoragePosition)
if valuesBatchStorage, isDelete, err = appendToValuesBatchStorageBigInt(valuesBatchStorage, valueBigInt); err != nil {
return nil, nil, err
Expand Down Expand Up @@ -341,7 +360,7 @@ func (s *SMT) DeleteKeySource(nodeKey *utils.NodeKey) error {
}

func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) {
val := convertStrintToBigInt(v)
val := convertStringToBigInt(v)

x := utils.ScalarToArrayBig(val)
value, err := utils.NodeValue8FromBigIntArray(x)
Expand All @@ -354,10 +373,10 @@ func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) {
return value, h, nil
}

func convertStrintToBigInt(v string) *big.Int {
func convertStringToBigInt(v string) *big.Int {
base := 10
if strings.HasPrefix(v, "0x") {
v = v[2:]
v = strings.TrimPrefix(v, "0x")
base = 16
}

Expand All @@ -374,14 +393,8 @@ func appendToValuesBatchStorageBigInt(valuesBatchStorage []*utils.NodeValue8, va
}

func convertBytecodeToBigInt(bytecode string) (*big.Int, int, error) {
var parsedBytecode string
bi := utils.HashContractBytecodeBigInt(bytecode)

if strings.HasPrefix(bytecode, "0x") {
parsedBytecode = bytecode[2:]
} else {
parsedBytecode = bytecode
}
parsedBytecode := strings.TrimPrefix(bytecode, "0x")

if len(parsedBytecode)%2 != 0 {
parsedBytecode = "0" + parsedBytecode
Expand Down
94 changes: 92 additions & 2 deletions smt/pkg/smt/smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type DB interface {
InsertKeySource(key utils.NodeKey, value []byte) error
DeleteKeySource(key utils.NodeKey) error
InsertHashKey(key utils.NodeKey, value utils.NodeKey) error
AddCode(code []byte) error
DeleteHashKey(key utils.NodeKey) error
Delete(string) error
DeleteByNodeKey(key utils.NodeKey) error
Expand Down Expand Up @@ -297,7 +298,9 @@ func (s *SMT) insert(k utils.NodeKey, v utils.NodeValue8, newValH [4]uint64, old
if err != nil {
return nil, err
}
s.Db.InsertHashKey(newLeafHash, k)
if err := s.Db.InsertHashKey(newLeafHash, k); err != nil {
return nil, err
}
if level >= 0 {
for j := 0; j < 4; j++ {
siblings[level][keys[level]*4+j] = new(big.Int).SetUint64(newLeafHash[j])
Expand Down Expand Up @@ -649,7 +652,7 @@ func (s *SMT) updateDepth(newDepth int) {

newDepthAsByte := byte(newDepth & 0xFF)
if oldDepth < newDepthAsByte {
s.Db.SetDepth(newDepthAsByte)
_ = s.Db.SetDepth(newDepthAsByte)
}
}

Expand Down Expand Up @@ -728,3 +731,90 @@ func (s *RoSMT) traverseAndMark(ctx context.Context, node *big.Int, visited Visi
return true, nil
})
}

// InsertHashNode inserts a hash node into the SMT. The SMT should not contain any other leaf nodes with the same path prefix. Otherwise, the new root hash will be incorrect.
// TODO: Support insertion of hash nodes even if there are leaf nodes with the same path prefix in SMT.
func (s *SMT) InsertHashNode(path []int, hash *big.Int) (*big.Int, error) {
s.clearUpMutex.Lock()
defer s.clearUpMutex.Unlock()

or, err := s.getLastRoot()
if err != nil {
return nil, err
}

h := utils.ScalarToArray(hash)

var nodeHash [4]uint64
copy(nodeHash[:], h[:4])

lastRoot, err := s.insertHashNode(path, nodeHash, or)
if err != nil {
return nil, err
}

if err = s.setLastRoot(lastRoot); err != nil {
return nil, err
}

return lastRoot.ToBigInt(), nil
}

func (s *SMT) insertHashNode(path []int, hash [4]uint64, root utils.NodeKey) (utils.NodeKey, error) {
if len(path) == 0 {
newValHBig := utils.ArrayToScalar(hash[:])
v := utils.ScalarToNodeValue8(newValHBig)

err := s.hashSave(v.ToUintArray(), utils.LeafCapacity, hash)
if err != nil {
return utils.NodeKey{}, err
}

return hash, nil
}

rootVal := utils.NodeValue12{}

if !root.IsZero() {
v, err := s.Db.Get(root)
if err != nil {
return utils.NodeKey{}, err
}

rootVal = v
}

childIndex := path[0]

childOldRoot := rootVal[childIndex*4 : childIndex*4+4]

childNewRoot, err := s.insertHashNode(path[1:], hash, utils.NodeKeyFromBigIntArray(childOldRoot))

if err != nil {
return utils.NodeKey{}, err
}

var newIn [8]uint64

emptyRootVal := utils.NodeValue12{}

if childIndex == 0 {
var sibling [4]uint64
if rootVal == emptyRootVal {
sibling = [4]uint64{0, 0, 0, 0}
} else {
sibling = *rootVal.Get4to8()
}
newIn = utils.ConcatArrays4(childNewRoot, sibling)
} else {
var sibling [4]uint64
if rootVal == emptyRootVal {
sibling = [4]uint64{0, 0, 0, 0}
} else {
sibling = *rootVal.Get0to4()
}
newIn = utils.ConcatArrays4(sibling, childNewRoot)
}

return s.hashcalcAndSave(newIn, utils.BranchCapacity)
}
Loading

0 comments on commit df6709f

Please sign in to comment.