diff --git a/server/handler/link_auth.go b/server/handler/link_auth.go index cabe728a..fa42c96f 100644 --- a/server/handler/link_auth.go +++ b/server/handler/link_auth.go @@ -159,7 +159,6 @@ func tplRequest(typ int, w io.Writer, data RequestData) { _ = xml.EscapeText(buf, []byte(data.Banner)) data.Banner = buf.String() } - t, _ := template.New("auth_complete").Parse(auth_complete) _ = t.Execute(w, data) case tpl_otp: diff --git a/server/handler/link_auth_otp.go b/server/handler/link_auth_otp.go index 313aa63f..683c577d 100644 --- a/server/handler/link_auth_otp.go +++ b/server/handler/link_auth_otp.go @@ -2,26 +2,28 @@ package handler import ( "crypto/md5" - "crypto/rand" - "crypto/sha256" - "encoding/base64" "encoding/xml" "fmt" "io" "net" "net/http" "sync" + "sync/atomic" "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/pkg/utils" "github.com/bjdgyc/anylink/sessdata" ) var SessStore = NewSessionStore() +const maxOtpErrCount = 3 + type AuthSession struct { ClientRequest *ClientRequest UserActLog *dbdata.UserActLog + OtpErrCount atomic.Uint32 // otp错误次数 } // 存储临时会话信息 @@ -60,15 +62,17 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) { delete(s.session, sessionID) } +func (a *AuthSession) AddOtpErrCount(i int) int { + newI := a.OtpErrCount.Add(uint32(i)) + return int(newI) +} + func GenerateSessionID() (string, error) { - b := make([]byte, 32) - _, err := rand.Read(b) - if err != nil { - return "", fmt.Errorf("failed to generate session ID: %w", err) + sessionID := utils.RandomRunes(32) + if sessionID == "" { + return "", fmt.Errorf("failed to generate session ID") } - hash := sha256.Sum256(b) - sessionID := base64.URLEncoding.EncodeToString(hash[:]) return sessionID, nil } @@ -186,14 +190,20 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) { otpSecret := sessionData.ClientRequest.Auth.OtpSecret otp := cr.Auth.SecondaryPassword + // 动态码错误 if !dbdata.CheckOtp(username, otp, otpSecret) { - base.Warn("OTP 动态码错误", r.RemoteAddr) + if sessionData.AddOtpErrCount(1) > maxOtpErrCount { + http.Error(w, "TooManyError, please login again", http.StatusBadRequest) + return + } + + base.Warn("OTP 动态码错误", username, r.RemoteAddr) ua.Info = "OTP 动态码错误" ua.Status = dbdata.UserAuthFail dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent) w.WriteHeader(http.StatusOK) - data := RequestData{Error: "OTP 动态码错误"} + data := RequestData{Error: "请求错误"} if base.Cfg.DisplayError { data.Error = "OTP 动态码错误" } @@ -216,7 +226,7 @@ var auth_otp = ` 验证失败: %s {{end}}
- +
` diff --git a/server/pkg/utils/util.go b/server/pkg/utils/util.go index 2b17f262..54f7df1b 100644 --- a/server/pkg/utils/util.go +++ b/server/pkg/utils/util.go @@ -1,7 +1,10 @@ package utils import ( + crand "crypto/rand" + "encoding/hex" "fmt" + "log" "math/rand" "strings" "sync/atomic" @@ -83,9 +86,7 @@ func HumanByte(bf interface{}) string { func RandomRunes(length int) string { letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890") - bytes := make([]rune, length) - for i := range bytes { bytes[i] = letterRunes[rand.Intn(len(letterRunes))] } @@ -93,6 +94,17 @@ func RandomRunes(length int) string { return string(bytes) } +func RandomHex(length int) string { + b := make([]byte, length) + _, err := crand.Read(b) + if err != nil { + log.Println(err) + return "" + } + + return hex.EncodeToString(b) +} + func ParseName(name string) string { name = strings.ReplaceAll(name, " ", "-") name = strings.ReplaceAll(name, "'", "-") diff --git a/server/sessdata/session.go b/server/sessdata/session.go index 60fae1d8..98854192 100644 --- a/server/sessdata/session.go +++ b/server/sessdata/session.go @@ -2,7 +2,6 @@ package sessdata import ( "fmt" - "math/rand" "net" "strconv" "strings" @@ -12,6 +11,7 @@ import ( "github.com/bjdgyc/anylink/base" "github.com/bjdgyc/anylink/dbdata" + "github.com/bjdgyc/anylink/pkg/utils" mapset "github.com/deckarep/golang-set" ) @@ -91,10 +91,6 @@ type Session struct { CSess *ConnSession } -func init() { - rand.Seed(time.Now().UnixNano()) -} - func checkSession() { // 检测过期的session go func() { @@ -144,28 +140,16 @@ func CloseUserLimittimeSession() { } } -func GenToken() string { - // 生成32位的 token - bToken := make([]byte, 32) - rand.Read(bToken) - return fmt.Sprintf("%x", bToken) -} - func NewSession(token string) *Session { if token == "" { - btoken := make([]byte, 32) - rand.Read(btoken) - token = fmt.Sprintf("%x", btoken) + token = utils.RandomHex(32) } // 生成 dtlsn session_id - dtlsid := make([]byte, 32) - rand.Read(dtlsid) - sess := &Session{ Sid: fmt.Sprintf("%d", time.Now().Unix()), Token: token, - DtlsSid: fmt.Sprintf("%x", dtlsid), + DtlsSid: utils.RandomHex(32), LastLogin: time.Now(), }