-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathshards.go
117 lines (99 loc) · 2.85 KB
/
shards.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package kinesis
import (
"crypto/md5"
"fmt"
"math"
"math/big"
"sort"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kinesis"
)
// Shard describes one Kinesis shard: its ID and the inclusive range of hash
// keys, [startingHashKey, endingHashKey], that it owns.
type Shard struct {
	shardID                        string
	startingHashKey, endingHashKey *big.Int
}
// shards is a list of Shard values that satisfies sort.Interface, ordering
// shards by ascending starting hash key.
type shards []Shard

// Len reports the number of shards.
func (s shards) Len() int { return len(s) }

// Less reports whether shard i starts at a lower hash key than shard j.
func (s shards) Less(i, j int) bool {
	left, right := s[i].startingHashKey, s[j].startingHashKey
	return left.Cmp(right) == -1
}

// Swap exchanges the shards at positions i and j.
func (s shards) Swap(i, j int) { s[j], s[i] = s[i], s[j] }
// ShardInfo locates the shard a trace ID belongs to. getIndex either shifts
// the trace's hash key right by shiftLen (when power is set) or looks the key
// up in the hash ranges held in shards.
type ShardInfo struct {
	shiftLen uint   // bits to drop from a 128-bit hash key to leave the shard index
	shards   shards // shard IDs and ranges; the lookup table when the shift cannot be used
	power    bool   // whether the shift-based lookup applies (shard count is a power of 2)
}
// getIndex returns the index into s.shards of the shard that should receive
// records keyed by traceID. The key is hashed with partitionKeyToHashKey (MD5
// read as an unsigned 128-bit integer) and matched against the shards' hash
// key ranges. It assumes those ranges are sorted and contiguous, as
// getShardInfo leaves them.
//
// When s.power is set, the index is the top bits of the hash key (a right
// shift by s.shiftLen). If that shift yields an index outside s.shards, the
// range search is used instead. An error (with index -1) is returned when s
// has no shards or when no shard's range can contain the key.
func (s *ShardInfo) getIndex(traceID string) (int, error) {
	switch len(s.shards) {
	case 0:
		return -1, fmt.Errorf("no shards available for partition key %s", traceID)
	case 1:
		return 0, nil
	}
	key := partitionKeyToHashKey(traceID)
	if s.power {
		// With 2^k evenly sized shards, the top k bits of the key are the index.
		idx := new(big.Int).Rsh(key, s.shiftLen)
		if idx.IsInt64() && idx.Int64() < int64(len(s.shards)) {
			return int(idx.Int64()), nil
		}
		// shiftLen does not match the shard count; fall back to the range search.
	}
	// Ending keys ascend along the sorted, contiguous ranges, so the first
	// shard whose ending key is >= key owns it. Binary search finds the same
	// shard a front-to-back scan would, in O(log n).
	i := sort.Search(len(s.shards), func(j int) bool {
		return key.Cmp(s.shards[j].endingHashKey) <= 0
	})
	if i == len(s.shards) {
		return -1, fmt.Errorf("no shard found for partition key %s", traceID)
	}
	return i, nil
}
// kinin is the single Kinesis client method getShardInfo depends on. Keeping
// the dependency this narrow lets any implementation (such as a fake) be
// passed in place of the full client.
type kinin interface {
	ListShards(in *kinesis.ListShardsInput) (*kinesis.ListShardsOutput, error)
}
// getShardInfo fetches every shard of streamName through k, following
// NextToken pagination. It keeps only the open shards and sorts them by
// starting hash key. It also records whether getIndex may use its bit-shift
// fast path.
//
// The fast path is enabled only when the open shards split the 128-bit hash
// key space into 2^k exactly equal ranges. A power-of-two shard count alone
// is not enough, because uneven splits would make the shift pick the wrong
// shard.
//
// An error is returned when a ListShards call fails, when an open shard lacks
// its ID or hash key range, or when the stream has no open shards.
func getShardInfo(k kinin, streamName string) (*ShardInfo, error) {
	listShardsInput := &kinesis.ListShardsInput{
		StreamName: aws.String(streamName),
		MaxResults: aws.Int64(100),
	}
	var open shards
	for {
		resp, err := k.ListShards(listShardsInput)
		if err != nil {
			return nil, fmt.Errorf("listShards error: %w", err)
		}
		for _, s := range resp.Shards {
			if s == nil {
				continue
			}
			// A shard with an ending sequence number is closed, so skip it.
			if s.SequenceNumberRange != nil && s.SequenceNumberRange.EndingSequenceNumber != nil {
				continue
			}
			hk := s.HashKeyRange
			if s.ShardId == nil || hk == nil || hk.StartingHashKey == nil || hk.EndingHashKey == nil {
				return nil, fmt.Errorf("listShards returned an open shard without an ID or hash key range for stream %q", streamName)
			}
			open = append(open, Shard{
				shardID:         *s.ShardId,
				startingHashKey: toBigInt(*hk.StartingHashKey),
				endingHashKey:   toBigInt(*hk.EndingHashKey),
			})
		}
		if resp.NextToken == nil {
			break
		}
		// Follow-up pages are requested by NextToken alone; StreamName must not be repeated.
		listShardsInput = &kinesis.ListShardsInput{
			NextToken: resp.NextToken,
		}
	}
	if len(open) == 0 {
		return nil, fmt.Errorf("no open shards found for stream %q", streamName)
	}
	sort.Sort(open)
	ret := &ShardInfo{shards: open}
	// Computed as before (exact for powers of two, truncated otherwise).
	// len(open) >= 1 here, so Log2 is finite.
	ret.shiftLen = uint(128 - math.Log2(float64(len(open))))
	ret.power = splitsHashSpaceEvenly(open, ret.shiftLen)
	return ret, nil
}

// splitsHashSpaceEvenly reports whether sorted has a power-of-two length and
// shard i owns exactly [i<<shift, (i+1)<<shift - 1]. The shards must be
// ordered by starting hash key. Only then does shifting a 128-bit hash key
// right by shift yield the owning shard's index.
func splitsHashSpaceEvenly(sorted shards, shift uint) bool {
	n := len(sorted)
	if n == 0 || n&(n-1) != 0 {
		return false
	}
	one := big.NewInt(1)
	for i, sh := range sorted {
		lo := new(big.Int).Lsh(big.NewInt(int64(i)), shift)
		hi := new(big.Int).Lsh(big.NewInt(int64(i+1)), shift)
		hi.Sub(hi, one)
		if sh.startingHashKey.Cmp(lo) != 0 || sh.endingHashKey.Cmp(hi) != 0 {
			return false
		}
	}
	return true
}
// toBigInt parses key as a base-10 integer, the format of the hash key
// strings in a shard's HashKeyRange. The parse outcome is not checked. For
// malformed input the result is a non-nil *big.Int holding an unspecified
// value (see big.Int.SetString).
func toBigInt(key string) *big.Int {
	var n big.Int
	n.SetString(key, 10)
	return &n
}
// belongsToShardKey reports whether key lies within the shard's inclusive
// hash key range. The error result is always nil.
func (s *Shard) belongsToShardKey(key *big.Int) (bool, error) {
	if key.Cmp(s.startingHashKey) < 0 {
		return false, nil
	}
	return key.Cmp(s.endingHashKey) <= 0, nil
}
// partitionKeyToHashKey maps a partition key onto the 128-bit hash key space.
// It computes the key's MD5 digest and reads it as an unsigned big-endian
// integer.
func partitionKeyToHashKey(partitionKey string) *big.Int {
	digest := md5.Sum([]byte(partitionKey))
	return new(big.Int).SetBytes(digest[:])
}