Skip to content

Commit

Permalink
Merge pull request #62 from sev-2/feature/support-rpc-with-trigger
Browse files Browse the repository at this point in the history
rpc : support return trigger and uuid params
  • Loading branch information
toopay authored Sep 5, 2024
2 parents 056aa2c + cff0dd8 commit 9066415
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 1 deletion.
110 changes: 109 additions & 1 deletion pkg/generator/rpc_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package generator

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"io/fs"
"path/filepath"
"strings"
Expand Down Expand Up @@ -114,7 +117,7 @@ func WalkScanRpc(rpcDir string) ([]string, error) {
err := filepath.Walk(rpcDir, func(path string, info fs.FileInfo, err error) error {
if strings.HasSuffix(path, ".go") {
RpcRegisterLogger.Trace("collect rpc", "path", path)
rs, e := getStructByBaseName(path, "RpcBase")
rs, e := getStructByBaseAndExcludeReturnType(path, "RpcBase", map[string]bool{"RpcReturnDataTypeTrigger": true})
if e != nil {
return e
}
Expand All @@ -130,3 +133,108 @@ func WalkScanRpc(rpcDir string) ([]string, error) {

return rpc, nil
}

func getStructByBaseAndExcludeReturnType(filePath string, baseStructName string, returnTypes map[string]bool) (r []string, err error) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, filePath, nil, parser.ParseComments)
if err != nil {
return r, err
}

mapRpc := map[string]struct {
IsFoundBaseStruct bool
IsExclude bool
}{}

// Traverse the AST to find the struct with the Http attribute
for _, decl := range file.Decls {

ft, fok := decl.(*ast.FuncDecl)
if fok {
if ft.Name == nil || (ft.Name != nil && ft.Name.Name != "GetReturnType") {
continue
}

startExp, isStartExp := ft.Recv.List[0].Type.(*ast.StarExpr)
if !isStartExp {
continue
}
stName := fmt.Sprintf("%s", startExp.X)

// Iterate over the statements in the function body
for _, stmt := range ft.Body.List {
// Check if the statement is a return statement
retStmt, isReturn := stmt.(*ast.ReturnStmt)
if isReturn {
// Iterate over the results in the return statement
for _, result := range retStmt.Results {
switch expr := result.(type) {
case *ast.SelectorExpr:
if returnTypes[expr.Sel.Name] {
// isExclude = true
if v, exist := mapRpc[stName]; exist {
mapRpc[stName] = struct {
IsFoundBaseStruct bool
IsExclude bool
}{
IsFoundBaseStruct: v.IsFoundBaseStruct,
IsExclude: true,
}
}
}
default:
continue
}
}
}
}
continue
}

genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}

for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}

st, ok := typeSpec.Type.(*ast.StructType)
if !ok {
continue
}

if len(st.Fields.List) == 0 {
continue
}

for _, f := range st.Fields.List {
if se, isSe := f.Type.(*ast.SelectorExpr); isSe && se.Sel.Name == baseStructName {
if _, exist := mapRpc[typeSpec.Name.Name]; !exist {
mapRpc[typeSpec.Name.Name] = struct {
IsFoundBaseStruct bool
IsExclude bool
}{
IsFoundBaseStruct: true,
IsExclude: false,
}
}
continue
}
}

}
}

for structName, checkValue := range mapRpc {
if checkValue.IsFoundBaseStruct && checkValue.IsExclude {
continue
}
r = append(r, structName)
}

return
}
9 changes: 9 additions & 0 deletions pkg/generator/rpc_register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ func TestWalkRpcDir(t *testing.T) {
assert.Equal(t, "GetVoteBy", rs[0])
}

func TestScanRpcAndExclude(t *testing.T) {
testPath, err := utils.GetAbsolutePath("/testdata")
assert.NoError(t, err)

rs, err := generator.WalkScanRpc(testPath)
assert.NoError(t, err)
assert.Len(t, rs, 1)
}

func TestGenerateRpcRegister(t *testing.T) {
dir, err := os.MkdirTemp("", "rpc_register")
assert.NoError(t, err)
Expand Down
36 changes: 36 additions & 0 deletions pkg/generator/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,39 @@ func TestGenerateRpc(t *testing.T) {
assert.NoError(t, err2)
assert.FileExists(t, dir+"/internal/rpc/get_submissions.go")
}

func TestRpcWithTrigger(t *testing.T) {
fn := objects.Function{
Schema: "public",
Name: "create_profile",
Language: "plpgsql",
Definition: `BEGIN INSERT INTO public.users (firstname,lastname, email) \nVALUES \n (\n NEW.raw_user_meta_data ->> 'name', \n NEW.raw_user_meta_data ->> 'name', \n NEW.raw_user_meta_data ->> 'email'\n );\nRETURN NEW;\nEND;`,
CompleteStatement: `CREATE OR REPLACE FUNCTION public.create_profile()\n RETURNS trigger\n LANGUAGE plpgsql\n SECURITY DEFINER\nAS $function$BEGIN INSERT INTO public.users (firstname,lastname, email) \nVALUES \n (\n NEW.raw_user_meta_data ->> 'name', \n NEW.raw_user_meta_data ->> 'name', \n NEW.raw_user_meta_data ->> 'email'\n );\nRETURN NEW;\nEND;$function$\n`,
Args: []objects.FunctionArg{},
ReturnTypeID: 2279,
ReturnType: "trigger",
IsSetReturningFunction: false,
Behavior: string(raiden.RpcBehaviorVolatile),
SecurityDefiner: true,
ConfigParams: nil,
}

result, err := generator.ExtractRpcFunction(&fn, []objects.Table{
{Name: "users"},
})
assert.NoError(t, err)

raidenPath := fmt.Sprintf("%q", "github.com/sev-2/raiden")
importsMap := map[string]bool{
raidenPath: true,
}
returnDecl, returnColumns, IsReturnArr, err := result.GetReturn(importsMap)
assert.NoError(t, err)

assert.Equal(t, "interface{}", returnDecl)
assert.Equal(t, 0, len(returnColumns))
assert.False(t, IsReturnArr)

// assert security type
assert.Equal(t, "RpcSecurityTypeDefiner", result.GetSecurity())
}
26 changes: 26 additions & 0 deletions pkg/generator/testdata/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,29 @@ type GetVoteBy struct {
Params *GetVoteByParams `json:"-"`
Return GetVoteByResult `json:"-"`
}

type CreateProfileParams struct {
}
type CreateProfileResult interface{}

type CreateProfile struct {
raiden.RpcBase
Params *CreateProfileParams `json:"-"`
Return CreateProfileResult `json:"-"`
}

func (r *CreateProfile) GetName() string {
return "create_profile"
}

func (r *CreateProfile) GetSecurity() raiden.RpcSecurityType {
return raiden.RpcSecurityTypeDefiner
}

func (r *CreateProfile) GetReturnType() raiden.RpcReturnDataType {
return raiden.RpcReturnDataTypeTrigger
}

func (r *CreateProfile) GetRawDefinition() string {
return `BEGIN INSERT INTO public.users (firstname,lastname, email) VALUES ( NEW.raw_user_meta_data ->> 'name', NEW.raw_user_meta_data ->> 'name', NEW.raw_user_meta_data ->> 'email' ); RETURN NEW; END;`
}
40 changes: 40 additions & 0 deletions pkg/state/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,43 @@ func TestExtractRpcResult_ToDeleteFlatMap(t *testing.T) {
assert.Equal(t, "rpc1", mapData["rpc1"].Name)
assert.Equal(t, "rpc2", mapData["rpc2"].Name)
}

// Test declaration query with return trigger

type CreateProfileParams struct{}
type CreateProfileResult interface{}

type CreateProfile struct {
raiden.RpcBase
Params *CreateProfileParams `json:"-"`
Return CreateProfileResult `json:"-"`
}

func (r *CreateProfile) GetName() string {
return "create_profile"
}

func (r *CreateProfile) GetSecurity() raiden.RpcSecurityType {
return raiden.RpcSecurityTypeDefiner
}

func (r *CreateProfile) GetReturnType() raiden.RpcReturnDataType {
return raiden.RpcReturnDataTypeTrigger
}

func (r *CreateProfile) GetRawDefinition() string {
return `BEGIN INSERT INTO public.users (firstname,lastname, email) VALUES ( NEW.raw_user_meta_data ->> 'name', NEW.raw_user_meta_data ->> 'name', NEW.raw_user_meta_data ->> 'email' ); RETURN NEW; END;`
}

func TestRpcFunction_ReturnTrigger(t *testing.T) {
rpc := &CreateProfile{}
e := raiden.BuildRpc(rpc)
assert.NoError(t, e)
fn := objects.Function{}

err := state.BindRpcFunction(rpc, &fn)
assert.NoError(t, err)
assert.Equal(t, "create_profile", fn.Name)
assert.Equal(t, "public", fn.Schema)
assert.Equal(t, "create or replace function public.create_profile() returns trigger language plpgsql security definer as $function$ begin insert into public.users (firstname,lastname, email) values ( new.raw_user_meta_data ->> 'name', new.raw_user_meta_data ->> 'name', new.raw_user_meta_data ->> 'email' ) ; return new ; end; $function$", fn.CompleteStatement)
}
10 changes: 10 additions & 0 deletions rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const (
RpcParamDataTypeTimestampTZAlias RpcParamDataType = "TIMESTAMPZ"
RpcParamDataTypeJSON RpcParamDataType = "JSON"
RpcParamDataTypeJSONB RpcParamDataType = "JSONB"
RpcParamDataTypeUuid RpcParamDataType = "UUID"
)

// Define constants for rpc return data type
Expand All @@ -64,6 +65,7 @@ const (
RpcReturnDataTypeTable RpcReturnDataType = "TABLE"
RpcReturnDataTypeSetOf RpcReturnDataType = "SETOF"
RpcReturnDataTypeVoid RpcReturnDataType = "VOID"
RpcReturnDataTypeTrigger RpcReturnDataType = "TRIGGER"
)

func RpcParamToGoType(dataType RpcParamDataType) string {
Expand All @@ -84,6 +86,8 @@ func RpcParamToGoType(dataType RpcParamDataType) string {
return "time.Time"
case RpcParamDataTypeJSON, RpcParamDataTypeJSONB:
return "map[string]interface{}"
case RpcParamDataTypeUuid:
return "uuid.UUID"
default:
return "interface{}" // Return interface{} for unknown types
}
Expand Down Expand Up @@ -127,6 +131,8 @@ func GetValidRpcParamType(pType string, returnAlias bool) (RpcParamDataType, err
return RpcParamDataTypeJSON, nil
case RpcParamDataTypeJSONB:
return RpcParamDataTypeJSONB, nil
case RpcParamDataTypeUuid:
return RpcParamDataTypeUuid, nil
default:
return "", fmt.Errorf("unsupported rpc param type : %s", pCheckType)
}
Expand Down Expand Up @@ -197,6 +203,8 @@ func GetValidRpcReturnType(pType string, returnAlias bool) (RpcReturnDataType, e
return RpcReturnDataTypeTable, nil
case RpcReturnDataTypeVoid:
return RpcReturnDataTypeVoid, nil
case RpcReturnDataTypeTrigger:
return RpcReturnDataTypeTrigger, nil
default:
return "", fmt.Errorf("unsupported rpc return type : %s", pCheckType)
}
Expand Down Expand Up @@ -243,6 +251,8 @@ func GetValidRpcReturnNameDecl(pType RpcReturnDataType, returnAlias bool) (strin
return "RpcReturnDataTypeTable", nil
case RpcReturnDataTypeVoid:
return "RpcReturnDataTypeVoid", nil
case RpcReturnDataTypeTrigger:
return "RpcReturnDataTypeTrigger", nil
default:
return "", fmt.Errorf("unsupported rpc return name declaration : %s", pType)
}
Expand Down

0 comments on commit 9066415

Please sign in to comment.