diff --git a/pkg/generic/thrift/http.go b/pkg/generic/thrift/http.go index 5bb0d9e8a2..711f1ce0e2 100644 --- a/pkg/generic/thrift/http.go +++ b/pkg/generic/thrift/http.go @@ -43,11 +43,12 @@ func NewHTTPReaderWriter(svc *descriptor.ServiceDescriptor) *HTTPReaderWriter { // WriteHTTPRequest implement of MessageWriter type WriteHTTPRequest struct { - svc *descriptor.ServiceDescriptor - binaryWithBase64 bool - convOpts conv.Options // used for dynamicgo conversion - convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on - dynamicgoEnabled bool + svc *descriptor.ServiceDescriptor + binaryWithBase64 bool + convOpts conv.Options // used for dynamicgo conversion + convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on + dynamicgoEnabled bool + failOnNilValueForRequiredField bool // will return error when field is required but input value is nil } var ( @@ -71,6 +72,10 @@ func (w *WriteHTTPRequest) SetBinaryWithBase64(enable bool) { w.binaryWithBase64 = enable } +func (w *WriteHTTPRequest) SetFailOnNilValueForRequiredField(enable bool) { + w.failOnNilValueForRequiredField = enable +} + // SetDynamicGo ... func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) { w.convOpts = *convOpts @@ -94,7 +99,7 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out bufiox.Writer, requestBase = nil } bw := thrift.NewBufferWriter(out) - err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64}) + err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64, failOnNilValueForRequiredField: w.failOnNilValueForRequiredField}) bw.Recycle() return err } diff --git a/pkg/generic/thrift/write.go b/pkg/generic/thrift/write.go index 04c849e9ba..c1994af22f 100644 --- a/pkg/generic/thrift/write.go +++ b/pkg/generic/thrift/write.go @@ -35,6 +35,8 @@ type writerOption struct { requestBase *base.Base // request base from metahandler // decoding Base64 to binary binaryWithBase64 bool + // will return error when field is required but input value is nil + failOnNilValueForRequiredField bool } type writer func(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error @@ -778,6 +780,9 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BufferWr if v == nil { if !field.Optional { + if opt != nil && opt.failOnNilValueForRequiredField { + return fmt.Errorf("value of field [%s] is nil", name) + } if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil { return err } diff --git a/pkg/generic/thrift/write_test.go b/pkg/generic/thrift/write_test.go index 79a469bdca..96e57580bd 100644 --- a/pkg/generic/thrift/write_test.go +++ b/pkg/generic/thrift/write_test.go @@ -1358,6 +1358,36 @@ func Test_writeHTTPRequest(t *testing.T) { }, false, }, + { + "writeStructRequiredFail", + args{ + val: &descriptor.HTTPRequest{ + Body: map[string]interface{}{"hello": nil}, + }, + + t: &descriptor.TypeDescriptor{ + Type: descriptor.STRUCT, + Key: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + Struct: &descriptor.StructDescriptor{ + Name: "Demo", + FieldsByName: map[string]*descriptor.FieldDescriptor{ + "hello": { + Name: "hello", + ID: 1, + Required: true, + Type: &descriptor.TypeDescriptor{Type: descriptor.STRING}, + HTTPMapping: descriptor.DefaultNewMapping("hello"), + }, + }, + }, + }, + opt: &writerOption{ + failOnNilValueForRequiredField: true, + }, + }, + true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1434,7 +1464,7 @@ func getReqPbBody() (proto.Message, error) { path := "main.proto" content := ` package kitex.test.server; - + message BizReq { optional int32 user_id = 1; optional string user_name = 2;