diff --git a/examples/helloworld/greeting/greet_host.pb.go b/examples/helloworld/greeting/greet_host.pb.go index 6a1de77..55e6e06 100644 --- a/examples/helloworld/greeting/greet_host.pb.go +++ b/examples/helloworld/greeting/greet_host.pb.go @@ -27,7 +27,7 @@ type GreeterPlugin struct { func NewGreeterPlugin(ctx context.Context, opts ...wazeroConfigOption) (*GreeterPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/examples/helloworld/greeting/greet_options.pb.go b/examples/helloworld/greeting/greet_options.pb.go index 837ef33..c47d13a 100644 --- a/examples/helloworld/greeting/greet_options.pb.go +++ b/examples/helloworld/greeting/greet_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/examples/host-function-library/README.md b/examples/host-function-library/README.md new file mode 100644 index 0000000..5eeef46 --- /dev/null +++ b/examples/host-function-library/README.md @@ -0,0 +1,106 @@ +# Register Host Function Library Example +This example shows how to embed functions defined in a host into plugins to expand the plugin functionalities. +Host functions can be defined either in the current plugin distribution in the proto file or imported from another module defined as a source for these host functions. +This allows host functions to be distributed as a library or SDK instead of having to copy them every time you create plugins. + +## Generate Go code for distributed host functions +A proto file is under `library/json-parser/export`. + +```protobuf +// Distributing host functions without plugin code +// go:plugin type=host module=json-parser +service ParserLibrary { + rpc ParseJson(ParseJsonRequest) returns (ParseJsonResponse) {} +} +``` + +> **_NOTE:_** You must specify `type=host` in the comment `module=json-parser` and module name to be unique for the host function redistribution and registration. +It represents the service is for host functions. + +Then, generate source code for `json-parser` module hosts functions. + +```shell +$ protoc --go-plugin_out=. --go-plugin_opt=paths=source_relative library/json-parser/export/library.proto +``` + +## Implement `json-parser` module host functions +The following interface is generated. + +```go +// Distributing host functions without plugin code +// go:plugin type=host module=json-parser +type ParserLibrary interface { + ParseJson(context.Context, *ParseJsonRequest) (*ParseJsonResponse, error) +} +``` + +Implement that interface in separate package which will be exported to the main plugin implementation. + +```go +// ParserLibraryImpl implements export.ParserLibrary functions +type ParserLibraryImpl struct{} + +// ParseJson is embedded into the plugin and can be called by the plugin. +func (ParserLibraryImpl) ParseJson(_ context.Context, request *export.ParseJsonRequest) (*export.ParseJsonResponse, error) { + var person export.Person + if err := json.Unmarshal(request.GetContent(), &person); err != nil { + return nil, err + } + + return &export.ParseJsonResponse{Response: &person}, nil +} +``` + +Then, generate source code stubs for the base plugin. + +```shell +$ protoc --go-plugin_out=. --go-plugin_opt=paths=source_relative proto/greer.proto +``` + +Register exported hosts functions module int while providing new WazeroRuntime in the `proto.NewGreeterPlugin()`. + +```go +ctx := context.Background() +p, err := proto.NewGreeterPlugin(ctx, proto.WazeroRuntime(func(ctx context.Context) (wazero.Runtime, error) { + r, err := proto.DefaultWazeroRuntime()(ctx) + if err != nil { + return nil, err + } + return r, export.Instantiate(ctx, r, impl.ParserLibraryImpl{}) +})) +``` + +## Call host functions in a plugin +The exported as a library or SDK host functions can be called in plugins. + +```go +parserLibrary := export.NewParserLibrary() + +// Call the host function to parse JSON +resp, err := parserLibrary.ParseJson(ctx, &export.ParseJsonRequest{ + Content: []byte(fmt.Sprintf(`{"name": "%s", "age": 20}`, sanrequest.Message)), +}) +if err != nil { + return nil, err +} +``` + +## Compile a plugin +Use TinyGo to compile the plugin to Wasm. + +```shell +$ go generate main.go +``` + +## Run +`main.go` loads the above plugin. + +```shell +$ go run main.go +``` +```shll +2022/08/28 10:13:57 Sending a HTTP request... +Hello, Sato. This is Yamada-san (age 20). +``` + + diff --git a/examples/host-function-library/library/json-parser/export/library.pb.go b/examples/host-function-library/library/json-parser/export/library.pb.go new file mode 100644 index 0000000..cb1fb7e --- /dev/null +++ b/examples/host-function-library/library/json-parser/export/library.pb.go @@ -0,0 +1,91 @@ +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/library/json-parser/export/library.proto + +package export + +import ( + context "context" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ParseJsonRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Content []byte `protobuf:"bytes,1,opt,name=content,proto3" json:"content,omitempty"` +} + +func (x *ParseJsonRequest) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *ParseJsonRequest) GetContent() []byte { + if x != nil { + return x.Content + } + return nil +} + +type ParseJsonResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Response *Person `protobuf:"bytes,1,opt,name=response,proto3" json:"response,omitempty"` +} + +func (x *ParseJsonResponse) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *ParseJsonResponse) GetResponse() *Person { + if x != nil { + return x.Response + } + return nil +} + +type Person struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Age int64 `protobuf:"varint,2,opt,name=age,proto3" json:"age,omitempty"` +} + +func (x *Person) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *Person) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *Person) GetAge() int64 { + if x != nil { + return x.Age + } + return 0 +} + +// Distributing host functions without plugin code +// go:plugin type=host module=json-parser +type ParserLibrary interface { + ParseJson(context.Context, *ParseJsonRequest) (*ParseJsonResponse, error) +} diff --git a/examples/host-function-library/library/json-parser/export/library.proto b/examples/host-function-library/library/json-parser/export/library.proto new file mode 100644 index 0000000..a70eddf --- /dev/null +++ b/examples/host-function-library/library/json-parser/export/library.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; +package host; + +option go_package = "github.com/knqyf263/go-plugin/examples/host-functions-library/library/json-parser/export"; + +// Distributing host functions without plugin code +// go:plugin type=host module=json-parser +service ParserLibrary { + rpc ParseJson(ParseJsonRequest) returns (ParseJsonResponse) {} +} + +message ParseJsonRequest { + bytes content = 1; +} + +message ParseJsonResponse { + Person response = 1; +} + +message Person { + string name = 1; + int64 age = 2; +} diff --git a/examples/host-function-library/library/json-parser/export/library_host.pb.go b/examples/host-function-library/library/json-parser/export/library_host.pb.go new file mode 100644 index 0000000..c998e6e --- /dev/null +++ b/examples/host-function-library/library/json-parser/export/library_host.pb.go @@ -0,0 +1,66 @@ +//go:build !tinygo.wasm + +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/library/json-parser/export/library.proto + +package export + +import ( + context "context" + wasm "github.com/knqyf263/go-plugin/wasm" + wazero "github.com/tetratelabs/wazero" + api "github.com/tetratelabs/wazero/api" +) + +const ( + i32 = api.ValueTypeI32 + i64 = api.ValueTypeI64 +) + +type _parserLibrary struct { + ParserLibrary +} + +// Instantiate a Go-defined module named "json-parser" that exports host functions. +func Instantiate(ctx context.Context, r wazero.Runtime, hostFunctions ParserLibrary) error { + envBuilder := r.NewHostModuleBuilder("json-parser") + h := _parserLibrary{hostFunctions} + + envBuilder.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc(h._ParseJson), []api.ValueType{i32, i32}, []api.ValueType{i64}). + WithParameterNames("offset", "size"). + Export("parse_json") + + _, err := envBuilder.Instantiate(ctx) + return err +} + +func (h _parserLibrary) _ParseJson(ctx context.Context, m api.Module, stack []uint64) { + offset, size := uint32(stack[0]), uint32(stack[1]) + buf, err := wasm.ReadMemory(m.Memory(), offset, size) + if err != nil { + panic(err) + } + request := new(ParseJsonRequest) + err = request.UnmarshalVT(buf) + if err != nil { + panic(err) + } + resp, err := h.ParseJson(ctx, request) + if err != nil { + panic(err) + } + buf, err = resp.MarshalVT() + if err != nil { + panic(err) + } + ptr, err := wasm.WriteMemory(ctx, m, buf) + if err != nil { + panic(err) + } + ptrLen := (ptr << uint64(32)) | uint64(len(buf)) + stack[0] = ptrLen +} diff --git a/examples/host-function-library/library/json-parser/export/library_plugin.pb.go b/examples/host-function-library/library/json-parser/export/library_plugin.pb.go new file mode 100644 index 0000000..49d3f58 --- /dev/null +++ b/examples/host-function-library/library/json-parser/export/library_plugin.pb.go @@ -0,0 +1,45 @@ +//go:build tinygo.wasm + +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/library/json-parser/export/library.proto + +package export + +import ( + context "context" + wasm "github.com/knqyf263/go-plugin/wasm" + _ "unsafe" +) + +type parserLibrary struct{} + +func NewParserLibrary() ParserLibrary { + return parserLibrary{} +} + +//go:wasm-module json-parser +//export parse_json +//go:linkname _parse_json +func _parse_json(ptr uint32, size uint32) uint64 + +func (h parserLibrary) ParseJson(ctx context.Context, request *ParseJsonRequest) (*ParseJsonResponse, error) { + buf, err := request.MarshalVT() + if err != nil { + return nil, err + } + ptr, size := wasm.ByteToPtr(buf) + ptrSize := _parse_json(ptr, size) + + ptr = uint32(ptrSize >> 32) + size = uint32(ptrSize) + buf = wasm.PtrToByte(ptr, size) + + response := new(ParseJsonResponse) + if err = response.UnmarshalVT(buf); err != nil { + return nil, err + } + return response, nil +} diff --git a/examples/host-function-library/library/json-parser/export/library_vtproto.pb.go b/examples/host-function-library/library/json-parser/export/library_vtproto.pb.go new file mode 100644 index 0000000..0c151e4 --- /dev/null +++ b/examples/host-function-library/library/json-parser/export/library_vtproto.pb.go @@ -0,0 +1,576 @@ +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/library/json-parser/export/library.proto + +package export + +import ( + fmt "fmt" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + io "io" + bits "math/bits" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +func (m *ParseJsonRequest) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ParseJsonRequest) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *ParseJsonRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Content) > 0 { + i -= len(m.Content) + copy(dAtA[i:], m.Content) + i = encodeVarint(dAtA, i, uint64(len(m.Content))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *ParseJsonResponse) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ParseJsonResponse) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *ParseJsonResponse) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if m.Response != nil { + size, err := m.Response.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *Person) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Person) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *Person) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if m.Age != 0 { + i = encodeVarint(dAtA, i, uint64(m.Age)) + i-- + dAtA[i] = 0x10 + } + if len(m.Name) > 0 { + i -= len(m.Name) + copy(dAtA[i:], m.Name) + i = encodeVarint(dAtA, i, uint64(len(m.Name))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarint(dAtA []byte, offset int, v uint64) int { + offset -= sov(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *ParseJsonRequest) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Content) + if l > 0 { + n += 1 + l + sov(uint64(l)) + } + if m.unknownFields != nil { + n += len(m.unknownFields) + } + return n +} + +func (m *ParseJsonResponse) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Response != nil { + l = m.Response.SizeVT() + n += 1 + l + sov(uint64(l)) + } + if m.unknownFields != nil { + n += len(m.unknownFields) + } + return n +} + +func (m *Person) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Name) + if l > 0 { + n += 1 + l + sov(uint64(l)) + } + if m.Age != 0 { + n += 1 + sov(uint64(m.Age)) + } + if m.unknownFields != nil { + n += len(m.unknownFields) + } + return n +} + +func sov(x uint64) (n int) { + return (bits.Len64(x|1) + 6) / 7 +} +func soz(x uint64) (n int) { + return sov(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *ParseJsonRequest) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ParseJsonRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ParseJsonRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Content", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Content = append(m.Content[:0], dAtA[iNdEx:postIndex]...) + if m.Content == nil { + m.Content = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ParseJsonResponse) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ParseJsonResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ParseJsonResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Response", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Response == nil { + m.Response = &Person{} + } + if err := m.Response.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Person) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Person: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Person: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Name = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Age", wireType) + } + m.Age = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Age |= int64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skip(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLength + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroup + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLength + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLength = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflow = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroup = fmt.Errorf("proto: unexpected end of group") +) diff --git a/examples/host-function-library/library/json-parser/impl/parser.go b/examples/host-function-library/library/json-parser/impl/parser.go new file mode 100644 index 0000000..fd7cd88 --- /dev/null +++ b/examples/host-function-library/library/json-parser/impl/parser.go @@ -0,0 +1,22 @@ +package impl + +import ( + "context" + "encoding/json" + "github.com/knqyf263/go-plugin/examples/host-function-library/library/json-parser/export" +) + +var _ export.ParserLibrary = (*ParserLibraryImpl)(nil) + +// ParserLibraryImpl implements export.ParserLibrary functions +type ParserLibraryImpl struct{} + +// ParseJson is embedded into the plugin and can be called by the plugin. +func (ParserLibraryImpl) ParseJson(_ context.Context, request *export.ParseJsonRequest) (*export.ParseJsonResponse, error) { + var person export.Person + if err := json.Unmarshal(request.GetContent(), &person); err != nil { + return nil, err + } + + return &export.ParseJsonResponse{Response: &person}, nil +} diff --git a/examples/host-function-library/main.go b/examples/host-function-library/main.go new file mode 100644 index 0000000..82d92ef --- /dev/null +++ b/examples/host-function-library/main.go @@ -0,0 +1,58 @@ +//go:build !tinygo.wasm + +//go:generate tinygo build -o plugin/plugin.wasm -scheduler=none -target=wasi --no-debug plugin/plugin.go + +package main + +import ( + "context" + "fmt" + "log" + + "github.com/tetratelabs/wazero" + + "github.com/knqyf263/go-plugin/examples/host-function-library/library/json-parser/export" + "github.com/knqyf263/go-plugin/examples/host-function-library/library/json-parser/impl" + "github.com/knqyf263/go-plugin/examples/host-function-library/proto" +) + +func main() { + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + ctx := context.Background() + p, err := proto.NewGreeterPlugin(ctx, proto.WazeroRuntime(func(ctx context.Context) (wazero.Runtime, error) { + r, err := proto.DefaultWazeroRuntime()(ctx) + if err != nil { + return nil, err + } + return r, export.Instantiate(ctx, r, impl.ParserLibraryImpl{}) + })) + + // Pass my host functions that are embedded into the plugin. + plugin, err := p.Load(ctx, "plugin/plugin.wasm", myHostFunctions{}) + if err != nil { + return err + } + + defer plugin.Close(ctx) + + reply, err := plugin.Greet(ctx, &proto.GreetRequest{ + Name: "Sato", + }) + + fmt.Println(reply.GetMessage()) + + return nil +} + +type myHostFunctions struct{} + +var _ proto.HostFunctions = (*myHostFunctions)(nil) + +func (m myHostFunctions) San(_ context.Context, request *proto.SanRequest) (*proto.SanResponse, error) { + return &proto.SanResponse{Message: fmt.Sprintf("%s-san", request.GetMessage())}, nil +} diff --git a/examples/host-function-library/main_test.go b/examples/host-function-library/main_test.go new file mode 100644 index 0000000..cdec647 --- /dev/null +++ b/examples/host-function-library/main_test.go @@ -0,0 +1,14 @@ +package main + +import ( + "testing" + + "github.com/knqyf263/go-plugin/tests" + "github.com/stretchr/testify/assert" +) + +func Test_main(t *testing.T) { + got := tests.TestStdout(t, main) + want := "Hello, Sato. This is Yamada-san (age 20)." + assert.Equal(t, want, got) +} diff --git a/examples/host-function-library/plugin/plugin.go b/examples/host-function-library/plugin/plugin.go new file mode 100644 index 0000000..f7521c4 --- /dev/null +++ b/examples/host-function-library/plugin/plugin.go @@ -0,0 +1,43 @@ +//go:build tinygo.wasm + +package main + +import ( + "context" + "fmt" + + "github.com/knqyf263/go-plugin/examples/host-function-library/library/json-parser/export" + "github.com/knqyf263/go-plugin/examples/host-function-library/proto" +) + +// main is required for TinyGo to compile to Wasm. +func main() { + proto.RegisterGreeter(TestPlugin{}) +} + +type TestPlugin struct{} + +var _ proto.Greeter = (*TestPlugin)(nil) + +func (p TestPlugin) Greet(ctx context.Context, request *proto.GreetRequest) (*proto.GreetReply, error) { + parserLibrary := export.NewParserLibrary() + localHostFunctions := proto.NewHostFunctions() + + sanrequest, err := localHostFunctions.San(ctx, &proto.SanRequest{Message: "Yamada"}) + if err != nil { + return nil, err + } + + // Call the host function to parse JSON + resp, err := parserLibrary.ParseJson(ctx, &export.ParseJsonRequest{ + Content: []byte(fmt.Sprintf(`{"name": "%s", "age": 20}`, sanrequest.Message)), + }) + if err != nil { + return nil, err + } + + return &proto.GreetReply{ + Message: fmt.Sprintf("Hello, %s. This is %s (age %d).", + request.GetName(), resp.GetResponse().GetName(), resp.GetResponse().GetAge()), + }, nil +} diff --git a/examples/host-function-library/proto/greet.pb.go b/examples/host-function-library/proto/greet.pb.go new file mode 100644 index 0000000..934af16 --- /dev/null +++ b/examples/host-function-library/proto/greet.pb.go @@ -0,0 +1,111 @@ +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/proto/greet.proto + +package proto + +import ( + context "context" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// The request message containing the user's name. +type GreetRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` +} + +func (x *GreetRequest) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *GreetRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +// The response message containing the greetings +type GreetReply struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *GreetReply) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *GreetReply) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type SanRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *SanRequest) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *SanRequest) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type SanResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *SanResponse) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *SanResponse) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +// The greeting service definition. +// go:plugin type=plugin version=1 +type Greeter interface { + // Sends a greeting + Greet(context.Context, *GreetRequest) (*GreetReply, error) +} + +// The host functions embedded into the plugin +// go:plugin type=host +type HostFunctions interface { + San(context.Context, *SanRequest) (*SanResponse, error) +} diff --git a/examples/host-function-library/proto/greet.proto b/examples/host-function-library/proto/greet.proto new file mode 100644 index 0000000..3ccad95 --- /dev/null +++ b/examples/host-function-library/proto/greet.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; +package host; + +option go_package = "github.com/knqyf263/go-plugin/tests/import-host-functions/proto"; + +// The greeting service definition. +// go:plugin type=plugin version=1 +service Greeter { + // Sends a greeting + rpc Greet(GreetRequest) returns (GreetReply) {} +} + +// The request message containing the user's name. +message GreetRequest { + string name = 1; +} + +// The response message containing the greetings +message GreetReply { + string message = 1; +} + +// The host functions embedded into the plugin +// go:plugin type=host +service HostFunctions { + rpc San(SanRequest) returns (SanResponse) {} +} + +message SanRequest { + string message = 1; +} + +message SanResponse { + string message = 1; +} diff --git a/examples/host-function-library/proto/greet_host.pb.go b/examples/host-function-library/proto/greet_host.pb.go new file mode 100644 index 0000000..5cb845f --- /dev/null +++ b/examples/host-function-library/proto/greet_host.pb.go @@ -0,0 +1,244 @@ +//go:build !tinygo.wasm + +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/proto/greet.proto + +package proto + +import ( + context "context" + errors "errors" + fmt "fmt" + wasm "github.com/knqyf263/go-plugin/wasm" + wazero "github.com/tetratelabs/wazero" + api "github.com/tetratelabs/wazero/api" + sys "github.com/tetratelabs/wazero/sys" + os "os" +) + +const ( + i32 = api.ValueTypeI32 + i64 = api.ValueTypeI64 +) + +type _hostFunctions struct { + HostFunctions +} + +// Instantiate a Go-defined module named "env" that exports host functions. +func (h _hostFunctions) Instantiate(ctx context.Context, r wazero.Runtime) error { + envBuilder := r.NewHostModuleBuilder("env") + + envBuilder.NewFunctionBuilder(). + WithGoModuleFunction(api.GoModuleFunc(h._San), []api.ValueType{i32, i32}, []api.ValueType{i64}). + WithParameterNames("offset", "size"). + Export("san") + + _, err := envBuilder.Instantiate(ctx) + return err +} + +func (h _hostFunctions) _San(ctx context.Context, m api.Module, stack []uint64) { + offset, size := uint32(stack[0]), uint32(stack[1]) + buf, err := wasm.ReadMemory(m.Memory(), offset, size) + if err != nil { + panic(err) + } + request := new(SanRequest) + err = request.UnmarshalVT(buf) + if err != nil { + panic(err) + } + resp, err := h.San(ctx, request) + if err != nil { + panic(err) + } + buf, err = resp.MarshalVT() + if err != nil { + panic(err) + } + ptr, err := wasm.WriteMemory(ctx, m, buf) + if err != nil { + panic(err) + } + ptrLen := (ptr << uint64(32)) | uint64(len(buf)) + stack[0] = ptrLen +} + +const GreeterPluginAPIVersion = 1 + +type GreeterPlugin struct { + newRuntime func(context.Context) (wazero.Runtime, error) + moduleConfig wazero.ModuleConfig +} + +func NewGreeterPlugin(ctx context.Context, opts ...wazeroConfigOption) (*GreeterPlugin, error) { + o := &WazeroConfig{ + newRuntime: DefaultWazeroRuntime(), + moduleConfig: wazero.NewModuleConfig(), + } + + for _, opt := range opts { + opt(o) + } + + return &GreeterPlugin{ + newRuntime: o.newRuntime, + moduleConfig: o.moduleConfig, + }, nil +} + +type greeter interface { + Close(ctx context.Context) error + Greeter +} + +func (p *GreeterPlugin) Load(ctx context.Context, pluginPath string, hostFunctions HostFunctions) (greeter, error) { + b, err := os.ReadFile(pluginPath) + if err != nil { + return nil, err + } + + // Create a new runtime so that multiple modules will not conflict + r, err := p.newRuntime(ctx) + if err != nil { + return nil, err + } + + h := _hostFunctions{hostFunctions} + + if err := h.Instantiate(ctx, r); err != nil { + return nil, err + } + + // Compile the WebAssembly module using the default configuration. + code, err := r.CompileModule(ctx, b) + if err != nil { + return nil, err + } + + // InstantiateModule runs the "_start" function, WASI's "main". + module, err := r.InstantiateModule(ctx, code, p.moduleConfig) + if err != nil { + // Note: Most compilers do not exit the module after running "_start", + // unless there was an Error. This allows you to call exported functions. + if exitErr, ok := err.(*sys.ExitError); ok && exitErr.ExitCode() != 0 { + return nil, fmt.Errorf("unexpected exit_code: %d", exitErr.ExitCode()) + } else if !ok { + return nil, err + } + } + + // Compare API versions with the loading plugin + apiVersion := module.ExportedFunction("greeter_api_version") + if apiVersion == nil { + return nil, errors.New("greeter_api_version is not exported") + } + results, err := apiVersion.Call(ctx) + if err != nil { + return nil, err + } else if len(results) != 1 { + return nil, errors.New("invalid greeter_api_version signature") + } + if results[0] != GreeterPluginAPIVersion { + return nil, fmt.Errorf("API version mismatch, host: %d, plugin: %d", GreeterPluginAPIVersion, results[0]) + } + + greet := module.ExportedFunction("greeter_greet") + if greet == nil { + return nil, errors.New("greeter_greet is not exported") + } + + malloc := module.ExportedFunction("malloc") + if malloc == nil { + return nil, errors.New("malloc is not exported") + } + + free := module.ExportedFunction("free") + if free == nil { + return nil, errors.New("free is not exported") + } + return &greeterPlugin{ + runtime: r, + module: module, + malloc: malloc, + free: free, + greet: greet, + }, nil +} + +func (p *greeterPlugin) Close(ctx context.Context) (err error) { + if r := p.runtime; r != nil { + r.Close(ctx) + } + return +} + +type greeterPlugin struct { + runtime wazero.Runtime + module api.Module + malloc api.Function + free api.Function + greet api.Function +} + +func (p *greeterPlugin) Greet(ctx context.Context, request *GreetRequest) (*GreetReply, error) { + data, err := request.MarshalVT() + if err != nil { + return nil, err + } + dataSize := uint64(len(data)) + + var dataPtr uint64 + // If the input data is not empty, we must allocate the in-Wasm memory to store it, and pass to the plugin. + if dataSize != 0 { + results, err := p.malloc.Call(ctx, dataSize) + if err != nil { + return nil, err + } + dataPtr = results[0] + // This pointer is managed by TinyGo, but TinyGo is unaware of external usage. + // So, we have to free it when finished + defer p.free.Call(ctx, dataPtr) + + // The pointer is a linear memory offset, which is where we write the name. + if !p.module.Memory().Write(uint32(dataPtr), data) { + return nil, fmt.Errorf("Memory.Write(%d, %d) out of range of memory size %d", dataPtr, dataSize, p.module.Memory().Size()) + } + } + + ptrSize, err := p.greet.Call(ctx, dataPtr, dataSize) + if err != nil { + return nil, err + } + + // Note: This pointer is still owned by TinyGo, so don't try to free it! + resPtr := uint32(ptrSize[0] >> 32) + resSize := uint32(ptrSize[0]) + var isErrResponse bool + if (resSize & (1 << 31)) > 0 { + isErrResponse = true + resSize &^= (1 << 31) + } + + // The pointer is a linear memory offset, which is where we write the name. + bytes, ok := p.module.Memory().Read(resPtr, resSize) + if !ok { + return nil, fmt.Errorf("Memory.Read(%d, %d) out of range of memory size %d", + resPtr, resSize, p.module.Memory().Size()) + } + + if isErrResponse { + return nil, errors.New(string(bytes)) + } + + response := new(GreetReply) + if err = response.UnmarshalVT(bytes); err != nil { + return nil, err + } + + return response, nil +} diff --git a/examples/host-function-library/proto/greet_options.pb.go b/examples/host-function-library/proto/greet_options.pb.go new file mode 100644 index 0000000..6799a61 --- /dev/null +++ b/examples/host-function-library/proto/greet_options.pb.go @@ -0,0 +1,47 @@ +//go:build !tinygo.wasm + +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/proto/greet.proto + +package proto + +import ( + context "context" + wazero "github.com/tetratelabs/wazero" + wasi_snapshot_preview1 "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" +) + +type wazeroConfigOption func(plugin *WazeroConfig) + +type WazeroNewRuntime func(context.Context) (wazero.Runtime, error) + +type WazeroConfig struct { + newRuntime func(context.Context) (wazero.Runtime, error) + moduleConfig wazero.ModuleConfig +} + +func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { + return func(h *WazeroConfig) { + h.newRuntime = newRuntime + } +} + +func DefaultWazeroRuntime() WazeroNewRuntime { + return func(ctx context.Context) (wazero.Runtime, error) { + r := wazero.NewRuntime(ctx) + if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { + return nil, err + } + + return r, nil + } +} + +func WazeroModuleConfig(moduleConfig wazero.ModuleConfig) wazeroConfigOption { + return func(h *WazeroConfig) { + h.moduleConfig = moduleConfig + } +} diff --git a/examples/host-function-library/proto/greet_plugin.pb.go b/examples/host-function-library/proto/greet_plugin.pb.go new file mode 100644 index 0000000..b8f2b9e --- /dev/null +++ b/examples/host-function-library/proto/greet_plugin.pb.go @@ -0,0 +1,82 @@ +//go:build tinygo.wasm + +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/proto/greet.proto + +package proto + +import ( + context "context" + wasm "github.com/knqyf263/go-plugin/wasm" + _ "unsafe" +) + +const GreeterPluginAPIVersion = 1 + +//export greeter_api_version +func _greeter_api_version() uint64 { + return GreeterPluginAPIVersion +} + +var greeter Greeter + +func RegisterGreeter(p Greeter) { + greeter = p +} + +//export greeter_greet +func _greeter_greet(ptr, size uint32) uint64 { + b := wasm.PtrToByte(ptr, size) + req := new(GreetRequest) + if err := req.UnmarshalVT(b); err != nil { + return 0 + } + response, err := greeter.Greet(context.Background(), req) + if err != nil { + ptr, size = wasm.ByteToPtr([]byte(err.Error())) + return (uint64(ptr) << uint64(32)) | uint64(size) | + // Indicate that this is the error string by setting the 32-th bit, assuming that + // no data exceeds 31-bit size (2 GiB). + (1 << 31) + } + + b, err = response.MarshalVT() + if err != nil { + return 0 + } + ptr, size = wasm.ByteToPtr(b) + return (uint64(ptr) << uint64(32)) | uint64(size) +} + +type hostFunctions struct{} + +func NewHostFunctions() HostFunctions { + return hostFunctions{} +} + +//go:wasm-module env +//export san +//go:linkname _san +func _san(ptr uint32, size uint32) uint64 + +func (h hostFunctions) San(ctx context.Context, request *SanRequest) (*SanResponse, error) { + buf, err := request.MarshalVT() + if err != nil { + return nil, err + } + ptr, size := wasm.ByteToPtr(buf) + ptrSize := _san(ptr, size) + + ptr = uint32(ptrSize >> 32) + size = uint32(ptrSize) + buf = wasm.PtrToByte(ptr, size) + + response := new(SanResponse) + if err = response.UnmarshalVT(buf); err != nil { + return nil, err + } + return response, nil +} diff --git a/examples/host-function-library/proto/greet_vtproto.pb.go b/examples/host-function-library/proto/greet_vtproto.pb.go new file mode 100644 index 0000000..1245f63 --- /dev/null +++ b/examples/host-function-library/proto/greet_vtproto.pb.go @@ -0,0 +1,679 @@ +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.21.12 +// source: examples/host-function-library/proto/greet.proto + +package proto + +import ( + fmt "fmt" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + io "io" + bits "math/bits" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +func (m *GreetRequest) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GreetRequest) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *GreetRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Name) > 0 { + i -= len(m.Name) + copy(dAtA[i:], m.Name) + i = encodeVarint(dAtA, i, uint64(len(m.Name))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *GreetReply) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *GreetReply) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *GreetReply) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Message) > 0 { + i -= len(m.Message) + copy(dAtA[i:], m.Message) + i = encodeVarint(dAtA, i, uint64(len(m.Message))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *SanRequest) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SanRequest) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *SanRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Message) > 0 { + i -= len(m.Message) + copy(dAtA[i:], m.Message) + i = encodeVarint(dAtA, i, uint64(len(m.Message))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *SanResponse) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *SanResponse) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *SanResponse) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Message) > 0 { + i -= len(m.Message) + copy(dAtA[i:], m.Message) + i = encodeVarint(dAtA, i, uint64(len(m.Message))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func encodeVarint(dAtA []byte, offset int, v uint64) int { + offset -= sov(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *GreetRequest) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Name) + if l > 0 { + n += 1 + l + sov(uint64(l)) + } + if m.unknownFields != nil { + n += len(m.unknownFields) + } + return n +} + +func (m *GreetReply) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Message) + if l > 0 { + n += 1 + l + sov(uint64(l)) + } + if m.unknownFields != nil { + n += len(m.unknownFields) + } + return n +} + +func (m *SanRequest) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Message) + if l > 0 { + n += 1 + l + sov(uint64(l)) + } + if m.unknownFields != nil { + n += len(m.unknownFields) + } + return n +} + +func (m *SanResponse) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Message) + if l > 0 { + n += 1 + l + sov(uint64(l)) + } + if m.unknownFields != nil { + n += len(m.unknownFields) + } + return n +} + +func sov(x uint64) (n int) { + return (bits.Len64(x|1) + 6) / 7 +} +func soz(x uint64) (n int) { + return sov(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *GreetRequest) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GreetRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GreetRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Name", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Name = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *GreetReply) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: GreetReply: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: GreetReply: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Message", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Message = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SanRequest) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SanRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SanRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Message", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Message = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *SanResponse) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: SanResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: SanResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Message", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Message = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skip(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLength + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroup + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLength + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLength = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflow = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroup = fmt.Errorf("proto: unexpected end of group") +) diff --git a/examples/host-functions/README.md b/examples/host-functions/README.md index 4007d72..0f985a4 100644 --- a/examples/host-functions/README.md +++ b/examples/host-functions/README.md @@ -67,7 +67,7 @@ Pass it to a plugin in `Load()`. ```go // Pass my host functions that are embedded into the plugin. -greetingPlugin, err := p.Load(ctx, "plugin/plugin.wasm") +greetingPlugin, err := p.Load(ctx, "plugin/plugin.wasm", myHostFunctions{}) ``` ## Call host functions in a plugin diff --git a/examples/host-functions/greeting/greet_host.pb.go b/examples/host-functions/greeting/greet_host.pb.go index 9d39346..65ffb9c 100644 --- a/examples/host-functions/greeting/greet_host.pb.go +++ b/examples/host-functions/greeting/greet_host.pb.go @@ -113,7 +113,7 @@ type GreeterPlugin struct { func NewGreeterPlugin(ctx context.Context, opts ...wazeroConfigOption) (*GreeterPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/examples/host-functions/greeting/greet_options.pb.go b/examples/host-functions/greeting/greet_options.pb.go index a91021d..4a34c8c 100644 --- a/examples/host-functions/greeting/greet_options.pb.go +++ b/examples/host-functions/greeting/greet_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/examples/known-types/known/known_host.pb.go b/examples/known-types/known/known_host.pb.go index ec730fc..9d46c82 100644 --- a/examples/known-types/known/known_host.pb.go +++ b/examples/known-types/known/known_host.pb.go @@ -27,7 +27,7 @@ type WellKnownPlugin struct { func NewWellKnownPlugin(ctx context.Context, opts ...wazeroConfigOption) (*WellKnownPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/examples/known-types/known/known_options.pb.go b/examples/known-types/known/known_options.pb.go index 8ed0b03..96c44d8 100644 --- a/examples/known-types/known/known_options.pb.go +++ b/examples/known-types/known/known_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/examples/wasi/cat/cat_host.pb.go b/examples/wasi/cat/cat_host.pb.go index 653e95c..aab8670 100644 --- a/examples/wasi/cat/cat_host.pb.go +++ b/examples/wasi/cat/cat_host.pb.go @@ -27,7 +27,7 @@ type FileCatPlugin struct { func NewFileCatPlugin(ctx context.Context, opts ...wazeroConfigOption) (*FileCatPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/examples/wasi/cat/cat_options.pb.go b/examples/wasi/cat/cat_options.pb.go index 4a0e14b..3786a09 100644 --- a/examples/wasi/cat/cat_options.pb.go +++ b/examples/wasi/cat/cat_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/gen/host.go b/gen/host.go index 8c71b20..5599558 100644 --- a/gen/host.go +++ b/gen/host.go @@ -12,7 +12,7 @@ func (gg *Generator) generateHostFile(f *fileInfo) { filename := f.GeneratedFilenamePrefix + "_host.pb.go" g := gg.plugin.NewGeneratedFile(filename, f.GoImportPath) - if len(f.pluginServices) == 0 { + if len(f.pluginServices) == 0 && f.hostService == nil { g.Skip() } @@ -46,14 +46,31 @@ func (gg *Generator) genHostFunctions(g *protogen.GeneratedFile, f *fileInfo) { `, structName, f.hostService.GoName)) // Define exporting functions - g.P(fmt.Sprintf(` - // Instantiate a Go-defined module named "env" that exports host functions. + // If it is only distributable host-functions, i.e. there is no plugin service definition + if len(f.pluginServices) == 0 { + g.P(fmt.Sprintf(` + // Instantiate a Go-defined module named "%s" that exports host functions. + func Instantiate(ctx %s, r %s, hostFunctions %s) error { + envBuilder := r.NewHostModuleBuilder("%s") + h := %s{hostFunctions}`, + f.hostService.Module, + g.QualifiedGoIdent(contextPackage.Ident("Context")), + g.QualifiedGoIdent(wazeroPackage.Ident("Runtime")), + f.hostService.GoName, + f.hostService.Module, + structName)) + } else { + g.P(fmt.Sprintf(` + // Instantiate a Go-defined module named "%s" that exports host functions. func (h %s) Instantiate(ctx %s, r %s) error { - envBuilder := r.NewHostModuleBuilder("env")`, - structName, - g.QualifiedGoIdent(contextPackage.Ident("Context")), - g.QualifiedGoIdent(wazeroPackage.Ident("Runtime")), - )) + envBuilder := r.NewHostModuleBuilder("%s")`, + f.hostService.Module, + structName, + g.QualifiedGoIdent(contextPackage.Ident("Context")), + g.QualifiedGoIdent(wazeroPackage.Ident("Runtime")), + f.hostService.Module)) + } + for _, method := range f.hostService.Methods { g.P(fmt.Sprintf(` envBuilder.NewFunctionBuilder(). @@ -124,7 +141,7 @@ func genHost(g *protogen.GeneratedFile, f *fileInfo, service *serviceInfo) { pluginName, )) g.P(fmt.Sprintf(`o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: %s(), } diff --git a/gen/init.go b/gen/init.go index a80a7f3..8025079 100644 --- a/gen/init.go +++ b/gen/init.go @@ -185,6 +185,7 @@ type serviceInfo struct { *protogen.Service Version int Type ServiceType + Module string } func newServiceInfo(service *protogen.Service, param Parameter) *serviceInfo { @@ -192,6 +193,7 @@ func newServiceInfo(service *protogen.Service, param Parameter) *serviceInfo { Service: service, Type: param.Type, Version: param.APIVersion, + Module: param.Module, } return x } diff --git a/gen/main.go b/gen/main.go index 66706ca..3d96403 100644 --- a/gen/main.go +++ b/gen/main.go @@ -832,6 +832,7 @@ func (c trailingComment) String() string { type Parameter struct { APIVersion int Type ServiceType + Module string } type ServiceType string @@ -841,6 +842,7 @@ const ( ServicePlugin ServiceType = "plugin" ServiceUnknown ServiceType = "unknown" ServiceNone ServiceType = "none" + EnvModuleName = "env" ) // parseParam parses a comment and extract parameters for go-plugin @@ -849,6 +851,7 @@ func parseParam(comment string) (Parameter, error) { param := Parameter{ APIVersion: 1, Type: ServiceNone, + Module: EnvModuleName, } for _, line := range strings.Split(comment, "\n") { line = strings.TrimPrefix(line, "//") @@ -883,6 +886,10 @@ func parseParam(comment string) (Parameter, error) { return Parameter{}, fmt.Errorf("invalid version: %w", err) } param.APIVersion = ver + case "module": + if param.Type == ServiceHost && len(value) > 0 { + param.Module = value + } } } } diff --git a/gen/options.go b/gen/options.go index 8ae3432..e589fa0 100644 --- a/gen/options.go +++ b/gen/options.go @@ -31,7 +31,7 @@ func (gg *Generator) generateOptionsFile(f *fileInfo) { } } - func defaultWazeroRuntime() WazeroNewRuntime { + func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx %s) (%s, error) { r := %s(ctx) if _, err := %s(ctx, r); err != nil { diff --git a/gen/plugin.go b/gen/plugin.go index a7e0157..47850f5 100644 --- a/gen/plugin.go +++ b/gen/plugin.go @@ -12,7 +12,7 @@ func (gg *Generator) generatePluginFile(f *fileInfo) { filename := f.GeneratedFilenamePrefix + "_plugin.pb.go" g := gg.plugin.NewGeneratedFile(filename, f.GoImportPath) - if len(f.pluginServices) == 0 { + if len(f.pluginServices) == 0 && f.hostService == nil { g.Skip() } @@ -103,7 +103,7 @@ func genHostFunctions(g *protogen.GeneratedFile, f *fileInfo) { for _, method := range f.hostService.Methods { importedName := toSnakeCase(method.GoName) g.P(fmt.Sprintf(` - //go:wasm-module env + //go:wasm-module %s //export %s //go:linkname _%s func _%s(ptr uint32, size uint32) uint64 @@ -126,7 +126,7 @@ func genHostFunctions(g *protogen.GeneratedFile, f *fileInfo) { } return response, nil }`, - importedName, importedName, importedName, structName, method.GoName, + f.hostService.Module, importedName, importedName, importedName, structName, method.GoName, g.QualifiedGoIdent(contextPackage.Ident("Context")), g.QualifiedGoIdent(method.Input.GoIdent), g.QualifiedGoIdent(method.Output.GoIdent), diff --git a/tests/fields/proto/fields_host.pb.go b/tests/fields/proto/fields_host.pb.go index 4b3e0e4..552cc08 100644 --- a/tests/fields/proto/fields_host.pb.go +++ b/tests/fields/proto/fields_host.pb.go @@ -28,7 +28,7 @@ type FieldTestPlugin struct { func NewFieldTestPlugin(ctx context.Context, opts ...wazeroConfigOption) (*FieldTestPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/tests/fields/proto/fields_options.pb.go b/tests/fields/proto/fields_options.pb.go index 9fa184d..f6f9c15 100644 --- a/tests/fields/proto/fields_options.pb.go +++ b/tests/fields/proto/fields_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/tests/host-functions/host_functions_test.go b/tests/host-functions/host_functions_test.go index b5dd029..3b10bea 100644 --- a/tests/host-functions/host_functions_test.go +++ b/tests/host-functions/host_functions_test.go @@ -6,24 +6,17 @@ import ( "os" "testing" + "github.com/knqyf263/go-plugin/tests/host-functions/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tetratelabs/wazero" - "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" - - "github.com/knqyf263/go-plugin/tests/host-functions/proto" ) func TestHostFunctions(t *testing.T) { ctx := context.Background() mc := wazero.NewModuleConfig().WithStdout(os.Stdout) p, err := proto.NewGreeterPlugin(ctx, proto.WazeroRuntime(func(ctx context.Context) (wazero.Runtime, error) { - r := wazero.NewRuntimeWithConfig(ctx, wazero.NewRuntimeConfig().WithCompilationCache(wazero.NewCompilationCache())) - if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { - return nil, err - } - - return r, nil + return proto.DefaultWazeroRuntime()(ctx) }), proto.WazeroModuleConfig(mc)) require.NoError(t, err) diff --git a/tests/host-functions/proto/host_host.pb.go b/tests/host-functions/proto/host_host.pb.go index cf18980..4240d7b 100644 --- a/tests/host-functions/proto/host_host.pb.go +++ b/tests/host-functions/proto/host_host.pb.go @@ -77,7 +77,7 @@ type GreeterPlugin struct { func NewGreeterPlugin(ctx context.Context, opts ...wazeroConfigOption) (*GreeterPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/tests/host-functions/proto/host_options.pb.go b/tests/host-functions/proto/host_options.pb.go index 0b87ded..a79893b 100644 --- a/tests/host-functions/proto/host_options.pb.go +++ b/tests/host-functions/proto/host_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/tests/import/proto/bar/bar_host.pb.go b/tests/import/proto/bar/bar_host.pb.go index bedf80d..f9001d8 100644 --- a/tests/import/proto/bar/bar_host.pb.go +++ b/tests/import/proto/bar/bar_host.pb.go @@ -27,7 +27,7 @@ type BarPlugin struct { func NewBarPlugin(ctx context.Context, opts ...wazeroConfigOption) (*BarPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/tests/import/proto/bar/bar_options.pb.go b/tests/import/proto/bar/bar_options.pb.go index 0b63b06..52783cf 100644 --- a/tests/import/proto/bar/bar_options.pb.go +++ b/tests/import/proto/bar/bar_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/tests/import/proto/foo/foo_host.pb.go b/tests/import/proto/foo/foo_host.pb.go index c86c3af..4c19ebc 100644 --- a/tests/import/proto/foo/foo_host.pb.go +++ b/tests/import/proto/foo/foo_host.pb.go @@ -28,7 +28,7 @@ type FooPlugin struct { func NewFooPlugin(ctx context.Context, opts ...wazeroConfigOption) (*FooPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/tests/import/proto/foo/foo_options.pb.go b/tests/import/proto/foo/foo_options.pb.go index 62893d4..5b94e2e 100644 --- a/tests/import/proto/foo/foo_options.pb.go +++ b/tests/import/proto/foo/foo_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil { diff --git a/tests/well-known/proto/known_host.pb.go b/tests/well-known/proto/known_host.pb.go index 38dd374..b98351b 100644 --- a/tests/well-known/proto/known_host.pb.go +++ b/tests/well-known/proto/known_host.pb.go @@ -28,7 +28,7 @@ type KnownTypesTestPlugin struct { func NewKnownTypesTestPlugin(ctx context.Context, opts ...wazeroConfigOption) (*KnownTypesTestPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } @@ -197,7 +197,7 @@ type EmptyTestPlugin struct { func NewEmptyTestPlugin(ctx context.Context, opts ...wazeroConfigOption) (*EmptyTestPlugin, error) { o := &WazeroConfig{ - newRuntime: defaultWazeroRuntime(), + newRuntime: DefaultWazeroRuntime(), moduleConfig: wazero.NewModuleConfig(), } diff --git a/tests/well-known/proto/known_options.pb.go b/tests/well-known/proto/known_options.pb.go index e46590f..8d33d49 100644 --- a/tests/well-known/proto/known_options.pb.go +++ b/tests/well-known/proto/known_options.pb.go @@ -29,7 +29,7 @@ func WazeroRuntime(newRuntime WazeroNewRuntime) wazeroConfigOption { } } -func defaultWazeroRuntime() WazeroNewRuntime { +func DefaultWazeroRuntime() WazeroNewRuntime { return func(ctx context.Context) (wazero.Runtime, error) { r := wazero.NewRuntime(ctx) if _, err := wasi_snapshot_preview1.Instantiate(ctx, r); err != nil {