Skip to content

Commit

Permalink
add StmtIndexConvert
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuxiujia committed Mar 15, 2020
1 parent 6c14103 commit e1ecbc4
Show file tree
Hide file tree
Showing 33 changed files with 212 additions and 103 deletions.
2 changes: 1 addition & 1 deletion DataSourceRouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type DataSourceRouter interface {
//返回(session,error)路由选择后的session,error异常
Router(mapperName string, engine SessionEngine) (Session, error)
//设置sql.DB,该方法会被GoMybatis框架内调用
SetDB(driver string, url string, db *sql.DB)
SetDB(driverName string, url string, db *sql.DB)

Name() string
}
43 changes: 30 additions & 13 deletions GoMybatis.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"github.com/zhuxiujia/GoMybatis/ast"
"github.com/zhuxiujia/GoMybatis/lib/github.com/beevik/etree"
"github.com/zhuxiujia/GoMybatis/stmt"
"github.com/zhuxiujia/GoMybatis/utils"
"log"
"reflect"
Expand Down Expand Up @@ -328,15 +329,11 @@ func exeMethodByXml(elementType ElementType, beanName string, sessionEngine Sess
var session Session
var sql string
var err error
var array_arg = []interface{}{}
session, sql, err = buildSql(proxyArg, nodes, sessionEngine.SqlBuilder(), &array_arg)
//session
session, err = findArgSession(proxyArg)
if err != nil {
return err
}
if sessionEngine.SessionFactory() == nil && session == nil {
panic("[GoMybatis] exe sql need a SessionFactory or Session!")
}
//session
if session == nil {
var goroutineID int64 //协程id
if sessionEngine.GoroutineIDEnable() {
Expand All @@ -354,8 +351,17 @@ func exeMethodByXml(elementType ElementType, beanName string, sessionEngine Sess
session = s
defer session.Close()
}
var haveLastReturnValue = returnValue != nil && (*returnValue).IsNil() == false
convert, err := session.StmtConvert()
if err != nil {
return err
}
var array_arg = []interface{}{}
sql, err = buildSql(proxyArg, nodes, sessionEngine.SqlBuilder(), &array_arg, convert)
if err != nil {
return err
}
//do CRUD
var haveLastReturnValue = returnValue != nil && (*returnValue).IsNil() == false
if elementType == Element_Select && haveLastReturnValue {
//is select and have return value
if sessionEngine.LogEnable() {
Expand Down Expand Up @@ -419,8 +425,22 @@ func closeSession(factory *SessionFactory, session Session) {
session.Close()
}

func buildSql(proxyArg ProxyArg, nodes []ast.Node, sqlBuilder SqlBuilder, array_arg *[]interface{}) (Session, string, error) {
func findArgSession(proxyArg ProxyArg) (Session, error) {
var session Session
for _, arg := range proxyArg.Args {
var argInterface = arg.Interface()
if arg.Kind() == reflect.Ptr && arg.IsNil() == false && argInterface != nil && arg.Type().String() == GoMybatis_Session_Ptr {
session = *(argInterface.(*Session))
continue
} else if argInterface != nil && arg.Kind() == reflect.Interface && arg.Type().String() == GoMybatis_Session {
session = argInterface.(Session)
continue
}
}
return session, nil
}

func buildSql(proxyArg ProxyArg, nodes []ast.Node, sqlBuilder SqlBuilder, array_arg *[]interface{}, stmtConvert stmt.StmtIndexConvert) (string, error) {
var paramMap = make(map[string]interface{})
var tagArgsLen = proxyArg.TagArgsLen
var argsLen = proxyArg.ArgsLen //参数长度,除session参数外。
Expand All @@ -429,10 +449,8 @@ func buildSql(proxyArg ProxyArg, nodes []ast.Node, sqlBuilder SqlBuilder, array_
for argIndex, arg := range proxyArg.Args {
var argInterface = arg.Interface()
if arg.Kind() == reflect.Ptr && arg.IsNil() == false && argInterface != nil && arg.Type().String() == GoMybatis_Session_Ptr {
session = *(argInterface.(*Session))
continue
} else if argInterface != nil && arg.Kind() == reflect.Interface && arg.Type().String() == GoMybatis_Session {
session = argInterface.(Session)
continue
}
if isCustomStruct(arg.Type()) {
Expand Down Expand Up @@ -466,9 +484,8 @@ func buildSql(proxyArg ProxyArg, nodes []ast.Node, sqlBuilder SqlBuilder, array_
}
paramMap = scanStructArgFields(proxyArg.Args[customIndex], tag)
}

result, err := sqlBuilder.BuildSql(paramMap, nodes, array_arg)
return session, result, err
result, err := sqlBuilder.BuildSql(paramMap, nodes, array_arg, stmtConvert)
return result, err
}

//scan params
Expand Down
22 changes: 11 additions & 11 deletions GoMybatisDataSourceRouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import (

//动态数据源路由
type GoMybatisDataSourceRouter struct {
dbMap map[string]*sql.DB
driverMap map[string]string
routerFunc func(mapperName string) *string
driverLinkDBMap map[string]*sql.DB // map[driverLink]*DB
driverTypeUrlMap map[string]string // map[driverType]Url
routerFunc func(mapperName string) *string
}

//初始化路由,routerFunc为nil或者routerFunc返回nil,则框架自行选择第一个数据库作为数据源
Expand All @@ -19,15 +19,15 @@ func (it GoMybatisDataSourceRouter) New(routerFunc func(mapperName string) *stri
return nil
}
}
it.dbMap = make(map[string]*sql.DB)
it.driverMap = make(map[string]string)
it.driverLinkDBMap = make(map[string]*sql.DB)
it.driverTypeUrlMap = make(map[string]string)
it.routerFunc = routerFunc
return it
}

func (it *GoMybatisDataSourceRouter) SetDB(driver string, url string, db *sql.DB) {
it.dbMap[url] = db
it.driverMap[url] = driver
func (it *GoMybatisDataSourceRouter) SetDB(driverType string, driverLink string, db *sql.DB) {
it.driverLinkDBMap[driverLink] = db
it.driverTypeUrlMap[driverLink] = driverType
}

func (it *GoMybatisDataSourceRouter) Router(mapperName string, engine SessionEngine) (Session, error) {
Expand All @@ -39,9 +39,9 @@ func (it *GoMybatisDataSourceRouter) Router(mapperName string, engine SessionEng
}

if key != nil && *key != "" {
db = it.dbMap[*key]
db = it.driverLinkDBMap[*key]
} else {
for k, v := range it.dbMap {
for k, v := range it.driverLinkDBMap {
if v != nil {
db = v
key = &k
Expand All @@ -56,7 +56,7 @@ func (it *GoMybatisDataSourceRouter) Router(mapperName string, engine SessionEng
if key != nil {
url = *key
}
var local = LocalSession{}.New(it.driverMap[url], url, db, engine.Log())
var local = LocalSession{}.New(it.driverTypeUrlMap[url], url, db, engine.Log())
var session = Session(&local)
return session, nil
}
Expand Down
6 changes: 3 additions & 3 deletions GoMybatisEngine.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,13 @@ func (it *GoMybatisEngine) SetSqlResultDecoder(decoder SqlResultDecoder) {

//打开数据库
//driverName: 驱动名称例如"mysql", dataSourceName: string 数据库url
func (it *GoMybatisEngine) Open(driverName, dataSourceName string) (*sql.DB, error) {
func (it *GoMybatisEngine) Open(driverName, dataSourceLink string) (*sql.DB, error) {
it.initCheck()
db, err := sql.Open(driverName, dataSourceName)
db, err := sql.Open(driverName, dataSourceLink)
if err != nil {
return nil, err
}
it.dataSourceRouter.SetDB(driverName, dataSourceName, db)
it.dataSourceRouter.SetDB(driverName, dataSourceLink, db)
return db, nil
}

Expand Down
5 changes: 3 additions & 2 deletions GoMybatisSqlBuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package GoMybatis

import (
"github.com/zhuxiujia/GoMybatis/ast"
"github.com/zhuxiujia/GoMybatis/stmt"
)

type GoMybatisSqlBuilder struct {
Expand Down Expand Up @@ -31,9 +32,9 @@ func (it GoMybatisSqlBuilder) New(SqlArgTypeConvert ast.SqlArgTypeConvert, expre
return it
}

func (it *GoMybatisSqlBuilder) BuildSql(paramMap map[string]interface{}, nodes []ast.Node, arg_array *[]interface{}) (string, error) {
func (it *GoMybatisSqlBuilder) BuildSql(paramMap map[string]interface{}, nodes []ast.Node, arg_array *[]interface{}, stmtConvert stmt.StmtIndexConvert) (string, error) {
//抽象语法树节点构建
var sql, err = ast.DoChildNodes(nodes, paramMap, arg_array)
var sql, err = ast.DoChildNodes(nodes, paramMap, arg_array, stmtConvert)
if err != nil {
return "", err
}
Expand Down
11 changes: 6 additions & 5 deletions GoMybatisSqlBuilder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/zhuxiujia/GoMybatis/example"
"github.com/zhuxiujia/GoMybatis/lib/github.com/Knetic/govaluate"
"github.com/zhuxiujia/GoMybatis/lib/github.com/beevik/etree"
"github.com/zhuxiujia/GoMybatis/stmt"
"github.com/zhuxiujia/GoMybatis/utils"
"testing"
"time"
Expand Down Expand Up @@ -63,7 +64,7 @@ func Benchmark_SqlBuilder(b *testing.B) {
b.StartTimer()
for i := 0; i < b.N; i++ {
var array = []interface{}{}
_, e := builder.BuildSql(paramMap, nodes, &array)
_, e := builder.BuildSql(paramMap, nodes, &array, &stmt.MysqlStmtIndexConvertImpl{})
if e != nil {
b.Fatal(e)
}
Expand Down Expand Up @@ -108,7 +109,7 @@ func Test_SqlBuilder_Tps(t *testing.T) {
for i := 0; i < 100000; i++ {
//var sql, e =
var array = []interface{}{}
_, e := builder.BuildSql(paramMap, nodes, &array)
_, e := builder.BuildSql(paramMap, nodes, &array, &stmt.MysqlStmtIndexConvertImpl{})
if e != nil {
t.Fatal(e)
}
Expand Down Expand Up @@ -206,7 +207,7 @@ func TestGoMybatisSqlBuilder_BuildSql(t *testing.T) {

var array = []interface{}{}

var sql, err = builder.BuildSql(paramMap, nodes, &array)
var sql, err = builder.BuildSql(paramMap, nodes, &array, &stmt.MysqlStmtIndexConvertImpl{})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -262,7 +263,7 @@ func Benchmark_SqlBuilder_If_Element(b *testing.B) {
b.StartTimer()
for i := 0; i < b.N; i++ {
var array = []interface{}{}
builder.BuildSql(paramMap, nodes, &array)
builder.BuildSql(paramMap, nodes, &array, &stmt.MysqlStmtIndexConvertImpl{})
}
}

Expand Down Expand Up @@ -316,7 +317,7 @@ func Benchmark_SqlBuilder_Nested(b *testing.B) {
b.StartTimer()
for i := 0; i < b.N; i++ {
var array = []interface{}{}
_, e := builder.BuildSql(paramMap, nodes, &array)
_, e := builder.BuildSql(paramMap, nodes, &array, &stmt.MysqlStmtIndexConvertImpl{})
if e != nil {
b.Fatal(e)
}
Expand Down
5 changes: 5 additions & 0 deletions LocalSession.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package GoMybatis
import (
"database/sql"
"errors"
"github.com/zhuxiujia/GoMybatis/stmt"
"github.com/zhuxiujia/GoMybatis/tx"
"github.com/zhuxiujia/GoMybatis/utils"
"strconv"
Expand Down Expand Up @@ -431,6 +432,10 @@ func (it *LocalSession) ExecPrepare(sqlPrepare string, args ...interface{}) (*Re
}
}

func (it *LocalSession) StmtConvert() (stmt.StmtIndexConvert, error) {
return stmt.BuildStmtConvert(it.driver)
}

func (it *LocalSession) dbErrorPack(e error) error {
if e != nil {
var sqlError = errors.New("[GoMybatis][LocalSession]" + e.Error())
Expand Down
2 changes: 1 addition & 1 deletion README-ch.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
MsSql: github.com/denisenkom/go-mssqldb
Oracle: github.com/mattn/go-oci8
//分布式NewSql数据库
Tidb: github.com/pingcap/tidb
Tidb: github.com/go-sql-driver/mysql
CockroachDB: github.com/lib/pq
```

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
MsSql: github.com/denisenkom/go-mssqldb
Oracle: github.com/mattn/go-oci8
//Distributed NewSql database
Tidb: github.com/pingcap/tidb
Tidb: github.com/go-sql-driver/mysql
CockroachDB: github.com/lib/pq
```

Expand Down
5 changes: 5 additions & 0 deletions SessionFactorySession.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package GoMybatis

import (
"github.com/zhuxiujia/GoMybatis/stmt"
"github.com/zhuxiujia/GoMybatis/tx"
"github.com/zhuxiujia/GoMybatis/utils"
)
Expand Down Expand Up @@ -74,3 +75,7 @@ func (it *SessionFactorySession) Close() {
func (it *SessionFactorySession) LastPROPAGATION() *tx.Propagation {
return it.Session.LastPROPAGATION()
}

func (it *SessionFactorySession) StmtConvert() (stmt.StmtIndexConvert, error) {
return it.Session.StmtConvert()
}
7 changes: 5 additions & 2 deletions SqlBuilder.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package GoMybatis

import "github.com/zhuxiujia/GoMybatis/ast"
import (
"github.com/zhuxiujia/GoMybatis/ast"
"github.com/zhuxiujia/GoMybatis/stmt"
)

//sql文本构建
type SqlBuilder interface {
BuildSql(paramMap map[string]interface{}, nodes []ast.Node, arg_array *[]interface{}) (string, error)
BuildSql(paramMap map[string]interface{}, nodes []ast.Node, arg_array *[]interface{}, stmtConvert stmt.StmtIndexConvert) (string, error)
ExpressionEngineProxy() *ExpressionEngineProxy
SqlArgTypeConvert() ast.SqlArgTypeConvert
SetEnableLog(enable bool)
Expand Down
4 changes: 3 additions & 1 deletion SqlEngine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package GoMybatis
import (
"database/sql"
"github.com/zhuxiujia/GoMybatis/ast"
"github.com/zhuxiujia/GoMybatis/stmt"
"github.com/zhuxiujia/GoMybatis/tx"
)

Expand All @@ -24,12 +25,13 @@ type Session interface {
Begin(p *tx.Propagation) error
Close()
LastPROPAGATION() *tx.Propagation
StmtConvert() (stmt.StmtIndexConvert, error)
}

//产生session的引擎
type SessionEngine interface {
//打开数据库
Open(driverName, dataSourceName string) (*sql.DB, error)
Open(driverName, dataSourceLink string) (*sql.DB, error)
//写方法到mapper
WriteMapperPtr(ptr interface{}, xml []byte)
//引擎名称
Expand Down
7 changes: 4 additions & 3 deletions ast/Node.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@ package ast

import (
"bytes"
"github.com/zhuxiujia/GoMybatis/stmt"
)

//sql构建抽象语法树节点
type Node interface {
Type() NodeType
Eval(env map[string]interface{}, arg_array *[]interface{}) ([]byte, error)
Eval(env map[string]interface{}, arg_array *[]interface{}, stmtConvert stmt.StmtIndexConvert) ([]byte, error)
}

//执行子所有节点
func DoChildNodes(childNodes []Node, env map[string]interface{}, arg_array *[]interface{}) ([]byte, error) {
func DoChildNodes(childNodes []Node, env map[string]interface{}, arg_array *[]interface{}, stmtConvert stmt.StmtIndexConvert) ([]byte, error) {
if childNodes == nil {
return nil, nil
}
var sql bytes.Buffer
for _, v := range childNodes {
var r, e = v.Eval(env, arg_array)
var r, e = v.Eval(env, arg_array, stmtConvert)
if e != nil {
return nil, e
}
Expand Down
4 changes: 3 additions & 1 deletion ast/NodeBind.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ast

import "github.com/zhuxiujia/GoMybatis/stmt"

type NodeBind struct {
t NodeType

Expand All @@ -13,7 +15,7 @@ func (it *NodeBind) Type() NodeType {
return NBind
}

func (it *NodeBind) Eval(env map[string]interface{}, arg_array *[]interface{}) ([]byte, error) {
func (it *NodeBind) Eval(env map[string]interface{}, arg_array *[]interface{}, stmtConvert stmt.StmtIndexConvert) ([]byte, error) {
if it.name == "" {
panic(`[GoMybatis] element <bind name = ""> name can not be nil!`)
}
Expand Down
Loading

0 comments on commit e1ecbc4

Please sign in to comment.