Skip to content

Commit

Permalink
Implement the hook interceptor (milvus-io#19294)
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>

Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG authored Sep 23, 2022
1 parent 901d3fb commit 68a2574
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 23 deletions.
9 changes: 9 additions & 0 deletions api/hook/hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package hook

type Hook interface {
Init(params map[string]string) error
Mock(req interface{}, fullMethod string) (bool, interface{}, error)
Before(req interface{}, fullMethod string) error
After(result interface{}, err error, fullMethod string) error
Release()
}
Empty file added configs/hook.yaml
Empty file.
6 changes: 1 addition & 5 deletions internal/core/thirdparty/rocksdb/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ macro( build_rocksdb )
message( STATUS "Building ROCKSDB-${ROCKSDB_VERSION} from source" )

set ( ROCKSDB_MD5 "e4a0625f0cec82060e62c81b787a1124" )

if ( EMBEDDED_MILVUS )
message ( STATUS "Turning on fPIC while building embedded Milvus" )
set( FPIC_ARG "-DCMAKE_POSITION_INDEPENDENT_CODE=ON" )
endif()
set( FPIC_ARG "-DCMAKE_POSITION_INDEPENDENT_CODE=ON" )
set( ROCKSDB_CMAKE_ARGS
"-DWITH_GFLAGS=OFF"
"-DROCKSDB_BUILD_SHARED=OFF"
Expand Down
1 change: 1 addition & 0 deletions internal/distributed/proxy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ func (s *Server) startExternalGrpc(grpcPort int, errChan chan error) {
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
ot.UnaryServerInterceptor(opts...),
grpc_auth.UnaryServerInterceptor(proxy.AuthenticationInterceptor),
proxy.UnaryServerHookInterceptor(),
proxy.UnaryServerInterceptor(proxy.PrivilegeInterceptor),
logutil.UnaryTraceLoggerInterceptor,
proxy.RateLimitInterceptor(limiter),
Expand Down
95 changes: 95 additions & 0 deletions internal/proxy/hook_interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package proxy

import (
"context"
"plugin"

"github.com/milvus-io/milvus/api/hook"

"go.uber.org/zap"

"google.golang.org/grpc"
)

type defaultHook struct {
}

func (d defaultHook) Init(params map[string]string) error {
return nil
}

func (d defaultHook) Mock(req interface{}, fullMethod string) (bool, interface{}, error) {
return false, nil, nil
}

func (d defaultHook) Before(req interface{}, fullMethod string) error {
return nil
}

func (d defaultHook) After(result interface{}, err error, fullMethod string) error {
return nil
}

func (d defaultHook) Release() {}

var hoo hook.Hook

func initHook() {
path := Params.ProxyCfg.SoPath
if path == "" {
hoo = defaultHook{}
return
}

logger.Debug("start to load plugin", zap.String("path", path))
p, err := plugin.Open(path)
if err != nil {
exit("fail to open the plugin", err)
}
logger.Debug("plugin open")

h, err := p.Lookup("MilvusHook")
if err != nil {
exit("fail to the 'MilvusHook' object in the plugin", err)
}

var ok bool
hoo, ok = h.(hook.Hook)
if !ok {
exit("fail to convert the `Hook` interface", nil)
}
if err = hoo.Init(Params.HookCfg.SoConfig); err != nil {
exit("fail to init configs for the hoo", err)
}
}

func exit(errMsg string, err error) {
logger.Panic("hoo error", zap.String("path", Params.ProxyCfg.SoPath), zap.String("msg", errMsg), zap.Error(err))
}

func UnaryServerHookInterceptor() grpc.UnaryServerInterceptor {
initHook()
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var (
fullMethod = info.FullMethod
isMock bool
mockResp interface{}
realResp interface{}
realErr error
err error
)

if isMock, mockResp, err = hoo.Mock(req, fullMethod); isMock {
return mockResp, err
}

if err = hoo.Before(req, fullMethod); err != nil {
return nil, err
}
realResp, realErr = handler(ctx, req)
if err = hoo.After(realResp, realErr, fullMethod); err != nil {
return nil, err
}
return realResp, realErr
}
}
127 changes: 127 additions & 0 deletions internal/proxy/hook_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package proxy

import (
"context"
"errors"
"testing"

"google.golang.org/grpc"

"github.com/stretchr/testify/assert"
)

func TestInitHook(t *testing.T) {
Params.ProxyCfg.SoPath = ""
initHook()
assert.IsType(t, defaultHook{}, hoo)

Params.ProxyCfg.SoPath = "/a/b/hook.so"
assert.Panics(t, func() {
initHook()
})
Params.ProxyCfg.SoPath = ""
}

type mockHook struct {
defaultHook
mockRes interface{}
mockErr error
}

func (m mockHook) Mock(req interface{}, fullMethod string) (bool, interface{}, error) {
return true, m.mockRes, m.mockErr
}

type req struct {
method string
}

type beforeMock struct {
defaultHook
method string
err error
}

func (b beforeMock) Before(r interface{}, fullMethod string) error {
re, ok := r.(*req)
if !ok {
return errors.New("r is invalid type")
}
re.method = b.method
return b.err
}

type resp struct {
method string
}

type afterMock struct {
defaultHook
method string
err error
}

func (a afterMock) After(r interface{}, err error, fullMethod string) error {
re, ok := r.(*resp)
if !ok {
return errors.New("r is invalid type")
}
re.method = a.method
return a.err
}

func TestHookInterceptor(t *testing.T) {
var (
ctx = context.Background()
info = &grpc.UnaryServerInfo{
FullMethod: "test",
}
interceptor = UnaryServerHookInterceptor()
mockHoo = mockHook{mockRes: "mock", mockErr: errors.New("mock")}
r = &req{method: "req"}
re = &resp{method: "resp"}
beforeHoo = beforeMock{method: "before", err: errors.New("before")}
afterHoo = afterMock{method: "after", err: errors.New("after")}

res interface{}
err error
)

hoo = mockHoo
res, err = interceptor(ctx, "request", info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, res, mockHoo.mockRes)
assert.Equal(t, err, mockHoo.mockErr)

hoo = beforeHoo
_, err = interceptor(ctx, r, info, func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
})
assert.Equal(t, r.method, beforeHoo.method)
assert.Equal(t, err, beforeHoo.err)

hoo = afterHoo
_, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return re, nil
})
assert.Equal(t, re.method, afterHoo.method)
assert.Equal(t, err, afterHoo.err)

hoo = defaultHook{}
res, err = interceptor(ctx, r, info, func(ctx context.Context, r interface{}) (interface{}, error) {
return &resp{
method: r.(*req).method,
}, nil
})
assert.Equal(t, res.(*resp).method, r.method)
assert.NoError(t, err)
}

func TestDefaultHook(t *testing.T) {
d := defaultHook{}
assert.NoError(t, d.Init(nil))
assert.NotPanics(t, func() {
d.Release()
})
}
41 changes: 24 additions & 17 deletions internal/util/paramtable/base_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ const (
DefaultEtcdEndpoints = "localhost:2379"
DefaultInsertBufferSize = "16777216"
DefaultEnvPrefix = "milvus"

DefaultLogFormat = "text"
DefaultLogLevelForBase = "debug"
DefaultRootPath = ""
DefaultMaxSize = 300
DefaultMaxAge = 10
DefaultMaxBackups = 20
)

var defaultYaml = DefaultMilvusYaml
Expand All @@ -71,6 +78,8 @@ type BaseTable struct {
RoleName string
Log log.Config
LogCfgFunc func(log.Config)

YamlFile string
}

// GlobalInitWithYaml initializes the param table with the given yaml.
Expand All @@ -94,6 +103,9 @@ func (gp *BaseTable) Init() {
ret = strings.ReplaceAll(ret, ".", "")
return ret
}
if gp.YamlFile == "" {
gp.YamlFile = defaultYaml
}
gp.initConfigsFromLocal(formatter)
gp.initConfigsFromRemote(formatter)
gp.InitLogCfg()
Expand All @@ -107,7 +119,7 @@ func (gp *BaseTable) initConfigsFromLocal(formatter func(key string) string) {
}

gp.configDir = gp.initConfPath()
configFilePath := gp.configDir + "/" + defaultYaml
configFilePath := gp.configDir + "/" + gp.YamlFile
gp.mgr, err = config.Init(config.WithEnvSource(formatter), config.WithFilesSource(configFilePath))
if err != nil {
log.Warn("init baseTable with file failed", zap.String("configFile", configFilePath), zap.Error(err))
Expand All @@ -127,7 +139,7 @@ func (gp *BaseTable) initConfigsFromRemote(formatter func(key string) string) {
return
}

configFilePath := gp.configDir + "/" + defaultYaml
configFilePath := gp.configDir + "/" + gp.YamlFile
gp.mgr, err = config.Init(config.WithEnvSource(formatter),
config.WithFilesSource(configFilePath),
config.WithEtcdSource(&config.EtcdInfo{
Expand Down Expand Up @@ -164,6 +176,10 @@ func (gp *BaseTable) initConfPath() string {
return configDir
}

func (gp *BaseTable) Configs() map[string]string {
return gp.mgr.Configs()
}

// Load loads an object with @key.
func (gp *BaseTable) Load(key string) (string, error) {
return gp.mgr.GetConfig(key)
Expand Down Expand Up @@ -366,19 +382,13 @@ func ConvertRangeToIntSlice(rangeStr, sep string) []int {
// InitLogCfg init log of the base table
func (gp *BaseTable) InitLogCfg() {
gp.Log = log.Config{}
format, err := gp.Load("log.format")
if err != nil {
panic(err)
}
format := gp.LoadWithDefault("log.format", DefaultLogFormat)
gp.Log.Format = format
level, err := gp.Load("log.level")
if err != nil {
panic(err)
}
level := gp.LoadWithDefault("log.level", DefaultLogLevelForBase)
gp.Log.Level = level
gp.Log.File.MaxSize = gp.ParseInt("log.file.maxSize")
gp.Log.File.MaxBackups = gp.ParseInt("log.file.maxBackups")
gp.Log.File.MaxDays = gp.ParseInt("log.file.maxAge")
gp.Log.File.MaxSize = gp.ParseIntWithDefault("log.file.maxSize", DefaultMaxSize)
gp.Log.File.MaxBackups = gp.ParseIntWithDefault("log.file.maxBackups", DefaultMaxBackups)
gp.Log.File.MaxDays = gp.ParseIntWithDefault("log.file.maxAge", DefaultMaxAge)
}

// SetLogConfig set log config of the base table
Expand All @@ -398,10 +408,7 @@ func (gp *BaseTable) SetLogConfig() {

// SetLogger sets the logger file by given id
func (gp *BaseTable) SetLogger(id UniqueID) {
rootPath, err := gp.Load("log.file.rootPath")
if err != nil {
panic(err)
}
rootPath := gp.LoadWithDefault("log.file.rootPath", DefaultRootPath)
if rootPath != "" {
if id < 0 {
gp.Log.File.Filename = path.Join(rootPath, gp.RoleName+".log")
Expand Down
11 changes: 10 additions & 1 deletion internal/util/paramtable/component_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type ComponentParam struct {
DataNodeCfg dataNodeConfig
IndexCoordCfg indexCoordConfig
IndexNodeCfg indexNodeConfig
HookCfg HookConfig
}

// InitOnce initialize once
Expand All @@ -76,6 +77,7 @@ func (p *ComponentParam) Init() {
p.DataNodeCfg.init(&p.BaseTable)
p.IndexCoordCfg.init(&p.BaseTable)
p.IndexNodeCfg.init(&p.BaseTable)
p.HookCfg.init()
}

// SetLogConfig set log config with given role
Expand Down Expand Up @@ -431,7 +433,8 @@ type proxyConfig struct {
IP string
NetworkAddress string

Alias string
Alias string
SoPath string

NodeID atomic.Value
TimeTickInterval time.Duration
Expand Down Expand Up @@ -475,13 +478,19 @@ func (p *proxyConfig) init(base *BaseTable) {
p.initGinLogging()
p.initMaxUserNum()
p.initMaxRoleNum()

p.initSoPath()
}

// InitAlias initialize Alias member.
func (p *proxyConfig) InitAlias(alias string) {
p.Alias = alias
}

func (p *proxyConfig) initSoPath() {
p.SoPath = p.Base.LoadWithDefault("proxy.soPath", "")
}

func (p *proxyConfig) initTimeTickInterval() {
interval := p.Base.ParseIntWithDefault("proxy.timeTickInterval", 200)
p.TimeTickInterval = time.Duration(interval) * time.Millisecond
Expand Down
Loading

0 comments on commit 68a2574

Please sign in to comment.