diff --git a/CHANGELOG.md b/CHANGELOG.md index 325d7319..99ed9194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Added - Initial support for Component Model [async](https://github.com/WebAssembly/component-model/blob/main/design/mvp/Async.md) types `stream`, `future`, and `error-context`. +- Initial support for JSON serialization of WIT `list`, `enum`, and `record` types. - [`wasm-tools`](https://crates.io/crates/wasm-tools) is now vendored as a WebAssembly module, executed using [Wazero](https://wazero.io/). This allows package `wit` and `wit-bindgen-go` to run on any supported platform without needing to separately install `wasm-tools`. ### Changed diff --git a/cm/CHANGELOG.md b/cm/CHANGELOG.md index 5dcb3dd2..cc08b9e5 100644 --- a/cm/CHANGELOG.md +++ b/cm/CHANGELOG.md @@ -7,7 +7,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), ### Added - Initial support for Component Model [async](https://github.com/WebAssembly/component-model/blob/main/design/mvp/Async.md) types `stream`, `future`, and `error-context`. -- Initial support for JSON serialization of WIT types, starting with `list` and `record`. +- Initial support for JSON serialization of WIT `list`, `enum`, and `record` types. +- Added `cm.CaseUnmarshaler` helper for text and JSON unmarshaling of `enum` and `variant` types. + +### Changed + +- Breaking: package `cm`: removed `bool` from `Discriminant` type constraint. It was not used by code generation. ## [v0.1.0] — 2024-12-14 diff --git a/cm/case.go b/cm/case.go new file mode 100644 index 00000000..65ade494 --- /dev/null +++ b/cm/case.go @@ -0,0 +1,51 @@ +package cm + +import "errors" + +// CaseUnmarshaler returns a function that can unmarshal text into +// [variant] or [enum] case T.
+// +// [enum]: https://component-model.bytecodealliance.org/design/wit.html#enums +// [variant]: https://component-model.bytecodealliance.org/design/wit.html#variants +func CaseUnmarshaler[T ~uint8 | ~uint16 | ~uint32](cases []string) func(v *T, text []byte) error { + if len(cases) <= linearScanThreshold { + return func(v *T, text []byte) error { + if len(text) == 0 { + return errEmpty + } + s := string(text) + for i := 0; i < len(cases); i++ { + if cases[i] == s { + *v = T(i) + return nil + } + } + return errNoMatchingCase + } + } + + m := make(map[string]T, len(cases)) + for i, v := range cases { + m[v] = T(i) + } + + return func(v *T, text []byte) error { + if len(text) == 0 { + return errEmpty + } + s := string(text) + c, ok := m[s] + if !ok { + return errNoMatchingCase + } + *v = c + return nil + } +} + +const linearScanThreshold = 16 + +var ( + errEmpty = errors.New("empty text") + errNoMatchingCase = errors.New("no matching case") +) diff --git a/cm/case_test.go b/cm/case_test.go new file mode 100644 index 00000000..4711684d --- /dev/null +++ b/cm/case_test.go @@ -0,0 +1,34 @@ +package cm + +import ( + "strings" + "testing" +) + +func TestCaseUnmarshaler(t *testing.T) { + tests := []struct { + name string + cases []string + }{ + {"nil", nil}, + {"empty slice", []string{}}, + {"a b c", strings.SplitAfter("abc", "")}, + {"a b c d e f g", strings.SplitAfter("abcdefg", "")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := CaseUnmarshaler[uint8](tt.cases) + for want, c := range tt.cases { + var got uint8 + err := f(&got, []byte(c)) + if err != nil { + t.Error(err) + return + } + if got != uint8(want) { + t.Errorf("f(%q): got %d, expected %d", c, got, want) + } + } + }) + } +} diff --git a/cm/variant.go b/cm/variant.go index 24703641..d0def34b 100644 --- a/cm/variant.go +++ b/cm/variant.go @@ -3,10 +3,9 @@ package cm import "unsafe" // Discriminant is the set of types that can represent the tag or discriminator of a variant. 
-// Use bool for 2-case variant types, result, or option types, uint8 where there are 256 or -// fewer cases, uint16 for up to 65,536 cases, or uint32 for anything greater. +// Use uint8 where there are 256 or fewer cases, uint16 for up to 65,536 cases, or uint32 for anything greater. type Discriminant interface { - bool | uint8 | uint16 | uint32 + uint8 | uint16 | uint32 } // Variant represents a loosely-typed Component Model variant. diff --git a/cm/variant_test.go b/cm/variant_test.go index 9977790c..0978e45c 100644 --- a/cm/variant_test.go +++ b/cm/variant_test.go @@ -17,18 +17,18 @@ func TestVariantLayout(t *testing.T) { size uintptr offset uintptr }{ - {"variant { string; string }", Variant[bool, string, string]{}, sizePlusAlignOf[string](), ptrSize}, - {"variant { bool; string }", Variant[bool, string, bool]{}, sizePlusAlignOf[string](), ptrSize}, - {"variant { string; _ }", Variant[bool, string, string]{}, sizePlusAlignOf[string](), ptrSize}, - {"variant { _; _ }", Variant[bool, string, struct{}]{}, sizePlusAlignOf[string](), ptrSize}, - {"variant { u64; u64 }", Variant[bool, uint64, uint64]{}, 16, alignOf[uint64]()}, - {"variant { u32; u64 }", Variant[bool, uint64, uint32]{}, 16, alignOf[uint64]()}, - {"variant { u64; u32 }", Variant[bool, uint64, uint32]{}, 16, alignOf[uint64]()}, - {"variant { u8; u64 }", Variant[bool, uint64, uint8]{}, 16, alignOf[uint64]()}, - {"variant { u64; u8 }", Variant[bool, uint64, uint8]{}, 16, alignOf[uint64]()}, - {"variant { u8; u32 }", Variant[bool, uint32, uint8]{}, 8, alignOf[uint32]()}, - {"variant { u32; u8 }", Variant[bool, uint32, uint8]{}, 8, alignOf[uint32]()}, - {"variant { [9]u8, u64 }", Variant[bool, [9]byte, uint64]{}, 24, alignOf[uint64]()}, + {"variant { string; string }", Variant[uint8, string, string]{}, sizePlusAlignOf[string](), ptrSize}, + {"variant { bool; string }", Variant[uint8, string, bool]{}, sizePlusAlignOf[string](), ptrSize}, + {"variant { string; _ }", Variant[uint8, string, string]{}, 
sizePlusAlignOf[string](), ptrSize}, + {"variant { _; _ }", Variant[uint8, string, struct{}]{}, sizePlusAlignOf[string](), ptrSize}, + {"variant { u64; u64 }", Variant[uint8, uint64, uint64]{}, 16, alignOf[uint64]()}, + {"variant { u32; u64 }", Variant[uint8, uint64, uint32]{}, 16, alignOf[uint64]()}, + {"variant { u64; u32 }", Variant[uint8, uint64, uint32]{}, 16, alignOf[uint64]()}, + {"variant { u8; u64 }", Variant[uint8, uint64, uint8]{}, 16, alignOf[uint64]()}, + {"variant { u64; u8 }", Variant[uint8, uint64, uint8]{}, 16, alignOf[uint64]()}, + {"variant { u8; u32 }", Variant[uint8, uint32, uint8]{}, 8, alignOf[uint32]()}, + {"variant { u32; u8 }", Variant[uint8, uint32, uint8]{}, 8, alignOf[uint32]()}, + {"variant { [9]u8, u64 }", Variant[uint8, [9]byte, uint64]{}, 24, alignOf[uint64]()}, } for _, tt := range tests { diff --git a/tests/generated/wasi/filesystem/v0.2.0/types/types.wit.go b/tests/generated/wasi/filesystem/v0.2.0/types/types.wit.go index a6783462..f8252d54 100755 --- a/tests/generated/wasi/filesystem/v0.2.0/types/types.wit.go +++ b/tests/generated/wasi/filesystem/v0.2.0/types/types.wit.go @@ -106,7 +106,7 @@ const ( DescriptorTypeSocket ) -var stringsDescriptorType = [8]string{ +var _DescriptorTypeStrings = [8]string{ "unknown", "block-device", "character-device", @@ -119,9 +119,22 @@ var stringsDescriptorType = [8]string{ // String implements [fmt.Stringer], returning the enum case name of e. func (e DescriptorType) String() string { - return stringsDescriptorType[e] + return _DescriptorTypeStrings[e] } +// MarshalText implements [encoding.TextMarshaler]. +func (e DescriptorType) MarshalText() ([]byte, error) { + return []byte(e.String()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler], unmarshaling into an enum +// case. Returns an error if the supplied text is not one of the enum cases. 
+func (e *DescriptorType) UnmarshalText(text []byte) error { + return _DescriptorTypeUnmarshalCase(e, text) +} + +var _DescriptorTypeUnmarshalCase = cm.CaseUnmarshaler[DescriptorType](_DescriptorTypeStrings[:]) + // DescriptorFlags represents the flags "wasi:filesystem/types@0.2.0#descriptor-flags". // // Descriptor flags. @@ -326,7 +339,7 @@ func (self *NewTimestamp) Timestamp() *DateTime { return cm.Case[DateTime](self, 2) } -var stringsNewTimestamp = [3]string{ +var _NewTimestampStrings = [3]string{ "no-change", "now", "timestamp", @@ -334,7 +347,7 @@ var stringsNewTimestamp = [3]string{ // String implements [fmt.Stringer], returning the variant case name of v. func (v NewTimestamp) String() string { - return stringsNewTimestamp[v.Tag()] + return _NewTimestampStrings[v.Tag()] } // DirectoryEntry represents the record "wasi:filesystem/types@0.2.0#directory-entry". @@ -516,7 +529,7 @@ const ( ErrorCodeCrossDevice ) -var stringsErrorCode = [37]string{ +var _ErrorCodeStrings = [37]string{ "access", "would-block", "already", @@ -558,9 +571,22 @@ var stringsErrorCode = [37]string{ // String implements [fmt.Stringer], returning the enum case name of e. func (e ErrorCode) String() string { - return stringsErrorCode[e] + return _ErrorCodeStrings[e] +} + +// MarshalText implements [encoding.TextMarshaler]. +func (e ErrorCode) MarshalText() ([]byte, error) { + return []byte(e.String()), nil } +// UnmarshalText implements [encoding.TextUnmarshaler], unmarshaling into an enum +// case. Returns an error if the supplied text is not one of the enum cases. +func (e *ErrorCode) UnmarshalText(text []byte) error { + return _ErrorCodeUnmarshalCase(e, text) +} + +var _ErrorCodeUnmarshalCase = cm.CaseUnmarshaler[ErrorCode](_ErrorCodeStrings[:]) + // Advice represents the enum "wasi:filesystem/types@0.2.0#advice". // // File or memory access pattern advisory information. 
@@ -601,7 +627,7 @@ const ( AdviceNoReuse ) -var stringsAdvice = [6]string{ +var _AdviceStrings = [6]string{ "normal", "sequential", "random", @@ -612,9 +638,22 @@ var stringsAdvice = [6]string{ // String implements [fmt.Stringer], returning the enum case name of e. func (e Advice) String() string { - return stringsAdvice[e] + return _AdviceStrings[e] +} + +// MarshalText implements [encoding.TextMarshaler]. +func (e Advice) MarshalText() ([]byte, error) { + return []byte(e.String()), nil } +// UnmarshalText implements [encoding.TextUnmarshaler], unmarshaling into an enum +// case. Returns an error if the supplied text is not one of the enum cases. +func (e *Advice) UnmarshalText(text []byte) error { + return _AdviceUnmarshalCase(e, text) +} + +var _AdviceUnmarshalCase = cm.CaseUnmarshaler[Advice](_AdviceStrings[:]) + // MetadataHashValue represents the record "wasi:filesystem/types@0.2.0#metadata-hash-value". // // A 128-bit hash value, split into parts because wasm doesn't have a diff --git a/tests/generated/wasi/io/v0.2.0/streams/streams.wit.go b/tests/generated/wasi/io/v0.2.0/streams/streams.wit.go index 48336ed2..e3a381b7 100755 --- a/tests/generated/wasi/io/v0.2.0/streams/streams.wit.go +++ b/tests/generated/wasi/io/v0.2.0/streams/streams.wit.go @@ -64,14 +64,14 @@ func (self *StreamError) Closed() bool { return self.Tag() == 1 } -var stringsStreamError = [2]string{ +var _StreamErrorStrings = [2]string{ "last-operation-failed", "closed", } // String implements [fmt.Stringer], returning the variant case name of v. func (v StreamError) String() string { - return stringsStreamError[v.Tag()] + return _StreamErrorStrings[v.Tag()] } // InputStream represents the imported resource "wasi:io/streams@0.2.0#input-stream". 
diff --git a/tests/generated/wasi/sockets/v0.2.0/network/network.wit.go b/tests/generated/wasi/sockets/v0.2.0/network/network.wit.go index 5263ce67..a37b5b9a 100755 --- a/tests/generated/wasi/sockets/v0.2.0/network/network.wit.go +++ b/tests/generated/wasi/sockets/v0.2.0/network/network.wit.go @@ -153,7 +153,7 @@ const ( ErrorCodePermanentResolverFailure ) -var stringsErrorCode = [21]string{ +var _ErrorCodeStrings = [21]string{ "unknown", "access-denied", "not-supported", @@ -179,9 +179,22 @@ var stringsErrorCode = [21]string{ // String implements [fmt.Stringer], returning the enum case name of e. func (e ErrorCode) String() string { - return stringsErrorCode[e] + return _ErrorCodeStrings[e] } +// MarshalText implements [encoding.TextMarshaler]. +func (e ErrorCode) MarshalText() ([]byte, error) { + return []byte(e.String()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler], unmarshaling into an enum +// case. Returns an error if the supplied text is not one of the enum cases. +func (e *ErrorCode) UnmarshalText(text []byte) error { + return _ErrorCodeUnmarshalCase(e, text) +} + +var _ErrorCodeUnmarshalCase = cm.CaseUnmarshaler[ErrorCode](_ErrorCodeStrings[:]) + // IPAddressFamily represents the enum "wasi:sockets/network@0.2.0#ip-address-family". // // enum ip-address-family { @@ -198,16 +211,29 @@ const ( IPAddressFamilyIPv6 ) -var stringsIPAddressFamily = [2]string{ +var _IPAddressFamilyStrings = [2]string{ "ipv4", "ipv6", } // String implements [fmt.Stringer], returning the enum case name of e. func (e IPAddressFamily) String() string { - return stringsIPAddressFamily[e] + return _IPAddressFamilyStrings[e] } +// MarshalText implements [encoding.TextMarshaler]. +func (e IPAddressFamily) MarshalText() ([]byte, error) { + return []byte(e.String()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler], unmarshaling into an enum +// case. Returns an error if the supplied text is not one of the enum cases. 
+func (e *IPAddressFamily) UnmarshalText(text []byte) error { + return _IPAddressFamilyUnmarshalCase(e, text) +} + +var _IPAddressFamilyUnmarshalCase = cm.CaseUnmarshaler[IPAddressFamily](_IPAddressFamilyStrings[:]) + // IPv4Address represents the tuple "wasi:sockets/network@0.2.0#ipv4-address". // // type ipv4-address = tuple<u8, u8, u8, u8> @@ -246,14 +272,14 @@ func (self *IPAddress) IPv6() *IPv6Address { return cm.Case[IPv6Address](self, 1) } -var stringsIPAddress = [2]string{ +var _IPAddressStrings = [2]string{ "ipv4", "ipv6", } // String implements [fmt.Stringer], returning the variant case name of v. func (v IPAddress) String() string { - return stringsIPAddress[v.Tag()] + return _IPAddressStrings[v.Tag()] } // IPv4SocketAddress represents the record "wasi:sockets/network@0.2.0#ipv4-socket-address". @@ -322,12 +348,12 @@ func (self *IPSocketAddress) IPv6() *IPv6SocketAddress { return cm.Case[IPv6SocketAddress](self, 1) } -var stringsIPSocketAddress = [2]string{ +var _IPSocketAddressStrings = [2]string{ "ipv4", "ipv6", } // String implements [fmt.Stringer], returning the variant case name of v. func (v IPSocketAddress) String() string { - return stringsIPSocketAddress[v.Tag()] + return _IPSocketAddressStrings[v.Tag()] } diff --git a/tests/generated/wasi/sockets/v0.2.0/tcp/tcp.wit.go b/tests/generated/wasi/sockets/v0.2.0/tcp/tcp.wit.go index f43fd44f..07cad4c4 100755 --- a/tests/generated/wasi/sockets/v0.2.0/tcp/tcp.wit.go +++ b/tests/generated/wasi/sockets/v0.2.0/tcp/tcp.wit.go @@ -71,7 +71,7 @@ const ( ShutdownTypeBoth ) -var stringsShutdownType = [3]string{ +var _ShutdownTypeStrings = [3]string{ "receive", "send", "both", @@ -79,9 +79,22 @@ var stringsShutdownType = [3]string{ // String implements [fmt.Stringer], returning the enum case name of e. func (e ShutdownType) String() string { - return stringsShutdownType[e] + return _ShutdownTypeStrings[e] } +// MarshalText implements [encoding.TextMarshaler].
+func (e ShutdownType) MarshalText() ([]byte, error) { + return []byte(e.String()), nil +} + +// UnmarshalText implements [encoding.TextUnmarshaler], unmarshaling into an enum +// case. Returns an error if the supplied text is not one of the enum cases. +func (e *ShutdownType) UnmarshalText(text []byte) error { + return _ShutdownTypeUnmarshalCase(e, text) +} + +var _ShutdownTypeUnmarshalCase = cm.CaseUnmarshaler[ShutdownType](_ShutdownTypeStrings[:]) + // TCPSocket represents the imported resource "wasi:sockets/tcp@0.2.0#tcp-socket". // // A TCP socket resource. diff --git a/tests/json/json_test.go b/tests/json/json_test.go new file mode 100644 index 00000000..6ff48433 --- /dev/null +++ b/tests/json/json_test.go @@ -0,0 +1,99 @@ +package json_test + +import ( + "encoding/json" + "reflect" + "testing" + + wallclock "tests/generated/wasi/clocks/v0.2.0/wall-clock" + "tests/generated/wasi/filesystem/v0.2.0/types" + + "go.bytecodealliance.org/cm" +) + +func TestJSON(t *testing.T) { + tests := []struct { + name string + json string + into any + want any + wantErr bool + }{ + { + "nil", + `null`, + ptr(ptr("")), + ptr((*string)(nil)), + false, + }, + { + "descriptor-type(block-device)", + `"block-device"`, + ptr(types.DescriptorType(0)), + ptr(types.DescriptorTypeBlockDevice), + false, + }, + { + "descriptor-type(directory)", + `"directory"`, + ptr(types.DescriptorType(0)), + ptr(types.DescriptorTypeDirectory), + false, + }, + { + "datetime", + `{"seconds":1,"nanoseconds":2}`, + &wallclock.DateTime{}, + &wallclock.DateTime{Seconds: 1, Nanoseconds: 2}, + false, + }, + { + "empty list", + `[]`, + &cm.List[uint8]{}, + &cm.List[uint8]{}, + false, + }, + { + "list of bool", + `[false,true,false]`, + &cm.List[bool]{}, + ptr(cm.ToList([]bool{false, true, false})), + false, + }, + { + "list of u32", + `[1,2,3]`, + &cm.List[uint32]{}, + ptr(cm.ToList([]uint32{1, 2, 3})), + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := 
json.Unmarshal([]byte(tt.json), &tt.into) + if tt.wantErr && err == nil { + t.Errorf("json.Unmarshal(%q): expected error, got nil error", tt.json) + return + } else if !tt.wantErr && err != nil { + t.Errorf("json.Unmarshal(%q): expected no error, got error: %v", tt.json, err) + return + } + got, err := json.Marshal(tt.into) + if err != nil { + t.Error(err) + return + } + if string(got) != tt.json { + if !reflect.DeepEqual(tt.want, tt.into) { + t.Errorf("json.Unmarshal(%q): resulting value different (%v != %v)", tt.json, tt.into, tt.want) + } + t.Errorf("json.Marshal(%v): %s, expected %s", tt.into, string(got), tt.json) + } + }) + } +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/wit/bindgen/generator.go b/wit/bindgen/generator.go index 19d659f6..4ad188a9 100644 --- a/wit/bindgen/generator.go +++ b/wit/bindgen/generator.go @@ -575,7 +575,9 @@ func (g *generator) declareTypeDef(file *gen.File, dir wit.Direction, t *wit.Typ // Predeclare reserved methods. switch t.Kind.(type) { case *wit.Enum: - decl.scope.DeclareName("String") // For fmt.Stringer + decl.scope.DeclareName("String") // For fmt.Stringer + decl.scope.DeclareName("MarshalText") // For encoding.TextMarshaler + decl.scope.DeclareName("UnmarshalText") // For encoding.TextUnmarshaler case *wit.Variant: decl.scope.DeclareName("Tag") // Method on cm.Variant decl.scope.DeclareName("String") // For fmt.Stringer @@ -839,7 +841,7 @@ func (g *generator) enumRep(file *gen.File, dir wit.Direction, e *wit.Enum, goNa } b.WriteString(")\n\n") - stringsName := file.DeclareName("strings" + GoName(goName, true)) + stringsName := file.DeclareName("_" + GoName(goName, true) + "Strings") stringio.Write(&b, "var ", stringsName, " = [", fmt.Sprintf("%d", len(e.Cases)), "]string {\n") for _, c := range e.Cases { stringio.Write(&b, `"`, c.Name, `"`, ",\n") @@ -851,6 +853,20 @@ func (g *generator) enumRep(file *gen.File, dir wit.Direction, e *wit.Enum, goNa stringio.Write(&b, "return ", stringsName, "[e]\n") 
b.WriteString("}\n\n") + b.WriteString(formatDocComments("MarshalText implements [encoding.TextMarshaler].", true)) + stringio.Write(&b, "func (e ", goName, ") MarshalText() ([]byte, error) {\n") + stringio.Write(&b, "return []byte(e.String()), nil\n") + b.WriteString("}\n\n") + + unmarshalName := file.DeclareName("_" + GoName(goName, true) + "UnmarshalCase") + + b.WriteString(formatDocComments("UnmarshalText implements [encoding.TextUnmarshaler], unmarshaling into an enum case. Returns an error if the supplied text is not one of the enum cases.", true)) + stringio.Write(&b, "func (e *", goName, ") UnmarshalText(text []byte) error {\n") + stringio.Write(&b, "return ", unmarshalName, "(e, text)\n") + b.WriteString("}\n\n") + + stringio.Write(&b, "var ", unmarshalName, " = ", file.Import(g.opts.cmPackage), ".CaseUnmarshaler[", goName, "](", stringsName, "[:])\n") + return b.String() } @@ -921,7 +937,7 @@ func (g *generator) variantRep(file *gen.File, dir wit.Direction, t *wit.TypeDef } } - stringsName := file.DeclareName("strings" + GoName(goName, true)) + stringsName := file.DeclareName("_" + GoName(goName, true) + "Strings") stringio.Write(&b, "var ", stringsName, " = [", fmt.Sprintf("%d", len(v.Cases)), "]string {\n") for _, c := range v.Cases { stringio.Write(&b, `"`, c.Name, `"`, ",\n")