Skip to content

Commit

Permalink
feat(generic_http_thrift): fail on nil value for required field
Browse files Browse the repository at this point in the history
  • Loading branch information
wasd96040501 committed Sep 14, 2024
1 parent 4e1dbe9 commit 4ffeadd
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
17 changes: 11 additions & 6 deletions pkg/generic/thrift/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ func NewHTTPReaderWriter(svc *descriptor.ServiceDescriptor) *HTTPReaderWriter {

// WriteHTTPRequest implement of MessageWriter
type WriteHTTPRequest struct {
svc *descriptor.ServiceDescriptor
binaryWithBase64 bool
convOpts conv.Options // used for dynamicgo conversion
convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
dynamicgoEnabled bool
svc *descriptor.ServiceDescriptor
binaryWithBase64 bool
convOpts conv.Options // used for dynamicgo conversion
convOptsWithThriftBase conv.Options // used for dynamicgo conversion with EnableThriftBase turned on
dynamicgoEnabled bool
failOnNilValueForRequiredField bool // will return error when field is required but input value is nil
}

var (
Expand All @@ -71,6 +72,10 @@ func (w *WriteHTTPRequest) SetBinaryWithBase64(enable bool) {
w.binaryWithBase64 = enable
}

func (w *WriteHTTPRequest) SetFailOnNilValueForRequiredField(enable bool) {
w.failOnNilValueForRequiredField = enable
}

// SetDynamicGo ...
func (w *WriteHTTPRequest) SetDynamicGo(convOpts, convOptsWithThriftBase *conv.Options) {
w.convOpts = *convOpts
Expand All @@ -94,7 +99,7 @@ func (w *WriteHTTPRequest) originalWrite(ctx context.Context, out bufiox.Writer,
requestBase = nil
}
bw := thrift.NewBufferWriter(out)
err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64})
err = wrapStructWriter(ctx, req, bw, fn.Request, &writerOption{requestBase: requestBase, binaryWithBase64: w.binaryWithBase64, failOnNilValueForRequiredField: w.failOnNilValueForRequiredField})
bw.Recycle()
return err
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/generic/thrift/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ type writerOption struct {
requestBase *base.Base // request base from metahandler
// decoding Base64 to binary
binaryWithBase64 bool
// will return error when field is required but input value is nil
failOnNilValueForRequiredField bool
}

type writer func(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error
Expand Down Expand Up @@ -778,6 +780,9 @@ func writeHTTPRequest(ctx context.Context, val interface{}, out *thrift.BufferWr

if v == nil {
if !field.Optional {
if opt != nil && opt.failOnNilValueForRequiredField {
return fmt.Errorf("value of field [%s] is nil", name)
}
if err := out.WriteFieldBegin(field.Type.Type.ToThriftTType(), int16(field.ID)); err != nil {
return err
}
Expand Down
32 changes: 31 additions & 1 deletion pkg/generic/thrift/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,36 @@ func Test_writeHTTPRequest(t *testing.T) {
},
false,
},
{
"writeStructRequiredFail",
args{
val: &descriptor.HTTPRequest{
Body: map[string]interface{}{"hello": nil},
},

t: &descriptor.TypeDescriptor{
Type: descriptor.STRUCT,
Key: &descriptor.TypeDescriptor{Type: descriptor.STRING},
Elem: &descriptor.TypeDescriptor{Type: descriptor.STRING},
Struct: &descriptor.StructDescriptor{
Name: "Demo",
FieldsByName: map[string]*descriptor.FieldDescriptor{
"hello": {
Name: "hello",
ID: 1,
Required: true,
Type: &descriptor.TypeDescriptor{Type: descriptor.STRING},
HTTPMapping: descriptor.DefaultNewMapping("hello"),
},
},
},
},
opt: &writerOption{
failOnNilValueForRequiredField: true,
},
},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -1434,7 +1464,7 @@ func getReqPbBody() (proto.Message, error) {
path := "main.proto"
content := `
package kitex.test.server;
message BizReq {
optional int32 user_id = 1;
optional string user_name = 2;
Expand Down

0 comments on commit 4ffeadd

Please sign in to comment.