Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuxiujia committed Dec 10, 2018
1 parent c08c3b7 commit fdbd738
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 65 deletions.
77 changes: 66 additions & 11 deletions GoMybatis.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func WriteMapperPtrByEngine(ptr interface{}, xml []byte, sessionEngine *SessionE
//func的结构体参数无需指定mapperParams的tag,框架会自动扫描它的属性,封装为map处理掉
//使用WriteMapper函数设置代理后即可正常使用。
func WriteMapper(bean reflect.Value, xml []byte, sessionFactory *SessionFactory, decoder SqlResultDecoder, sqlBuilder SqlBuilder, enableLog bool) {
beanCheck(bean)
var mapperTree = LoadMapperXml(xml)
//make a map[method]xml
var methodXmlMap = makeMethodXmlMap(bean, mapperTree)
Expand Down Expand Up @@ -82,6 +83,29 @@ func WriteMapper(bean reflect.Value, xml []byte, sessionFactory *SessionFactory,
UseMapperValue(bean, proxyFunc)
}

//check beans
func beanCheck(value reflect.Value) {
var t = value.Type()
if value.Kind() == reflect.Ptr {
value = value.Elem()
t = value.Type()
}
for i := 0; i < t.NumField(); i++ {
var fieldItem = t.Field(i)
var argsLen = fieldItem.Type.NumIn() //参数长度,除session参数外。
var customLen = 0
for argIndex := 0; argIndex < fieldItem.Type.NumIn(); argIndex++ {
var inType = fieldItem.Type.In(argIndex)
if isCustomStruct(inType) {
customLen++
}
}
if argsLen > 1 && customLen > 1 {
panic(`[GoMybats] ` + fieldItem.Name + ` must add tag "mapperParams:"*,*..."`)
}
}
}

func buildReturnValues(returnType *ReturnType, returnValue *reflect.Value, e error) []reflect.Value {
var returnValues = make([]reflect.Value, returnType.NumOut)
for index, _ := range returnValues {
Expand Down Expand Up @@ -263,6 +287,9 @@ func buildSql(tagArgs []TagArg, args []reflect.Value, mapperXml *MapperXml, sqlB
var session Session
var paramMap = make(map[string]SqlArg)
var tagArgsLen = len(tagArgs)
var argsLen = len(args) //参数长度,除session参数外。
var customLen = 0
var customIndex = -1
for argIndex, arg := range args {
var argInterface = arg.Interface()
if arg.Kind() == reflect.Ptr && arg.IsNil() == false && argInterface != nil && arg.Type().String() == GoMybatis_Session_Ptr {
Expand All @@ -272,31 +299,51 @@ func buildSql(tagArgs []TagArg, args []reflect.Value, mapperXml *MapperXml, sqlB
session = argInterface.(Session)
continue
}
if arg.Kind() == reflect.Struct && arg.Type().String() != GoMybatis_Time {
paramMap = scanStructArgFields(argInterface, nil)
} else if tagArgsLen > 0 && argIndex < tagArgsLen && tagArgs[argIndex].Name != "" && argInterface != nil {
if isCustomStruct(arg.Type()) {
customLen++
customIndex = argIndex
}
if arg.Type().String() == GoMybatis_Session_Ptr || arg.Type().String() == GoMybatis_Session {
if argsLen > 0 {
argsLen--
}
if tagArgsLen > 0 {
tagArgsLen --
}
}
if tagArgsLen > 0 && argIndex < tagArgsLen && tagArgs[argIndex].Name != "" {
paramMap[tagArgs[argIndex].Name] = SqlArg{
Value: argInterface,
Type: arg.Type(),
}
} else {
if arg.Kind() != reflect.Ptr {
paramMap[DefaultOneArg] = SqlArg{
Value: argInterface,
Type: arg.Type(),
}
paramMap[DefaultOneArg] = SqlArg{
Value: argInterface,
Type: arg.Type(),
}
}
}
if customLen == 1 && customIndex != -1 {
//只有一个结构体参数,需要展开它的成员变量 加入到map
paramMap = scanStructArgFields(args[customIndex], nil)
}

result, err := sqlBuilder.BuildSql(paramMap, mapperXml, enableLog)
return session, result, err
}

//scan params
func scanStructArgFields(arg interface{}, typeConvert func(arg interface{}) interface{}) map[string]SqlArg {
func scanStructArgFields(v reflect.Value, typeConvert func(arg interface{}) interface{}) map[string]SqlArg {
var t = v.Type()
parameters := make(map[string]SqlArg)
v := reflect.ValueOf(arg)
t := reflect.TypeOf(arg)
if v.Kind() == reflect.Ptr {
if v.IsNil() == true {
return parameters
}
//为指针,解引用
v = v.Elem()
t = t.Elem()
}
if t.Kind() != reflect.Struct {
panic(`[GoMybatis] the scanParamterBean() arg is not a struct type!,type =` + t.String())
}
Expand All @@ -321,3 +368,11 @@ func scanStructArgFields(arg interface{}, typeConvert func(arg interface{}) inte
}
return parameters
}

func isCustomStruct(value reflect.Type) bool {
if value.Kind() == reflect.Struct && value.String() != GoMybatis_Time && value.String() != GoMybatis_Time_Ptr {
return true
} else {
return false
}
}
1 change: 1 addition & 0 deletions GoMybatisEnableType.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ package GoMybatis
const GoMybatis_Session_Ptr = `*GoMybatis.Session`
const GoMybatis_Session = `GoMybatis.Session`
const GoMybatis_Time = `time.Time`
const GoMybatis_Time_Ptr = `*time.Time`
19 changes: 12 additions & 7 deletions GoMybatisProxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,29 @@ func buildMapper(v reflect.Value, proxyFunc func(method string, args []reflect.V
}

func buildRemoteMethod(f reflect.Value, ft reflect.Type, sf reflect.StructField, proxyFunc func(method string, args []reflect.Value, tagArgs []TagArg) []reflect.Value) {
var params []string
var tagParams []string
var mapperParams = sf.Tag.Get(`mapperParams`)
if mapperParams != `` {
params = strings.Split(mapperParams, `,`)
tagParams = strings.Split(mapperParams, `,`)
}
if len(params) > ft.NumIn() {
panic(`[GoMybatisProxy] method fail! the tag "mapperParams" length can not > arg length! filed=` + ft.String())
var tagParamsLen = len(tagParams)
if tagParamsLen > ft.NumIn() {
panic(`[GoMybatisProxy] method fail! the tag "mapperParams" length can not > arg length ! filed=` + sf.Name)
}
var tagArgs = make([]TagArg, 0)
if len(params) != 0 {
for index, v := range params {
if tagParamsLen != 0 {
for index, v := range tagParams {
var tagArg = TagArg{
Index: index,
Name: v,
}
tagArgs = append(tagArgs, tagArg)
}
}
var tagArgsLen = len(tagArgs)
if tagArgsLen > 0 && ft.NumIn() != tagArgsLen {
panic(`[GoMybatisProxy] method fail! the tag "mapperParams" length != args length ! filed = ` + sf.Name)
}
var fn = func(args []reflect.Value) (results []reflect.Value) {
proxyResults := proxyFunc(sf.Name, args, tagArgs)
for _, returnV := range proxyResults {
Expand All @@ -91,5 +96,5 @@ func buildRemoteMethod(f reflect.Value, ft reflect.Type, sf reflect.StructField,
} else {
f.Set(reflect.MakeFunc(ft, fn))
}
params = nil
tagParams = nil
}
11 changes: 4 additions & 7 deletions TransactionFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@ func (this TransactionFactory) New(SessionFactory *SessionFactory) TransactionFa
return this
}

func (this *TransactionFactory) GetTransactionStatus(transactionId string) *TransactionStatus {
func (this *TransactionFactory) GetTransactionStatus(transactionId string) (*TransactionStatus, error) {
var Session Session
if transactionId == "" {
Session = this.SessionFactory.NewSession(SessionType_Default, nil)
transactionId = Session.Id()
}
var result = this.TransactionStatuss[transactionId]
if result == nil {
Session = this.SessionFactory.NewSession(SessionType_Default, nil)
var transaction = Transaction{
Id: transactionId,
Session: Session,
Expand All @@ -30,7 +27,7 @@ func (this *TransactionFactory) GetTransactionStatus(transactionId string) *Tran
result = &transactionStatus
this.TransactionStatuss[transactionId] = result
}
return result
return result, nil
}

func (this *TransactionFactory) SetTransactionStatus(transactionId string, transaction *TransactionStatus) {
Expand All @@ -44,7 +41,7 @@ func (this *TransactionFactory) Append(transactionId string, transaction Transac
if transactionId == "" {
return
}
var old = this.GetTransactionStatus(transactionId)
var old,_ = this.GetTransactionStatus(transactionId)
if old != nil {
this.SetTransactionStatus(transactionId, old)
}
Expand Down
49 changes: 38 additions & 11 deletions TransactionManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,24 @@ func (this DefaultTransationManager) New(SessionFactory *SessionFactory, Transac
}

func (this DefaultTransationManager) GetTransaction(def *TransactionDefinition, transactionId string, OwnerId string) (*TransactionStatus, error) {
//if transactionId == "" {
// return nil, errors.New("[TransactionManager] transactionId =" + transactionId + " transations is nil!")
//}
if transactionId == "" {
return nil, errors.New("[TransactionManager] transactionId =" + transactionId + " transations is nil!")
}
if def == nil {
var d = TransactionDefinition{}.Default()
def = &d
}
var transationStatus = this.TransactionFactory.GetTransactionStatus(transactionId)
var transationStatus, err = this.TransactionFactory.GetTransactionStatus(transactionId)
if err != nil {
return nil, err
}
if def.PropagationBehavior == PROPAGATION_REQUIRED {
//todo doBegin
if transationStatus.IsNewTransaction {
//新事务,则调用begin
transationStatus.OwnerId = OwnerId
var err = transationStatus.Begin()
if err != nil {
if err == nil {
if def.Timeout != 0 {
//transation out of time,default not set out of time
//事务超时,时间大于0则启动超时机制
Expand All @@ -84,12 +87,20 @@ func (this DefaultTransationManager) GetTransaction(def *TransactionDefinition,
}

func (this DefaultTransationManager) Commit(transactionId string) error {
var transactions = this.TransactionFactory.GetTransactionStatus(transactionId)
var transactions, err = this.TransactionFactory.GetTransactionStatus(transactionId)
if err != nil {
log.Println(err)
return err
}
return transactions.Commit()
}

func (this DefaultTransationManager) Rollback(transactionId string) error {
var transactions = this.TransactionFactory.GetTransactionStatus(transactionId)
var transactions, err = this.TransactionFactory.GetTransactionStatus(transactionId)
if err != nil {
log.Println(err)
return err
}
return transactions.Rollback()
}

Expand All @@ -99,14 +110,25 @@ func (this DefaultTransationManager) DoTransaction(dto TransactionReqDTO) Transa
var err error

transcationStatus, err = this.GetTransaction(nil, dto.TransactionId, dto.OwnerId)
dto.TransactionId = transcationStatus.Transaction.Session.Id()
if transcationStatus == nil || transcationStatus.Transaction == nil || transcationStatus.Transaction.Session == nil {
return TransactionRspDTO{
TransactionId: dto.TransactionId,
Error: "Transaction does not exist,id=" + dto.TransactionId,
}
}
if err != nil {
return TransactionRspDTO{
TransactionId: dto.TransactionId,
Error: err.Error(),
}
}
if err != nil {
return TransactionRspDTO{
TransactionId: dto.TransactionId,
Error: err.Error(),
}
}
log.Println("[TransactionManager] transactionId=", dto.TransactionId)
log.Println("[TransactionManager] do transactionId=", dto.TransactionId,",sessionId=",transcationStatus.Transaction.Session.Id())

if dto.Status == Transaction_Status_NO {
defer transcationStatus.Flush() //关闭
Expand All @@ -123,7 +145,12 @@ func (this DefaultTransationManager) DoTransaction(dto TransactionReqDTO) Transa
Error: err.Error(),
}
}
this.TransactionFactory.GetTransactionStatus(dto.TransactionId).Flush()
var transaction, err = this.TransactionFactory.GetTransactionStatus(dto.TransactionId)
if err != nil {
log.Println(err)
} else {
transaction.Flush()
}
}
} else if dto.Status == Transaction_Status_Rollback {
defer transcationStatus.Flush() //关闭,//PROPAGATION_REQUIRED 情况下 子事务 可关闭
Expand Down Expand Up @@ -163,7 +190,7 @@ func (this DefaultTransationManager) DoAction(dto TransactionReqDTO, transcation
return TransactionRspDTO
}
if dto.ActionType == ActionType_Exec {
log.Println("[TransactionManager] Exec ", dto.Sql)
log.Println("[TransactionManager] TransactionId:",dto.TransactionId,",Exec:", dto.Sql)
var res, e = transcationStatus.Transaction.Session.Exec(dto.Sql)
var err string
if e != nil {
Expand Down
Loading

0 comments on commit fdbd738

Please sign in to comment.