Skip to content

Commit

Permalink
kdf: refactoring, create one interface
Browse files Browse the repository at this point in the history
  • Loading branch information
emmansun authored May 17, 2024
1 parent 7fb729f commit 9ef3fdc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
12 changes: 10 additions & 2 deletions kdf/kdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,19 @@ import (
"hash"
)

// KdfInterface is the interface implemented by some specific Hash implementations.
type KdfInterface interface {
Kdf(z []byte, keyLen int) []byte
}

// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
// ANSI-X9.63-KDF
func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
baseMD := newHash()
// If the hash implements KdfInterface, use the optimized Kdf method.
if kdfImpl, ok := baseMD.(KdfInterface); ok {
return kdfImpl.Kdf(z, keyLen)
}
limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size())
if limit >= uint64(1<<32)-1 {
panic("kdf: key length too long")
Expand All @@ -19,8 +28,7 @@ func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
var ct uint32 = 1
var k []byte

marshaler, ok := baseMD.(encoding.BinaryMarshaler)
if limit == 1 || len(z) < baseMD.BlockSize() || !ok {
if marshaler, ok := baseMD.(encoding.BinaryMarshaler); limit == 1 || len(z) < baseMD.BlockSize() || !ok {
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
baseMD.Write(z)
Expand Down
8 changes: 6 additions & 2 deletions sm3/sm3.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,14 @@ func Sum(data []byte) [Size]byte {
}

// Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3.
func Kdf(z []byte, keyLen int) []byte {
func (baseMD *digest) Kdf(z []byte, keyLen int) []byte {
limit := uint64(keyLen+Size-1) / uint64(Size)
if limit >= uint64(1<<32)-1 {
panic("sm3: key length too long")
}
var countBytes [4]byte
var ct uint32 = 1
k := make([]byte, keyLen)
baseMD := new(digest)
baseMD.Reset()
baseMD.Write(z)
for i := 0; i < int(limit); i++ {
Expand All @@ -234,3 +233,8 @@ func Kdf(z []byte, keyLen int) []byte {
}
return k
}

func Kdf(z []byte, keyLen int) []byte {
baseMD := new(digest)
return baseMD.Kdf(z, keyLen)
}

0 comments on commit 9ef3fdc

Please sign in to comment.