From 515aa3125973264499aa079a0a26d1a50fccb8ad Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 15 Dec 2023 15:11:49 +0800 Subject: [PATCH] sm2: add comments and refactor --- sm2/sm2.go | 22 ++++++++++++++++------ sm2/sm2_envelopedkey.go | 24 +++++++++++------------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index cf3df1ec..3409ac65 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -620,7 +620,7 @@ func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, erro } func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) { - // get/compute inv(d+1) + // dp1Inv = (d+1)⁻¹ dp1Inv, err := priv.inverseOfPrivateKeyPlus1(c) if err != nil { return nil, err @@ -649,21 +649,27 @@ func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig if err != nil { return nil, err } - r.Add(e, c.N) // r = (Rx + e) mod N + + // r = [Rx + e] + r.Add(e, c.N) + + // checks if r is zero or [r+k] is zero if r.IsZero() == 0 { - t := bigmod.NewNat().Set(k) - t.Add(r, c.N) - if t.IsZero() == 0 { // if (r + k) != N then ok + t := bigmod.NewNat().Set(k).Add(r, c.N) + if t.IsZero() == 0 { break } } } + // s = [r * d] s, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) if err != nil { return nil, err } s.Mul(r, c.N) + // k = [k - s] k.Sub(s, c.N) + // k = [(d+1)⁻¹ * (k - r * d)] k.Mul(dp1Inv, c.N) if k.IsZero() == 0 { break @@ -738,21 +744,25 @@ func verifySM2EC(c *sm2Curve, pub *ecdsa.PublicKey, hash, sig []byte) bool { e := bigmod.NewNat() hashToNat(c, e, hash) + // t = [r + s] t := bigmod.NewNat().Set(r) t.Add(s, c.N) if t.IsZero() == 1 { return false } + // p₁ = [s]G p1, err := c.newPoint().ScalarBaseMult(s.Bytes(c.N)) if err != nil { return false } + // p₂ = [t]Q p2, err := Q.ScalarMult(Q, t.Bytes(c.N)) if err != nil { return false } + // BytesX returns an error for the point at infinity. Rx, err := p1.Add(p1, p2).BytesX() if err != nil { return false @@ -762,8 +772,8 @@ func verifySM2EC(c *sm2Curve, pub *ecdsa.PublicKey, hash, sig []byte) bool { if err != nil { return false } - v.Add(e, c.N) + return v.Equal(r) == 1 } diff --git a/sm2/sm2_envelopedkey.go b/sm2/sm2_envelopedkey.go index 5bd6c5d9..fb02355f 100644 --- a/sm2/sm2_envelopedkey.go +++ b/sm2/sm2_envelopedkey.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "math/big" "github.com/emmansun/gmsm/cipher" "github.com/emmansun/gmsm/sm4" @@ -82,9 +81,9 @@ func MarshalEnvelopedPrivateKey(rand io.Reader, pub *ecdsa.PublicKey, tobeEnvelo func ParseEnvelopedPrivateKey(priv *PrivateKey, enveloped []byte) (*PrivateKey, error) { // unmarshal the asn.1 data var ( - symAlgId pkix.AlgorithmIdentifier - encryptedPrivateKey, pub asn1.BitString - inner, symEncryptedKey, symAlgIdBytes cryptobyte.String + symAlgId pkix.AlgorithmIdentifier + encryptedPrivateKey, pub asn1.BitString + inner, symEncryptedKey, symAlgIdBytes cryptobyte.String ) input := cryptobyte.String(enveloped) if !input.ReadASN1(&inner, cryptobyte_asn1.SEQUENCE) || @@ -106,9 +105,9 @@ func ParseEnvelopedPrivateKey(priv *PrivateKey, enveloped []byte) (*PrivateKey, } // parse public key - x, y := elliptic.Unmarshal(P256(), pub.RightAlign()) - if x == nil || y == nil { - return nil, errors.New("sm2: invald public key in enveloped data") + pubKey, err := NewPublicKey(pub.RightAlign()) + if err != nil { + return nil, err } // decrypt symmetric cipher key @@ -127,12 +126,11 @@ func ParseEnvelopedPrivateKey(priv *PrivateKey, enveloped []byte) (*PrivateKey, plaintext := make([]byte, len(bytes)) mode.CryptBlocks(plaintext, bytes) // Do we need to check length in order to be compatible with some implementations with padding? - sm2Key := new(PrivateKey) - sm2Key.D = new(big.Int).SetBytes(plaintext) - sm2Key.Curve = P256() - sm2Key.X, sm2Key.Y = sm2Key.ScalarBaseMult(plaintext) - - if sm2Key.X.Cmp(x) != 0 || sm2Key.Y.Cmp(y) != 0 { + sm2Key, err := NewPrivateKey(plaintext) + if err != nil { + return nil, err + } + if !sm2Key.PublicKey.Equal(pubKey) { return nil, errors.New("sm2: mismatch key pair in enveloped data") }