Skip to content

Commit

Permalink
优化代码
Browse files Browse the repository at this point in the history
  • Loading branch information
bjdgyc committed Oct 24, 2024
1 parent 772b111 commit bd6ee0b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 34 deletions.
1 change: 0 additions & 1 deletion server/handler/link_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ func tplRequest(typ int, w io.Writer, data RequestData) {
_ = xml.EscapeText(buf, []byte(data.Banner))
data.Banner = buf.String()
}

t, _ := template.New("auth_complete").Parse(auth_complete)
_ = t.Execute(w, data)
case tpl_otp:
Expand Down
34 changes: 22 additions & 12 deletions server/handler/link_auth_otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,28 @@ package handler

import (
"crypto/md5"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"net"
"net/http"
"sync"
"sync/atomic"

"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/pkg/utils"
"github.com/bjdgyc/anylink/sessdata"
)

var SessStore = NewSessionStore()

const maxOtpErrCount = 3

type AuthSession struct {
ClientRequest *ClientRequest
UserActLog *dbdata.UserActLog
OtpErrCount atomic.Uint32 // otp错误次数
}

// 存储临时会话信息
Expand Down Expand Up @@ -60,15 +62,17 @@ func (s *SessionStore) DeleteAuthSession(sessionID string) {
delete(s.session, sessionID)
}

func (a *AuthSession) AddOtpErrCount(i int) int {
newI := a.OtpErrCount.Add(uint32(i))
return int(newI)
}

func GenerateSessionID() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
sessionID := utils.RandomRunes(32)
if sessionID == "" {
return "", fmt.Errorf("failed to generate session ID")
}

hash := sha256.Sum256(b)
sessionID := base64.URLEncoding.EncodeToString(hash[:])
return sessionID, nil
}

Expand Down Expand Up @@ -186,14 +190,20 @@ func LinkAuth_otp(w http.ResponseWriter, r *http.Request) {
otpSecret := sessionData.ClientRequest.Auth.OtpSecret
otp := cr.Auth.SecondaryPassword

// 动态码错误
if !dbdata.CheckOtp(username, otp, otpSecret) {
base.Warn("OTP 动态码错误", r.RemoteAddr)
if sessionData.AddOtpErrCount(1) > maxOtpErrCount {
http.Error(w, "TooManyError, please login again", http.StatusBadRequest)
return
}

base.Warn("OTP 动态码错误", username, r.RemoteAddr)
ua.Info = "OTP 动态码错误"
ua.Status = dbdata.UserAuthFail
dbdata.UserActLogIns.Add(*ua, sessionData.ClientRequest.UserAgent)

w.WriteHeader(http.StatusOK)
data := RequestData{Error: "OTP 动态码错误"}
data := RequestData{Error: "请求错误"}
if base.Cfg.DisplayError {
data.Error = "OTP 动态码错误"
}
Expand All @@ -216,7 +226,7 @@ var auth_otp = `<?xml version="1.0" encoding="UTF-8"?>
<error id="otp-verification" param1="{{.Error}}" param2="">验证失败: %s</error>
{{end}}
<form method="post" action="/otp-verification">
<input type="password" name="secondary_password" label="OTP"/>
<input type="password" name="secondary_password" label="OTPCode:"/>
</form>
</auth>
</config-auth>`
16 changes: 14 additions & 2 deletions server/pkg/utils/util.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package utils

import (
crand "crypto/rand"
"encoding/hex"
"fmt"
"log"
"math/rand"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -83,16 +86,25 @@ func HumanByte(bf interface{}) string {

func RandomRunes(length int) string {
letterRunes := []rune("abcdefghijklmnpqrstuvwxy1234567890")

bytes := make([]rune, length)

for i := range bytes {
bytes[i] = letterRunes[rand.Intn(len(letterRunes))]
}

return string(bytes)
}

func RandomHex(length int) string {
b := make([]byte, length)
_, err := crand.Read(b)
if err != nil {
log.Println(err)
return ""
}

return hex.EncodeToString(b)
}

func ParseName(name string) string {
name = strings.ReplaceAll(name, " ", "-")
name = strings.ReplaceAll(name, "'", "-")
Expand Down
22 changes: 3 additions & 19 deletions server/sessdata/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package sessdata

import (
"fmt"
"math/rand"
"net"
"strconv"
"strings"
Expand All @@ -12,6 +11,7 @@ import (

"github.com/bjdgyc/anylink/base"
"github.com/bjdgyc/anylink/dbdata"
"github.com/bjdgyc/anylink/pkg/utils"
mapset "github.com/deckarep/golang-set"
)

Expand Down Expand Up @@ -91,10 +91,6 @@ type Session struct {
CSess *ConnSession
}

func init() {
rand.Seed(time.Now().UnixNano())
}

func checkSession() {
// 检测过期的session
go func() {
Expand Down Expand Up @@ -144,28 +140,16 @@ func CloseUserLimittimeSession() {
}
}

func GenToken() string {
// 生成32位的 token
bToken := make([]byte, 32)
rand.Read(bToken)
return fmt.Sprintf("%x", bToken)
}

func NewSession(token string) *Session {
if token == "" {
btoken := make([]byte, 32)
rand.Read(btoken)
token = fmt.Sprintf("%x", btoken)
token = utils.RandomHex(32)
}

// 生成 dtlsn session_id
dtlsid := make([]byte, 32)
rand.Read(dtlsid)

sess := &Session{
Sid: fmt.Sprintf("%d", time.Now().Unix()),
Token: token,
DtlsSid: fmt.Sprintf("%x", dtlsid),
DtlsSid: utils.RandomHex(32),
LastLogin: time.Now(),
}

Expand Down

0 comments on commit bd6ee0b

Please sign in to comment.