diff --git a/bitmap.go b/bitmap.go new file mode 100644 index 0000000..a404293 --- /dev/null +++ b/bitmap.go @@ -0,0 +1,115 @@ +package hamt + +import ( + "fmt" + "math" + "math/bits" +) + +// Bitmap is a managed bitmap, primarily for the purpose of tracking the +// presence or absence of elements in an associated array. It can set and unset +// individual bits and perform limited popcount for a given index to calculate +// the position in the associated compacted array. +type Bitmap struct { + Bytes []byte +} + +// NewBitmap creates a new bitmap for a given bitWidth. The bitmap will hold +// 2^bitWidth bytes. +func NewBitmap(bitWidth int) *Bitmap { + bc := (1 << uint(bitWidth)) / 8 + if bc == 0 { + panic("bitWidth too small") + } + + return NewBitmapFrom(make([]byte, bc)) +} + +// NewBitmapFrom creates a new Bitmap from an existing byte array. It is +// assumed that bytes is the correct length for the bitWidth of this Bitmap. +func NewBitmapFrom(bytes []byte) *Bitmap { + if len(bytes) == 0 { + panic("can't form Bitmap from zero bytes") + } + bm := Bitmap{Bytes: bytes} + return &bm +} + +// BitWidth calculates the bitWidth of this Bitmap by performing a +// log2(bits). The bitWidth is the minimum number of bits required to +// form indexes that address all of this Bitmap. e.g. a bitWidth of 5 can form +// indexes of 0 to 31, i.e. 4 bytes. +func (bm *Bitmap) BitWidth() int { + return int(math.Log2(float64(len(bm.Bytes) * 8))) +} + +func (bm *Bitmap) bindex(in int) int { + // Return `in` to flip the byte addressing order to LE. For BE we address + // from the last byte backward. + bi := len(bm.Bytes) - 1 - in + if bi > len(bm.Bytes) || bi < 0 { + panic(fmt.Sprintf("invalid index for this Bitmap (index: %v, bytes: %v)", in, len(bm.Bytes))) + } + return bi +} + +// IsSet indicates whether the bit at the provided position is set or not. +func (bm *Bitmap) IsSet(position int) bool { + byt := bm.bindex(position / 8) + offset := position % 8 + return (bm.Bytes[byt]>>offset)&1 == 1 +} + +// Set sets or unsets the bit at the given position according. If set is true, +// the bit will be set. If set is false, the bit will be unset. +func (bm *Bitmap) Set(position int, set bool) { + has := bm.IsSet(position) + byt := bm.bindex(position / 8) + offset := position % 8 + + if set && !has { + bm.Bytes[byt] |= 1 << offset + } else if !set && has { + bm.Bytes[byt] ^= 1 << offset + } +} + +// Index performs a limited popcount up to the given position. This calculates +// the number of set bits up to the index of the bitmap. Useful for calculating +// the position of an element in an associated compacted array. +func (bm *Bitmap) Index(position int) int { + t := 0 + eb := position / 8 + byt := 0 + for ; byt < eb; byt++ { + // quick popcount for the full bytes + t += bits.OnesCount(uint(bm.Bytes[bm.bindex(byt)])) + } + eb = eb * 8 + if position > eb { + for i := byt * 8; i < position; i++ { + // manual per-bit check for the remainder <8 bits + if bm.IsSet(i) { + t++ + } + } + } + return t +} + +// Copy creates a clone of the Bitmap, creating a new byte array with the same +// contents as the original. +func (bm *Bitmap) Copy() *Bitmap { + ba := make([]byte, len(bm.Bytes)) + copy(ba, bm.Bytes) + return NewBitmapFrom(ba) +} + +// BitsSetCount counts how many bits are set in the bitmap. +func (bm *Bitmap) BitsSetCount() int { + count := 0 + for _, b := range bm.Bytes { + count += bits.OnesCount(uint(b)) + } + return count +} diff --git a/bitmap_test.go b/bitmap_test.go new file mode 100644 index 0000000..30918ae --- /dev/null +++ b/bitmap_test.go @@ -0,0 +1,250 @@ +package hamt + +import ( + "bytes" + "testing" +) + +// many cases taken from https://github.com/rvagg/iamap/blob/fad95295b013c8b4f0faac6dd5d9be175f6e606c/test/bit-utils-test.js +// but rev() is used to reverse the data in most instances + +// reverse for BE format +func rev(in []byte) []byte { + out := make([]byte, len(in)) + for i := 0; i < len(in); i++ { + out[len(in)-1-i] = in[i] + } + return out +} + +// 8-char binary string to byte, no binary literals in old Go +func bb(s string) byte { + var r byte + for i, c := range s { + if c == '1' { + r |= 1 << uint(7-i) + } + } + return r +} + +func TestBitmapHas(t *testing.T) { + type tcase struct { + bytes []byte + pos int + set bool + } + cases := []tcase{ + {b(0x0), 0, false}, + {b(0x1), 0, true}, + {b(bb("00101010")), 2, false}, + {b(bb("00101010")), 3, true}, + {b(bb("00101010")), 4, false}, + {b(bb("00101010")), 5, true}, + {b(bb("00100000")), 5, true}, + {[]byte{0x0, bb("00100000")}, 8 + 5, true}, + {[]byte{0x0, 0x0, bb("00100000")}, 8*2 + 5, true}, + {[]byte{0x0, 0x0, 0x0, bb("00100000")}, 8*3 + 5, true}, + {[]byte{0x0, 0x0, 0x0, 0x0, bb("00100000")}, 8*4 + 5, true}, + {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, bb("00100000")}, 8*5 + 5, true}, + {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, bb("00100000")}, 8*4 + 5, false}, + {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, bb("00100000")}, 8*3 + 5, false}, + {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, bb("00100000")}, 8*2 + 5, false}, + {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, bb("00100000")}, 8 + 5, false}, + {[]byte{0x0, 0x0, 0x0, 0x0, 0x0, bb("00100000")}, 5, false}, + } + + for _, c := range cases { + bm := NewBitmapFrom(rev(c.bytes)) + if bm.IsSet(c.pos) != c.set { + t.Fatalf("bitmap %v IsSet(%v) should be %v", c.bytes, c.pos, c.set) + } + } +} + +func TestBitmapBitWidth(t *testing.T) { + for i := 3; i <= 16; i++ { + if NewBitmap(i).BitWidth() != i { + t.Fatal("incorrect bitWidth calculation") + } + if NewBitmapFrom(make([]byte, (1< 0 { buf := make([]byte, extra) if _, err := io.ReadFull(br, buf); err != nil { return err } - t.Bitfield = big.NewInt(0).SetBytes(buf) - } else { - t.Bitfield = big.NewInt(0) + t.Bitfield = NewBitmapFrom(buf) } // t.Pointers ([]*hamt.Pointer) (slice) diff --git a/hamt.go b/hamt.go index 6b0ba61..4dd06e0 100644 --- a/hamt.go +++ b/hamt.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "math/big" cid "github.com/ipfs/go-cid" cbor "github.com/ipfs/go-ipld-cbor" @@ -66,7 +65,7 @@ var ErrMalformedHamt = fmt.Errorf("HAMT node was malformed") // pointers [Pointer] // } representation tuple type Node struct { - Bitfield *big.Int + Bitfield *Bitmap Pointers []*Pointer bitWidth int @@ -181,17 +180,23 @@ func UseHashFunction(hash func([]byte) []byte) Option { // This function creates a new HAMT that you can use directly and is also // used internally to create child nodes. func NewNode(cs cbor.IpldStore, options ...Option) *Node { + return newNode(cs, defaultBitWidth, defaultHashFunction, options...) +} + +// internal version of NewNode that will properly set up defaults and correct +// size bitfield for the supplied parameters +func newNode(cs cbor.IpldStore, bitWidth int, hashFunction func([]byte) []byte, options ...Option) *Node { nd := &Node{ - Bitfield: big.NewInt(0), Pointers: make([]*Pointer, 0), store: cs, - hash: defaultHashFunction, - bitWidth: defaultBitWidth, + hash: hashFunction, + bitWidth: bitWidth, } // apply functional options to node before using for _, option := range options { option(nd) } + nd.Bitfield = NewBitmap(nd.bitWidth) return nd } @@ -263,14 +268,14 @@ func (n *Node) getValue(ctx context.Context, hv *hashBits, k string, cb func(*KV // if the element expected at this node isn't here then we can be sure it // doesn't exist in the HAMT. - if n.Bitfield.Bit(idx) == 0 { + if !n.Bitfield.IsSet(idx) { return ErrNotFound } // otherwise, the value is either local or in a child // perform a popcount of bits up to the `idx` to find `cindex` - cindex := byte(n.indexForBitPos(idx)) + cindex := byte(n.Bitfield.Index(idx)) c := n.getPointer(cindex) if c.isShard() { @@ -327,7 +332,7 @@ func LoadNode(ctx context.Context, cs cbor.IpldStore, c cid.Cid, options ...Opti return loadNode(ctx, cs, c, true, defaultBitWidth, defaultHashFunction, options...) } -// internal version of loadNode that is aware of whether this is a root node or +// internal version of LoadNode that is aware of whether this is a root node or // not for the purpose of additional validation on non-root nodes. func loadNode( ctx context.Context, @@ -359,8 +364,13 @@ func loadNode( return nil, ErrMalformedHamt } + // bitmap needs to be exactly the right number of bytes + if out.Bitfield.BitWidth() != out.bitWidth { + return nil, ErrMalformedHamt + } + // the bifield is lying or the elements array is - if out.bitsSetCount() != len(out.Pointers) { + if out.Bitfield.BitsSetCount() != len(out.Pointers) { return nil, ErrMalformedHamt } @@ -573,14 +583,14 @@ func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k []byte, v *cbg.D // if the element expected at this node isn't here then we can be sure it // doesn't exist in the HAMT already and can insert it at the appropriate // position. - if n.Bitfield.Bit(idx) != 1 { + if !n.Bitfield.IsSet(idx) { return n.insertKV(idx, k, v) } // otherwise, the value is either local or in a child // perform a popcount of bits up to the `idx` to find `cindex` - cindex := byte(n.indexForBitPos(idx)) + cindex := byte(n.Bitfield.Index(idx)) child := n.getPointer(cindex) if child.isShard() { @@ -648,9 +658,7 @@ func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k []byte, v *cbg.D // not cause an overflow - i.e. we just need to take each element, hash it // and consume the correct number of bytes off the digest and figure out // where it should be in the new node. - sub := NewNode(n.store) - sub.bitWidth = n.bitWidth - sub.hash = n.hash + sub := newNode(n.store, n.bitWidth, n.hash) hvcopy := &hashBits{b: hv.b, consumed: hv.consumed} if err := sub.modifyValue(ctx, hvcopy, k, v); err != nil { return err @@ -692,8 +700,8 @@ func (n *Node) insertKV(idx int, k []byte, v *cbg.Deferred) error { return ErrNotFound } - i := n.indexForBitPos(idx) - n.Bitfield.SetBit(n.Bitfield, idx, 1) + i := n.Bitfield.Index(idx) + n.Bitfield.Set(idx, true) p := &Pointer{KVs: []*KV{{Key: k, Value: v}}} @@ -715,7 +723,7 @@ func (n *Node) setPointer(i byte, p *Pointer) error { func (n *Node) rmPointer(i byte, idx int) error { copy(n.Pointers[i:], n.Pointers[i+1:]) n.Pointers = n.Pointers[:len(n.Pointers)-1] - n.Bitfield.SetBit(n.Bitfield, idx, 0) + n.Bitfield.Set(idx, false) return nil } @@ -740,10 +748,8 @@ func (n *Node) getPointer(i byte) *Pointer { // as cached nodes. func (n *Node) Copy() *Node { // TODO(rvagg): clarify what situations this method is actually useful for. - nn := NewNode(n.store) - nn.bitWidth = n.bitWidth - nn.hash = n.hash - nn.Bitfield.Set(n.Bitfield) + nn := newNode(n.store, n.bitWidth, n.hash) + nn.Bitfield = n.Bitfield.Copy() nn.Pointers = make([]*Pointer, len(n.Pointers)) for i, p := range n.Pointers { diff --git a/hamt_test.go b/hamt_test.go index 6ffbe1a..8adebb1 100644 --- a/hamt_test.go +++ b/hamt_test.go @@ -6,6 +6,8 @@ import ( "crypto/sha256" "encoding/hex" "fmt" + "math/big" + "math/bits" "math/rand" "strings" "testing" @@ -859,6 +861,9 @@ func TestMalformedHamt(t *testing.T) { blocks := newMockBlocks() cs := cbor.NewCborStore(blocks) bcid, err := cid.Decode("bafy2bzaceab7vkg5c3zti7ebqensb3onksjkc4wwktkiledkezgvnbvzs4cti") + if err != nil { + t.Fatal(err) + } bccid, err := cid.Decode("bafy2bzaceab7vkg5c3zti7ebqensb3onksjkc4wwktkiledkezgvnbvzs4cqa") if err != nil { t.Fatal(err) @@ -874,15 +879,15 @@ func TestMalformedHamt(t *testing.T) { store := func(blob []byte) { blocks.data[bcid] = block.NewBlock(blob) } - load := func() *Node { - n, err := LoadNode(ctx, cs, bcid, UseTreeBitWidth(8), UseHashFunction(identityHash)) + load := func(bitWidth int) *Node { + n, err := LoadNode(ctx, cs, bcid, UseTreeBitWidth(bitWidth), UseHashFunction(identityHash)) if err != nil { t.Fatal(err) } return n } - find := func(key []byte, expected []byte) *[]byte { - vg, err := load().FindRaw(ctx, string(key)) + find := func(nd *Node, key []byte, expected []byte) *[]byte { + vg, err := nd.FindRaw(ctx, string(key)) if err != nil { t.Fatal(err) } @@ -916,38 +921,46 @@ func TestMalformedHamt(t *testing.T) { bcat(b(0x80+1), // array(1) bucketCbor(kv{0x00, 0xff})))) // 0x00=0xff // should find a bytes(1) "\xff" - find(b(0x00), bcat(b(0x40+1), b(0xff))) + find(load(3), b(0x00), bcat(b(0x40+1), b(0xff))) // print the raw cbor: fmt.Printf("%v\n", hex.EncodeToString(blocks.data[bcid].RawData())) - // 10 entry node, assumed bitwidth of >3 + // 10 entry node, assumed bitwidth of 4 store( bcat(b(0x80+2), // array(2) bcat(b(0x40+2), []byte{0x03, 0xff}), // bytes(1) "\x3ff" (bitmap with lower 10 bits set) bcat(b(0x80+10), // array(10) - bucketCbor(kv{0x00, 0xf0}), // 0x00=0xf0 - bucketCbor(kv{0x01, 0xf1}), // 0x01=0xf1 - bucketCbor(kv{0x02, 0xf2}), // 0x02=0xf2 - bucketCbor(kv{0x03, 0xf3}), // 0x03=0xf3 - bucketCbor(kv{0x04, 0xf4}), // 0x04=0xf4 - bucketCbor(kv{0x05, 0xf5}), // 0x05=0xf5 - bucketCbor(kv{0x06, 0xf6}), // 0x06=0xf6 - bucketCbor(kv{0x07, 0xf7}), // 0x07=0xf7 - bucketCbor(kv{0x08, 0xf8}), // 0x08=0xf8 - bucketCbor(kv{0x09, 0xf9})))) // 0x09=0xf9 + // shift these indexes up by 4 because with a bitWidth of 4 we are + // chomping the top bits first + bucketCbor(kv{0x00 << 4, 0xf0}), // 0x00=0xf0 + bucketCbor(kv{0x01 << 4, 0xf1}), // 0x01=0xf1 + bucketCbor(kv{0x02 << 4, 0xf2}), // 0x02=0xf2 + bucketCbor(kv{0x03 << 4, 0xf3}), // 0x03=0xf3 + bucketCbor(kv{0x04 << 4, 0xf4}), // 0x04=0xf4 + bucketCbor(kv{0x05 << 4, 0xf5}), // 0x05=0xf5 + bucketCbor(kv{0x06 << 4, 0xf6}), // 0x06=0xf6 + bucketCbor(kv{0x07 << 4, 0xf7}), // 0x07=0xf7 + bucketCbor(kv{0x08 << 4, 0xf8}), // 0x08=0xf8 + bucketCbor(kv{0x09 << 4, 0xf9})))) // 0x09=0xf9 // sanity check for i := 0; i < 10; i++ { v := bcat(b(0x40+1), b(0xf0+byte(i))) - if vg := find(b(0x00+byte(i)), v); vg != nil { + if vg := find(load(4), b(byte(i<<4)), v); vg != nil { t.Fatalf("expected a value of %v, got %v", hex.EncodeToString(v), hex.EncodeToString(*vg)) } } - // load as bitWidth=3, which can only handle a max of 8 elements + // load as bitWidth=3 which needs a 1-byte index n, err := LoadNode(ctx, cs, bcid, UseTreeBitWidth(3), UseHashFunction(identityHash)) if err != ErrMalformedHamt || n != nil { t.Fatal("Should have returned ErrMalformedHamt for too-small bitWidth") } + // load as bitWidth=5 which needs a 4-byte index + n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(5), UseHashFunction(identityHash)) + if err != ErrMalformedHamt || n != nil { + t.Fatal("Should have returned ErrMalformedHamt for too-large bitWidth") + } + // test that the bitfield set count matches array size // this node says it has 3 elements in the bitfield, but there are 4 buckets store( @@ -972,7 +985,7 @@ func TestMalformedHamt(t *testing.T) { bcat(b(0xd8), b(0x2a), // tag(42) b(0x58), b(0x27), // bytes(39) cidBytes)))) // cid - load() + load(3) // node pointing to a non-dag-cbor node store( @@ -982,7 +995,7 @@ func TestMalformedHamt(t *testing.T) { bcat(b(0xd8), b(0x2a), // tag(42) b(0x58), b(0x27), // bytes(39) badCidBytes)))) // cid - n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(8), UseHashFunction(identityHash)) + n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(3), UseHashFunction(identityHash)) if err != ErrMalformedHamt || n != nil { t.Fatal("Should have returned ErrMalformedHamt for bad child link codec") } @@ -993,7 +1006,7 @@ func TestMalformedHamt(t *testing.T) { bcat(b(0x40+1), b(0x01)), // bytes(1) "\x01" (bitmap) bcat(b(0x80+1), // array(1) bucketCbor()))) // empty bucket - n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(8), UseHashFunction(identityHash)) + n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(3), UseHashFunction(identityHash)) if err != ErrMalformedHamt || n != nil { t.Fatal("Should have returned ErrMalformedHamt for zero element bucket") } @@ -1009,7 +1022,7 @@ func TestMalformedHamt(t *testing.T) { kv{0x02, 0xff}, kv{0x03, 0xff})))) // bucket with 4 entires - n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(8), UseHashFunction(identityHash)) + n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(3), UseHashFunction(identityHash)) if err != ErrMalformedHamt || n != nil { t.Fatal("Should have returned ErrMalformedHamt for four element bucket") } @@ -1024,7 +1037,7 @@ func TestMalformedHamt(t *testing.T) { kv{0x01, 0xff}, kv{0x00, 0xff})))) // bucket with 2, misordered entries - n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(8), UseHashFunction(identityHash)) + n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(3), UseHashFunction(identityHash)) if err != ErrMalformedHamt || n != nil { t.Fatal("Should have returned ErrMalformedHamt for mis-ordered bucket") } @@ -1039,7 +1052,7 @@ func TestMalformedHamt(t *testing.T) { kv{0x01, 0xf0}, kv{0x01, 0xff})))) // bucket with 3 element, 2 dupes with different values - n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(8), UseHashFunction(identityHash)) + n, err = LoadNode(ctx, cs, bcid, UseTreeBitWidth(3), UseHashFunction(identityHash)) if err != ErrMalformedHamt || n != nil { t.Fatal("Should have returned ErrMalformedHamt for mis-ordered bucket") } @@ -1049,7 +1062,7 @@ func TestMalformedHamt(t *testing.T) { bcat(b(0x80+2), // array(2) bcat(b(0x40+1), b(0x00)), // bytes(1) "\x00" (bitmap) bcat(b(0x80+0)))) // array(0) - load() + load(3) // make a child empty block and point to it in a root blocks.data[bccid] = block.NewBlock( @@ -1065,7 +1078,7 @@ func TestMalformedHamt(t *testing.T) { b(0x58), b(0x27), // bytes(39) ccidBytes)))) // cid - vg, err := load().FindRaw(ctx, string([]byte{0x00, 0x01})) + vg, err := load(3).FindRaw(ctx, string([]byte{0x00, 0x01})) // without validation of the child block, this would return an ErrNotFound if err != ErrMalformedHamt || vg != nil { t.Fatal("Should have returned ErrMalformedHamt for its empty child node") @@ -1077,7 +1090,7 @@ func TestMalformedHamt(t *testing.T) { bcat(b(0x40+1), b(0x01)), // bytes(1) "\x01" (bitmap) bcat(b(0x80+1), // array(1) bucketCbor(kv{0x00, 0x01})))) - vg, err = load().FindRaw(ctx, string([]byte{0x00, 0x01})) + vg, err = load(3).FindRaw(ctx, string([]byte{0x00, 0x01})) // without validation of the child block, this would return an ErrNotFound if err != ErrMalformedHamt || vg != nil { t.Fatal("Should have returned ErrMalformedHamt for its too-small child node") @@ -1085,19 +1098,19 @@ func TestMalformedHamt(t *testing.T) { blocks.data[bccid] = block.NewBlock( bcat(b(0x80+2), // array(2) - bcat(b(0x40+1), b(0x01)), // bytes(1) "\x01" (bitmap) + bcat(b(0x40+1), b(0x07)), // bytes(1) "\x07" (bitmap) bcat(b(0x80+3), // array(1) bucketCbor(kv{0x00, 0x01}), bucketCbor(kv{0x01, 0x01}), bucketCbor(kv{0x02, 0x01})))) - vg, err = load().FindRaw(ctx, string([]byte{0x00, 0x01})) + vg, err = load(3).FindRaw(ctx, string([]byte{0x00, 0x01})) // without validation of the child block, this would return an ErrNotFound if err != ErrMalformedHamt || vg != nil { t.Fatal("Should have returned ErrMalformedHamt for its too-small child node") } - // same as the above case, too few direct entries, but this one has a link in - // it to a child so we can't perform this check, so this should work + // same as the above case, too few direct entries, but this one has a link + // in it to a child so we can't perform this check, so this should work blocks.data[bccid] = block.NewBlock( bcat(b(0x80+2), // array(2) bcat(b(0x40+1), b(0x03)), // bytes(1) "\x03" (bitmap) @@ -1110,9 +1123,65 @@ func TestMalformedHamt(t *testing.T) { b(0x58), b(0x27), // bytes(39) ccidBytes)))) // cid - vg, err = load().FindRaw(ctx, string([]byte{0x00, 0x01})) + vg, err = load(3).FindRaw(ctx, string([]byte{0x00, 0x01})) // without validation of the child block, this would return an ErrNotFound if err != nil && bytes.Compare(vg, []byte{0x40 + 2, 0x00, 0x01}) != 0 { t.Fatal("Should have returned found entry") } } + +func TestIndexForBitRandom(t *testing.T) { + t.Parallel() + r := rand.New(rand.NewSource(int64(42))) + + count := 100000 + slot := make([]byte, 32) + for i := 0; i < count; i++ { + _, err := r.Read(slot) + if err != nil { + t.Fatal("couldn't create random bitfield") + } + bi := big.NewInt(0).SetBytes(slot) + bm := NewBitmapFrom(slot) + for k := 0; k < 256; k++ { + if indexForBitPosOriginal(k, bi) != bm.Index(k) { + t.Fatalf("indexForBit doesn't match with new") + } + } + } +} + +func TestIndexForBitLinear(t *testing.T) { + t.Parallel() + var i int64 + for i = 0; i < 1<<16-1; i++ { + bi := big.NewInt(i) + // turn the bigint into a proper 8 byte slice + bib := bi.Bytes() + if len(bib) < 8 { + bib = append(make([]byte, 8-len(bib)), bib...) + } + bm := NewBitmapFrom(bib) + for k := 0; k < 16; k++ { + if indexForBitPosOriginal(k, bi) != bm.Index(k) { + t.Fatalf("indexForBit doesn't match with new") + } + } + } +} + +// Original implementation of indexForBit, before #39. +func indexForBitPosOriginal(bp int, bitfield *big.Int) int { + mask := new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(bp)), nil), big.NewInt(1)) + mask.And(mask, bitfield) + + return popCount(mask) +} + +func popCount(i *big.Int) int { + var n int + for _, v := range i.Bits() { + n += bits.OnesCount64(uint64(v)) + } + return n +} diff --git a/uhamt.go b/uhamt.go deleted file mode 100644 index 29e6740..0000000 --- a/uhamt.go +++ /dev/null @@ -1,41 +0,0 @@ -package hamt - -import ( - "math/big" - "math/bits" -) - -// indexForBitPos returns the index within the collapsed array corresponding to -// the given bit in the bitset. The collapsed array contains only one entry -// per bit set in the bitfield, and this function is used to map the indices. -// This is similar to a popcount() operation but is limited to a certain index. -// e.g. a Bitfield of `10010110000` shows that we have a 4 elements in the -// associated array. Indexes `[1]` and `[2]` are not present, but index `[3]` -// is at the second position of our Pointers array. -func (n *Node) indexForBitPos(bp int) int { - return indexForBitPos(bp, n.Bitfield) -} - -func indexForBitPos(bp int, bitfield *big.Int) int { - var x uint - var count, i int - w := bitfield.Bits() - for x = uint(bp); x > bits.UintSize && i < len(w); x -= bits.UintSize { - count += bits.OnesCount(uint(w[i])) - i++ - } - if i == len(w) { - return count - } - return count + bits.OnesCount(uint(w[i])&((1<