From 149f443a86e100ec7801446038561864cd2dca96 Mon Sep 17 00:00:00 2001 From: Ezequiel Raynaudo Date: Wed, 26 Jun 2024 14:50:31 -0300 Subject: [PATCH] feat: add basic common code used in sdk and cometbft repos (#2) * Add files from cometBFT repo * Rename hash package * wip * Clean up curves (for now) * Add bls curve and cache util * Add github action * Fix action * Fix tests * Fix ci * Simpler ci * Fix lint * Add lint jobs * Add lint ci --- .github/workflows/build.yml | 36 +++ .github/workflows/lint.yml | 21 ++ .gitignore | 1 + Makefile | 35 +++ README.md | 9 + armor/armor.go | 52 ++++ armor/armor_test.go | 21 ++ curves/bls12381/alias.go | 9 + curves/bls12381/doc.go | 6 + curves/bls12381/helper_test.go | 15 + curves/bls12381/init.go | 27 ++ curves/bls12381/interface.go | 21 ++ curves/bls12381/pubkey.go | 76 +++++ curves/bls12381/pubkey_test.go | 97 ++++++ curves/bls12381/secret_key.go | 73 +++++ curves/bls12381/secret_key_test.go | 113 +++++++ curves/bls12381/signature.go | 77 +++++ curves/bls12381/signature_test.go | 171 +++++++++++ doc.go | 3 + go.mod | 18 ++ go.sum | 20 ++ hash/sha256/bench_test.go | 52 ++++ hash/sha256/hash.go | 78 +++++ hash/sha256/hash_test.go | 46 +++ internal/cache/cache.go | 144 +++++++++ internal/cache/list.go | 123 ++++++++ internal/libs/bytes/bytes.go | 65 ++++ internal/libs/bytes/bytes_test.go | 75 +++++ internal/libs/bytes/byteslice.go | 10 + internal/libs/json/decoder.go | 277 ++++++++++++++++++ internal/libs/json/decoder_test.go | 150 ++++++++++ internal/libs/json/doc.go | 98 +++++++ internal/libs/json/encoder.go | 257 ++++++++++++++++ internal/libs/json/encoder_test.go | 120 ++++++++ internal/libs/json/helpers_test.go | 92 ++++++ internal/libs/json/structs.go | 86 ++++++ internal/libs/json/types.go | 107 +++++++ internal/sync/deadlock.go | 18 ++ internal/sync/sync.go | 16 + random/random.go | 35 +++ random/random_test.go | 22 ++ symmetric/types.go | 7 + symmetric/xchacha20poly1305/vector_test.go | 122 ++++++++ symmetric/xchacha20poly1305/xchachapoly.go | 264 +++++++++++++++++ .../xchacha20poly1305/xchachapoly_test.go | 113 +++++++ symmetric/xsalsa20symmetric/symmetric.go | 60 ++++ symmetric/xsalsa20symmetric/symmetric_test.go | 36 +++ types/address.go | 20 ++ types/keys.go | 18 ++ 49 files changed, 3412 insertions(+) create mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/lint.yml create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 README.md create mode 100644 armor/armor.go create mode 100644 armor/armor_test.go create mode 100644 curves/bls12381/alias.go create mode 100644 curves/bls12381/doc.go create mode 100644 curves/bls12381/helper_test.go create mode 100644 curves/bls12381/init.go create mode 100644 curves/bls12381/interface.go create mode 100644 curves/bls12381/pubkey.go create mode 100644 curves/bls12381/pubkey_test.go create mode 100644 curves/bls12381/secret_key.go create mode 100644 curves/bls12381/secret_key_test.go create mode 100644 curves/bls12381/signature.go create mode 100644 curves/bls12381/signature_test.go create mode 100644 doc.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 hash/sha256/bench_test.go create mode 100644 hash/sha256/hash.go create mode 100644 hash/sha256/hash_test.go create mode 100644 internal/cache/cache.go create mode 100644 internal/cache/list.go create mode 100644 internal/libs/bytes/bytes.go create mode 100644 internal/libs/bytes/bytes_test.go create mode 100644 internal/libs/bytes/byteslice.go create mode 100644 internal/libs/json/decoder.go create mode 100644 internal/libs/json/decoder_test.go create mode 100644 internal/libs/json/doc.go create mode 100644 internal/libs/json/encoder.go create mode 100644 internal/libs/json/encoder_test.go create mode 100644 internal/libs/json/helpers_test.go create mode 100644 internal/libs/json/structs.go create mode 100644 internal/libs/json/types.go create mode 100644 internal/sync/deadlock.go create mode 100644 internal/sync/sync.go create mode 100644 random/random.go create mode 100644 random/random_test.go create mode 100644 symmetric/types.go create mode 100644 symmetric/xchacha20poly1305/vector_test.go create mode 100644 symmetric/xchacha20poly1305/xchachapoly.go create mode 100644 symmetric/xchacha20poly1305/xchachapoly_test.go create mode 100644 symmetric/xsalsa20symmetric/symmetric.go create mode 100644 symmetric/xsalsa20symmetric/symmetric_test.go create mode 100644 types/address.go create mode 100644 types/keys.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..aefa828 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,36 @@ +name: Build and Test +on: + pull_request: + merge_group: + push: + branches: + - main + - release/** +permissions: + contents: read + +concurrency: + group: ci-${{ github.ref }}-build + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "1.22.2" + - name: Build + run: make build + + test: + runs-on: ubuntu-latest + needs: build + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "1.22.2" + - name: Test + run: make test diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..19c31aa --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,21 @@ +name: Lint +on: + pull_request: + push: + branches: + - main + - release/** + +permissions: + contents: read + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "1.22.2" + - name: Run Linter + run: make lint \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..723ef36 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..14e117d --- /dev/null +++ b/Makefile @@ -0,0 +1,35 @@ +# Variables +PKG := ./... +GOFILES := $(shell find . -name '*.go' | grep -v _test.go) +TESTFILES := $(shell find . -name '*_test.go') +GOLANGCI_VERSION := v1.59.0 + +all: build + +build: + @echo "Building..." + @go build $(PKG) + +# Run tests +test: + @echo "Running tests..." + @go test -v $(PKG) + +# Install golangci-lint +lint-install: + @echo "--> Installing golangci-lint $(GOLANGCI_VERSION)" + @go install github.com/golangci/golangci-lint/cmd/golangci-lint@$(GOLANGCI_VERSION) + +# Run golangci-lint +lint: + @echo "--> Running linter" + $(MAKE) lint-install + @golangci-lint run --timeout=15m + +# Run golangci-lint and fix +lint-fix: + @echo "--> Running linter with fix" + $(MAKE) lint-install + @golangci-lint run --fix + +.PHONY: build test lint-install lint lint-fix diff --git a/README.md b/README.md new file mode 100644 index 0000000..7a0a557 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# crypto + +cosmos-crypto is the cryptographic package adapted for the interchain stack + +## Importing it + +To get the interfaces, +`import "github.com/cosmos/crypto/types"` + diff --git a/armor/armor.go b/armor/armor.go new file mode 100644 index 0000000..af61b20 --- /dev/null +++ b/armor/armor.go @@ -0,0 +1,52 @@ +package armor + +import ( + "bytes" + "fmt" + "io" + + "golang.org/x/crypto/openpgp/armor" //nolint: staticcheck +) + +// ErrEncode represents an error from calling [EncodeArmor]. +type ErrEncode struct { + Err error +} + +func (e ErrEncode) Error() string { + return fmt.Sprintf("armor: could not encode ASCII armor: %v", e.Err) +} + +func (e ErrEncode) Unwrap() error { + return e.Err +} + +func EncodeArmor(blockType string, headers map[string]string, data []byte) (string, error) { + buf := new(bytes.Buffer) + w, err := armor.Encode(buf, blockType, headers) + if err != nil { + return "", ErrEncode{Err: err} + } + _, err = w.Write(data) + if err != nil { + return "", ErrEncode{Err: err} + } + err = w.Close() + if err != nil { + return "", ErrEncode{Err: err} + } + return buf.String(), nil +} + +func DecodeArmor(armorStr string) (blockType string, headers map[string]string, data []byte, err error) { + buf := bytes.NewBufferString(armorStr) + block, err := armor.Decode(buf) + if err != nil { + return "", nil, nil, err + } + data, err = io.ReadAll(block.Body) + if err != nil { + return "", nil, nil, err + } + return block.Type, block.Header, data, nil +} diff --git a/armor/armor_test.go b/armor/armor_test.go new file mode 100644 index 0000000..051d285 --- /dev/null +++ b/armor/armor_test.go @@ -0,0 +1,21 @@ +package armor + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestArmor(t *testing.T) { + blockType := "MINT TEST" + data := []byte("somedata") + armorStr, err := EncodeArmor(blockType, nil, data) + require.NoError(t, err, "%+v", err) + + // Decode armorStr and test for equivalence. + blockType2, _, data2, err := DecodeArmor(armorStr) + require.NoError(t, err, "%+v", err) + assert.Equal(t, blockType, blockType2) + assert.Equal(t, data, data2) +} diff --git a/curves/bls12381/alias.go b/curves/bls12381/alias.go new file mode 100644 index 0000000..6682f30 --- /dev/null +++ b/curves/bls12381/alias.go @@ -0,0 +1,9 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst + +import blst "github.com/supranational/blst/bindings/go" + +// Internal types for blst. +type blstPublicKey = blst.P1Affine +type blstSignature = blst.P2Affine diff --git a/curves/bls12381/doc.go b/curves/bls12381/doc.go new file mode 100644 index 0000000..b5f8850 --- /dev/null +++ b/curves/bls12381/doc.go @@ -0,0 +1,6 @@ +// Package blst implements a go-wrapper around a library implementing the +// BLS12-381 curve and signature scheme. This package exposes a public API for +// verifying and aggregating BLS signatures used by Ethereum. +// +// This implementation uses the library written by Supranational, blst. +package blst diff --git a/curves/bls12381/helper_test.go b/curves/bls12381/helper_test.go new file mode 100644 index 0000000..20cf48c --- /dev/null +++ b/curves/bls12381/helper_test.go @@ -0,0 +1,15 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst + +// Note: These functions are for tests to access private globals, such as pubkeyCache. + +// DisableCaches sets the cache sizes to 0. +func DisableCaches() { + pubkeyCache.Resize(0) +} + +// EnableCaches sets the cache sizes to the default values. +func EnableCaches() { + pubkeyCache.Resize(maxKeys) +} diff --git a/curves/bls12381/init.go b/curves/bls12381/init.go new file mode 100644 index 0000000..3112fa0 --- /dev/null +++ b/curves/bls12381/init.go @@ -0,0 +1,27 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst + +import ( + "fmt" + "runtime" + + blst "github.com/supranational/blst/bindings/go" + + "github.com/cosmos/crypto/internal/cache" +) + +func init() { + // Reserve 1 core for general application work + maxProcs := runtime.GOMAXPROCS(0) - 1 + if maxProcs <= 0 { + maxProcs = 1 + } + blst.SetMaxProcs(maxProcs) + onEvict := func(_ [48]byte, _ PubKey) {} + keysCache, err := cache.NewLRU(maxKeys, onEvict) + if err != nil { + panic(fmt.Sprintf("Could not initiate public keys cache: %v", err)) + } + pubkeyCache = keysCache +} diff --git a/curves/bls12381/interface.go b/curves/bls12381/interface.go new file mode 100644 index 0000000..1e8265c --- /dev/null +++ b/curves/bls12381/interface.go @@ -0,0 +1,21 @@ +package blst + +type PubKey interface { + Marshal() []byte + Copy() PubKey + Equals(p2 PubKey) bool +} + +// SignatureI represents a BLS signature. +type SignatureI interface { + Verify(pubKey PubKey, msg []byte) bool + Marshal() []byte + Copy() SignatureI +} + +// SecretKey represents a BLS secret or private key. +type SecretKey interface { + PublicKey() PubKey + Sign(msg []byte) SignatureI + Marshal() []byte +} diff --git a/curves/bls12381/pubkey.go b/curves/bls12381/pubkey.go new file mode 100644 index 0000000..5f8590d --- /dev/null +++ b/curves/bls12381/pubkey.go @@ -0,0 +1,76 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst + +import ( + "errors" + "fmt" + + "github.com/cosmos/crypto/internal/cache" +) + +const ( + SignatureLength = 96 + PubkeyLength = 48 // PubkeyLength defines the byte length of a BLSSignature. +) + +var maxKeys = 2_000_000 +var pubkeyCache *cache.LRU[[48]byte, PubKey] + +// PublicKey used in the BLS signature scheme. +type PublicKey struct { + p *blstPublicKey +} + +// Marshal a public key into a LittleEndian byte slice. +func (p *PublicKey) Marshal() []byte { + return p.p.Compress() +} + +// Copy the public key to a new pointer reference. +func (p *PublicKey) Copy() PubKey { + np := *p.p + return &PublicKey{p: &np} +} + +// Equals checks if the provided public key is equal to +// the current one. +func (p *PublicKey) Equals(p2 PubKey) bool { + return p.p.Equals(p2.(*PublicKey).p) +} + +// PublicKeyFromBytes creates a BLS public key from a BigEndian byte slice. +func PublicKeyFromBytes(pubKey []byte) (PubKey, error) { + return publicKeyFromBytes(pubKey, true) +} + +func publicKeyFromBytes(pubKey []byte, cacheCopy bool) (PubKey, error) { + if len(pubKey) != PubkeyLength { //TODO: make this a parameter + return nil, fmt.Errorf("public key must be %d bytes", PubkeyLength) + } + + newKey := (*[PubkeyLength]byte)(pubKey) + if cv, ok := pubkeyCache.Get(*newKey); ok { + if cacheCopy { + return cv.Copy(), nil + } + return cv, nil + } + + // Subgroup check NOT done when decompressing pubkey. + p := new(blstPublicKey).Uncompress(pubKey) + if p == nil { + return nil, errors.New("could not unmarshal bytes into public key") + } + // Subgroup and infinity check + if !p.KeyValidate() { + // NOTE: the error is not quite accurate since it includes group check + return nil, errors.New("publickey is infinite") + } + + pubKeyObj := &PublicKey{p: p} + copiedKey := pubKeyObj.Copy() + cacheKey := *newKey + pubkeyCache.Add(cacheKey, copiedKey) + return pubKeyObj, nil +} diff --git a/curves/bls12381/pubkey_test.go b/curves/bls12381/pubkey_test.go new file mode 100644 index 0000000..fd656ea --- /dev/null +++ b/curves/bls12381/pubkey_test.go @@ -0,0 +1,97 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst_test + +import ( + "bytes" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + blst "github.com/cosmos/crypto/curves/bls12381" +) + +func TestPublicKeyFromBytes(t *testing.T) { + tests := []struct { + name string + input []byte + err error + }{ + { + name: "Nil", + err: errors.New("public key must be 48 bytes"), + }, + { + name: "Empty", + input: []byte{}, + err: errors.New("public key must be 48 bytes"), + }, + { + name: "Short", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("public key must be 48 bytes"), + }, + { + name: "Long", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("public key must be 48 bytes"), + }, + { + name: "Bad", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("could not unmarshal bytes into public key"), + }, + { + name: "Good", + input: []byte{0xa9, 0x9a, 0x76, 0xed, 0x77, 0x96, 0xf7, 0xbe, 0x22, 0xd5, 0xb7, 0xe8, 0x5d, 0xee, 0xb7, 0xc5, 0x67, 0x7e, 0x88, 0xe5, 0x11, 0xe0, 0xb3, 0x37, 0x61, 0x8f, 0x8c, 0x4e, 0xb6, 0x13, 0x49, 0xb4, 0xbf, 0x2d, 0x15, 0x3f, 0x64, 0x9f, 0x7b, 0x53, 0x35, 0x9f, 0xe8, 0xb9, 0x4a, 0x38, 0xe4, 0x4c}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res, err := blst.PublicKeyFromBytes(test.input) + if test.err != nil { + assert.NotEqual(t, nil, err, "No error returned") + assert.ErrorContains(t, test.err, err.Error(), "Unexpected error returned") + } else { + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare(res.Marshal(), test.input)) + } + }) + } +} + +func TestPublicKey_Copy(t *testing.T) { + priv, err := blst.RandKey() + require.NoError(t, err) + pubkeyA := priv.PublicKey() + pubkeyBytes := pubkeyA.Marshal() + + require.Equal(t, pubkeyA.Marshal(), pubkeyBytes, "Pubkey was mutated after copy") +} + +func BenchmarkPublicKeyFromBytes(b *testing.B) { + priv, err := blst.RandKey() + require.NoError(b, err) + pubkey := priv.PublicKey() + pubkeyBytes := pubkey.Marshal() + + b.Run("cache on", func(b *testing.B) { + blst.EnableCaches() + for i := 0; i < b.N; i++ { + _, err := blst.PublicKeyFromBytes(pubkeyBytes) + require.NoError(b, err) + } + }) + + b.Run("cache off", func(b *testing.B) { + // blst.DisableCaches() + for i := 0; i < b.N; i++ { + _, err := blst.PublicKeyFromBytes(pubkeyBytes) + require.NoError(b, err) + } + }) + +} diff --git a/curves/bls12381/secret_key.go b/curves/bls12381/secret_key.go new file mode 100644 index 0000000..deaf0a4 --- /dev/null +++ b/curves/bls12381/secret_key.go @@ -0,0 +1,73 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst + +import ( + "crypto/subtle" + "errors" + "fmt" + + blst "github.com/supranational/blst/bindings/go" +) + +// bls12SecretKey used in the BLS signature scheme. +type bls12SecretKey struct { + p *blst.SecretKey +} + +// RandKey creates a new private key using a random method provided as an io.Reader. +func RandKey() (SecretKey, error) { + // Generate 32 bytes of randomness + var ikm [32]byte + _, err := rand.NewGenerator().Read(ikm[:]) + if err != nil { + return nil, err + } + // Defensive check, that we have not generated a secret key, + secKey := &bls12SecretKey{blst.KeyGen(ikm[:])} + if IsZero(secKey.Marshal()) { + return nil, errors.New("received secret key is zero") + } + return secKey, nil +} + +// SecretKeyFromBytes creates a BLS private key from a BigEndian byte slice. +func SecretKeyFromBytes(privKey []byte) (SecretKey, error) { + if len(privKey) != 32 { + return nil, fmt.Errorf("secret key must be %d bytes", 32) + } + if IsZero(privKey) { + return nil, errors.New("received secret key is zero") + } + secKey := new(blst.SecretKey).Deserialize(privKey) + if secKey == nil { + return nil, errors.New("could not unmarshal bytes into secret key") + } + wrappedKey := &bls12SecretKey{p: secKey} + return wrappedKey, nil +} + +// IsZero checks if the secret key is a zero key. +func IsZero(sKey []byte) bool { + b := byte(0) + for _, s := range sKey { + b |= s + } + return subtle.ConstantTimeByteEq(b, 0) == 1 +} + +func (s *bls12SecretKey) Sign(msg []byte) SignatureI { + signature := new(blstSignature).Sign(s.p, msg, dst) + return &Signature{s: signature} +} + +// Marshal a secret key into a LittleEndian byte slice. +func (s *bls12SecretKey) Marshal() []byte { + keyBytes := s.p.Serialize() + return keyBytes +} + +// PublicKey obtains the public key corresponding to the BLS secret key. +func (s *bls12SecretKey) PublicKey() PubKey { + return &PublicKey{p: new(blstPublicKey).From(s.p)} +} diff --git a/curves/bls12381/secret_key_test.go b/curves/bls12381/secret_key_test.go new file mode 100644 index 0000000..bfea916 --- /dev/null +++ b/curves/bls12381/secret_key_test.go @@ -0,0 +1,113 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst_test + +import ( + "bytes" + "crypto/rand" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + blst "github.com/cosmos/crypto/curves/bls12381" +) + +func TestMarshalUnmarshal(t *testing.T) { + priv, err := blst.RandKey() + require.NoError(t, err) + b := priv.Marshal() + b32 := ToBytes32(b) + pk, err := blst.SecretKeyFromBytes(b32[:]) + require.NoError(t, err) + pk2, err := blst.SecretKeyFromBytes(b32[:]) + require.NoError(t, err) + assert.Equal(t, pk.Marshal(), pk2.Marshal(), "Keys not equal") +} + +func TestSecretKeyFromBytes(t *testing.T) { + tests := []struct { + name string + input []byte + err error + }{ + { + name: "Nil", + err: errors.New("secret key must be 32 bytes"), + }, + { + name: "Empty", + input: []byte{}, + err: errors.New("secret key must be 32 bytes"), + }, + { + name: "Short", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("secret key must be 32 bytes"), + }, + { + name: "Long", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("secret key must be 32 bytes"), + }, + { + name: "Bad", + input: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + err: errors.New("could not unmarshal bytes into secret key"), + }, + { + name: "Good", + input: []byte{0x25, 0x29, 0x5f, 0x0d, 0x1d, 0x59, 0x2a, 0x90, 0xb3, 0x33, 0xe2, 0x6e, 0x85, 0x14, 0x97, 0x08, 0x20, 0x8e, 0x9f, 0x8e, 0x8b, 0xc1, 0x8f, 0x6c, 0x77, 0xbd, 0x62, 0xf8, 0xad, 0x7a, 0x68, 0x66}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res, err := blst.SecretKeyFromBytes(test.input) + if test.err != nil { + assert.NotEqual(t, nil, err, "No error returned") + assert.ErrorContains(t, test.err, err.Error(), "Unexpected error returned") + } else { + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare(res.Marshal(), test.input)) + } + }) + } +} + +func TestSerialize(t *testing.T) { + rk, err := blst.RandKey() + require.NoError(t, err) + b := rk.Marshal() + + _, err = blst.SecretKeyFromBytes(b) + assert.NoError(t, err) +} + +func TestZeroKey(t *testing.T) { + // Is Zero + var zKey [32]byte + assert.Equal(t, true, blst.IsZero(zKey[:])) + + // Is Not Zero + _, err := rand.Read(zKey[:]) + assert.NoError(t, err) + assert.Equal(t, false, blst.IsZero(zKey[:])) +} + +// PadTo pads a byte slice to the given size. If the byte slice is larger than the given size, the +// original slice is returned. +func PadTo(b []byte, size int) []byte { + if len(b) >= size { + return b + } + return append(b, make([]byte, size-len(b))...) +} + +// ToBytes32 is a convenience method for converting a byte slice to a fix +// sized 32 byte array. This method will truncate the input if it is larger +// than 32 bytes. +func ToBytes32(x []byte) [32]byte { + return [32]byte(PadTo(x, 32)) +} diff --git a/curves/bls12381/signature.go b/curves/bls12381/signature.go new file mode 100644 index 0000000..a5a8655 --- /dev/null +++ b/curves/bls12381/signature.go @@ -0,0 +1,77 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst + +import ( + "errors" + "fmt" +) + +var dst = []byte("BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_") + +// Signature used in the BLS signature scheme. +type Signature struct { + s *blstSignature +} + +// Marshal a signature into a LittleEndian byte slice. +func (s *Signature) Marshal() []byte { + return s.s.Compress() +} + +// Copy returns a full deep copy of a signature. +func (s *Signature) Copy() SignatureI { + sign := *s.s + return &Signature{s: &sign} +} + +func (s *Signature) Verify(pubKey PubKey, msg []byte) bool { + // Signature and PKs are assumed to have been validated upon decompression! + return s.s.Verify(false, pubKey.(*PublicKey).p, false, msg, dst) +} + +// VerifySignature verifies a single signature using public key and message. +func VerifySignature(sig []byte, msg [32]byte, pubKey PubKey) (bool, error) { + rSig, err := SignatureFromBytes(sig) + if err != nil { + return false, err + } + return rSig.Verify(pubKey, msg[:]), nil +} + +// signatureFromBytesNoValidation creates a BLS signature from a LittleEndian +// byte slice. It does not validate that the signature is in the BLS group +func signatureFromBytesNoValidation(sig []byte) (*blstSignature, error) { + if len(sig) != SignatureLength { + return nil, fmt.Errorf("signature must be %d bytes", SignatureLength) + } + signature := new(blstSignature).Uncompress(sig) + if signature == nil { + return nil, errors.New("could not unmarshal bytes into signature") + } + return signature, nil +} + +// SignatureFromBytesNoValidation creates a BLS signature from a LittleEndian +// byte slice. It does not validate that the signature is in the BLS group +func SignatureFromBytesNoValidation(sig []byte) (SignatureI, error) { + signature, err := signatureFromBytesNoValidation(sig) + if err != nil { + return nil, fmt.Errorf("could not create signature from byte slice: %w", err) + } + return &Signature{s: signature}, nil +} + +// SignatureFromBytes creates a BLS signature from a LittleEndian byte slice. +func SignatureFromBytes(sig []byte) (SignatureI, error) { + signature, err := signatureFromBytesNoValidation(sig) + if err != nil { + return nil, fmt.Errorf("could not create signature from byte slice: %w", err) + } + // Group check signature. Do not check for infinity since an aggregated signature + // could be infinite. + if !signature.SigValidate(false) { + return nil, errors.New("signature not in group") + } + return &Signature{s: signature}, nil +} diff --git a/curves/bls12381/signature_test.go b/curves/bls12381/signature_test.go new file mode 100644 index 0000000..318adc9 --- /dev/null +++ b/curves/bls12381/signature_test.go @@ -0,0 +1,171 @@ +//go:build ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) && bls12381 + +package blst + +import ( + "bytes" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSignVerify(t *testing.T) { + priv, err := RandKey() + require.NoError(t, err) + pub := priv.PublicKey() + msg := []byte("hello") + sig := priv.Sign(msg) + assert.Equal(t, true, sig.Verify(pub, msg), "Signature did not verify") +} + +func TestVerifySingleSignature_InvalidSignature(t *testing.T) { + priv, err := RandKey() + require.NoError(t, err) + pub := priv.PublicKey() + msgA := [32]byte{'h', 'e', 'l', 'l', 'o'} + msgB := [32]byte{'o', 'l', 'l', 'e', 'h'} + sigA := priv.Sign(msgA[:]).Marshal() + valid, err := VerifySignature(sigA, msgB, pub) + assert.NoError(t, err) + assert.Equal(t, false, valid, "Signature did verify") +} + +func TestVerifySingleSignature_ValidSignature(t *testing.T) { + priv, err := RandKey() + require.NoError(t, err) + pub := priv.PublicKey() + msg := [32]byte{'h', 'e', 'l', 'l', 'o'} + sig := priv.Sign(msg[:]).Marshal() + valid, err := VerifySignature(sig, msg, pub) + assert.NoError(t, err) + assert.Equal(t, true, valid, "Signature did not verify") +} + +func TestSignatureFromBytes(t *testing.T) { + tests := []struct { + name string + input []byte + err error + }{ + { + name: "Nil", + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Empty", + input: []byte{}, + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Short", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Long", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Bad", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("could not unmarshal bytes into signature"), + }, + { + input: []byte{0xac, 0xb0, 0x12, 0x4c, 0x75, 0x74, 0xf2, 0x81, 0xa2, 0x93, 0xf4, 0x18, 0x5c, 0xad, 0x3c, 0xb2, 0x26, 0x81, 0xd5, 0x20, 0x91, 0x7c, 0xe4, 0x66, 0x65, 0x24, 0x3e, 0xac, 0xb0, 0x51, 0x00, 0x0d, 0x8b, 0xac, 0xf7, 0x5e, 0x14, 0x51, 0x87, 0x0c, 0xa6, 0xb3, 0xb9, 0xe6, 0xc9, 0xd4, 0x1a, 0x7b, 0x02, 0xea, 0xd2, 0x68, 0x5a, 0x84, 0x18, 0x8a, 0x4f, 0xaf, 0xd3, 0x82, 0x5d, 0xaf, 0x6a, 0x98, 0x96, 0x25, 0xd7, 0x19, 0xcc, 0xd2, 0xd8, 0x3a, 0x40, 0x10, 0x1f, 0x4a, 0x45, 0x3f, 0xca, 0x62, 0x87, 0x8c, 0x89, 0x0e, 0xca, 0x62, 0x23, 0x63, 0xf9, 0xdd, 0xb8, 0xf3, 0x67, 0xa9, 0x1e, 0x84}, + name: "Not in group", + err: errors.New("signature not in group"), + }, + { + name: "Good", + input: []byte{0xab, 0xb0, 0x12, 0x4c, 0x75, 0x74, 0xf2, 0x81, 0xa2, 0x93, 0xf4, 0x18, 0x5c, 0xad, 0x3c, 0xb2, 0x26, 0x81, 0xd5, 0x20, 0x91, 0x7c, 0xe4, 0x66, 0x65, 0x24, 0x3e, 0xac, 0xb0, 0x51, 0x00, 0x0d, 0x8b, 0xac, 0xf7, 0x5e, 0x14, 0x51, 0x87, 0x0c, 0xa6, 0xb3, 0xb9, 0xe6, 0xc9, 0xd4, 0x1a, 0x7b, 0x02, 0xea, 0xd2, 0x68, 0x5a, 0x84, 0x18, 0x8a, 0x4f, 0xaf, 0xd3, 0x82, 0x5d, 0xaf, 0x6a, 0x98, 0x96, 0x25, 0xd7, 0x19, 0xcc, 0xd2, 0xd8, 0x3a, 0x40, 0x10, 0x1f, 0x4a, 0x45, 0x3f, 0xca, 0x62, 0x87, 0x8c, 0x89, 0x0e, 0xca, 0x62, 0x23, 0x63, 0xf9, 0xdd, 0xb8, 0xf3, 0x67, 0xa9, 0x1e, 0x84}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res, err := SignatureFromBytes(test.input) + if test.err != nil { + assert.NotEqual(t, nil, err, "No error returned") + assert.ErrorContains(t, test.err, err.Error(), "Unexpected error returned") + } else { + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare(res.Marshal(), test.input)) + } + }) + } +} + +func TestSignatureFromBytesNoValidation(t *testing.T) { + tests := []struct { + name string + input []byte + err error + }{ + { + name: "Nil", + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Empty", + input: []byte{}, + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Short", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Long", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("signature must be 96 bytes"), + }, + { + name: "Bad", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + err: errors.New("could not unmarshal bytes into signature"), + }, + { + name: "Not in group", + input: []byte{0xac, 0xb0, 0x12, 0x4c, 0x75, 0x74, 0xf2, 0x81, 0xa2, 0x93, 0xf4, 0x18, 0x5c, 0xad, 0x3c, 0xb2, 0x26, 0x81, 0xd5, 0x20, 0x91, 0x7c, 0xe4, 0x66, 0x65, 0x24, 0x3e, 0xac, 0xb0, 0x51, 0x00, 0x0d, 0x8b, 0xac, 0xf7, 0x5e, 0x14, 0x51, 0x87, 0x0c, 0xa6, 0xb3, 0xb9, 0xe6, 0xc9, 0xd4, 0x1a, 0x7b, 0x02, 0xea, 0xd2, 0x68, 0x5a, 0x84, 0x18, 0x8a, 0x4f, 0xaf, 0xd3, 0x82, 0x5d, 0xaf, 0x6a, 0x98, 0x96, 0x25, 0xd7, 0x19, 0xcc, 0xd2, 0xd8, 0x3a, 0x40, 0x10, 0x1f, 0x4a, 0x45, 0x3f, 0xca, 0x62, 0x87, 0x8c, 0x89, 0x0e, 0xca, 0x62, 0x23, 0x63, 0xf9, 0xdd, 0xb8, 0xf3, 0x67, 0xa9, 0x1e, 0x84}, + }, + { + name: "Good", + input: []byte{0xab, 0xb0, 0x12, 0x4c, 0x75, 0x74, 0xf2, 0x81, 0xa2, 0x93, 0xf4, 0x18, 0x5c, 0xad, 0x3c, 0xb2, 0x26, 0x81, 0xd5, 0x20, 0x91, 0x7c, 0xe4, 0x66, 0x65, 0x24, 0x3e, 0xac, 0xb0, 0x51, 0x00, 0x0d, 0x8b, 0xac, 0xf7, 0x5e, 0x14, 0x51, 0x87, 0x0c, 0xa6, 0xb3, 0xb9, 0xe6, 0xc9, 0xd4, 0x1a, 0x7b, 0x02, 0xea, 0xd2, 0x68, 0x5a, 0x84, 0x18, 0x8a, 0x4f, 0xaf, 0xd3, 0x82, 0x5d, 0xaf, 0x6a, 0x98, 0x96, 0x25, 0xd7, 0x19, 0xcc, 0xd2, 0xd8, 0x3a, 0x40, 0x10, 0x1f, 0x4a, 0x45, 0x3f, 0xca, 0x62, 0x87, 0x8c, 0x89, 0x0e, 0xca, 0x62, 0x23, 0x63, 0xf9, 0xdd, 0xb8, 0xf3, 0x67, 0xa9, 0x1e, 0x84}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + res, err := SignatureFromBytesNoValidation(test.input) + if test.err != nil { + assert.NotEqual(t, nil, err, "No error returned") + assert.ErrorContains(t, test.err, err.Error(), "Unexpected error returned") + } else { + assert.NoError(t, err) + assert.Equal(t, 0, bytes.Compare(res.Marshal(), test.input)) + } + }) + } +} + +func TestCopy(t *testing.T) { + priv, err := RandKey() + require.NoError(t, err) + key, ok := priv.(*bls12SecretKey) + require.Equal(t, true, ok) + + signatureA := &Signature{s: new(blstSignature).Sign(key.p, []byte("foo"), dst)} + signatureB, ok := signatureA.Copy().(*Signature) + require.Equal(t, true, ok) + + assert.NotEqual(t, signatureA, signatureB) + assert.NotEqual(t, signatureA.s, signatureB.s) + assert.Equal(t, signatureA, signatureB) + + signatureA.s.Sign(key.p, []byte("bar"), dst) + assert.NotEqual(t, signatureA, signatureB) +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..097fb6a --- /dev/null +++ b/doc.go @@ -0,0 +1,3 @@ +package crypto + +// TODO: Add more docs in here diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..1eb0484 --- /dev/null +++ b/go.mod @@ -0,0 +1,18 @@ +module github.com/cosmos/crypto + +go 1.22.2 + +require ( + github.com/sasha-s/go-deadlock v0.3.1 + github.com/stretchr/testify v1.9.0 + github.com/supranational/blst v0.3.12 + golang.org/x/crypto v0.24.0 +) + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + golang.org/x/sys v0.21.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..cfdcc7c --- /dev/null +++ b/go.sum @@ -0,0 +1,20 @@ +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5 h1:q2e307iGHPdTGp0hoxKjt1H5pDo6utceo3dQVK3I5XQ= +github.com/petermattis/goid v0.0.0-20180202154549-b0b1615b78e5/go.mod h1:jvVRKCrJTQWu0XVbaOlby/2lO20uSCHEMzzplHXte1o= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sasha-s/go-deadlock v0.3.1 h1:sqv7fDNShgjcaxkO0JNcOAlr8B9+cV5Ey/OB71efZx0= +github.com/sasha-s/go-deadlock v0.3.1/go.mod h1:F73l+cr82YSh10GxyRI6qZiCgK64VaZjwesgfQ1/iLM= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/supranational/blst v0.3.12 h1:Vfas2U2CFHhniv2QkUm2OVa1+pGTdqtpqm9NnhUUbZ8= +github.com/supranational/blst v0.3.12/go.mod h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hash/sha256/bench_test.go b/hash/sha256/bench_test.go new file mode 100644 index 0000000..0ff0867 --- /dev/null +++ b/hash/sha256/bench_test.go @@ -0,0 +1,52 @@ +package sha256 + +import ( + "bytes" + "crypto/sha256" + "strings" + "testing" +) + +var sink any + +var manySlices = []struct { + name string + in [][]byte + want [32]byte +}{ + { + name: "all empty", + in: [][]byte{[]byte(""), []byte("")}, + want: sha256.Sum256(nil), + }, + { + name: "ax6", + in: [][]byte{[]byte("aaaa"), []byte("😎"), []byte("aaaa")}, + want: sha256.Sum256([]byte("aaaa😎aaaa")), + }, + { + name: "composite joined", + in: [][]byte{bytes.Repeat([]byte("a"), 1<<10), []byte("AA"), bytes.Repeat([]byte("z"), 100)}, + want: sha256.Sum256([]byte(strings.Repeat("a", 1<<10) + "AA" + strings.Repeat("z", 100))), + }, +} + +func BenchmarkSHA256Many(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tt := range manySlices { + got := SumMany(tt.in[0], tt.in[1:]...) + if !bytes.Equal(got, tt.want[:]) { + b.Fatalf("Outward checksum mismatch for %q\n\tGot: %x\n\tWant: %x", tt.name, got, tt.want) + } + sink = got + } + } + + if sink == nil { + b.Fatal("Benchmark did not run!") + } + + sink = nil +} diff --git a/hash/sha256/hash.go b/hash/sha256/hash.go new file mode 100644 index 0000000..277a3a8 --- /dev/null +++ b/hash/sha256/hash.go @@ -0,0 +1,78 @@ +package sha256 + +import ( + "crypto/sha256" + "hash" +) + +const ( + Size = sha256.Size + BlockSize = sha256.BlockSize +) + +// New returns a new hash.Hash. +func New() hash.Hash { + return sha256.New() +} + +// Sum returns the SHA256 of the bz. +func Sum(bz []byte) []byte { + h := sha256.Sum256(bz) + return h[:] +} + +// SumMany takes at least 1 byteslice along with a variadic +// number of other byteslices and produces the SHA256 sum from +// hashing them as if they were 1 joined slice. +func SumMany(data []byte, rest ...[]byte) []byte { + h := sha256.New() + h.Write(data) + for _, data := range rest { + h.Write(data) + } + return h.Sum(nil) +} + +// ------------------------------------------------------------- + +const ( + TruncatedSize = 20 +) + +type sha256trunc struct { + sha256 hash.Hash +} + +func (h sha256trunc) Write(p []byte) (n int, err error) { + return h.sha256.Write(p) +} + +func (h sha256trunc) Sum(b []byte) []byte { + shasum := h.sha256.Sum(b) + return shasum[:TruncatedSize] +} + +func (h sha256trunc) Reset() { + h.sha256.Reset() +} + +func (sha256trunc) Size() int { + return TruncatedSize +} + +func (h sha256trunc) BlockSize() int { + return h.sha256.BlockSize() +} + +// NewTruncated returns a new hash.Hash. +func NewTruncated() hash.Hash { + return sha256trunc{ + sha256: sha256.New(), + } +} + +// SumTruncated returns the first 20 bytes of SHA256 of the bz. +func SumTruncated(bz []byte) []byte { + h := sha256.Sum256(bz) + return h[:TruncatedSize] +} diff --git a/hash/sha256/hash_test.go b/hash/sha256/hash_test.go new file mode 100644 index 0000000..5d66676 --- /dev/null +++ b/hash/sha256/hash_test.go @@ -0,0 +1,46 @@ +package sha256 + +import ( + "crypto/sha256" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHash(t *testing.T) { + testVector := []byte("abc") + hasher := New() + _, err := hasher.Write(testVector) + require.NoError(t, err) + bz := hasher.Sum(nil) + + bz2 := Sum(testVector) + + hasher = sha256.New() + _, err = hasher.Write(testVector) + require.NoError(t, err) + bz3 := hasher.Sum(nil) + + assert.Equal(t, bz, bz2) + assert.Equal(t, bz, bz3) +} + +func TestHashTruncated(t *testing.T) { + testVector := []byte("abc") + hasher := NewTruncated() + _, err := hasher.Write(testVector) + require.NoError(t, err) + bz := hasher.Sum(nil) + + bz2 := SumTruncated(testVector) + + hasher = sha256.New() + _, err = hasher.Write(testVector) + require.NoError(t, err) + bz3 := hasher.Sum(nil) + bz3 = bz3[:TruncatedSize] + + assert.Equal(t, bz, bz2) + assert.Equal(t, bz, bz3) +} diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..066e8b8 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,144 @@ +package cache + +import ( + "errors" + "sync" +) + +// EvictCallback is used to get a callback when a cache entry is evicted. +type EvictCallback[K comparable, V any] func(key K, value V) + +// LRU implements a non-thread safe fixed size LRU cache. +type LRU[K comparable, V any] struct { + itemsLock sync.RWMutex + evictListLock sync.RWMutex + size int + evictList *lruList[K, V] + items map[K]*entry[K, V] + onEvict EvictCallback[K, V] + getChan chan *entry[K, V] +} + +// NewLRU constructs an LRU of the given size. +func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V]) (*LRU[K, V], error) { + if size <= 0 { + return nil, errors.New("must provide a positive size") + } + // Initialize the channel buffer size as being 10% of the cache size. + chanSize := size / 10 + + c := &LRU[K, V]{ + size: size, + evictList: newList[K, V](), + items: make(map[K]*entry[K, V]), + onEvict: onEvict, + getChan: make(chan *entry[K, V], chanSize), + } + // Spin off separate go-routine to handle evict list + // operations. + go c.handleGetRequests() + return c, nil +} + +// Add adds a value to the cache. Returns true if an eviction occurred. +func (c *LRU[K, V]) Add(key K, value V) (evicted bool) { + // Check for existing item + c.itemsLock.RLock() + if ent, ok := c.items[key]; ok { + c.itemsLock.RUnlock() + + c.evictListLock.Lock() + c.evictList.moveToFront(ent) + c.evictListLock.Unlock() + ent.value = value + return false + } + c.itemsLock.RUnlock() + + // Add new item + c.evictListLock.Lock() + ent := c.evictList.pushFront(key, value) + c.evictListLock.Unlock() + + c.itemsLock.Lock() + c.items[key] = ent + c.itemsLock.Unlock() + + c.evictListLock.RLock() + evict := c.evictList.length() > c.size + c.evictListLock.RUnlock() + + // Verify size not exceeded + if evict { + c.removeOldest() + } + return evict +} + +// Get looks up a key's value from the cache. +func (c *LRU[K, V]) Get(key K) (value V, ok bool) { + c.itemsLock.RLock() + if ent, ok := c.items[key]; ok { + c.itemsLock.RUnlock() + + // Make this get function non-blocking for multiple readers. + c.getChan <- ent + return ent.value, true + } + c.itemsLock.RUnlock() + return +} + +// Len returns the number of items in the cache. +func (c *LRU[K, V]) Len() int { + c.evictListLock.RLock() + defer c.evictListLock.RUnlock() + return c.evictList.length() +} + +// Resize changes the cache size. +func (c *LRU[K, V]) Resize(size int) (evicted int) { + diff := c.Len() - size + if diff < 0 { + diff = 0 + } + for i := 0; i < diff; i++ { + c.removeOldest() + } + c.size = size + return diff +} + +// removeOldest removes the oldest item from the cache. +func (c *LRU[K, V]) removeOldest() { + c.evictListLock.RLock() + if ent := c.evictList.back(); ent != nil { + c.evictListLock.RUnlock() + c.removeElement(ent) + return + } + c.evictListLock.RUnlock() +} + +// removeElement is used to remove a given list element from the cache. +func (c *LRU[K, V]) removeElement(e *entry[K, V]) { + c.evictListLock.Lock() + c.evictList.remove(e) + c.evictListLock.Unlock() + + c.itemsLock.Lock() + delete(c.items, e.key) + c.itemsLock.Unlock() + if c.onEvict != nil { + c.onEvict(e.key, e.value) + } +} + +func (c *LRU[K, V]) handleGetRequests() { + for { + entry := <-c.getChan + c.evictListLock.Lock() + c.evictList.moveToFront(entry) + c.evictListLock.Unlock() + } +} diff --git a/internal/cache/list.go b/internal/cache/list.go new file mode 100644 index 0000000..5bf4f56 --- /dev/null +++ b/internal/cache/list.go @@ -0,0 +1,123 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE_list file. +package cache + +// entry is an LRU entry. +type entry[K comparable, V any] struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *entry[K, V] + + // The list to which this element belongs. + list *lruList[K, V] + + // The LRU key of this element. + key K + + // The value stored with this element. + value V +} + +// lruList represents a doubly linked list. +// The zero value for lruList is an empty list ready to use. +type lruList[K comparable, V any] struct { + root entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// init initializes or clears list l. +func (l *lruList[K, V]) init() *lruList[K, V] { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// newList returns an initialized list. +func newList[K comparable, V any]() *lruList[K, V] { return new(lruList[K, V]).init() } + +// length returns the number of elements of list l. +// The complexity is O(1). +func (l *lruList[K, V]) length() int { return l.len } + +// back returns the last element of list l or nil if the list is empty. +func (l *lruList[K, V]) back() *entry[K, V] { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *lruList[K, V]) lazyInit() { + if l.root.next == nil { + l.init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *lruList[K, V]) insert(e, at *entry[K, V]) *entry[K, V] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *lruList[K, V]) insertValue(k K, v V, at *entry[K, V]) *entry[K, V] { + return l.insert(&entry[K, V]{value: v, key: k}, at) +} + +// remove removes e from its list, decrements l.len. +func (l *lruList[K, V]) remove(e *entry[K, V]) V { + // If already removed, do nothing. + if e.prev == nil && e.next == nil { + return e.value + } + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- + + return e.value +} + +// move moves e to next to at. +func (*lruList[K, V]) move(e, at *entry[K, V]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} + +// pushFront inserts a new element e with value v at the front of list l and returns e. +func (l *lruList[K, V]) pushFront(k K, v V) *entry[K, V] { + l.lazyInit() + return l.insertValue(k, v, &l.root) +} + +// moveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *lruList[K, V]) moveToFront(e *entry[K, V]) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) +} diff --git a/internal/libs/bytes/bytes.go b/internal/libs/bytes/bytes.go new file mode 100644 index 0000000..3b6eefc --- /dev/null +++ b/internal/libs/bytes/bytes.go @@ -0,0 +1,65 @@ +package bytes + +import ( + "encoding/hex" + "fmt" + "strings" +) + +// HexBytes enables HEX-encoding for json/encoding. +type HexBytes []byte + +// Marshal needed for protobuf compatibility. +func (bz HexBytes) Marshal() ([]byte, error) { + return bz, nil +} + +// Unmarshal needed for protobuf compatibility. +func (bz *HexBytes) Unmarshal(data []byte) error { + *bz = data + return nil +} + +// This is the point of Bytes. +func (bz HexBytes) MarshalJSON() ([]byte, error) { + s := strings.ToUpper(hex.EncodeToString(bz)) + jbz := make([]byte, len(s)+2) + jbz[0] = '"' + copy(jbz[1:], s) + jbz[len(jbz)-1] = '"' + return jbz, nil +} + +// This is the point of Bytes. +func (bz *HexBytes) UnmarshalJSON(data []byte) error { + if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' { + return fmt.Errorf("invalid hex string: %s", data) + } + bz2, err := hex.DecodeString(string(data[1 : len(data)-1])) + if err != nil { + return err + } + *bz = bz2 + return nil +} + +// Bytes fulfills various interfaces in light-client, etc... +func (bz HexBytes) Bytes() []byte { + return bz +} + +func (bz HexBytes) String() string { + return strings.ToUpper(hex.EncodeToString(bz)) +} + +// Format writes either address of 0th element in a slice in base 16 notation, +// with leading 0x (%p), or casts HexBytes to bytes and writes as hexadecimal +// string to s. +func (bz HexBytes) Format(s fmt.State, verb rune) { + switch verb { + case 'p': + _, _ = s.Write([]byte(fmt.Sprintf("%p", bz))) + default: + _, _ = s.Write([]byte(fmt.Sprintf("%X", []byte(bz)))) + } +} diff --git a/internal/libs/bytes/bytes_test.go b/internal/libs/bytes/bytes_test.go new file mode 100644 index 0000000..b4e2c8b --- /dev/null +++ b/internal/libs/bytes/bytes_test.go @@ -0,0 +1,75 @@ +package bytes + +import ( + "encoding/json" + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// This is a trivial test for protobuf compatibility. +func TestMarshal(t *testing.T) { + bz := []byte("hello world") + dataB := HexBytes(bz) + bz2, err := dataB.Marshal() + require.NoError(t, err) + assert.Equal(t, bz, bz2) + + var dataB2 HexBytes + err = (&dataB2).Unmarshal(bz) + require.NoError(t, err) + assert.Equal(t, dataB, dataB2) +} + +// Test that the hex encoding works. +func TestJSONMarshal(t *testing.T) { + type TestStruct struct { + B1 []byte `json:"B1" yaml:"B1"` // normal bytes + B2 HexBytes `json:"B2" yaml:"B2"` // hex bytes + } + + cases := []struct { + input []byte + expected string + }{ + {[]byte(``), `{"B1":"","B2":""}`}, + {[]byte(`a`), `{"B1":"YQ==","B2":"61"}`}, + {[]byte(`abc`), `{"B1":"YWJj","B2":"616263"}`}, + } + + for i, tc := range cases { + tc := tc + t.Run(fmt.Sprintf("Case %d", i), func(t *testing.T) { + ts := TestStruct{B1: tc.input, B2: tc.input} + + // Test that it marshals correctly to JSON. + jsonBytes, err := json.Marshal(ts) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, tc.expected, string(jsonBytes)) + + // TODO do fuzz testing to ensure that unmarshal fails + + // Test that unmarshaling works correctly. + ts2 := TestStruct{} + err = json.Unmarshal(jsonBytes, &ts2) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, ts2.B1, tc.input) + assert.Equal(t, ts2.B2, HexBytes(tc.input)) + }) + } +} + +// Test that the hex encoding works. +func TestHexBytes_String(t *testing.T) { + hs := HexBytes([]byte("test me")) + if _, err := strconv.ParseInt(hs.String(), 16, 64); err != nil { + t.Fatal(err) + } +} diff --git a/internal/libs/bytes/byteslice.go b/internal/libs/bytes/byteslice.go new file mode 100644 index 0000000..1d535eb --- /dev/null +++ b/internal/libs/bytes/byteslice.go @@ -0,0 +1,10 @@ +package bytes + +// Fingerprint returns the first 6 bytes of a byte slice. +// If the slice is less than 6 bytes, the fingerprint +// contains trailing zeroes. +func Fingerprint(slice []byte) []byte { + fingerprint := make([]byte, 6) + copy(fingerprint, slice) + return fingerprint +} diff --git a/internal/libs/json/decoder.go b/internal/libs/json/decoder.go new file mode 100644 index 0000000..1aca2ca --- /dev/null +++ b/internal/libs/json/decoder.go @@ -0,0 +1,277 @@ +package json + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" +) + +// Unmarshal unmarshals JSON into the given value, using Amino-compatible JSON encoding (strings +// for 64-bit numbers, and type wrappers for registered types). +func Unmarshal(bz []byte, v any) error { + return decode(bz, v) +} + +func decode(bz []byte, v any) error { + if len(bz) == 0 { + return errors.New("cannot decode empty bytes") + } + + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return errors.New("must decode into a pointer") + } + rv = rv.Elem() + + // If this is a registered type, defer to interface decoder regardless of whether the input is + // an interface or a bare value. This retains Amino's behavior, but is inconsistent with + // behavior in structs where an interface field will get the type wrapper while a bare value + // field will not. + if typeRegistry.name(rv.Type()) != "" { + return decodeReflectInterface(bz, rv) + } + + return decodeReflect(bz, rv) +} + +func decodeReflect(bz []byte, rv reflect.Value) error { + if !rv.CanAddr() { + return errors.New("value is not addressable") + } + + // Handle null for slices, interfaces, and pointers + if bytes.Equal(bz, []byte("null")) { + rv.Set(reflect.Zero(rv.Type())) + return nil + } + + // Dereference-and-construct pointers, to handle nested pointers. + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + // Times must be UTC and end with Z + if rv.Type() == timeType { + switch { + case len(bz) < 2 || bz[0] != '"' || bz[len(bz)-1] != '"': + return fmt.Errorf("JSON time must be an RFC3339 string, but got %q", bz) + case bz[len(bz)-2] != 'Z': + return fmt.Errorf("JSON time must be UTC and end with 'Z', but got %q", bz) + } + } + + // If value implements json.Umarshaler, call it. + if rv.Addr().Type().Implements(jsonUnmarshalerType) { + return rv.Addr().Interface().(json.Unmarshaler).UnmarshalJSON(bz) + } + + switch rv.Type().Kind() { + // Decode complex types recursively. + case reflect.Slice, reflect.Array: + return decodeReflectList(bz, rv) + + case reflect.Map: + return decodeReflectMap(bz, rv) + + case reflect.Struct: + return decodeReflectStruct(bz, rv) + + case reflect.Interface: + return decodeReflectInterface(bz, rv) + + // For 64-bit integers, unwrap expected string and defer to stdlib for integer decoding. + case reflect.Int64, reflect.Int, reflect.Uint64, reflect.Uint: + if bz[0] != '"' || bz[len(bz)-1] != '"' { + return fmt.Errorf("invalid 64-bit integer encoding %q, expected string", string(bz)) + } + bz = bz[1 : len(bz)-1] + fallthrough + + // Anything else we defer to the stdlib. + default: + return decodeStdlib(bz, rv) + } +} + +func decodeReflectList(bz []byte, rv reflect.Value) error { + if !rv.CanAddr() { + return errors.New("list value is not addressable") + } + + switch rv.Type().Elem().Kind() { + // Decode base64-encoded bytes using stdlib decoder, via byte slice for arrays. + case reflect.Uint8: + if rv.Type().Kind() == reflect.Array { + var buf []byte + if err := json.Unmarshal(bz, &buf); err != nil { + return err + } + if len(buf) != rv.Len() { + return fmt.Errorf("got %v bytes, expected %v", len(buf), rv.Len()) + } + reflect.Copy(rv, reflect.ValueOf(buf)) + } else if err := decodeStdlib(bz, rv); err != nil { + return err + } + + // Decode anything else into a raw JSON slice, and decode values recursively. + default: + var rawSlice []json.RawMessage + if err := json.Unmarshal(bz, &rawSlice); err != nil { + return err + } + if rv.Type().Kind() == reflect.Slice { + rv.Set(reflect.MakeSlice(reflect.SliceOf(rv.Type().Elem()), len(rawSlice), len(rawSlice))) + } + if rv.Len() != len(rawSlice) { // arrays of wrong size + return fmt.Errorf("got list of %v elements, expected %v", len(rawSlice), rv.Len()) + } + for i, bz := range rawSlice { + if err := decodeReflect(bz, rv.Index(i)); err != nil { + return err + } + } + } + + // Replace empty slices with nil slices, for Amino compatibility + if rv.Type().Kind() == reflect.Slice && rv.Len() == 0 { + rv.Set(reflect.Zero(rv.Type())) + } + + return nil +} + +func decodeReflectMap(bz []byte, rv reflect.Value) error { + if !rv.CanAddr() { + return errors.New("map value is not addressable") + } + + // Decode into a raw JSON map, using string keys. + rawMap := make(map[string]json.RawMessage) + if err := json.Unmarshal(bz, &rawMap); err != nil { + return err + } + if rv.Type().Key().Kind() != reflect.String { + return fmt.Errorf("map keys must be strings, got %v", rv.Type().Key().String()) + } + + // Recursively decode values. + rv.Set(reflect.MakeMapWithSize(rv.Type(), len(rawMap))) + for key, bz := range rawMap { + value := reflect.New(rv.Type().Elem()).Elem() + if err := decodeReflect(bz, value); err != nil { + return err + } + rv.SetMapIndex(reflect.ValueOf(key), value) + } + return nil +} + +func decodeReflectStruct(bz []byte, rv reflect.Value) error { + if !rv.CanAddr() { + return errors.New("struct value is not addressable") + } + sInfo := makeStructInfo(rv.Type()) + + // Decode raw JSON values into a string-keyed map. + rawMap := make(map[string]json.RawMessage) + if err := json.Unmarshal(bz, &rawMap); err != nil { + return err + } + for i, fInfo := range sInfo.fields { + if !fInfo.hidden { + frv := rv.Field(i) + bz := rawMap[fInfo.jsonName] + if len(bz) > 0 { + if err := decodeReflect(bz, frv); err != nil { + return err + } + } else if !fInfo.omitEmpty { + frv.Set(reflect.Zero(frv.Type())) + } + } + } + + return nil +} + +func decodeReflectInterface(bz []byte, rv reflect.Value) error { + if !rv.CanAddr() { + return errors.New("interface value not addressable") + } + + // Decode the interface wrapper. + wrapper := interfaceWrapper{} + if err := json.Unmarshal(bz, &wrapper); err != nil { + return err + } + if wrapper.Type == "" { + return errors.New("interface type cannot be empty") + } + if len(wrapper.Value) == 0 { + return errors.New("interface value cannot be empty") + } + + // Dereference-and-construct pointers, to handle nested pointers. + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + // Look up the interface type, and construct a concrete value. + rt, returnPtr := typeRegistry.lookup(wrapper.Type) + if rt == nil { + return fmt.Errorf("unknown type %q", wrapper.Type) + } + + cptr := reflect.New(rt) + crv := cptr.Elem() + if err := decodeReflect(wrapper.Value, crv); err != nil { + return err + } + + // This makes sure interface implementations with pointer receivers (e.g. func (c *Car)) are + // constructed as pointers behind the interface. The types must be registered as pointers with + // RegisterType(). + if rv.Type().Kind() == reflect.Interface && returnPtr { + if !cptr.Type().AssignableTo(rv.Type()) { + return fmt.Errorf("invalid type %q for this value", wrapper.Type) + } + rv.Set(cptr) + } else { + if !crv.Type().AssignableTo(rv.Type()) { + return fmt.Errorf("invalid type %q for this value", wrapper.Type) + } + rv.Set(crv) + } + return nil +} + +func decodeStdlib(bz []byte, rv reflect.Value) error { + if !rv.CanAddr() && rv.Kind() != reflect.Ptr { + return errors.New("value must be addressable or pointer") + } + + // Make sure we are unmarshaling into a pointer. + target := rv + if rv.Kind() != reflect.Ptr { + target = reflect.New(rv.Type()) + } + if err := json.Unmarshal(bz, target.Interface()); err != nil { + return err + } + rv.Set(target.Elem()) + return nil +} + +type interfaceWrapper struct { + Type string `json:"type"` + Value json.RawMessage `json:"value"` +} diff --git a/internal/libs/json/decoder_test.go b/internal/libs/json/decoder_test.go new file mode 100644 index 0000000..7e28020 --- /dev/null +++ b/internal/libs/json/decoder_test.go @@ -0,0 +1,150 @@ +package json_test + +import ( + "reflect" + "testing" + "time" + + "github.com/cosmos/crypto/internal/libs/json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnmarshal(t *testing.T) { + i64Nil := (*int64)(nil) + str := "string" + strPtr := &str + structNil := (*Struct)(nil) + i32 := int32(32) + i64 := int64(64) + + testcases := map[string]struct { + json string + value any + err bool + }{ + "bool true": {"true", true, false}, + "bool false": {"false", false, false}, + "float32": {"3.14", float32(3.14), false}, + "float64": {"3.14", float64(3.14), false}, + "int32": {`32`, int32(32), false}, + "int32 string": {`"32"`, int32(32), true}, + "int32 ptr": {`32`, &i32, false}, + "int64": {`"64"`, int64(64), false}, + "int64 noend": {`"64`, int64(64), true}, + "int64 number": {`64`, int64(64), true}, + "int64 ptr": {`"64"`, &i64, false}, + "int64 ptr nil": {`null`, i64Nil, false}, + "string": {`"foo"`, "foo", false}, + "string noend": {`"foo`, "foo", true}, + "string ptr": {`"string"`, &str, false}, + "slice byte": {`"AQID"`, []byte{1, 2, 3}, false}, + "slice bytes": {`["AQID"]`, [][]byte{{1, 2, 3}}, false}, + "slice int32": {`[1,2,3]`, []int32{1, 2, 3}, false}, + "slice int64": {`["1","2","3"]`, []int64{1, 2, 3}, false}, + "slice int64 number": {`[1,2,3]`, []int64{1, 2, 3}, true}, + "slice int64 ptr": {`["64"]`, []*int64{&i64}, false}, + "slice int64 empty": {`[]`, []int64(nil), false}, + "slice int64 null": {`null`, []int64(nil), false}, + "array byte": {`"AQID"`, [3]byte{1, 2, 3}, false}, + "array byte large": {`"AQID"`, [4]byte{1, 2, 3, 4}, true}, + "array byte small": {`"AQID"`, [2]byte{1, 2}, true}, + "array int32": {`[1,2,3]`, [3]int32{1, 2, 3}, false}, + "array int64": {`["1","2","3"]`, [3]int64{1, 2, 3}, false}, + "array int64 number": {`[1,2,3]`, [3]int64{1, 2, 3}, true}, + "array int64 large": {`["1","2","3"]`, [4]int64{1, 2, 3, 4}, true}, + "array int64 small": {`["1","2","3"]`, [2]int64{1, 2}, true}, + "map bytes": {`{"b":"AQID"}`, map[string][]byte{"b": {1, 2, 3}}, false}, + "map int32": {`{"a":1,"b":2}`, map[string]int32{"a": 1, "b": 2}, false}, + "map int64": {`{"a":"1","b":"2"}`, map[string]int64{"a": 1, "b": 2}, false}, + "map int64 empty": {`{}`, map[string]int64{}, false}, + "map int64 null": {`null`, map[string]int64(nil), false}, + "map int key": {`{}`, map[int]int{}, true}, + "time": {`"2020-06-03T17:35:30Z"`, time.Date(2020, 6, 3, 17, 35, 30, 0, time.UTC), false}, + "time non-utc": {`"2020-06-03T17:35:30+02:00"`, time.Time{}, true}, + "time nozone": {`"2020-06-03T17:35:30"`, time.Time{}, true}, + "car": {`{"type":"vehicle/car","value":{"Wheels":4}}`, Car{Wheels: 4}, false}, + "car ptr": {`{"type":"vehicle/car","value":{"Wheels":4}}`, &Car{Wheels: 4}, false}, + "car iface": {`{"type":"vehicle/car","value":{"Wheels":4}}`, Vehicle(&Car{Wheels: 4}), false}, + "boat": {`{"type":"vehicle/boat","value":{"Sail":true}}`, Boat{Sail: true}, false}, + "boat ptr": {`{"type":"vehicle/boat","value":{"Sail":true}}`, &Boat{Sail: true}, false}, + "boat iface": {`{"type":"vehicle/boat","value":{"Sail":true}}`, Vehicle(Boat{Sail: true}), false}, + "boat into car": {`{"type":"vehicle/boat","value":{"Sail":true}}`, Car{}, true}, + "boat into car iface": {`{"type":"vehicle/boat","value":{"Sail":true}}`, Vehicle(&Car{}), true}, + "shoes": {`{"type":"vehicle/shoes","value":{"Soles":"rubber"}}`, Car{}, true}, + "shoes ptr": {`{"type":"vehicle/shoes","value":{"Soles":"rubber"}}`, &Car{}, true}, + "shoes iface": {`{"type":"vehicle/shoes","value":{"Soles":"rubbes"}}`, Vehicle(&Car{}), true}, + "key public": {`{"type":"key/public","value":"AQIDBAUGBwg="}`, PublicKey{1, 2, 3, 4, 5, 6, 7, 8}, false}, + "key wrong": {`{"type":"key/public","value":"AQIDBAUGBwg="}`, PrivateKey{1, 2, 3, 4, 5, 6, 7, 8}, true}, + "key into car": {`{"type":"key/public","value":"AQIDBAUGBwg="}`, Vehicle(&Car{}), true}, + "tags": { + `{"name":"name","OmitEmpty":"foo","Hidden":"bar","tags":{"name":"child"}}`, + Tags{JSONName: "name", OmitEmpty: "foo", Tags: &Tags{JSONName: "child"}}, + false, + }, + "tags ptr": { + `{"name":"name","OmitEmpty":"foo","tags":null}`, + &Tags{JSONName: "name", OmitEmpty: "foo"}, + false, + }, + "tags real name": {`{"JSONName":"name"}`, Tags{}, false}, + "struct": { + `{ + "Bool":true, "Float64":3.14, "Int32":32, "Int64":"64", "Int64Ptr":"64", + "String":"foo", "StringPtrPtr": "string", "Bytes":"AQID", + "Time":"2020-06-02T16:05:13.004346374Z", + "Car":{"Wheels":4}, + "Boat":{"Sail":true}, + "Vehicles":[ + {"type":"vehicle/car","value":{"Wheels":4}}, + {"type":"vehicle/boat","value":{"Sail":true}} + ], + "Child":{ + "Bool":false, "Float64":0, "Int32":0, "Int64":"0", "Int64Ptr":null, + "String":"child", "StringPtrPtr":null, "Bytes":null, + "Time":"0001-01-01T00:00:00Z", + "Car":null, "Boat":{"Sail":false}, "Vehicles":null, "Child":null + }, + "private": "foo", "unknown": "bar" + }`, + Struct{ + Bool: true, Float64: 3.14, Int32: 32, Int64: 64, Int64Ptr: &i64, + String: "foo", StringPtrPtr: &strPtr, Bytes: []byte{1, 2, 3}, + Time: time.Date(2020, 6, 2, 16, 5, 13, 4346374, time.UTC), + Car: &Car{Wheels: 4}, Boat: Boat{Sail: true}, Vehicles: []Vehicle{ + Vehicle(&Car{Wheels: 4}), + Vehicle(Boat{Sail: true}), + }, + Child: &Struct{Bool: false, String: "child"}, + }, + false, + }, + "struct key into vehicle": {`{"Vehicles":[ + {"type":"vehicle/car","value":{"Wheels":4}}, + {"type":"key/public","value":"MTIzNDU2Nzg="} + ]}`, Struct{}, true}, + "struct ptr null": {`null`, structNil, false}, + "custom value": {`{"Value":"foo"}`, CustomValue{}, false}, + "custom ptr": {`"foo"`, &CustomPtr{Value: "custom"}, false}, + "custom ptr value": {`"foo"`, CustomPtr{Value: "custom"}, false}, + "invalid type": {`"foo"`, Struct{}, true}, + } + for name, tc := range testcases { + tc := tc + t.Run(name, func(t *testing.T) { + // Create a target variable as a pointer to the zero value of the tc.value type, + // and wrap it in an empty interface. Decode into that interface. + target := reflect.New(reflect.TypeOf(tc.value)).Interface() + err := json.Unmarshal([]byte(tc.json), target) + if tc.err { + require.Error(t, err) + return + } + require.NoError(t, err) + + // Unwrap the target pointer and get the value behind the interface. + actual := reflect.ValueOf(target).Elem().Interface() + assert.Equal(t, tc.value, actual) + }) + } +} diff --git a/internal/libs/json/doc.go b/internal/libs/json/doc.go new file mode 100644 index 0000000..18a4c97 --- /dev/null +++ b/internal/libs/json/doc.go @@ -0,0 +1,98 @@ +// Package json provides functions for marshaling and unmarshaling JSON in a format that is +// backwards-compatible with Amino JSON encoding. This mostly differs from encoding/json in +// encoding of integers (64-bit integers are encoded as strings, not numbers), and handling +// of interfaces (wrapped in an interface object with type/value keys). +// +// JSON tags (e.g. `json:"name,omitempty"`) are supported in the same way as encoding/json, as is +// custom marshaling overrides via the json.Marshaler and json.Unmarshaler interfaces. +// +// Note that not all JSON emitted by CometBFT is generated by this library; some is generated by +// encoding/json instead, and kept like that for backwards compatibility. +// +// Encoding of numbers uses strings for 64-bit integers (including unspecified ints), to improve +// compatibility with e.g. Javascript (which uses 64-bit floats for numbers, having 53-bit +// precision): +// +// int32(32) // Output: 32 +// uint32(32) // Output: 32 +// int64(64) // Output: "64" +// uint64(64) // Output: "64" +// int(64) // Output: "64" +// uint(64) // Output: "64" +// +// Encoding of other scalars follows encoding/json: +// +// nil // Output: null +// true // Output: true +// "foo" // Output: "foo" +// "" // Output: "" +// +// Slices and arrays are encoded as encoding/json, including base64-encoding of byte slices +// with additional base64-encoding of byte arrays as well: +// +// []int64(nil) // Output: null +// []int64{} // Output: [] +// []int64{1, 2, 3} // Output: ["1", "2", "3"] +// []int32{1, 2, 3} // Output: [1, 2, 3] +// []byte{1, 2, 3} // Output: "AQID" +// [3]int64{1, 2, 3} // Output: ["1", "2", "3"] +// [3]byte{1, 2, 3} // Output: "AQID" +// +// Maps are encoded as encoding/json, but only strings are allowed as map keys (nil maps are not +// emitted as null, to retain Amino backwards-compatibility): +// +// map[string]int64(nil) // Output: {} +// map[string]int64{} // Output: {} +// map[string]int64{"a":1,"b":2} // Output: {"a":"1","b":"2"} +// map[string]int32{"a":1,"b":2} // Output: {"a":1,"b":2} +// map[bool]int{true:1} // Errors +// +// Times are encoded as encoding/json, in RFC3339Nano format, but requiring UTC time zone (with zero +// times emitted as "0001-01-01T00:00:00Z" as with encoding/json): +// +// time.Date(2020, 6, 8, 16, 21, 28, 123, time.FixedZone("UTC+2", 2*60*60)) +// // Output: "2020-06-08T14:21:28.000000123Z" +// time.Time{} // Output: "0001-01-01T00:00:00Z" +// (*time.Time)(nil) // Output: null +// +// Structs are encoded as encoding/json, supporting JSON tags and ignoring private fields: +// +// type Struct struct{ +// Name string +// Value int32 `json:"value,omitempty"` +// private bool +// } +// +// Struct{Name: "foo", Value: 7, private: true} // Output: {"Name":"foo","value":7} +// Struct{} // Output: {"Name":""} +// +// Registered types are encoded with type wrapper, regardless of whether they are given as interface +// or bare struct, but inside structs they are only emitted with type wrapper for interface fields +// (this follows Amino behavior): +// +// type Vehicle interface { +// Drive() error +// } +// +// type Car struct { +// Wheels int8 +// } +// +// func (c *Car) Drive() error { return nil } +// +// RegisterType(&Car{}, "vehicle/car") +// +// Car{Wheels: 4} // Output: {"type":"vehicle/car","value":{"Wheels":4}} +// &Car{Wheels: 4} // Output: {"type":"vehicle/car","value":{"Wheels":4}} +// (*Car)(nil) // Output: null +// Vehicle(Car{Wheels: 4}) // Output: {"type":"vehicle/car","value":{"Wheels":4}} +// Vehicle(nil) // Output: null +// +// type Struct struct { +// Car *Car +// Vehicle +// } +// +// Struct{Car: &Car{Wheels: 4}, Vehicle: &Car{Wheels: 4}} +// // Output: {"Car": {"Wheels: 4"}, "Vehicle": {"type":"vehicle/car","value":{"Wheels":4}}} +package json diff --git a/internal/libs/json/encoder.go b/internal/libs/json/encoder.go new file mode 100644 index 0000000..86d71a5 --- /dev/null +++ b/internal/libs/json/encoder.go @@ -0,0 +1,257 @@ +package json + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "time" +) + +var ( + timeType = reflect.TypeOf(time.Time{}) + jsonMarshalerType = reflect.TypeOf(new(json.Marshaler)).Elem() + jsonUnmarshalerType = reflect.TypeOf(new(json.Unmarshaler)).Elem() +) + +// Marshal marshals the value as JSON, using Amino-compatible JSON encoding (strings for +// 64-bit numbers, and type wrappers for registered types). +func Marshal(v any) ([]byte, error) { + buf := new(bytes.Buffer) + err := encode(buf, v) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// MarshalIndent marshals the value as JSON, using the given prefix and indentation. +func MarshalIndent(v any, prefix, indent string) ([]byte, error) { + bz, err := Marshal(v) + if err != nil { + return nil, err + } + buf := new(bytes.Buffer) + err = json.Indent(buf, bz, prefix, indent) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func encode(w *bytes.Buffer, v any) error { + // Bare nil values can't be reflected, so we must handle them here. + if v == nil { + return writeStr(w, "null") + } + rv := reflect.ValueOf(v) + + // If this is a registered type, defer to interface encoder regardless of whether the input is + // an interface or a bare value. This retains Amino's behavior, but is inconsistent with + // behavior in structs where an interface field will get the type wrapper while a bare value + // field will not. + if typeRegistry.name(rv.Type()) != "" { + return encodeReflectInterface(w, rv) + } + + return encodeReflect(w, rv) +} + +func encodeReflect(w *bytes.Buffer, rv reflect.Value) error { + if !rv.IsValid() { + return errors.New("invalid reflect value") + } + + // Recursively dereference if pointer. + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + return writeStr(w, "null") + } + rv = rv.Elem() + } + + // Convert times to UTC. + if rv.Type() == timeType { + rv = reflect.ValueOf(rv.Interface().(time.Time).Round(0).UTC()) + } + + // If the value implements json.Marshaler, defer to stdlib directly. Since we've already + // dereferenced, we try implementations with both value receiver and pointer receiver. We must + // do this after the time normalization above, and thus after dereferencing. + if rv.Type().Implements(jsonMarshalerType) { + return encodeStdlib(w, rv.Interface()) + } else if rv.CanAddr() && rv.Addr().Type().Implements(jsonMarshalerType) { + return encodeStdlib(w, rv.Addr().Interface()) + } + + switch rv.Type().Kind() { + // Complex types must be recursively encoded. + case reflect.Interface: + return encodeReflectInterface(w, rv) + + case reflect.Array, reflect.Slice: + return encodeReflectList(w, rv) + + case reflect.Map: + return encodeReflectMap(w, rv) + + case reflect.Struct: + return encodeReflectStruct(w, rv) + + // 64-bit integers are emitted as strings, to avoid precision problems with e.g. + // Javascript which uses 64-bit floats (having 53-bit precision). + case reflect.Int64, reflect.Int: + return writeStr(w, `"`+strconv.FormatInt(rv.Int(), 10)+`"`) + + case reflect.Uint64, reflect.Uint: + return writeStr(w, `"`+strconv.FormatUint(rv.Uint(), 10)+`"`) + + // For everything else, defer to the stdlib encoding/json encoder + default: + return encodeStdlib(w, rv.Interface()) + } +} + +func encodeReflectList(w *bytes.Buffer, rv reflect.Value) error { + // Emit nil slices as null. + if rv.Kind() == reflect.Slice && rv.IsNil() { + return writeStr(w, "null") + } + + // Encode byte slices as base64 with the stdlib encoder. + if rv.Type().Elem().Kind() == reflect.Uint8 { + // Stdlib does not base64-encode byte arrays, only slices, so we copy to slice. + if rv.Type().Kind() == reflect.Array { + slice := reflect.MakeSlice(reflect.SliceOf(rv.Type().Elem()), rv.Len(), rv.Len()) + reflect.Copy(slice, rv) + rv = slice + } + return encodeStdlib(w, rv.Interface()) + } + + // Anything else we recursively encode ourselves. + length := rv.Len() + if err := writeStr(w, "["); err != nil { + return err + } + for i := 0; i < length; i++ { + if err := encodeReflect(w, rv.Index(i)); err != nil { + return err + } + if i < length-1 { + if err := writeStr(w, ","); err != nil { + return err + } + } + } + return writeStr(w, "]") +} + +func encodeReflectMap(w *bytes.Buffer, rv reflect.Value) error { + if rv.Type().Key().Kind() != reflect.String { + return errors.New("map key must be string") + } + + // nil maps are not emitted as nil, to retain Amino compatibility. + + if err := writeStr(w, "{"); err != nil { + return err + } + writeComma := false + for _, keyrv := range rv.MapKeys() { + if writeComma { + if err := writeStr(w, ","); err != nil { + return err + } + } + if err := encodeStdlib(w, keyrv.Interface()); err != nil { + return err + } + if err := writeStr(w, ":"); err != nil { + return err + } + if err := encodeReflect(w, rv.MapIndex(keyrv)); err != nil { + return err + } + writeComma = true + } + return writeStr(w, "}") +} + +func encodeReflectStruct(w *bytes.Buffer, rv reflect.Value) error { + sInfo := makeStructInfo(rv.Type()) + if err := writeStr(w, "{"); err != nil { + return err + } + writeComma := false + for i, fInfo := range sInfo.fields { + frv := rv.Field(i) + if fInfo.hidden || (fInfo.omitEmpty && frv.IsZero()) { + continue + } + + if writeComma { + if err := writeStr(w, ","); err != nil { + return err + } + } + if err := encodeStdlib(w, fInfo.jsonName); err != nil { + return err + } + if err := writeStr(w, ":"); err != nil { + return err + } + if err := encodeReflect(w, frv); err != nil { + return err + } + writeComma = true + } + return writeStr(w, "}") +} + +func encodeReflectInterface(w *bytes.Buffer, rv reflect.Value) error { + // Get concrete value and dereference pointers. + for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface { + if rv.IsNil() { + return writeStr(w, "null") + } + rv = rv.Elem() + } + + // Look up the name of the concrete type + name := typeRegistry.name(rv.Type()) + if name == "" { + return fmt.Errorf("cannot encode unregistered type %v", rv.Type()) + } + + // Write value wrapped in interface envelope + if err := writeStr(w, fmt.Sprintf(`{"type":%q,"value":`, name)); err != nil { + return err + } + if err := encodeReflect(w, rv); err != nil { + return err + } + return writeStr(w, "}") +} + +func encodeStdlib(w *bytes.Buffer, v any) error { + // Stream the output of the JSON marshaling directly into the buffer. + // The stdlib encoder will write a newline, so we must truncate it, + // which is why we pass in a bytes.Buffer throughout, not io.Writer. + enc := json.NewEncoder(w) + err := enc.Encode(v) + if err != nil { + return err + } + // Remove the last byte from the buffer + w.Truncate(w.Len() - 1) + return err +} + +func writeStr(w io.Writer, s string) error { + _, err := w.Write([]byte(s)) + return err +} diff --git a/internal/libs/json/encoder_test.go b/internal/libs/json/encoder_test.go new file mode 100644 index 0000000..6584214 --- /dev/null +++ b/internal/libs/json/encoder_test.go @@ -0,0 +1,120 @@ +package json_test + +import ( + "testing" + "time" + + "github.com/cosmos/crypto/internal/libs/json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMarshal(t *testing.T) { + s := "string" + sPtr := &s + i64 := int64(64) + ti := time.Date(2020, 6, 2, 18, 5, 13, 4346374, time.FixedZone("UTC+2", 2*60*60)) + car := &Car{Wheels: 4} + boat := Boat{Sail: true} + + testcases := map[string]struct { + value any + output string + }{ + "nil": {nil, `null`}, + "string": {"foo", `"foo"`}, + "float32": {float32(3.14), `3.14`}, + "float32 neg": {float32(-3.14), `-3.14`}, + "float64": {float64(3.14), `3.14`}, + "float64 neg": {float64(-3.14), `-3.14`}, + "int32": {int32(32), `32`}, + "int64": {int64(64), `"64"`}, + "int64 neg": {int64(-64), `"-64"`}, + "int64 ptr": {&i64, `"64"`}, + "uint64": {uint64(64), `"64"`}, + "time": {ti, `"2020-06-02T16:05:13.004346374Z"`}, + "time empty": {time.Time{}, `"0001-01-01T00:00:00Z"`}, + "time ptr": {&ti, `"2020-06-02T16:05:13.004346374Z"`}, + "customptr": {CustomPtr{Value: "x"}, `{"Value":"x"}`}, // same as encoding/json + "customptr ptr": {&CustomPtr{Value: "x"}, `"custom"`}, + "customvalue": {CustomValue{Value: "x"}, `"custom"`}, + "customvalue ptr": {&CustomValue{Value: "x"}, `"custom"`}, + "slice nil": {[]int(nil), `null`}, + "slice empty": {[]int{}, `[]`}, + "slice bytes": {[]byte{1, 2, 3}, `"AQID"`}, + "slice int64": {[]int64{1, 2, 3}, `["1","2","3"]`}, + "slice int64 ptr": {[]*int64{&i64, nil}, `["64",null]`}, + "array bytes": {[3]byte{1, 2, 3}, `"AQID"`}, + "array int64": {[3]int64{1, 2, 3}, `["1","2","3"]`}, + "map nil": {map[string]int64(nil), `{}`}, // retain Amino compatibility + "map empty": {map[string]int64{}, `{}`}, + "map int64": {map[string]int64{"a": 1, "b": 2, "c": 3}, `{"a":"1","b":"2","c":"3"}`}, + "car": {car, `{"type":"vehicle/car","value":{"Wheels":4}}`}, + "car value": {*car, `{"type":"vehicle/car","value":{"Wheels":4}}`}, + "car iface": {Vehicle(car), `{"type":"vehicle/car","value":{"Wheels":4}}`}, + "car nil": {(*Car)(nil), `null`}, + "boat": {boat, `{"type":"vehicle/boat","value":{"Sail":true}}`}, + "boat ptr": {&boat, `{"type":"vehicle/boat","value":{"Sail":true}}`}, + "boat iface": {Vehicle(boat), `{"type":"vehicle/boat","value":{"Sail":true}}`}, + "key public": {PublicKey{1, 2, 3, 4, 5, 6, 7, 8}, `{"type":"key/public","value":"AQIDBAUGBwg="}`}, + "tags": { + Tags{JSONName: "name", OmitEmpty: "foo", Hidden: "bar", Tags: &Tags{JSONName: "child"}}, + `{"name":"name","OmitEmpty":"foo","tags":{"name":"child"}}`, + }, + "tags empty": {Tags{}, `{"name":""}`}, + // The encoding of the Car and Boat fields do not have type wrappers, even though they get + // type wrappers when encoded directly (see "car" and "boat" tests). This is to retain the + // same behavior as Amino. If the field was a Vehicle interface instead, it would get + // type wrappers, as seen in the Vehicles field. + "struct": { + Struct{ + Bool: true, Float64: 3.14, Int32: 32, Int64: 64, Int64Ptr: &i64, + String: "foo", StringPtrPtr: &sPtr, Bytes: []byte{1, 2, 3}, + Time: ti, Car: car, Boat: boat, Vehicles: []Vehicle{car, boat}, + Child: &Struct{Bool: false, String: "child"}, private: "private", + }, + `{ + "Bool":true, "Float64":3.14, "Int32":32, "Int64":"64", "Int64Ptr":"64", + "String":"foo", "StringPtrPtr": "string", "Bytes":"AQID", + "Time":"2020-06-02T16:05:13.004346374Z", + "Car":{"Wheels":4}, + "Boat":{"Sail":true}, + "Vehicles":[ + {"type":"vehicle/car","value":{"Wheels":4}}, + {"type":"vehicle/boat","value":{"Sail":true}} + ], + "Child":{ + "Bool":false, "Float64":0, "Int32":0, "Int64":"0", "Int64Ptr":null, + "String":"child", "StringPtrPtr":null, "Bytes":null, + "Time":"0001-01-01T00:00:00Z", + "Car":null, "Boat":{"Sail":false}, "Vehicles":null, "Child":null + } + }`, + }, + } + for name, tc := range testcases { + tc := tc + t.Run(name, func(t *testing.T) { + bz, err := json.Marshal(tc.value) + require.NoError(t, err) + assert.JSONEq(t, tc.output, string(bz)) + }) + } +} + +func BenchmarkJsonMarshalStruct(b *testing.B) { + s := "string" + sPtr := &s + i64 := int64(64) + ti := time.Date(2020, 6, 2, 18, 5, 13, 4346374, time.FixedZone("UTC+2", 2*60*60)) + car := &Car{Wheels: 4} + boat := Boat{Sail: true} + for i := 0; i < b.N; i++ { + _, _ = json.Marshal(Struct{ + Bool: true, Float64: 3.14, Int32: 32, Int64: 64, Int64Ptr: &i64, + String: "foo", StringPtrPtr: &sPtr, Bytes: []byte{1, 2, 3}, + Time: ti, Car: car, Boat: boat, Vehicles: []Vehicle{car, boat}, + Child: &Struct{Bool: false, String: "child"}, private: "private", + }) + } +} diff --git a/internal/libs/json/helpers_test.go b/internal/libs/json/helpers_test.go new file mode 100644 index 0000000..aae5e85 --- /dev/null +++ b/internal/libs/json/helpers_test.go @@ -0,0 +1,92 @@ +package json_test + +import ( + "github.com/cosmos/crypto/internal/libs/json" + "time" +) + +// Register Car, an instance of the Vehicle interface. +func init() { + json.RegisterType(&Car{}, "vehicle/car") + json.RegisterType(Boat{}, "vehicle/boat") + json.RegisterType(PublicKey{}, "key/public") + json.RegisterType(PrivateKey{}, "key/private") +} + +type Vehicle interface { + Drive() error +} + +// Car is a pointer implementation of Vehicle. +type Car struct { + Wheels int32 +} + +func (*Car) Drive() error { return nil } + +// Boat is a value implementation of Vehicle. +type Boat struct { + Sail bool +} + +func (Boat) Drive() error { return nil } + +// These are public and private encryption keys. +type ( + PublicKey [8]byte + PrivateKey [8]byte +) + +// Custom has custom marshalers and unmarshalers, taking pointer receivers. +type CustomPtr struct { + Value string +} + +func (*CustomPtr) MarshalJSON() ([]byte, error) { + return []byte("\"custom\""), nil +} + +func (c *CustomPtr) UnmarshalJSON(_ []byte) error { + c.Value = "custom" + return nil +} + +// CustomValue has custom marshalers and unmarshalers, taking value receivers (which usually doesn't +// make much sense since the unmarshaler can't change anything). +type CustomValue struct { + Value string +} + +func (CustomValue) MarshalJSON() ([]byte, error) { + return []byte("\"custom\""), nil +} + +func (CustomValue) UnmarshalJSON(_ []byte) error { + return nil +} + +// Tags tests JSON tags. +type Tags struct { + JSONName string `json:"name"` + OmitEmpty string `json:",omitempty"` + Hidden string `json:"-"` + Tags *Tags `json:"tags,omitempty"` +} + +// Struct tests structs with lots of contents. +type Struct struct { + Bool bool + Float64 float64 + Int32 int32 + Int64 int64 + Int64Ptr *int64 + String string + StringPtrPtr **string + Bytes []byte + Time time.Time + Car *Car + Boat Boat + Vehicles []Vehicle + Child *Struct + private string +} diff --git a/internal/libs/json/structs.go b/internal/libs/json/structs.go new file mode 100644 index 0000000..d6380f2 --- /dev/null +++ b/internal/libs/json/structs.go @@ -0,0 +1,86 @@ +package json + +import ( + "fmt" + "reflect" + "strings" + "unicode" + + "github.com/cosmos/crypto/internal/sync" +) + +// cache caches struct info. +var cache = newStructInfoCache() + +// structCache is a cache of struct info. +type structInfoCache struct { + sync.RWMutex + structInfos map[reflect.Type]*structInfo +} + +func newStructInfoCache() *structInfoCache { + return &structInfoCache{ + structInfos: make(map[reflect.Type]*structInfo), + } +} + +func (c *structInfoCache) get(rt reflect.Type) *structInfo { + c.RLock() + defer c.RUnlock() + return c.structInfos[rt] +} + +func (c *structInfoCache) set(rt reflect.Type, sInfo *structInfo) { + c.Lock() + defer c.Unlock() + c.structInfos[rt] = sInfo +} + +// structInfo contains JSON info for a struct. +type structInfo struct { + fields []*fieldInfo +} + +// fieldInfo contains JSON info for a struct field. +type fieldInfo struct { + jsonName string + omitEmpty bool + hidden bool +} + +// makeStructInfo generates structInfo for a struct as a reflect.Value. +func makeStructInfo(rt reflect.Type) *structInfo { + if rt.Kind() != reflect.Struct { + panic(fmt.Sprintf("can't make struct info for non-struct value %v", rt)) + } + if sInfo := cache.get(rt); sInfo != nil { + return sInfo + } + fields := make([]*fieldInfo, 0, rt.NumField()) + for i := 0; i < cap(fields); i++ { + frt := rt.Field(i) + fInfo := &fieldInfo{ + jsonName: frt.Name, + omitEmpty: false, + hidden: frt.Name == "" || !unicode.IsUpper(rune(frt.Name[0])), + } + o := frt.Tag.Get("json") + if o == "-" { + fInfo.hidden = true + } else if o != "" { + opts := strings.Split(o, ",") + if opts[0] != "" { + fInfo.jsonName = opts[0] + } + for _, o := range opts[1:] { + if o == "omitempty" { + fInfo.omitEmpty = true + } + } + } + fields = append(fields, fInfo) + } + sInfo := &structInfo{fields: fields} + cache.set(rt, sInfo) + return sInfo +} diff --git a/internal/libs/json/types.go b/internal/libs/json/types.go new file mode 100644 index 0000000..20d3424 --- /dev/null +++ b/internal/libs/json/types.go @@ -0,0 +1,107 @@ +package json + +import ( + "errors" + "fmt" + "reflect" + + cmtsync "github.com/cosmos/crypto/internal/sync" +) + +// typeRegistry contains globally registered types for JSON encoding/decoding. +var typeRegistry = newTypes() + +// RegisterType registers a type for Amino-compatible interface encoding in the global type +// registry. These types will be encoded with a type wrapper `{"type":"","value":}` +// regardless of which interface they are wrapped in (if any). If the type is a pointer, it will +// still be valid both for value and pointer types, but decoding into an interface will generate +// the a value or pointer based on the registered type. +// +// Should only be called in init() functions, as it panics on error. +func RegisterType(_type any, name string) { + if _type == nil { + panic("cannot register nil type") + } + err := typeRegistry.register(name, reflect.ValueOf(_type).Type()) + if err != nil { + panic(err) + } +} + +// typeInfo contains type information. +type typeInfo struct { + name string + rt reflect.Type + returnPtr bool +} + +// types is a type registry. It is safe for concurrent use. +type types struct { + cmtsync.RWMutex + byType map[reflect.Type]*typeInfo + byName map[string]*typeInfo +} + +// newTypes creates a new type registry. +func newTypes() types { + return types{ + byType: map[reflect.Type]*typeInfo{}, + byName: map[string]*typeInfo{}, + } +} + +// registers the given type with the given name. The name and type must not be registered already. +func (t *types) register(name string, rt reflect.Type) error { + if name == "" { + return errors.New("name cannot be empty") + } + // If this is a pointer type, we recursively resolve until we get a bare type, but register that + // we should return pointers. + returnPtr := false + for rt.Kind() == reflect.Ptr { + returnPtr = true + rt = rt.Elem() + } + tInfo := &typeInfo{ + name: name, + rt: rt, + returnPtr: returnPtr, + } + + t.Lock() + defer t.Unlock() + if _, ok := t.byName[tInfo.name]; ok { + return fmt.Errorf("a type with name %q is already registered", name) + } + if _, ok := t.byType[tInfo.rt]; ok { + return fmt.Errorf("the type %v is already registered", rt) + } + t.byName[name] = tInfo + t.byType[rt] = tInfo + return nil +} + +// lookup looks up a type from a name, or nil if not registered. +func (t *types) lookup(name string) (reflect.Type, bool) { + t.RLock() + defer t.RUnlock() + tInfo := t.byName[name] + if tInfo == nil { + return nil, false + } + return tInfo.rt, tInfo.returnPtr +} + +// name looks up the name of a type, or empty if not registered. Unwraps pointers as necessary. +func (t *types) name(rt reflect.Type) string { + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + t.RLock() + defer t.RUnlock() + tInfo := t.byType[rt] + if tInfo == nil { + return "" + } + return tInfo.name +} diff --git a/internal/sync/deadlock.go b/internal/sync/deadlock.go new file mode 100644 index 0000000..21b5130 --- /dev/null +++ b/internal/sync/deadlock.go @@ -0,0 +1,18 @@ +//go:build deadlock +// +build deadlock + +package sync + +import ( + deadlock "github.com/sasha-s/go-deadlock" +) + +// A Mutex is a mutual exclusion lock. +type Mutex struct { + deadlock.Mutex +} + +// An RWMutex is a reader/writer mutual exclusion lock. +type RWMutex struct { + deadlock.RWMutex +} diff --git a/internal/sync/sync.go b/internal/sync/sync.go new file mode 100644 index 0000000..c6e7101 --- /dev/null +++ b/internal/sync/sync.go @@ -0,0 +1,16 @@ +//go:build !deadlock +// +build !deadlock + +package sync + +import "sync" + +// A Mutex is a mutual exclusion lock. +type Mutex struct { + sync.Mutex +} + +// An RWMutex is a reader/writer mutual exclusion lock. +type RWMutex struct { + sync.RWMutex +} diff --git a/random/random.go b/random/random.go new file mode 100644 index 0000000..6530a18 --- /dev/null +++ b/random/random.go @@ -0,0 +1,35 @@ +package random + +import ( + crand "crypto/rand" + "encoding/hex" + "io" +) + +// This only uses the OS's randomness. +func randBytes(numBytes int) []byte { + b := make([]byte, numBytes) + _, err := crand.Read(b) + if err != nil { + panic(err) + } + return b +} + +// This only uses the OS's randomness. +func CRandBytes(numBytes int) []byte { + return randBytes(numBytes) +} + +// CRandHex returns a hex encoded string that's floor(numDigits/2) * 2 long. +// +// Note: CRandHex(24) gives 96 bits of randomness that +// are usually strong enough for most purposes. +func CRandHex(numDigits int) string { + return hex.EncodeToString(CRandBytes(numDigits / 2)) +} + +// Returns a crand.Reader. +func CReader() io.Reader { + return crand.Reader +} diff --git a/random/random_test.go b/random/random_test.go new file mode 100644 index 0000000..80b8602 --- /dev/null +++ b/random/random_test.go @@ -0,0 +1,22 @@ +package random_test + +import ( + "github.com/cosmos/crypto/random" + "testing" + + "github.com/stretchr/testify/require" +) + +// the purpose of this test is primarily to ensure that the randomness +// generation won't error. +func TestRandomConsistency(t *testing.T) { + x1 := random.CRandBytes(256) + x2 := random.CRandBytes(256) + x3 := random.CRandBytes(256) + x4 := random.CRandBytes(256) + x5 := random.CRandBytes(256) + require.NotEqual(t, x1, x2) + require.NotEqual(t, x3, x4) + require.NotEqual(t, x4, x5) + require.NotEqual(t, x1, x5) +} diff --git a/symmetric/types.go b/symmetric/types.go new file mode 100644 index 0000000..84fd39f --- /dev/null +++ b/symmetric/types.go @@ -0,0 +1,7 @@ +package symmetric + +type Symmetric interface { + Keygen() []byte + Encrypt(plaintext []byte, secret []byte) (ciphertext []byte) + Decrypt(ciphertext []byte, secret []byte) (plaintext []byte, err error) +} diff --git a/symmetric/xchacha20poly1305/vector_test.go b/symmetric/xchacha20poly1305/vector_test.go new file mode 100644 index 0000000..c6ca9d8 --- /dev/null +++ b/symmetric/xchacha20poly1305/vector_test.go @@ -0,0 +1,122 @@ +package xchacha20poly1305 + +import ( + "bytes" + "encoding/hex" + "testing" +) + +func toHex(bits []byte) string { + return hex.EncodeToString(bits) +} + +func fromHex(bits string) []byte { + b, err := hex.DecodeString(bits) + if err != nil { + panic(err) + } + return b +} + +func TestHChaCha20(t *testing.T) { + for i, v := range hChaCha20Vectors { + var key [32]byte + var nonce [16]byte + copy(key[:], v.key) + copy(nonce[:], v.nonce) + + HChaCha20(&key, &nonce, &key) + if !bytes.Equal(key[:], v.keystream) { + t.Errorf("test %d: keystream mismatch:\n \t got: %s\n \t want: %s", i, toHex(key[:]), toHex(v.keystream)) + } + } +} + +var hChaCha20Vectors = []struct { + key, nonce, keystream []byte +}{ + { + fromHex("0000000000000000000000000000000000000000000000000000000000000000"), + fromHex("000000000000000000000000000000000000000000000000"), + fromHex("1140704c328d1d5d0e30086cdf209dbd6a43b8f41518a11cc387b669b2ee6586"), + }, + { + fromHex("8000000000000000000000000000000000000000000000000000000000000000"), + fromHex("000000000000000000000000000000000000000000000000"), + fromHex("7d266a7fd808cae4c02a0a70dcbfbcc250dae65ce3eae7fc210f54cc8f77df86"), + }, + { + fromHex("0000000000000000000000000000000000000000000000000000000000000001"), + fromHex("000000000000000000000000000000000000000000000002"), + fromHex("e0c77ff931bb9163a5460c02ac281c2b53d792b1c43fea817e9ad275ae546963"), + }, + { + fromHex("000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + fromHex("000102030405060708090a0b0c0d0e0f1011121314151617"), + fromHex("51e3ff45a895675c4b33b46c64f4a9ace110d34df6a2ceab486372bacbd3eff6"), + }, + { + fromHex("24f11cce8a1b3d61e441561a696c1c1b7e173d084fd4812425435a8896a013dc"), + fromHex("d9660c5900ae19ddad28d6e06e45fe5e"), + fromHex("5966b3eec3bff1189f831f06afe4d4e3be97fa9235ec8c20d08acfbbb4e851e3"), + }, +} + +func TestVectors(t *testing.T) { + for i, v := range vectors { + if len(v.plaintext) == 0 { + v.plaintext = make([]byte, len(v.ciphertext)) + } + + var nonce [24]byte + copy(nonce[:], v.nonce) + + aead, err := New(v.key) + if err != nil { + t.Error(err) + } + + dst := aead.Seal(nil, nonce[:], v.plaintext, v.ad) + if !bytes.Equal(dst, v.ciphertext) { + t.Errorf("test %d: ciphertext mismatch:\n \t got: %s\n \t want: %s", i, toHex(dst), toHex(v.ciphertext)) + } + open, err := aead.Open(nil, nonce[:], dst, v.ad) + if err != nil { + t.Error(err) + } + if !bytes.Equal(open, v.plaintext) { + t.Errorf("test %d: plaintext mismatch:\n \t got: %s\n \t want: %s", i, string(open), string(v.plaintext)) + } + } +} + +var vectors = []struct { + key, nonce, ad, plaintext, ciphertext []byte +}{ + { + []byte{ + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, + 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, + 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + }, + []byte{0x07, 0x00, 0x00, 0x00, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b}, + []byte{0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7}, + []byte( + "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it.", + ), + []byte{ + 0x45, 0x3c, 0x06, 0x93, 0xa7, 0x40, 0x7f, 0x04, 0xff, 0x4c, 0x56, + 0xae, 0xdb, 0x17, 0xa3, 0xc0, 0xa1, 0xaf, 0xff, 0x01, 0x17, 0x49, + 0x30, 0xfc, 0x22, 0x28, 0x7c, 0x33, 0xdb, 0xcf, 0x0a, 0xc8, 0xb8, + 0x9a, 0xd9, 0x29, 0x53, 0x0a, 0x1b, 0xb3, 0xab, 0x5e, 0x69, 0xf2, + 0x4c, 0x7f, 0x60, 0x70, 0xc8, 0xf8, 0x40, 0xc9, 0xab, 0xb4, 0xf6, + 0x9f, 0xbf, 0xc8, 0xa7, 0xff, 0x51, 0x26, 0xfa, 0xee, 0xbb, 0xb5, + 0x58, 0x05, 0xee, 0x9c, 0x1c, 0xf2, 0xce, 0x5a, 0x57, 0x26, 0x32, + 0x87, 0xae, 0xc5, 0x78, 0x0f, 0x04, 0xec, 0x32, 0x4c, 0x35, 0x14, + 0x12, 0x2c, 0xfc, 0x32, 0x31, 0xfc, 0x1a, 0x8b, 0x71, 0x8a, 0x62, + 0x86, 0x37, 0x30, 0xa2, 0x70, 0x2b, 0xb7, 0x63, 0x66, 0x11, 0x6b, + 0xed, 0x09, 0xe0, 0xfd, 0x5c, 0x6d, 0x84, 0xb6, 0xb0, 0xc1, 0xab, + 0xaf, 0x24, 0x9d, 0x5d, 0xd0, 0xf7, 0xf5, 0xa7, 0xea, + }, + }, +} diff --git a/symmetric/xchacha20poly1305/xchachapoly.go b/symmetric/xchacha20poly1305/xchachapoly.go new file mode 100644 index 0000000..763d4db --- /dev/null +++ b/symmetric/xchacha20poly1305/xchachapoly.go @@ -0,0 +1,264 @@ +// Package xchacha20poly1305 creates an AEAD using hchacha, chacha, and poly1305 +// This allows for randomized nonces to be used in conjunction with chacha. +package xchacha20poly1305 + +import ( + "crypto/cipher" + "encoding/binary" + "errors" + + "golang.org/x/crypto/chacha20poly1305" +) + +// Implements crypto.AEAD. +type xchacha20poly1305 struct { + key [KeySize]byte +} + +const ( + // KeySize is the size of the key used by this AEAD, in bytes. + KeySize = 32 + // NonceSize is the size of the nonce used with this AEAD, in bytes. + NonceSize = 24 + // TagSize is the size added from poly1305. + TagSize = 16 + // MaxPlaintextSize is the max size that can be passed into a single call of Seal. + MaxPlaintextSize = (1 << 38) - 64 + // MaxCiphertextSize is the max size that can be passed into a single call of Open, + // this differs from plaintext size due to the tag. + MaxCiphertextSize = (1 << 38) - 48 + + // sigma are constants used in xchacha. + // Unrolled from a slice so that they can be inlined, as slices can't be constants. + sigma0 = uint32(0x61707865) + sigma1 = uint32(0x3320646e) + sigma2 = uint32(0x79622d32) + sigma3 = uint32(0x6b206574) +) + +var ( + ErrInvalidKeyLen = errors.New("xchacha20poly1305: bad key length") + ErrInvalidNonceLen = errors.New("xchacha20poly1305: bad nonce length") + ErrInvalidCipherTextLen = errors.New("xchacha20poly1305: ciphertext too large") +) + +// New returns a new xchachapoly1305 AEAD. +func New(key []byte) (cipher.AEAD, error) { + if len(key) != KeySize { + return nil, ErrInvalidKeyLen + } + ret := new(xchacha20poly1305) + copy(ret.key[:], key) + return ret, nil +} + +func (*xchacha20poly1305) NonceSize() int { + return NonceSize +} + +func (*xchacha20poly1305) Overhead() int { + return TagSize +} + +func (c *xchacha20poly1305) Seal(dst, nonce, plaintext, additionalData []byte) []byte { + if len(nonce) != NonceSize { + panic("xchacha20poly1305: bad nonce length passed to Seal") + } + + if uint64(len(plaintext)) > MaxPlaintextSize { + panic("xchacha20poly1305: plaintext too large") + } + + var subKey [KeySize]byte + var hNonce [16]byte + var subNonce [chacha20poly1305.NonceSize]byte + copy(hNonce[:], nonce[:16]) + + HChaCha20(&subKey, &hNonce, &c.key) + + // This can't error because we always provide a correctly sized key + chacha20poly1305, _ := chacha20poly1305.New(subKey[:]) + + copy(subNonce[4:], nonce[16:]) + + return chacha20poly1305.Seal(dst, subNonce[:], plaintext, additionalData) +} + +func (c *xchacha20poly1305) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { + if len(nonce) != NonceSize { + return nil, ErrInvalidNonceLen + } + if uint64(len(ciphertext)) > MaxCiphertextSize { + return nil, ErrInvalidCipherTextLen + } + var subKey [KeySize]byte + var hNonce [16]byte + var subNonce [chacha20poly1305.NonceSize]byte + copy(hNonce[:], nonce[:16]) + + HChaCha20(&subKey, &hNonce, &c.key) + + // This can't error because we always provide a correctly sized key + chacha20poly1305, _ := chacha20poly1305.New(subKey[:]) + + copy(subNonce[4:], nonce[16:]) + + return chacha20poly1305.Open(dst, subNonce[:], ciphertext, additionalData) +} + +// HChaCha exported from +// https://github.com/aead/chacha20/blob/8b13a72661dae6e9e5dea04f344f0dc95ea29547/chacha/chacha_generic.go#L194 +// TODO: Add support for the different assembly instructions used there. + +// The MIT License (MIT) + +// Copyright (c) 2016 Andreas Auernhammer + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// HChaCha20 generates 32 pseudo-random bytes from a 128 bit nonce and a 256 bit secret key. +// It can be used as a key-derivation-function (KDF). +func HChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { hChaCha20Generic(out, nonce, key) } + +func hChaCha20Generic(out *[32]byte, nonce *[16]byte, key *[32]byte) { + v00 := sigma0 + v01 := sigma1 + v02 := sigma2 + v03 := sigma3 + v04 := binary.LittleEndian.Uint32(key[0:]) + v05 := binary.LittleEndian.Uint32(key[4:]) + v06 := binary.LittleEndian.Uint32(key[8:]) + v07 := binary.LittleEndian.Uint32(key[12:]) + v08 := binary.LittleEndian.Uint32(key[16:]) + v09 := binary.LittleEndian.Uint32(key[20:]) + v10 := binary.LittleEndian.Uint32(key[24:]) + v11 := binary.LittleEndian.Uint32(key[28:]) + v12 := binary.LittleEndian.Uint32(nonce[0:]) + v13 := binary.LittleEndian.Uint32(nonce[4:]) + v14 := binary.LittleEndian.Uint32(nonce[8:]) + v15 := binary.LittleEndian.Uint32(nonce[12:]) + + for i := 0; i < 20; i += 2 { + v00 += v04 + v12 ^= v00 + v12 = (v12 << 16) | (v12 >> 16) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 12) | (v04 >> 20) + v00 += v04 + v12 ^= v00 + v12 = (v12 << 8) | (v12 >> 24) + v08 += v12 + v04 ^= v08 + v04 = (v04 << 7) | (v04 >> 25) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 16) | (v13 >> 16) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 12) | (v05 >> 20) + v01 += v05 + v13 ^= v01 + v13 = (v13 << 8) | (v13 >> 24) + v09 += v13 + v05 ^= v09 + v05 = (v05 << 7) | (v05 >> 25) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 16) | (v14 >> 16) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 12) | (v06 >> 20) + v02 += v06 + v14 ^= v02 + v14 = (v14 << 8) | (v14 >> 24) + v10 += v14 + v06 ^= v10 + v06 = (v06 << 7) | (v06 >> 25) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 16) | (v15 >> 16) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 12) | (v07 >> 20) + v03 += v07 + v15 ^= v03 + v15 = (v15 << 8) | (v15 >> 24) + v11 += v15 + v07 ^= v11 + v07 = (v07 << 7) | (v07 >> 25) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 16) | (v15 >> 16) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 12) | (v05 >> 20) + v00 += v05 + v15 ^= v00 + v15 = (v15 << 8) | (v15 >> 24) + v10 += v15 + v05 ^= v10 + v05 = (v05 << 7) | (v05 >> 25) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 16) | (v12 >> 16) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 12) | (v06 >> 20) + v01 += v06 + v12 ^= v01 + v12 = (v12 << 8) | (v12 >> 24) + v11 += v12 + v06 ^= v11 + v06 = (v06 << 7) | (v06 >> 25) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 16) | (v13 >> 16) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 12) | (v07 >> 20) + v02 += v07 + v13 ^= v02 + v13 = (v13 << 8) | (v13 >> 24) + v08 += v13 + v07 ^= v08 + v07 = (v07 << 7) | (v07 >> 25) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 16) | (v14 >> 16) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 12) | (v04 >> 20) + v03 += v04 + v14 ^= v03 + v14 = (v14 << 8) | (v14 >> 24) + v09 += v14 + v04 ^= v09 + v04 = (v04 << 7) | (v04 >> 25) + } + + binary.LittleEndian.PutUint32(out[0:], v00) + binary.LittleEndian.PutUint32(out[4:], v01) + binary.LittleEndian.PutUint32(out[8:], v02) + binary.LittleEndian.PutUint32(out[12:], v03) + binary.LittleEndian.PutUint32(out[16:], v12) + binary.LittleEndian.PutUint32(out[20:], v13) + binary.LittleEndian.PutUint32(out[24:], v14) + binary.LittleEndian.PutUint32(out[28:], v15) +} diff --git a/symmetric/xchacha20poly1305/xchachapoly_test.go b/symmetric/xchacha20poly1305/xchachapoly_test.go new file mode 100644 index 0000000..6844f74 --- /dev/null +++ b/symmetric/xchacha20poly1305/xchachapoly_test.go @@ -0,0 +1,113 @@ +package xchacha20poly1305 + +import ( + "bytes" + cr "crypto/rand" + mr "math/rand" + "testing" +) + +// The following test is taken from +// https://github.com/golang/crypto/blob/master/chacha20poly1305/chacha20poly1305_test.go#L69 +// It requires the below copyright notice, where "this source code" refers to the following function. +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found at the bottom of this file. +func TestRandom(t *testing.T) { + // Some random tests to verify Open(Seal) == Plaintext + for i := 0; i < 256; i++ { + var nonce [24]byte + var key [32]byte + + al := mr.Intn(128) + pl := mr.Intn(16384) + ad := make([]byte, al) + plaintext := make([]byte, pl) + _, err := cr.Read(key[:]) + if err != nil { + t.Errorf("error on read: %v", err) + } + _, err = cr.Read(nonce[:]) + if err != nil { + t.Errorf("error on read: %v", err) + } + _, err = cr.Read(ad) + if err != nil { + t.Errorf("error on read: %v", err) + } + _, err = cr.Read(plaintext) + if err != nil { + t.Errorf("error on read: %v", err) + } + + aead, err := New(key[:]) + if err != nil { + t.Fatal(err) + } + + ct := aead.Seal(nil, nonce[:], plaintext, ad) + + plaintext2, err := aead.Open(nil, nonce[:], ct, ad) + if err != nil { + t.Errorf("random #%d: Open failed", i) + continue + } + + if !bytes.Equal(plaintext, plaintext2) { + t.Errorf("random #%d: plaintext's don't match: got %x vs %x", i, plaintext2, plaintext) + continue + } + + if len(ad) > 0 { + alterAdIdx := mr.Intn(len(ad)) + ad[alterAdIdx] ^= 0x80 + if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil { + t.Errorf("random #%d: Open was successful after altering additional data", i) + } + ad[alterAdIdx] ^= 0x80 + } + + alterNonceIdx := mr.Intn(aead.NonceSize()) + nonce[alterNonceIdx] ^= 0x80 + if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil { + t.Errorf("random #%d: Open was successful after altering nonce", i) + } + nonce[alterNonceIdx] ^= 0x80 + + alterCtIdx := mr.Intn(len(ct)) + ct[alterCtIdx] ^= 0x80 + if _, err := aead.Open(nil, nonce[:], ct, ad); err == nil { + t.Errorf("random #%d: Open was successful after altering ciphertext", i) + } + ct[alterCtIdx] ^= 0x80 + } +} + +// AFOREMENTIONED LICENSE +// Copyright (c) 2009 The Go Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/symmetric/xsalsa20symmetric/symmetric.go b/symmetric/xsalsa20symmetric/symmetric.go new file mode 100644 index 0000000..61d02cb --- /dev/null +++ b/symmetric/xsalsa20symmetric/symmetric.go @@ -0,0 +1,60 @@ +package xsalsa20symmetric + +import ( + "errors" + "fmt" + crypto "github.com/cosmos/crypto/random" + + "golang.org/x/crypto/nacl/secretbox" +) + +// TODO, make this into a struct that implements crypto.Symmetric. + +const ( + nonceLen = 24 + secretLen = 32 +) + +var ( + ErrInvalidCiphertextLen = errors.New("xsalsa20symmetric: ciphertext is too short") + ErrCiphertextDecryption = errors.New("xsalsa20symmetric: ciphertext decryption failed") +) + +// secret must be 32 bytes long. Use something like Sha256(Bcrypt(passphrase)) +// The ciphertext is (secretbox.Overhead + 24) bytes longer than the plaintext. +func EncryptSymmetric(plaintext []byte, secret []byte) (ciphertext []byte) { + if len(secret) != secretLen { + panic(fmt.Sprintf("Secret must be 32 bytes long, got len %v", len(secret))) + } + nonce := crypto.CRandBytes(nonceLen) + nonceArr := [nonceLen]byte{} + copy(nonceArr[:], nonce) + secretArr := [secretLen]byte{} + copy(secretArr[:], secret) + ciphertext = make([]byte, nonceLen+secretbox.Overhead+len(plaintext)) + copy(ciphertext, nonce) + secretbox.Seal(ciphertext[nonceLen:nonceLen], plaintext, &nonceArr, &secretArr) + return ciphertext +} + +// secret must be 32 bytes long. Use something like Sha256(Bcrypt(passphrase)) +// The ciphertext is (secretbox.Overhead + 24) bytes longer than the plaintext. +func DecryptSymmetric(ciphertext []byte, secret []byte) (plaintext []byte, err error) { + if len(secret) != secretLen { + panic(fmt.Sprintf("Secret must be 32 bytes long, got len %v", len(secret))) + } + if len(ciphertext) <= secretbox.Overhead+nonceLen { + return nil, ErrInvalidCiphertextLen + } + nonce := ciphertext[:nonceLen] + nonceArr := [nonceLen]byte{} + copy(nonceArr[:], nonce) + secretArr := [secretLen]byte{} + copy(secretArr[:], secret) + plaintext = make([]byte, len(ciphertext)-nonceLen-secretbox.Overhead) + _, ok := secretbox.Open(plaintext[:0], ciphertext[nonceLen:], &nonceArr, &secretArr) + if !ok { + return nil, ErrCiphertextDecryption + } + return plaintext, nil +} diff --git a/symmetric/xsalsa20symmetric/symmetric_test.go b/symmetric/xsalsa20symmetric/symmetric_test.go new file mode 100644 index 0000000..785f4cc --- /dev/null +++ b/symmetric/xsalsa20symmetric/symmetric_test.go @@ -0,0 +1,36 @@ +package xsalsa20symmetric + +import ( + "github.com/cosmos/crypto/hash/sha256" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" +) + +func TestSimple(t *testing.T) { + plaintext := []byte("sometext") + secret := []byte("somesecretoflengththirtytwo===32") + ciphertext := EncryptSymmetric(plaintext, secret) + plaintext2, err := DecryptSymmetric(ciphertext, secret) + + require.NoError(t, err, "%+v", err) + assert.Equal(t, plaintext, plaintext2) +} + +func TestSimpleWithKDF(t *testing.T) { + plaintext := []byte("sometext") + secretPass := []byte("somesecret") + secret, err := bcrypt.GenerateFromPassword(secretPass, 12) + if err != nil { + t.Error(err) + } + secret = sha256.Sum(secret) + + ciphertext := EncryptSymmetric(plaintext, secret) + plaintext2, err := DecryptSymmetric(ciphertext, secret) + + require.NoError(t, err, "%+v", err) + assert.Equal(t, plaintext, plaintext2) +} diff --git a/types/address.go b/types/address.go new file mode 100644 index 0000000..db2f3b5 --- /dev/null +++ b/types/address.go @@ -0,0 +1,20 @@ +package types + +import ( + "github.com/cosmos/crypto/hash/sha256" + "github.com/cosmos/crypto/internal/libs/bytes" +) + +const ( + // AddressSize is the size of a pubkey address. + AddressSize = sha256.TruncatedSize +) + +// Address An address is a []byte, but hex-encoded even in JSON. +// []byte leaves us the option to change the address length. +// Use an alias so Unmarshal methods (with ptr receivers) are available too. +type Address = bytes.HexBytes + +func AddressHash(bz []byte) Address { + return sha256.SumTruncated(bz) +} diff --git a/types/keys.go b/types/keys.go new file mode 100644 index 0000000..44702b4 --- /dev/null +++ b/types/keys.go @@ -0,0 +1,18 @@ +package types + +type PubKey interface { + Address() Address + Bytes() []byte + VerifySignature(msg []byte, sig []byte) bool + Equals(other PubKey) bool + Type() string +} + +// PrivKey interface with generics +type PrivKey[T PubKey] interface { + Bytes() []byte + Sign(msg []byte) ([]byte, error) + PubKey() T + Equals(other PrivKey[T]) bool + Type() string +}