Skip to content

Commit

Permalink
sm4: reduce allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
emmansun authored Mar 27, 2024
1 parent 178241a commit e4909be
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 54 deletions.
7 changes: 2 additions & 5 deletions sm4/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
)

// Encrypt one block from src into dst, using the expanded key xk.
func encryptBlockGo(xk []uint32, dst, src []byte) {
func encryptBlockGo(xk *[rounds]uint32, dst, src []byte) {
_ = src[15] // early bounds check
_ = xk[31] // bounds check elimination hint

var b0, b1, b2, b3 uint32
b0 = binary.BigEndian.Uint32(src[0:4])
Expand Down Expand Up @@ -68,10 +67,8 @@ func encryptBlockGo(xk []uint32, dst, src []byte) {
}

// Key expansion algorithm.
func expandKeyGo(key []byte, enc, dec []uint32) {
func expandKeyGo(key []byte, enc, dec *[rounds]uint32) {
// Encryption key setup.
enc = enc[:rounds]
dec = dec[:rounds]
key = key[:KeySize]
var b0, b1, b2, b3 uint32
b0 = binary.BigEndian.Uint32(key[:4]) ^ fk[0]
Expand Down
14 changes: 7 additions & 7 deletions sm4/cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ const rounds = 32

// A cipher is an instance of SM4 encryption using a particular key.
type sm4Cipher struct {
enc []uint32
dec []uint32
enc [rounds]uint32
dec [rounds]uint32
}

// NewCipher creates and returns a new cipher.Block.
Expand All @@ -35,9 +35,9 @@ func NewCipher(key []byte) (cipher.Block, error) {
// newCipherGeneric creates and returns a new cipher.Block
// implemented in pure Go.
func newCipherGeneric(key []byte) (cipher.Block, error) {
c := sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}
expandKeyGo(key, c.enc, c.dec)
return &c, nil
c := &sm4Cipher{}
expandKeyGo(key, &c.enc, &c.dec)
return c, nil
}

func (c *sm4Cipher) BlockSize() int { return BlockSize }
Expand All @@ -52,7 +52,7 @@ func (c *sm4Cipher) Encrypt(dst, src []byte) {
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("sm4: invalid buffer overlap")
}
encryptBlockGo(c.enc, dst, src)
encryptBlockGo(&c.enc, dst, src)
}

func (c *sm4Cipher) Decrypt(dst, src []byte) {
Expand All @@ -65,5 +65,5 @@ func (c *sm4Cipher) Decrypt(dst, src []byte) {
if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) {
panic("sm4: invalid buffer overlap")
}
encryptBlockGo(c.dec, dst, src)
encryptBlockGo(&c.dec, dst, src)
}
12 changes: 6 additions & 6 deletions sm4/cipher_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ func newCipher(key []byte) (cipher.Block, error) {
if useAVX2 {
blocks = 8
}
c := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize}
c := &sm4CipherGCM{sm4CipherAsm{sm4Cipher{}, blocks, blocks * BlockSize}}
expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_AES)
if supportsGFMUL {
return &sm4CipherGCM{c}, nil
return c, nil
}
return c, nil
return &c.sm4CipherAsm, nil
}

func (c *sm4CipherAsm) Concurrency() int { return c.batchBlocks }
Expand All @@ -74,7 +74,7 @@ func (c *sm4CipherAsm) Encrypt(dst, src []byte) {
if useAESNI4SingleBlock {
encryptBlockAsm(&c.enc[0], &dst[0], &src[0], INST_AES)
} else {
encryptBlockGo(c.enc, dst, src)
encryptBlockGo(&c.enc, dst, src)
}
}

Expand All @@ -91,7 +91,7 @@ func (c *sm4CipherAsm) Decrypt(dst, src []byte) {
if useAESNI4SingleBlock {
encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_AES)
} else {
encryptBlockGo(c.dec, dst, src)
encryptBlockGo(&c.dec, dst, src)
}
}

Expand Down Expand Up @@ -129,6 +129,6 @@ func expandKey(key []byte, enc, dec []uint32) {
} else if supportsAES {
expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], INST_AES)
} else {
expandKeyGo(key, enc, dec)
expandKeyGo(key, (*[rounds]uint32)(enc), (*[rounds]uint32)(dec))
}
}
14 changes: 7 additions & 7 deletions sm4/cipher_asm_fuzzy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
func TestExpandKey(t *testing.T) {
key := make([]byte, 16)

encRes1 := make([]uint32, 32)
decRes1 := make([]uint32, 32)
var encRes1 [rounds]uint32
var decRes1 [rounds]uint32
encRes2 := make([]uint32, 32)
decRes2 := make([]uint32, 32)
var timeout *time.Timer
Expand All @@ -32,13 +32,13 @@ func TestExpandKey(t *testing.T) {
default:
}
io.ReadFull(rand.Reader, key)
expandKeyGo(key, encRes1, decRes1)
expandKeyGo(key, &encRes1, &decRes1)
expandKey(key, encRes2, decRes2)
if !reflect.DeepEqual(encRes1, encRes2) {
t.Errorf("expected=%x, result=%x\n", encRes1, encRes2)
if !reflect.DeepEqual(encRes1[:], encRes2) {
t.Errorf("expected=%x, result=%x\n", encRes1[:], encRes2)
}
if !reflect.DeepEqual(decRes1, decRes2) {
t.Errorf("expected=%x, result=%x\n", encRes1, encRes2)
if !reflect.DeepEqual(decRes1[:], decRes2) {
t.Errorf("expected=%x, result=%x\n", decRes1[:], decRes2)
}
}
}
2 changes: 1 addition & 1 deletion sm4/cipher_asm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestWithoutGFMUL(t *testing.T) {
if useAVX2 {
blocks = 8
}
c1 := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize}
c1 := &sm4CipherAsm{sm4Cipher{}, blocks, blocks * BlockSize}
expandKeyAsm(&key[0], &ck[0], &c1.enc[0], &c1.dec[0], INST_AES)
c = c1
}
Expand Down
6 changes: 0 additions & 6 deletions sm4/cipher_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,3 @@ import "crypto/cipher"
func newCipher(key []byte) (cipher.Block, error) {
	// No accelerated implementation is selected here: construction is handed
	// straight to the pure-Go newCipherGeneric and its result is passed
	// through unchanged.
	block, err := newCipherGeneric(key)
	return block, err
}

// expandKey runs the key expansion for key, writing the results into enc
// and dec. It is used by BenchmarkExpand and should call an assembly
// implementation if one is available; this variant has none, so it always
// delegates to the pure-Go expandKeyGo.
func expandKey(key []byte, enc, dec []uint32) {
expandKeyGo(key, enc, dec)
}
6 changes: 3 additions & 3 deletions sm4/cipher_ni.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ type sm4CipherNI struct {
}

func newCipherNI(key []byte) (cipher.Block, error) {
c := &sm4CipherNI{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}}
c := &sm4CipherNIGCM{sm4CipherNI{sm4Cipher{}}}
expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_SM4)
if supportsGFMUL {
return &sm4CipherNIGCM{c}, nil
return c, nil
}
return c, nil
return &c.sm4CipherNI, nil
}

func (c *sm4CipherNI) Encrypt(dst, src []byte) {
Expand Down
9 changes: 0 additions & 9 deletions sm4/cipher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,3 @@ func BenchmarkDecrypt(b *testing.B) {
c.Decrypt(out, tt.out)
}
}

func BenchmarkExpand(b *testing.B) {
tt := encryptTests[0]
c := &sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}
b.ResetTimer()
for i := 0; i < b.N; i++ {
expandKey(tt.key, c.enc, c.dec)
}
}
10 changes: 5 additions & 5 deletions sm4/sm4_gcm_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
// will use the optimised implementation in this file when possible. Instances
// of this type only exist when hasGCMAsm and hasAES return true.
type sm4CipherGCM struct {
*sm4CipherAsm
sm4CipherAsm
}

// Assert that sm4CipherGCM implements the gcmAble interface.
Expand Down Expand Up @@ -43,10 +43,10 @@ type gcmAsm struct {
// called by crypto/cipher.NewGCM via the gcmAble interface.
func (c *sm4CipherGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) {
g := &gcmAsm{}
g.cipher = c.sm4CipherAsm
g.cipher = &c.sm4CipherAsm
g.nonceSize = nonceSize
g.tagSize = tagSize
gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_AES)
gcmSm4Init(&g.bytesProductTable, g.cipher.enc[:], INST_AES)
return g, nil
}

Expand Down Expand Up @@ -91,7 +91,7 @@ func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte {
}

if len(plaintext) > 0 {
gcmSm4Enc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc)
gcmSm4Enc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data)))
copy(out[len(plaintext):], tagOut[:])
Expand Down Expand Up @@ -144,7 +144,7 @@ func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
panic("cipher: invalid buffer overlap")
}
if len(ciphertext) > 0 {
gcmSm4Dec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc)
gcmSm4Dec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data)))

Expand Down
10 changes: 5 additions & 5 deletions sm4/sm4ni_gcm_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func gcmSm4niDec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk
// will use the optimised implementation in this file when possible. Instances
// of this type only exist when hasGCMAsm and hasSM4 return true.
type sm4CipherNIGCM struct {
*sm4CipherNI
sm4CipherNI
}

// Assert that sm4CipherNIGCM implements the gcmAble interface.
Expand All @@ -44,10 +44,10 @@ func (g *gcmNI) Overhead() int {
// called by crypto/cipher.NewGCM via the gcmAble interface.
func (c *sm4CipherNIGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) {
g := &gcmNI{}
g.cipher = c.sm4CipherNI
g.cipher = &c.sm4CipherNI
g.nonceSize = nonceSize
g.tagSize = tagSize
gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_SM4)
gcmSm4Init(&g.bytesProductTable, g.cipher.enc[:], INST_SM4)
return g, nil
}

Expand Down Expand Up @@ -84,7 +84,7 @@ func (g *gcmNI) Seal(dst, nonce, plaintext, data []byte) []byte {
}

if len(plaintext) > 0 {
gcmSm4niEnc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc)
gcmSm4niEnc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data)))
copy(out[len(plaintext):], tagOut[:])
Expand Down Expand Up @@ -137,7 +137,7 @@ func (g *gcmNI) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
panic("cipher: invalid buffer overlap")
}
if len(ciphertext) > 0 {
gcmSm4niDec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc)
gcmSm4niDec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc[:])
}
gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data)))

Expand Down

0 comments on commit e4909be

Please sign in to comment.