diff --git a/pkg/generic/httpthrift_codec.go b/pkg/generic/httpthrift_codec.go index e142811145..035632bdbc 100644 --- a/pkg/generic/httpthrift_codec.go +++ b/pkg/generic/httpthrift_codec.go @@ -41,19 +41,20 @@ type HTTPRequest = descriptor.HTTPRequest type HTTPResponse = descriptor.HTTPResponse type httpThriftCodec struct { - svcDsc atomic.Value // *idl - provider DescriptorProvider - binaryWithBase64 bool - convOpts conv.Options // used for dynamicgo conversion - convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on - dynamicgoEnabled bool - useRawBodyForHTTPResp bool - svcName string + svcDsc atomic.Value // *idl + provider DescriptorProvider + binaryWithBase64 bool + convOpts conv.Options // used for dynamicgo conversion + convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on + dynamicgoEnabled bool + useRawBodyForHTTPResp bool + failOnNilValueForRequiredField bool + svcName string } func newHTTPThriftCodec(p DescriptorProvider, opts *Options) *httpThriftCodec { svc := <-p.Provide() - c := &httpThriftCodec{provider: p, binaryWithBase64: false, dynamicgoEnabled: false, useRawBodyForHTTPResp: opts.useRawBodyForHTTPResp, svcName: svc.Name} + c := &httpThriftCodec{provider: p, binaryWithBase64: false, dynamicgoEnabled: false, useRawBodyForHTTPResp: opts.useRawBodyForHTTPResp, failOnNilValueForRequiredField: opts.failOnNilValueForRequiredField, svcName: svc.Name} if dp, ok := p.(GetProviderOption); ok && dp.Option().DynamicGoEnabled { c.dynamicgoEnabled = true @@ -95,6 +96,7 @@ func (c *httpThriftCodec) configureHTTPRequestWriter(writer *thrift.WriteHTTPReq if c.dynamicgoEnabled { writer.SetDynamicGo(&c.convOpts, &c.convOptsWithThriftBase) } + writer.SetFailOnNilValueForRequiredField(c.failOnNilValueForRequiredField) } func (c *httpThriftCodec) configureHTTPResponseReader(reader *thrift.ReadHTTPResponse) { diff --git a/pkg/generic/httpthrift_codec_test.go b/pkg/generic/httpthrift_codec_test.go index 1852288de7..0cb1c5a782 100644 --- a/pkg/generic/httpthrift_codec_test.go +++ b/pkg/generic/httpthrift_codec_test.go @@ -109,6 +109,41 @@ func TestHttpThriftCodecWithDynamicGo(t *testing.T) { test.Assert(t, ok) } +func TestHttpThriftCodecWithFailOnNilValueForRequired(t *testing.T) { + // without dynamicgo + p, err := NewThriftFileProvider("./http_test/idl/binary_echo.thrift") + test.Assert(t, err == nil) + gOpts := &Options{dynamicgoConvOpts: DefaultHTTPDynamicGoConvOpts, failOnNilValueForRequiredField: true} + htc := newHTTPThriftCodec(p, gOpts) + test.Assert(t, !htc.dynamicgoEnabled) + test.Assert(t, !htc.useRawBodyForHTTPResp) + test.Assert(t, htc.failOnNilValueForRequiredField) + test.DeepEqual(t, htc.convOpts, conv.Options{}) + test.DeepEqual(t, htc.convOptsWithThriftBase, conv.Options{}) + defer htc.Close() + test.Assert(t, htc.Name() == "HttpThrift") + + req := &HTTPRequest{Request: getStdHttpRequest()} + // wrong + method, err := htc.getMethod("test") + test.Assert(t, err.Error() == "req is invalid, need descriptor.HTTPRequest" && method == nil) + // right + method, err = htc.getMethod(req) + test.Assert(t, err == nil && method.Name == "BinaryEcho") + test.Assert(t, method.StreamingMode == serviceinfo.StreamingNone) + test.Assert(t, htc.svcName == "ExampleService") + + rw := htc.getMessageReaderWriter() + _, ok := rw.(error) + test.Assert(t, !ok) + + rw = htc.getMessageReaderWriter() + _, ok = rw.(thrift.MessageWriter) + test.Assert(t, ok) + _, ok = rw.(thrift.MessageReader) + test.Assert(t, ok) +} + func getStdHttpRequest() *http.Request { body := map[string]interface{}{ "msg": []byte("hello"), diff --git a/pkg/generic/option.go b/pkg/generic/option.go index 1ebe319291..a687727da3 100644 --- a/pkg/generic/option.go +++ b/pkg/generic/option.go @@ -41,6 +41,8 @@ type Options struct { dynamicgoConvOpts conv.Options // flag to set whether to store http resp body into HTTPResponse.RawBody useRawBodyForHTTPResp bool + // will return error when field is required but input value is nil + failOnNilValueForRequiredField bool } type Option struct { @@ -68,3 +70,10 @@ func UseRawBodyForHTTPResp(enable bool) Option { opt.useRawBodyForHTTPResp = enable }} } + +// will return error when field is required but input value is nil +func WithFailOnNilValueForRequiredField(enable bool) Option { + return Option{F: func(opt *Options) { + opt.failOnNilValueForRequiredField = enable + }} +} diff --git a/pkg/generic/thrift/http.go b/pkg/generic/thrift/http.go index 5bb0d9e8a2..711f1ce0e2 100644 --- a/pkg/generic/thrift/http.go +++ b/pkg/generic/thrift/http.go @@ -43,11 +43,12 @@ func NewHTTPReaderWriter(svc *descriptor.ServiceDescriptor) *HTTPReaderWriter { // WriteHTTPRequest implement of MessageWriter type WriteHTTPRequest struct { - svc *descriptor.ServiceDescriptor - binaryWithBase64 bool - convOpts conv.Options // used for dynamicgo conversion - convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on - dynamicgoEnabled bool + svc *descriptor.ServiceDescriptor + binaryWithBase64 bool + convOpts conv.Options // used for dynamicgo conversion + convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on + dynamicgoEnabled bool + failOnNilValueForRequiredField bool // will return error when field is required but input value is nil } var ( @@ -71,6 +72,10 @@ func (w *WriteHTTPRequest) SetBinaryWithBase64(enable bool) { w.binaryWithBase64 = enable } +func (w *WriteHTTPRequest) SetFailOnNilValueForRequiredField(enable bool) { + w.failOnNilValueForRequiredField = enable +} + // SetDynamicGo ... func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) { w.convOpts = *convOpts @@ -94,7 +99,7 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out bufiox.Writer, requestBase = nil } bw := thrift.NewBufferWriter(out) - err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}) + err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64, failOnNilValueForRequiredField: w.failOnNilValueForRequiredField}) bw.Recycle() return err } diff --git a/pkg/generic/thrift/write.go b/pkg/generic/thrift/write.go index 04c849e9ba..c1994af22f 100644 --- a/pkg/generic/thrift/write.go +++ b/pkg/generic/thrift/write.go @@ -35,6 +35,8 @@ type writerOption struct { requestBase *base.Base // request base from metahandler // decoding Base64 to binary binaryWithBase64 bool + // will return error when field is required but input value is nil + failOnNilValueForRequiredField bool } type writer func(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error @@ -778,6 +780,9 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BufferWr if v == nil { if !field.Optional { + if opt != nil && opt.failOnNilValueForRequiredField { + return fmt.Errorf("value of field [%s] is nil", name) + } if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { return err } diff --git a/pkg/generic/thrift/write_test.go b/pkg/generic/thrift/write_test.go index 79a469bdca..96e57580bd 100644 --- a/pkg/generic/thrift/write_test.go +++ b/pkg/generic/thrift/write_test.go @@ -1358,6 +1358,36 @@ func Test_writeHTTPRequest(t *testing.T) { }, false, }, + { + "writeStructRequiredFail", + args{ + val: &descriptor.HTTPRequest{ + Body: map[string]interface{}{"hello": nil}, + }, + + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": { + Name: "hello", + ID: 1, + Required: true, + Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + HTTPMapping: descriptor.DefaultNewMapping("hello"), + }, + }, + }, + }, + opt: &writerOption{ + failOnNilValueForRequiredField: true, + }, + }, + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1434,7 +1464,7 @@ func getReqPbBody() (proto.Message, error) { path := "main.proto" content := ` package kitex.test.server; - + message BizReq { optional int32 user_id = 1; optional string user_name = 2;