Skip to content

Commit

Permalink
feat(codec): Unknown Method Handler (cloudwego#1360)
Browse files Browse the repository at this point in the history
  • Loading branch information
lokistars committed Jul 18, 2024
1 parent f31874e commit d0c77ca
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 132 deletions.
4 changes: 2 additions & 2 deletions pkg/remote/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package remote

import (
"context"
"github.com/cloudwego/kitex/pkg/unknownservice/service"
"net"
"time"

"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/profiler"
unknown "github.com/cloudwego/kitex/pkg/remote/codec/unknown/service"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand Down Expand Up @@ -114,7 +114,7 @@ type ServerOption struct {

GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error

UnknownMethodHandler unknown.UnknownMethodService
UnknownMethodService service.UnknownMethodService

// RefuseTrafficWithoutServiceName is used for a server with multi services
RefuseTrafficWithoutServiceName bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package unknown
package service

import (
"context"
Expand All @@ -24,8 +24,8 @@ import (
const (
// UnknownService name
UnknownService = "$UnknownService" // private as "$"
// UnknownCall name
UnknownMethod = "$UnknownCall"
// UnknownMethod name
UnknownMethod = "$UnknownMethod"
)

type Args struct {
Expand All @@ -41,7 +41,7 @@ type Result struct {
}

type UnknownMethodService interface {
UnknownMethodHandler(ctx context.Context, method string, request []byte) ([]byte, error)
UnknownMethodHandler(ctx context.Context, serviceName, method string, request []byte) ([]byte, error)
}

// NewServiceInfo create serviceInfo
Expand Down Expand Up @@ -78,7 +78,7 @@ func callHandler(ctx context.Context, handler, arg, result interface{}) error {
realResult := result.(*Result)
realResult.Method = realArg.Method
realResult.ServiceName = realArg.ServiceName
success, err := handler.(UnknownMethodService).UnknownMethodHandler(ctx, realArg.Method, realArg.Request)
success, err := handler.(UnknownMethodService).UnknownMethodHandler(ctx, realArg.ServiceName, realArg.Method, realArg.Request)
if err != nil {
return err
}
Expand Down
137 changes: 37 additions & 100 deletions pkg/remote/codec/unknown/unknown.go → pkg/unknownservice/unknown.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,30 @@
* limitations under the License.
*/

package unknown
package unknownservice

import (
"context"
"encoding/binary"
"errors"
"fmt"
thrif "github.com/apache/thrift/lib/go/thrift"
"github.com/cloudwego/dynamicgo/thrift"
"github.com/cloudwego/kitex/pkg/protocol/bthrift"
thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
unknowns "github.com/cloudwego/kitex/pkg/remote/codec/unknown/service"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
service2 "github.com/cloudwego/kitex/pkg/unknownservice/service"
)

// UnknownCodec implements PayloadCodec
type unknownCodec struct {
Codec remote.PayloadCodec
}

// NewUnknownCodec creates the unknown binary codec.
func NewUnknownCodec(code remote.PayloadCodec) remote.PayloadCodec {
// NewUnknownServiceCodec creates the unknown binary codec.
func NewUnknownServiceCodec(code remote.PayloadCodec) remote.PayloadCodec {
return &unknownCodec{code}
}

Expand All @@ -47,7 +46,7 @@ func (c unknownCodec) Marshal(ctx context.Context, msg remote.Message, out remot
ink := msg.RPCInfo().Invocation()
data := msg.Data()

res, ok := data.(*unknowns.Result)
res, ok := data.(*service2.Result)
if !ok {
return c.Codec.Marshal(ctx, msg, out)
}
Expand All @@ -72,37 +71,34 @@ func (c unknownCodec) Marshal(ctx context.Context, msg remote.Message, out remot
// Unmarshal implements the remote.PayloadCodec interface.
func (c unknownCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
ink := message.RPCInfo().Invocation()
service, method, size, err := decode(message, in)
service, method, err := decode(message, in)
if err != nil {
return c.Codec.Unmarshal(ctx, message, in)
}
err = codec.SetOrCheckMethodName(method, message)
if te, ok := err.(*remote.TransError); ok && te.TypeID() == remote.UnknownMethod {
svcInfo, err := message.SpecifyServiceInfo(unknowns.UnknownService, unknowns.UnknownMethod)
var te *remote.TransError
if errors.As(err, &te) && (te.TypeID() == remote.UnknownMethod || te.TypeID() == remote.UnknownService) {
svcInfo, err := message.SpecifyServiceInfo(service2.UnknownService, service2.UnknownMethod)
if err != nil {
return err
}

if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
ink.SetMethodName(unknowns.UnknownMethod)
ink.SetMethodName(service2.UnknownMethod)
ink.SetPackageName(svcInfo.GetPackageName())
ink.SetServiceName(unknowns.UnknownService)
ink.SetServiceName(service2.UnknownService)
} else {
return errors.New("the interface Invocation doesn't implement InvocationSetter")
}
if err = codec.NewDataIfNeeded(unknowns.UnknownMethod, message); err != nil {
if err = codec.NewDataIfNeeded(service2.UnknownMethod, message); err != nil {
return err
}

data := message.Data()

if data, ok := data.(*unknowns.Args); ok {
if data, ok := data.(*service2.Args); ok {
data.Method = method
data.ServiceName = service
err := in.Skip(int(size))
if err != nil {
return err
}
buf, err := in.Next(in.ReadableLen())
if err != nil {
return err
Expand All @@ -124,83 +120,23 @@ func write(dst, src []byte) {
copy(dst, src)
}

func decode(message remote.Message, in remote.ByteBuffer) (string, string, int32, error) {
func decode(message remote.Message, in remote.ByteBuffer) (string, string, error) {
code := message.ProtocolInfo().CodecType
if code == serviceinfo.Thrift {
return decodeThrift(message, in)
} else if code == serviceinfo.Protobuf {
return decodeProtobuf(message, in)
}
return "", "", 0, nil
}

// decodeThrift Thrift decoder
func decodeThrift(message remote.Message, in remote.ByteBuffer) (string, string, int32, error) {
buf, err := in.Peek(4)
if err != nil {
return "", "", 0, perrors.NewProtocolError(err)
}
size := int32(binary.BigEndian.Uint32(buf))
if size > 0 {
return "", "", 0, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Missing version in ReadMessageBegin")
}
msgType := thrift.TMessageType(size & 0x0ff)
if err = codec.UpdateMsgType(uint32(msgType), message); err != nil {
return "", "", 0, err
}
version := int64(int64(size) & thrift.VERSION_MASK)
if version != thrift.VERSION_1 {
return "", "", 0, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in ReadMessageBegin")
}
// exception message
if message.MessageType() == remote.Exception {
return "", "", 0, perrors.NewProtocolErrorWithMsg("thrift unmarshal")
}
// 获取method
method, size, err := peekMethod(in)
if err != nil {
return "", "", 0, perrors.NewProtocolError(err)
}
seqID, err := peekSeqID(in, size)
if err != nil {
return "", "", 0, perrors.NewProtocolError(err)
}
if err = codec.SetOrCheckSeqID(seqID, message); err != nil {
return "", "", 0, err
}
return message.RPCInfo().Invocation().ServiceName(), method, size + 4, nil
}

// decodeProtobuf Protobuf decoder
func decodeProtobuf(message remote.Message, in remote.ByteBuffer) (string, string, int32, error) {
magicAndMsgType, err := codec.PeekUint32(in)
if err != nil {
return "", "", 0, err
}
if magicAndMsgType&codec.MagicMask != codec.ProtobufV1Magic {
return "", "", 0, perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in protobuf Unmarshal")
}
msgType := magicAndMsgType & codec.FrontMask
if err = codec.UpdateMsgType(msgType, message); err != nil {
return "", "", 0, err
}

method, size, err := peekMethod(in)
if err != nil {
return "", "", 0, perrors.NewProtocolError(err)
}
seqID, err := peekSeqID(in, size)
if err != nil {
return "", "", 0, perrors.NewProtocolError(err)
}
if err = codec.SetOrCheckSeqID(seqID, message); err != nil && msgType != uint32(remote.Exception) {
return "", "", 0, err
}
// exception message
if message.MessageType() == remote.Exception {
return "", "", 0, perrors.NewProtocolErrorWithMsg("protobuf unmarshal")
if code == serviceinfo.Thrift || code == serviceinfo.Protobuf {
method, size, err := peekMethod(in)
if err != nil {
return "", "", perrors.NewProtocolError(err)
}
seqID, err := peekSeqID(in, size)
if err != nil {
return "", "", perrors.NewProtocolError(err)
}
if err = codec.SetOrCheckSeqID(seqID, message); err != nil {
return "", "", err
}
return message.RPCInfo().Invocation().ServiceName(), method, nil
}
return message.RPCInfo().Invocation().ServiceName(), method, size + 4, nil
return "", "", nil
}

func peekMethod(in remote.ByteBuffer) (string, int32, error) {
Expand Down Expand Up @@ -229,29 +165,30 @@ func peekSeqID(in remote.ByteBuffer, size int32) (int32, error) {
return seqID, nil
}

func encode(res *unknowns.Result, msg remote.Message, out remote.ByteBuffer) error {
func encode(res *service2.Result, msg remote.Message, out remote.ByteBuffer) error {

if msg.ProtocolInfo().CodecType == serviceinfo.Thrift {
return encodeThrift(res, msg, out)
} else if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf {
return encodeProtobuf(res, msg, out)
}
if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf {
return encodeKitexProtobuf(res, msg, out)
}
return nil
}

// encodeThrift Thrift encoder
func encodeThrift(res *unknowns.Result, msg remote.Message, out remote.ByteBuffer) error {
func encodeThrift(res *service2.Result, msg remote.Message, out remote.ByteBuffer) error {
nw, _ := out.(remote.NocopyWrite)
msgType := msg.MessageType()
ink := msg.RPCInfo().Invocation()
msgBeginLen := bthrift.Binary.MessageBeginLength(res.Method, thrif.TMessageType(msgType), ink.SeqID())
msgBeginLen := bthrift.Binary.MessageBeginLength(res.Method, thrift.TMessageType(msgType), ink.SeqID())
msgEndLen := bthrift.Binary.MessageEndLength()

buf, err := out.Malloc(msgBeginLen + len(res.Success) + msgEndLen)
if err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error()))
}
offset := bthrift.Binary.WriteMessageBegin(buf, res.Method, thrif.TMessageType(msgType), ink.SeqID())
offset := bthrift.Binary.WriteMessageBegin(buf, res.Method, thrift.TMessageType(msgType), ink.SeqID())
write(buf[offset:], res.Success)
bthrift.Binary.WriteMessageEnd(buf[offset:])
if nw == nil {
Expand All @@ -262,7 +199,7 @@ func encodeThrift(res *unknowns.Result, msg remote.Message, out remote.ByteBuffe
}

// encodeProtobuf Protobuf encoder
func encodeProtobuf(res *unknowns.Result, msg remote.Message, out remote.ByteBuffer) error {
func encodeKitexProtobuf(res *service2.Result, msg remote.Message, out remote.ByteBuffer) error {
ink := msg.RPCInfo().Invocation()
// 3.1 magic && msgType
if err := codec.WriteUint32(codec.ProtobufV1Magic+uint32(msg.MessageType()), out); err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
* limitations under the License.
*/

package unknown
package unknownservice

import (
"context"
"github.com/cloudwego/kitex/internal/mocks"
mt "github.com/cloudwego/kitex/internal/mocks/thrift/fast"
mt "github.com/cloudwego/kitex/internal/mocks/thrift"
"github.com/cloudwego/kitex/internal/test"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/thrift"
unknown "github.com/cloudwego/kitex/pkg/remote/codec/unknown/service"
netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/unknownservice/service"
"github.com/cloudwego/kitex/transport"
"github.com/cloudwego/netpoll"
"testing"
Expand All @@ -43,14 +43,17 @@ func TestNormal(t *testing.T) {
ctx := context.Background()
err := payloadCodec.Marshal(ctx, sendMsg, buf)
test.Assert(t, err == nil, err)
buf.Flush()
err = buf.Flush()
test.Assert(t, err == nil, err)
recvMsg := initRecvMsg()
recvMsg.SetPayloadLen(buf.ReadableLen())
_, size, err := peekMethod(buf)
err = payloadCodec.Unmarshal(ctx, recvMsg, buf)
test.Assert(t, err == nil, err)

req := (sendMsg.Data()).(*unknown.Result).Success
resp := (recvMsg.Data()).(*unknown.Args).Request
req := (sendMsg.Data()).(*service.Result).Success
resp := (recvMsg.Data()).(*service.Args).Request
resp = resp[size+4:]
for i, item := range req {
test.Assert(t, item == resp[i])
}
Expand All @@ -76,9 +79,8 @@ func initSendMsg(tp transport.Protocol) remote.Message {
length := _args.BLength()
bytes := make([]byte, length)
_args.FastWriteNocopy(bytes, nil)
arg := unknown.Result{Success: bytes, Method: "mock", ServiceName: ""}

ink := rpcinfo.NewInvocation("", unknown.UnknownMethod)
arg := service.Result{Success: bytes, Method: "mock", ServiceName: ""}
ink := rpcinfo.NewInvocation("", service.UnknownMethod)
ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)

msg := remote.NewMessage(&arg, svcInfo, ri, remote.Call, remote.Client)
Expand All @@ -89,10 +91,10 @@ func initSendMsg(tp transport.Protocol) remote.Message {
}

func initRecvMsg() remote.Message {
arg := unknown.Args{Request: make([]byte, 0), Method: "mock", ServiceName: ""}
ink := rpcinfo.NewInvocation("", unknown.UnknownMethod)
arg := service.Args{Request: make([]byte, 0), Method: "mock", ServiceName: ""}
ink := rpcinfo.NewInvocation("", service.UnknownMethod)
ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
svc := unknown.NewServiceInfo(svcInfo.PayloadCodec, unknown.UnknownService, unknown.UnknownMethod)
svc := service.NewServiceInfo(svcInfo.PayloadCodec, service.UnknownService, service.UnknownMethod)
msg := remote.NewMessage(&arg, svc, ri, remote.Call, remote.Server)
return msg
}
Expand Down
Loading

0 comments on commit d0c77ca

Please sign in to comment.