Skip to content

Commit

Permalink
Merge pull request #8121 from Roasbeef/tlv-record-type-param
Browse files Browse the repository at this point in the history
tlv: add new RecordT[T] utility type
  • Loading branch information
Roasbeef authored Dec 13, 2023
2 parents f2d48c3 + 63e86b7 commit ac9ca02
Show file tree
Hide file tree
Showing 7 changed files with 1,073 additions and 3 deletions.
4 changes: 3 additions & 1 deletion tlv/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ require (
github.com/btcsuite/btcd/btcec/v2 v2.1.3
github.com/davecgh/go-spew v1.1.1
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1
github.com/lightningnetwork/lnd/fn v1.0.0
github.com/stretchr/testify v1.8.2
golang.org/x/exp v0.0.0-20231006140011-7918f672742d
)

require (
Expand All @@ -14,7 +16,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.9.0 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/sys v0.13.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Expand Down
8 changes: 6 additions & 2 deletions tlv/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lightningnetwork/lnd/fn v1.0.0 h1:I5VG9AD63mOQ89RMQEu7HRI1r68wn8yz539LoylUIKM=
github.com/lightningnetwork/lnd/fn v1.0.0/go.mod h1:XV+0vBXSnh3aUjskJUv58TOpsveiXQ+ac8rEnXZDGFc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
Expand All @@ -33,8 +35,10 @@ github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
Expand Down
59 changes: 59 additions & 0 deletions tlv/internal/gen/gen_tlv_types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"bytes"
"flag"
"os"
"text/template"
)

const (
numberOfTypes = 100
defaultOutputFile = "tlv_types_generated.go"
)

const typeCodeTemplate = `// Code generated by tlv/internal/gen; DO NOT EDIT.
package tlv
{{- range $index, $element := . }}
type tlvType{{ $index }} struct{}
func (t *tlvType{{ $index }}) typeVal() Type {
return {{ $index }}
}
type TlvType{{ $index }} = *tlvType{{ $index }}
{{- end }}
`

func main() {
// Create a slice of items that the template can range over.
var items []struct{}
for i := uint16(0); i <= numberOfTypes; i++ {
items = append(items, struct{}{})
}

tpl, err := template.New("tlv").Parse(typeCodeTemplate)
if err != nil {
panic(err)
}

// Execute the template
var out bytes.Buffer
err = tpl.Execute(&out, items)
if err != nil {
panic(err)
}

outputFile := flag.String(
"o", defaultOutputFile, "Output file for generated code",
)
flag.Parse()

err = os.WriteFile(*outputFile, out.Bytes(), 0644)
if err != nil {
panic(err)
}
}
85 changes: 85 additions & 0 deletions tlv/record_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package tlv

import (
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/fn"
"golang.org/x/exp/constraints"
)

// RecordT is a high-order type makes it easy to encode known "primitive" types
// as TLV records.
type RecordT[T TlvType, V any] struct {
// recordType is the type of the TLV record.
recordType T

// val is the value of the underlying record. Go doesn't let us just
// embed the type param as a struct field, so we need an intermediate
// variable.
Val V
}

// RecordProducerT is a type-aware wrapper around the normal RecordProducer
// interface.
type RecordProducerT[T any] interface {
RecordProducer

// This is a non-interface type constraint that allows us to pass a
// concrete type as a type parameter rather than a pointer to the type
// that satisfies the Record interface.
*T
}

// NewRecordT creates a new RecordT type from a given RecordProducer type. This
// is useful to wrap a given record in this utility type, which also serves as
// an extra type annotation. The underlying type of the record is retained.
func NewRecordT[T TlvType, K any, V RecordProducerT[K]](
record K,
) RecordT[T, K] {

return RecordT[T, K]{
Val: record,
}
}

// Primitive is a type constraint that capture the set of "primitive" types,
// which are the built in stdlib types, and type defs of those types
type Primitive interface {
constraints.Unsigned | ~[]byte | ~[32]byte | ~[33]byte | ~bool |
~*btcec.PublicKey | ~[64]byte
}

// NewPrimitiveRecord creates a new RecordT type from a given primitive type.
func NewPrimitiveRecord[T TlvType, V Primitive](val V) RecordT[T, V] {
return RecordT[T, V]{
Val: val,
}
}

// Record returns the underlying record interface for the record type.
func (t *RecordT[T, V]) Record() Record {
// Go doesn't allow type assertions on a type param, so to work around
// this, we'll convert to any, then do our type assertion.
tlvRecord, ok := any(&t.Val).(RecordProducer)
if !ok {
return MakePrimitiveRecord(
t.recordType.typeVal(), &t.Val,
)
}

return tlvRecord.Record()
}

// OptionalRecordT is a high-order type that represents an optional TLV record.
// This can be used when a TLV record doesn't always need to be present (ok to
// be odd).
type OptionalRecordT[T TlvType, V any] struct {
fn.Option[RecordT[T, V]]
}

// ZeroRecordT returns a zero value of the RecordT type.
func ZeroRecordT[T TlvType, V any]() RecordT[T, V] {
var v V
return RecordT[T, V]{
Val: v,
}
}
93 changes: 93 additions & 0 deletions tlv/record_type_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package tlv

import (
"bytes"
"testing"

"github.com/stretchr/testify/require"
)

const (
fakeCsvDelayType = 1

fakeIsCoolType = 2
)

type fakeWireMsg struct {
CsvDelay RecordT[TlvType1, uint16]

IsCool RecordT[TlvType2, bool]
}

// TestRecordTFromPrimitive tests the RecordT type. We should be able to create
// types of both record types, and also primitive types, and encode/decode them
// as normal.
func TestRecordTFromPrimitive(t *testing.T) {
t.Parallel()

wireMsg := fakeWireMsg{
CsvDelay: NewPrimitiveRecord[TlvType1](uint16(5)),
IsCool: NewPrimitiveRecord[TlvType2](true),
}

encodeStream, err := NewStream(
wireMsg.CsvDelay.Record(), wireMsg.IsCool.Record(),
)
require.NoError(t, err)

var b bytes.Buffer
err = encodeStream.Encode(&b)
require.NoError(t, err)

var newWireMsg fakeWireMsg

decodeStream, err := NewStream(
newWireMsg.CsvDelay.Record(),
newWireMsg.IsCool.Record(),
)
require.NoError(t, err)

err = decodeStream.Decode(&b)
require.NoError(t, err)

require.Equal(t, wireMsg, newWireMsg)
}

type wireCsv uint16

func (w *wireCsv) Record() Record {
return MakeStaticRecord(fakeCsvDelayType, (*uint16)(w), 2, EUint16, DUint16)
}

type coolWireMsg struct {
CsvDelay RecordT[TlvType1, wireCsv]
}

// TestRecordTFromRecord tests that we can create a RecordT type from an
// existing record type and encode/decode as normal.
func TestRecordTFromRecord(t *testing.T) {
t.Parallel()

val := wireCsv(5)

wireMsg := coolWireMsg{
CsvDelay: NewRecordT[TlvType1](val),
}

encodeStream, err := NewStream(wireMsg.CsvDelay.Record())
require.NoError(t, err)

var b bytes.Buffer
err = encodeStream.Encode(&b)
require.NoError(t, err)

var wireMsg2 coolWireMsg

decodeStream, err := NewStream(wireMsg2.CsvDelay.Record())
require.NoError(t, err)

err = decodeStream.Decode(&b)
require.NoError(t, err)

require.Equal(t, wireMsg, wireMsg2)
}
16 changes: 16 additions & 0 deletions tlv/tlv_type_param.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package tlv

import "fmt"

// TlvType is an interface used to enable binding the integer type of a TLV
// record to the type at compile time.
type TlvType interface {
typeVal() Type
}

//go:generate go run internal/gen/gen_tlv_types.go -o tlv_types_generated.go

func main() {
// This function is only here to satisfy the go:generate directive.
fmt.Println("Generating TLV type structures...")
}
Loading

0 comments on commit ac9ca02

Please sign in to comment.