Skip to content

Commit

Permalink
Deliver response code with context (#539)
Browse files Browse the repository at this point in the history
  • Loading branch information
crazytaxii authored Nov 25, 2024
1 parent 5ef533c commit 7fc7d47
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 47 deletions.
68 changes: 57 additions & 11 deletions api/server/httputils/httputils.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package httputils

import (
"context"
goerrors "errors"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -49,12 +50,6 @@ func (r *Response) SetMessage(m interface{}) {
}
}

func (r *Response) IsSuccessful() bool {
return r.Code == http.StatusOK ||
r.Code == http.StatusCreated ||
r.Code == http.StatusAccepted
}

func (r *Response) SetMessageWithCode(m interface{}, c int) {
r.SetCode(c)
r.SetMessage(m)
Expand All @@ -79,6 +74,7 @@ func NewResponse() *Response {

// SetSuccess 设置成功返回值
func SetSuccess(c *gin.Context, r *Response) {
_ = contextBind(c).withResponseCode(http.StatusOK)
r.SetMessageWithCode("success", http.StatusOK)
c.JSON(http.StatusOK, r)
}
Expand All @@ -87,28 +83,31 @@ func SetSuccess(c *gin.Context, r *Response) {
func SetFailed(c *gin.Context, r *Response, err error) {
switch e := err.(type) {
case errors.Error:
SetFailedWithCode(c, r, e.Code, e)
setFailedWithCode(c, r, e.Code, e)
case validator.ValidationErrors:
SetFailedWithValidationError(c, r, validatorutil.TranslateError(e))
setFailedWithValidationError(c, r, validatorutil.TranslateError(e))
default:
SetFailedWithCode(c, r, http.StatusBadRequest, err)
setFailedWithCode(c, r, http.StatusBadRequest, err)
}
}

// SetFailedWithCode 设置错误返回值
func SetFailedWithCode(c *gin.Context, r *Response, code int, err error) {
func setFailedWithCode(c *gin.Context, r *Response, code int, err error) {
_ = contextBind(c).withResponseCode(code).withRawError(err)
r.SetMessageWithCode(err, code)
c.JSON(http.StatusOK, r)
}

func SetFailedWithValidationError(c *gin.Context, r *Response, e string) {
func setFailedWithValidationError(c *gin.Context, r *Response, e string) {
_ = contextBind(c).withResponseCode(http.StatusBadRequest).withRawError(goerrors.New(e))
r.SetMessageWithCode(e, http.StatusBadRequest)
c.JSON(http.StatusOK, r)
}

// AbortFailedWithCode 设置错误,code 返回值并终止请求
func AbortFailedWithCode(c *gin.Context, code int, err error) {
r := NewResponse()
_ = contextBind(c).withResponseCode(code).withRawError(err)
r.SetMessageWithCode(err, code)
c.JSON(http.StatusOK, r)
c.Abort()
Expand Down Expand Up @@ -202,3 +201,50 @@ func GetIdRangeFromListReq(ctx context.Context) (exists bool, ids []int64) {
ids, exists = val.([]int64)
return
}

const (
ResponseCodeKey = "response_code"
RawErrorKey = "raw_error"
)

type ctxBind struct {
*gin.Context
}

func contextBind(c *gin.Context) *ctxBind {
return &ctxBind{c}
}

// withResponseCode puts the response code into the HTTP context.
func (cb *ctxBind) withResponseCode(code int) *ctxBind {
cb.Set(ResponseCodeKey, code)
return cb
}

// withRawError puts the raw error into the HTTP context.
func (cb *ctxBind) withRawError(err error) *ctxBind {
cb.Set(RawErrorKey, err)
return cb
}

// GetResponseCode gets the response code from the HTTP context.
func GetResponseCode(ctx context.Context) (code int) {
val := ctx.Value(ResponseCodeKey)
if val == nil {
return
}

code = val.(int)
return
}

// GetRawError gets the raw error from the HTTP context.
func GetRawError(ctx context.Context) (err error) {
val := ctx.Value(RawErrorKey)
if val == nil {
return
}

err = val.(error)
return
}
48 changes: 24 additions & 24 deletions api/server/middleware/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ limitations under the License.
package middleware

import (
"bytes"
"context"
"encoding/json"
"net/http"

"github.com/gin-contrib/requestid"
Expand All @@ -33,28 +31,18 @@ import (

// 自定义 ResponseWriter 用于捕获写入的数据
type auditWriter struct {
gin.ResponseWriter
resp *httputils.Response
opts *options.Options
}

func newResponseWriter(w gin.ResponseWriter, o *options.Options) *auditWriter {
func newResponseWriter(o *options.Options) *auditWriter {
return &auditWriter{
ResponseWriter: w,
resp: httputils.NewResponse(),
opts: o,
opts: o,
}
}

func (w *auditWriter) Write(b []byte) (int, error) {
_ = json.NewDecoder(bytes.NewReader(b)).Decode(w.resp)
return w.ResponseWriter.Write(b)
}

func Audit(o *options.Options) gin.HandlerFunc {
return func(c *gin.Context) {
auditor := newResponseWriter(c.Writer, o)
c.Writer = auditor
auditor := newResponseWriter(o)
c.Next()

// do audit asynchronously
Expand All @@ -80,24 +68,36 @@ func (w *auditWriter) asyncAudit(c *gin.Context) {
return
}

status := model.AuditOpUnknown
if w.resp != nil {
status = model.AuditOpFail
if w.resp.IsSuccessful() {
status = model.AuditOpSuccess
}
}

audit := &model.Audit{
RequestId: requestid.Get(c),
Action: c.Request.Method,
Ip: c.ClientIP(),
Operator: userName,
Path: c.Request.RequestURI,
ObjectType: model.ObjectType(obj),
Status: status,
Status: getAuditStatus(c),
}
if _, err := w.opts.Factory.Audit().Create(context.TODO(), audit); err != nil {
klog.Errorf("failed to create audit record [%s]: %v", audit.String(), err)
}
}

// getAuditStatus returns the status of operation.
func getAuditStatus(c *gin.Context) model.AuditOperationStatus {
respCode := httputils.GetResponseCode(c)
if respCode == 0 {
return model.AuditOpUnknown
}

if responseOK(respCode) {
return model.AuditOpSuccess
}

return model.AuditOpFail
}

func responseOK(code int) bool {
return code == http.StatusOK ||
code == http.StatusCreated ||
code == http.StatusAccepted
}
19 changes: 7 additions & 12 deletions api/server/middleware/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ limitations under the License.
package middleware

import (
"fmt"

"github.com/gin-contrib/requestid"
"github.com/gin-gonic/gin"

"github.com/caoyingjunz/pixiu/api/server/httputils"
"github.com/caoyingjunz/pixiu/pkg/db"
logutil "github.com/caoyingjunz/pixiu/pkg/util/log"
)
Expand All @@ -34,17 +33,13 @@ func Logger(cfg *logutil.LogOptions) gin.HandlerFunc {
// 处理请求操作
c.Next()

var err error
if errs := c.Errors; len(errs) > 0 {
err = fmt.Errorf("%v", errs.Errors())
}
l.WithLogFields(map[string]interface{}{
"request_id": requestid.Get(c),
"method": c.Request.Method,
"uri": c.Request.RequestURI,
"status_code": c.Writer.Status(),
"client_ip": c.ClientIP(),
"request_id": requestid.Get(c),
"method": c.Request.Method,
"uri": c.Request.RequestURI,
httputils.ResponseCodeKey: httputils.GetResponseCode(c),
"client_ip": c.ClientIP(),
})
l.Log(c, logutil.InfoLevel, err)
l.Log(c, logutil.InfoLevel, httputils.GetRawError(c))
}
}

0 comments on commit 7fc7d47

Please sign in to comment.