-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8121 from Roasbeef/tlv-record-type-param
tlv: add new RecordT[T] utility type
- Loading branch information
Showing
7 changed files
with
1,073 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package main | ||
|
||
import ( | ||
"bytes" | ||
"flag" | ||
"os" | ||
"text/template" | ||
) | ||
|
||
const ( | ||
numberOfTypes = 100 | ||
defaultOutputFile = "tlv_types_generated.go" | ||
) | ||
|
||
const typeCodeTemplate = `// Code generated by tlv/internal/gen; DO NOT EDIT. | ||
package tlv | ||
{{- range $index, $element := . }} | ||
type tlvType{{ $index }} struct{} | ||
func (t *tlvType{{ $index }}) typeVal() Type { | ||
return {{ $index }} | ||
} | ||
type TlvType{{ $index }} = *tlvType{{ $index }} | ||
{{- end }} | ||
` | ||
|
||
func main() { | ||
// Create a slice of items that the template can range over. | ||
var items []struct{} | ||
for i := uint16(0); i <= numberOfTypes; i++ { | ||
items = append(items, struct{}{}) | ||
} | ||
|
||
tpl, err := template.New("tlv").Parse(typeCodeTemplate) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
// Execute the template | ||
var out bytes.Buffer | ||
err = tpl.Execute(&out, items) | ||
if err != nil { | ||
panic(err) | ||
} | ||
|
||
outputFile := flag.String( | ||
"o", defaultOutputFile, "Output file for generated code", | ||
) | ||
flag.Parse() | ||
|
||
err = os.WriteFile(*outputFile, out.Bytes(), 0644) | ||
if err != nil { | ||
panic(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package tlv | ||
|
||
import ( | ||
"github.com/btcsuite/btcd/btcec/v2" | ||
"github.com/lightningnetwork/lnd/fn" | ||
"golang.org/x/exp/constraints" | ||
) | ||
|
||
// RecordT is a high-order type makes it easy to encode known "primitive" types | ||
// as TLV records. | ||
type RecordT[T TlvType, V any] struct { | ||
// recordType is the type of the TLV record. | ||
recordType T | ||
|
||
// val is the value of the underlying record. Go doesn't let us just | ||
// embed the type param as a struct field, so we need an intermediate | ||
// variable. | ||
Val V | ||
} | ||
|
||
// RecordProducerT is a type-aware wrapper around the normal RecordProducer | ||
// interface. | ||
type RecordProducerT[T any] interface { | ||
RecordProducer | ||
|
||
// This is a non-interface type constraint that allows us to pass a | ||
// concrete type as a type parameter rather than a pointer to the type | ||
// that satisfies the Record interface. | ||
*T | ||
} | ||
|
||
// NewRecordT creates a new RecordT type from a given RecordProducer type. This | ||
// is useful to wrap a given record in this utility type, which also serves as | ||
// an extra type annotation. The underlying type of the record is retained. | ||
func NewRecordT[T TlvType, K any, V RecordProducerT[K]]( | ||
record K, | ||
) RecordT[T, K] { | ||
|
||
return RecordT[T, K]{ | ||
Val: record, | ||
} | ||
} | ||
|
||
// Primitive is a type constraint that capture the set of "primitive" types, | ||
// which are the built in stdlib types, and type defs of those types | ||
type Primitive interface { | ||
constraints.Unsigned | ~[]byte | ~[32]byte | ~[33]byte | ~bool | | ||
~*btcec.PublicKey | ~[64]byte | ||
} | ||
|
||
// NewPrimitiveRecord creates a new RecordT type from a given primitive type. | ||
func NewPrimitiveRecord[T TlvType, V Primitive](val V) RecordT[T, V] { | ||
return RecordT[T, V]{ | ||
Val: val, | ||
} | ||
} | ||
|
||
// Record returns the underlying record interface for the record type. | ||
func (t *RecordT[T, V]) Record() Record { | ||
// Go doesn't allow type assertions on a type param, so to work around | ||
// this, we'll convert to any, then do our type assertion. | ||
tlvRecord, ok := any(&t.Val).(RecordProducer) | ||
if !ok { | ||
return MakePrimitiveRecord( | ||
t.recordType.typeVal(), &t.Val, | ||
) | ||
} | ||
|
||
return tlvRecord.Record() | ||
} | ||
|
||
// OptionalRecordT is a high-order type that represents an optional TLV record. | ||
// This can be used when a TLV record doesn't always need to be present (ok to | ||
// be odd). | ||
type OptionalRecordT[T TlvType, V any] struct { | ||
fn.Option[RecordT[T, V]] | ||
} | ||
|
||
// ZeroRecordT returns a zero value of the RecordT type. | ||
func ZeroRecordT[T TlvType, V any]() RecordT[T, V] { | ||
var v V | ||
return RecordT[T, V]{ | ||
Val: v, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package tlv | ||
|
||
import ( | ||
"bytes" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const ( | ||
fakeCsvDelayType = 1 | ||
|
||
fakeIsCoolType = 2 | ||
) | ||
|
||
type fakeWireMsg struct { | ||
CsvDelay RecordT[TlvType1, uint16] | ||
|
||
IsCool RecordT[TlvType2, bool] | ||
} | ||
|
||
// TestRecordTFromPrimitive tests the RecordT type. We should be able to create | ||
// types of both record types, and also primitive types, and encode/decode them | ||
// as normal. | ||
func TestRecordTFromPrimitive(t *testing.T) { | ||
t.Parallel() | ||
|
||
wireMsg := fakeWireMsg{ | ||
CsvDelay: NewPrimitiveRecord[TlvType1](uint16(5)), | ||
IsCool: NewPrimitiveRecord[TlvType2](true), | ||
} | ||
|
||
encodeStream, err := NewStream( | ||
wireMsg.CsvDelay.Record(), wireMsg.IsCool.Record(), | ||
) | ||
require.NoError(t, err) | ||
|
||
var b bytes.Buffer | ||
err = encodeStream.Encode(&b) | ||
require.NoError(t, err) | ||
|
||
var newWireMsg fakeWireMsg | ||
|
||
decodeStream, err := NewStream( | ||
newWireMsg.CsvDelay.Record(), | ||
newWireMsg.IsCool.Record(), | ||
) | ||
require.NoError(t, err) | ||
|
||
err = decodeStream.Decode(&b) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, wireMsg, newWireMsg) | ||
} | ||
|
||
type wireCsv uint16 | ||
|
||
func (w *wireCsv) Record() Record { | ||
return MakeStaticRecord(fakeCsvDelayType, (*uint16)(w), 2, EUint16, DUint16) | ||
} | ||
|
||
type coolWireMsg struct { | ||
CsvDelay RecordT[TlvType1, wireCsv] | ||
} | ||
|
||
// TestRecordTFromRecord tests that we can create a RecordT type from an | ||
// existing record type and encode/decode as normal. | ||
func TestRecordTFromRecord(t *testing.T) { | ||
t.Parallel() | ||
|
||
val := wireCsv(5) | ||
|
||
wireMsg := coolWireMsg{ | ||
CsvDelay: NewRecordT[TlvType1](val), | ||
} | ||
|
||
encodeStream, err := NewStream(wireMsg.CsvDelay.Record()) | ||
require.NoError(t, err) | ||
|
||
var b bytes.Buffer | ||
err = encodeStream.Encode(&b) | ||
require.NoError(t, err) | ||
|
||
var wireMsg2 coolWireMsg | ||
|
||
decodeStream, err := NewStream(wireMsg2.CsvDelay.Record()) | ||
require.NoError(t, err) | ||
|
||
err = decodeStream.Decode(&b) | ||
require.NoError(t, err) | ||
|
||
require.Equal(t, wireMsg, wireMsg2) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package tlv | ||
|
||
import "fmt" | ||
|
||
// TlvType is an interface used to enable binding the integer type of a TLV | ||
// record to the type at compile time. | ||
type TlvType interface { | ||
typeVal() Type | ||
} | ||
|
||
//go:generate go run internal/gen/gen_tlv_types.go -o tlv_types_generated.go | ||
|
||
func main() { | ||
// This function is only here to satisfy the go:generate directive. | ||
fmt.Println("Generating TLV type structures...") | ||
} |
Oops, something went wrong.