From e4909bed2d28ae8ee3fbc2f22cbca6cc00175cda Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 27 Mar 2024 08:38:25 +0800 Subject: [PATCH] sm4: reduce allocations --- sm4/block.go | 7 ++----- sm4/cipher.go | 14 +++++++------- sm4/cipher_asm.go | 12 ++++++------ sm4/cipher_asm_fuzzy_test.go | 14 +++++++------- sm4/cipher_asm_test.go | 2 +- sm4/cipher_generic.go | 6 ------ sm4/cipher_ni.go | 6 +++--- sm4/cipher_test.go | 9 --------- sm4/sm4_gcm_asm.go | 10 +++++----- sm4/sm4ni_gcm_asm.go | 10 +++++----- 10 files changed, 36 insertions(+), 54 deletions(-) diff --git a/sm4/block.go b/sm4/block.go index 46b8db20..18554be3 100644 --- a/sm4/block.go +++ b/sm4/block.go @@ -7,9 +7,8 @@ import ( ) // Encrypt one block from src into dst, using the expanded key xk. -func encryptBlockGo(xk []uint32, dst, src []byte) { +func encryptBlockGo(xk *[rounds]uint32, dst, src []byte) { _ = src[15] // early bounds check - _ = xk[31] // bounds check elimination hint var b0, b1, b2, b3 uint32 b0 = binary.BigEndian.Uint32(src[0:4]) @@ -68,10 +67,8 @@ func encryptBlockGo(xk []uint32, dst, src []byte) { } // Key expansion algorithm. -func expandKeyGo(key []byte, enc, dec []uint32) { +func expandKeyGo(key []byte, enc, dec *[rounds]uint32) { // Encryption key setup. - enc = enc[:rounds] - dec = dec[:rounds] key = key[:KeySize] var b0, b1, b2, b3 uint32 b0 = binary.BigEndian.Uint32(key[:4]) ^ fk[0] diff --git a/sm4/cipher.go b/sm4/cipher.go index e3225a74..1b0ecee1 100644 --- a/sm4/cipher.go +++ b/sm4/cipher.go @@ -15,8 +15,8 @@ const rounds = 32 // A cipher is an instance of SM4 encryption using a particular key. type sm4Cipher struct { - enc []uint32 - dec []uint32 + enc [rounds]uint32 + dec [rounds]uint32 } // NewCipher creates and returns a new cipher.Block. @@ -35,9 +35,9 @@ func NewCipher(key []byte) (cipher.Block, error) { // newCipher creates and returns a new cipher.Block // implemented in pure Go. func newCipherGeneric(key []byte) (cipher.Block, error) { - c := sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)} - expandKeyGo(key, c.enc, c.dec) - return &c, nil + c := &sm4Cipher{} + expandKeyGo(key, &c.enc, &c.dec) + return c, nil } func (c *sm4Cipher) BlockSize() int { return BlockSize } @@ -52,7 +52,7 @@ func (c *sm4Cipher) Encrypt(dst, src []byte) { if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } - encryptBlockGo(c.enc, dst, src) + encryptBlockGo(&c.enc, dst, src) } func (c *sm4Cipher) Decrypt(dst, src []byte) { @@ -65,5 +65,5 @@ func (c *sm4Cipher) Decrypt(dst, src []byte) { if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } - encryptBlockGo(c.dec, dst, src) + encryptBlockGo(&c.dec, dst, src) } diff --git a/sm4/cipher_asm.go b/sm4/cipher_asm.go index effe4a22..ca874445 100644 --- a/sm4/cipher_asm.go +++ b/sm4/cipher_asm.go @@ -51,12 +51,12 @@ func newCipher(key []byte) (cipher.Block, error) { if useAVX2 { blocks = 8 } - c := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize} + c := &sm4CipherGCM{sm4CipherAsm{sm4Cipher{}, blocks, blocks * BlockSize}} expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_AES) if supportsGFMUL { - return &sm4CipherGCM{c}, nil + return c, nil } - return c, nil + return &c.sm4CipherAsm, nil } func (c *sm4CipherAsm) Concurrency() int { return c.batchBlocks } @@ -74,7 +74,7 @@ func (c *sm4CipherAsm) Encrypt(dst, src []byte) { if useAESNI4SingleBlock { encryptBlockAsm(&c.enc[0], &dst[0], &src[0], INST_AES) } else { - encryptBlockGo(c.enc, dst, src) + encryptBlockGo(&c.enc, dst, src) } } @@ -91,7 +91,7 @@ func (c *sm4CipherAsm) Decrypt(dst, src []byte) { if useAESNI4SingleBlock { encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_AES) } else { - encryptBlockGo(c.dec, dst, src) + encryptBlockGo(&c.dec, dst, src) } } @@ -129,6 +129,6 @@ func expandKey(key []byte, enc, dec []uint32) { } else if supportsAES { expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], INST_AES) } else { - expandKeyGo(key, enc, dec) + expandKeyGo(key, (*[rounds]uint32)(enc), (*[rounds]uint32)(dec)) } } diff --git a/sm4/cipher_asm_fuzzy_test.go b/sm4/cipher_asm_fuzzy_test.go index 11c88811..5836e02e 100644 --- a/sm4/cipher_asm_fuzzy_test.go +++ b/sm4/cipher_asm_fuzzy_test.go @@ -13,8 +13,8 @@ import ( func TestExpandKey(t *testing.T) { key := make([]byte, 16) - encRes1 := make([]uint32, 32) - decRes1 := make([]uint32, 32) + var encRes1 [rounds]uint32 + var decRes1 [rounds]uint32 encRes2 := make([]uint32, 32) decRes2 := make([]uint32, 32) var timeout *time.Timer @@ -32,13 +32,13 @@ func TestExpandKey(t *testing.T) { default: } io.ReadFull(rand.Reader, key) - expandKeyGo(key, encRes1, decRes1) + expandKeyGo(key, &encRes1, &decRes1) expandKey(key, encRes2, decRes2) - if !reflect.DeepEqual(encRes1, encRes2) { - t.Errorf("expected=%x, result=%x\n", encRes1, encRes2) + if !reflect.DeepEqual(encRes1[:], encRes2) { + t.Errorf("expected=%x, result=%x\n", encRes1[:], encRes2) } - if !reflect.DeepEqual(decRes1, decRes2) { - t.Errorf("expected=%x, result=%x\n", encRes1, encRes2) + if !reflect.DeepEqual(decRes1[:], decRes2) { + t.Errorf("expected=%x, result=%x\n", decRes1[:], decRes2) } } } diff --git a/sm4/cipher_asm_test.go b/sm4/cipher_asm_test.go index e75db8e1..af420fe8 100644 --- a/sm4/cipher_asm_test.go +++ b/sm4/cipher_asm_test.go @@ -25,7 +25,7 @@ func TestWithoutGFMUL(t *testing.T) { if useAVX2 { blocks = 8 } - c1 := &sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize} + c1 := &sm4CipherAsm{sm4Cipher{}, blocks, blocks * BlockSize} expandKeyAsm(&key[0], &ck[0], &c1.enc[0], &c1.dec[0], INST_AES) c = c1 } diff --git a/sm4/cipher_generic.go b/sm4/cipher_generic.go index 215280b4..d51a9ea7 100644 --- a/sm4/cipher_generic.go +++ b/sm4/cipher_generic.go @@ -12,9 +12,3 @@ import "crypto/cipher" func newCipher(key []byte) (cipher.Block, error) { return newCipherGeneric(key) } - -// expandKey is used by BenchmarkExpand and should -// call an assembly implementation if one is available. -func expandKey(key []byte, enc, dec []uint32) { - expandKeyGo(key, enc, dec) -} diff --git a/sm4/cipher_ni.go b/sm4/cipher_ni.go index efe458ba..3e2ee0af 100644 --- a/sm4/cipher_ni.go +++ b/sm4/cipher_ni.go @@ -13,12 +13,12 @@ type sm4CipherNI struct { } func newCipherNI(key []byte) (cipher.Block, error) { - c := &sm4CipherNI{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}} + c := &sm4CipherNIGCM{sm4CipherNI{sm4Cipher{}}} expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], INST_SM4) if supportsGFMUL { - return &sm4CipherNIGCM{c}, nil + return c, nil } - return c, nil + return &c.sm4CipherNI, nil } func (c *sm4CipherNI) Encrypt(dst, src []byte) { diff --git a/sm4/cipher_test.go b/sm4/cipher_test.go index c693c550..e01e8cd0 100644 --- a/sm4/cipher_test.go +++ b/sm4/cipher_test.go @@ -114,12 +114,3 @@ func BenchmarkDecrypt(b *testing.B) { c.Decrypt(out, tt.out) } } - -func BenchmarkExpand(b *testing.B) { - tt := encryptTests[0] - c := &sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)} - b.ResetTimer() - for i := 0; i < b.N; i++ { - expandKey(tt.key, c.enc, c.dec) - } -} diff --git a/sm4/sm4_gcm_asm.go b/sm4/sm4_gcm_asm.go index def90dd2..236dce4d 100644 --- a/sm4/sm4_gcm_asm.go +++ b/sm4/sm4_gcm_asm.go @@ -13,7 +13,7 @@ import ( // will use the optimised implementation in this file when possible. Instances // of this type only exist when hasGCMAsm and hasAES returns true. type sm4CipherGCM struct { - *sm4CipherAsm + sm4CipherAsm } // Assert that sm4CipherGCM implements the gcmAble interface. @@ -43,10 +43,10 @@ type gcmAsm struct { // called by crypto/cipher.NewGCM via the gcmAble interface. func (c *sm4CipherGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { g := &gcmAsm{} - g.cipher = c.sm4CipherAsm + g.cipher = &c.sm4CipherAsm g.nonceSize = nonceSize g.tagSize = tagSize - gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_AES) + gcmSm4Init(&g.bytesProductTable, g.cipher.enc[:], INST_AES) return g, nil } @@ -91,7 +91,7 @@ func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte { } if len(plaintext) > 0 { - gcmSm4Enc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc) + gcmSm4Enc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc[:]) } gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data))) copy(out[len(plaintext):], tagOut[:]) @@ -144,7 +144,7 @@ func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { panic("cipher: invalid buffer overlap") } if len(ciphertext) > 0 { - gcmSm4Dec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc) + gcmSm4Dec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc[:]) } gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data))) diff --git a/sm4/sm4ni_gcm_asm.go b/sm4/sm4ni_gcm_asm.go index 11ce21ee..0c50fe93 100644 --- a/sm4/sm4ni_gcm_asm.go +++ b/sm4/sm4ni_gcm_asm.go @@ -19,7 +19,7 @@ func gcmSm4niDec(productTable *[256]byte, dst, src []byte, ctr, T *[16]byte, rk // will use the optimised implementation in this file when possible. Instances // of this type only exist when hasGCMAsm and hasSM4 returns true. type sm4CipherNIGCM struct { - *sm4CipherNI + sm4CipherNI } // Assert that sm4CipherNIGCM implements the gcmAble interface. @@ -44,10 +44,10 @@ func (g *gcmNI) Overhead() int { // called by crypto/cipher.NewGCM via the gcmAble interface. func (c *sm4CipherNIGCM) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) { g := &gcmNI{} - g.cipher = c.sm4CipherNI + g.cipher = &c.sm4CipherNI g.nonceSize = nonceSize g.tagSize = tagSize - gcmSm4Init(&g.bytesProductTable, g.cipher.enc, INST_SM4) + gcmSm4Init(&g.bytesProductTable, g.cipher.enc[:], INST_SM4) return g, nil } @@ -84,7 +84,7 @@ func (g *gcmNI) Seal(dst, nonce, plaintext, data []byte) []byte { } if len(plaintext) > 0 { - gcmSm4niEnc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc) + gcmSm4niEnc(&g.bytesProductTable, out, plaintext, &counter, &tagOut, g.cipher.enc[:]) } gcmSm4Finish(&g.bytesProductTable, &tagMask, &tagOut, uint64(len(plaintext)), uint64(len(data))) copy(out[len(plaintext):], tagOut[:]) @@ -137,7 +137,7 @@ func (g *gcmNI) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { panic("cipher: invalid buffer overlap") } if len(ciphertext) > 0 { - gcmSm4niDec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc) + gcmSm4niDec(&g.bytesProductTable, out, ciphertext, &counter, &expectedTag, g.cipher.enc[:]) } gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data)))