Skip to content

Commit

Permalink
Update keyPath to ceil(maxLvl/8),add chck len(key)
Browse files Browse the repository at this point in the history
Update keyPath to be ceil(maxLevels/8), add check of key len on Add,
AddBatch, Update, GenProof, Get.
  • Loading branch information
arnaucube committed Sep 30, 2021
1 parent 9eb7c8e commit 97a223b
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 32 deletions.
49 changes: 29 additions & 20 deletions tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,13 +311,15 @@ func (t *Tree) AddWithTx(wTx db.WriteTx, k, v []byte) error {
}

func (t *Tree) add(wTx db.WriteTx, root []byte, fromLvl int, k, v []byte) ([]byte, error) {
keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return nil, fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
if len(k) < len(keyPath) {
return nil, fmt.Errorf("len(k) < ceil(maxLevels/8), where"+
" len(k): %d, maxLevels: %d, ceil(maxLevels/8): %d", len(k),
t.maxLevels, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
}
copy(keyPath[:], k)

path := getPath(t.maxLevels, keyPath)

// go down to the leaf
var siblings [][]byte
_, _, siblings, err := t.down(wTx, k, root, siblings, path, fromLvl, false)
Expand Down Expand Up @@ -595,10 +597,12 @@ func (t *Tree) UpdateWithTx(wTx db.WriteTx, k, v []byte) error {

var err error

keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
if len(k) < len(keyPath) {
return fmt.Errorf("len(k) < ceil(maxLevels/8), where"+
" len(k): %d, maxLevels: %d, ceil(maxLevels/8): %d", len(k),
t.maxLevels, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
}
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)

Expand Down Expand Up @@ -655,18 +659,21 @@ func (t *Tree) GenProof(k []byte) ([]byte, []byte, []byte, bool, error) {
// GenProofWithTx does the same than the GenProof method, but allowing to pass
// the db.ReadTx that is used.
func (t *Tree) GenProofWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, []byte, bool, error) {
keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return nil, nil, nil, false, fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
if len(k) < len(keyPath) {
return nil, nil, nil, false,
fmt.Errorf("len(k) < ceil(maxLevels/8), where"+
" len(k): %d, maxLevels: %d, ceil(maxLevels/8): %d", len(k),
t.maxLevels, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
}
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)

root, err := t.RootWithTx(rTx)
if err != nil {
return nil, nil, nil, false, err
}

path := getPath(t.maxLevels, keyPath)
// go down to the leaf
var siblings [][]byte
_, value, siblings, err := t.down(rTx, k, root, siblings, path, 0, true)
Expand Down Expand Up @@ -793,18 +800,20 @@ func (t *Tree) Get(k []byte) ([]byte, []byte, error) {
// ErrKeyNotFound, and in the leafK & leafV parameters will be placed the data
// found in the tree in the leaf that was on the path going to the input key.
func (t *Tree) GetWithTx(rTx db.ReadTx, k []byte) ([]byte, []byte, error) {
keyPath := make([]byte, t.hashFunction.Len())
// if len(k) > t.hashFunction.Len() { // WIP
// return nil, nil, fmt.Errorf("len(k) > hashFunction.Len()")
// }
keyPath := make([]byte, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
if len(k) < len(keyPath) {
return nil, nil, fmt.Errorf("len(k) < ceil(maxLevels/8), where"+
" len(k): %d, maxLevels: %d, ceil(maxLevels/8): %d", len(k),
t.maxLevels, int(math.Ceil(float64(t.maxLevels)/float64(8)))) //nolint:gomnd
}
copy(keyPath[:], k)
path := getPath(t.maxLevels, keyPath)

root, err := t.RootWithTx(rTx)
if err != nil {
return nil, nil, err
}

path := getPath(t.maxLevels, keyPath)
// go down to the leaf
var siblings [][]byte
_, value, _, err := t.down(rTx, k, root, siblings, path, 0, true)
Expand All @@ -827,7 +836,7 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
return false, err
}

keyPath := make([]byte, hashFunc.Len())
keyPath := make([]byte, len(siblings))
copy(keyPath[:], k)

key, _, err := newLeafValue(hashFunc, k, v)
Expand Down
41 changes: 35 additions & 6 deletions tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,9 @@ func TestSetRoot(t *testing.T) {
checkRootBIString(c, tree, expectedRoot)

// check that the tree can be updated
err = tree.Add([]byte("test"), []byte("test"))
err = tree.Add([]byte("testtesttestt"), []byte("test"))
c.Assert(err, qt.IsNil)
err = tree.Update([]byte("test"), []byte("test"))
err = tree.Update([]byte("testtesttestt"), []byte("test"))
c.Assert(err, qt.IsNil)

// check that the k-v '1000' does not exist in the new tree
Expand Down Expand Up @@ -601,22 +601,22 @@ func TestSnapshot(t *testing.T) {
// check that the snapshotTree can not be updated
_, err = snapshotTree.AddBatch(keys, values)
c.Assert(err, qt.Equals, ErrSnapshotNotEditable)
err = snapshotTree.Add([]byte("test"), []byte("test"))
err = snapshotTree.Add([]byte("testtesttestt"), []byte("test"))
c.Assert(err, qt.Equals, ErrSnapshotNotEditable)
err = snapshotTree.Update([]byte("test"), []byte("test"))
err = snapshotTree.Update([]byte("testtesttestt"), []byte("test"))
c.Assert(err, qt.Equals, ErrSnapshotNotEditable)
err = snapshotTree.ImportDump(nil)
c.Assert(err, qt.Equals, ErrSnapshotNotEditable)

// update the original tree by adding a new key-value, and check that
// snapshotTree still has the old root, but the original tree has a new
// root
err = tree.Add([]byte("test"), []byte("test"))
err = tree.Add([]byte("testtesttestt"), []byte("test"))
c.Assert(err, qt.IsNil)
checkRootBIString(c, snapshotTree,
"13742386369878513332697380582061714160370929283209286127733983161245560237407")
checkRootBIString(c, tree,
"1025190963769001718196479367844646783678188389989148142691917685159698888868")
"11012181874962502280319215908999587695675615549060172228896013220013270140826")
}

func TestGetFromSnapshotExpectArboErrKeyNotFound(t *testing.T) {
Expand All @@ -640,6 +640,35 @@ func TestGetFromSnapshotExpectArboErrKeyNotFound(t *testing.T) {
c.Assert(err, qt.Equals, ErrKeyNotFound) // and not equal to db.ErrKeyNotFound
}

func TestKeyLen(t *testing.T) {
c := qt.New(t)
database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
c.Assert(err, qt.IsNil)
// maxLevels is 100, minimum key size = ceil(maxLevels/8) = 13
maxLevels := 100
tree, err := NewTree(database, maxLevels, HashFunctionPoseidon)
c.Assert(err, qt.IsNil)

bLen := 4
incorrectK := BigIntToBytes(bLen, big.NewInt(1))
v := BigIntToBytes(bLen, big.NewInt(1))

expectedErrMsg := "len(k) < ceil(maxLevels/8), where len(k): 4," +
" maxLevels: 100, ceil(maxLevels/8): 13"

err = tree.Add(incorrectK, v)
c.Assert(err.Error(), qt.Equals, expectedErrMsg)

err = tree.Update(incorrectK, v)
c.Assert(err.Error(), qt.Equals, expectedErrMsg)

_, _, _, _, err = tree.GenProof(incorrectK)
c.Assert(err.Error(), qt.Equals, expectedErrMsg)

_, _, err = tree.Get(incorrectK)
c.Assert(err.Error(), qt.Equals, expectedErrMsg)
}

func BenchmarkAdd(b *testing.B) {
bLen := 32 // for both Poseidon & Sha256
// prepare inputs
Expand Down
22 changes: 16 additions & 6 deletions vt.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,33 @@ type kv struct {
v []byte
}

func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, error) {
func (p *params) keysValuesToKvs(ks, vs [][]byte) ([]kv, []int, error) {
if len(ks) != len(vs) {
return nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
return nil, nil, fmt.Errorf("len(keys)!=len(values) (%d!=%d)",
len(ks), len(vs))
}
kvs := make([]kv, len(ks))
var invalids []int
for i := 0; i < len(ks); i++ {
keyPath := make([]byte, p.hashFunction.Len())
keyPath := make([]byte, int(math.Ceil(float64(p.maxLevels)/float64(8)))) //nolint:gomnd
if len(ks[i]) < len(keyPath) {
// TODO in a future iteration, invalids will contain
// the reason of the error of why each index is invalid.
// invalid reason = fmt.Errorf("len(k) < ceil(maxLevels/8), where"+
// " len(k): %d, maxLevels: %d, ceil(maxLevels/8): %d", len(ks[i]),
// p.maxLevels, int(math.Ceil(float64(p.maxLevels)/float64(8))))
invalids = append(invalids, i)
continue
}

copy(keyPath[:], ks[i])
kvs[i].pos = i
kvs[i].keyPath = keyPath
kvs[i].k = ks[i]
kvs[i].v = vs[i]
}

return kvs, nil
return kvs, invalids, nil
}

// vt stands for virtual tree. It's a tree that does not have any computed hash
Expand Down Expand Up @@ -94,7 +105,7 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) {

l := int(math.Log2(float64(nCPU)))

kvs, err := t.params.keysValuesToKvs(ks, vs)
kvs, invalids, err := t.params.keysValuesToKvs(ks, vs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -186,7 +197,6 @@ func (t *vt) addBatch(ks, vs [][]byte) ([]int, error) {
}
wg.Wait()

var invalids []int
for i := 0; i < len(invalidsInBucket); i++ {
invalids = append(invalids, invalidsInBucket[i]...)
}
Expand Down

0 comments on commit 97a223b

Please sign in to comment.