diff --git a/mu/func.go b/mu/func.go index 8e850a2f..d96e02f9 100644 --- a/mu/func.go +++ b/mu/func.go @@ -25,61 +25,65 @@ import ( "time" ) +const ( + idType = 0 // address type index + idIP0 = 1 // ip addres start index + idDmLen = 1 // domain address length index + idDm0 = 2 // domain address start index + + typeIPv4 = 1 // type is ipv4 address + typeDm = 3 // type is domain address + typeIPv6 = 4 // type is ipv6 address + + lenIPv4 = net.IPv4len + 2 // ipv4 + 2port + lenIPv6 = net.IPv6len + 2 // ipv6 + 2port + lenDmBase = 2 // 1addrLen + 2port, plus addrLen + lenHmacSha1 = 10 +) + var ssdebug ss.DebugLog -func getRequest(conn *ss.Conn) (host string, extra []byte, err error) { - const ( - idType = 0 // address type index - idIP0 = 1 // ip addres start index - idDmLen = 1 // domain address length index - idDm0 = 2 // domain address start index - - typeIPv4 = 1 // type is ipv4 address - typeDm = 3 // type is domain address - typeIPv6 = 4 // type is ipv6 address - - lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port - lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port - lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen - ) +func getRequest(conn *ss.Conn, auth bool) (host string, res_size int, ota bool, err error) { + var n int + ss.SetReadTimeout(conn) // buf size should at least have the same size with the largest possible // request size (when addrType is 3, domain name has at most 256 bytes) - // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) - buf := make([]byte, 260) - var n int + // 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1) + buf := make([]byte, 270) // read till we get possible domain length field - ss.SetReadTimeout(conn) - if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil { + if n, err = io.ReadFull(conn, buf[:idType+1]); err != nil { return } + res_size += n - reqLen := -1 - switch buf[idType] { + var reqStart, reqEnd int + addrType := buf[idType] + switch addrType & ss.AddrMask { case typeIPv4: - reqLen = lenIPv4 + reqStart, reqEnd = idIP0, idIP0+lenIPv4 case typeIPv6: - reqLen = lenIPv6 + reqStart, reqEnd = idIP0, idIP0+lenIPv6 case typeDm: - reqLen = int(buf[idDmLen]) + lenDmBase + if n, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil { + return + } + reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase) default: - err = fmt.Errorf("addr type %d not supported", buf[idType]) + err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask) return } + res_size += n - if n < reqLen { // rare case - if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil { - return - } - } else if n > reqLen { - // it's possible to read more than just the request head - extra = buf[reqLen:n] + if n, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil { + return } + res_size += n // Return string for typeIP is not most efficient, but browsers (Chrome, // Safari, Firefox) all seems using typeDm exclusively. So this is not a // big problem. - switch buf[idType] { + switch addrType & ss.AddrMask { case typeIPv4: host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String() case typeIPv6: @@ -88,8 +92,23 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) { host = string(buf[idDm0 : idDm0+buf[idDmLen]]) } // parse port - port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen]) + port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd]) host = net.JoinHostPort(host, strconv.Itoa(int(port))) + // if specified one time auth enabled, we should verify this + if auth || addrType&ss.OneTimeAuthMask > 0 { + ota = true + if n, err = io.ReadFull(conn, buf[reqEnd:reqEnd+lenHmacSha1]); err != nil { + return + } + iv := conn.GetIv() + key := conn.GetKey() + actualHmacSha1Buf := ss.HmacSha1(append(iv, key...), buf[:reqEnd]) + if !bytes.Equal(buf[reqEnd:reqEnd+lenHmacSha1], actualHmacSha1Buf) { + err = fmt.Errorf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, buf[:reqEnd]) + return + } + res_size += n + } return } @@ -98,13 +117,9 @@ const logCntDelta = 100 var connCnt int var nextLogConnCnt int = logCntDelta -func handleConnection(user user.User, conn *ss.Conn) { +func handleConnection(user user.User, conn *ss.Conn, auth bool) { var host string - var size = 0 - var raw_req_header, raw_res_header []byte - var is_http = false - var res_size = 0 - var req_chan = make(chan []byte) + connCnt++ // this maybe not accurate, but should be enough if connCnt-nextLogConnCnt >= 0 { // XXX There's no xadd in the atomic package, so it's difficult to log @@ -128,7 +143,7 @@ func handleConnection(user user.User, conn *ss.Conn) { } }() - host, extra, err := getRequest(conn) + host, res_size, ota, err := getRequest(conn, auth) if err != nil { Log.Error("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err) return @@ -151,65 +166,30 @@ func handleConnection(user user.User, conn *ss.Conn) { } }() - defer func() { - if is_http { - tmp_req_header := <-req_chan - buffer := bytes.NewBuffer(raw_req_header) - buffer.Write(tmp_req_header) - raw_req_header = buffer.Bytes() - } - showConn(raw_req_header, raw_res_header, host, user, size, is_http) - close(req_chan) - if !closed { - remote.Close() - } - }() + // debug conn info + Log.Debug(fmt.Sprintf("%d conn debug: local addr: %s | remote addr: %s network: %s ", user.GetPort(), + conn.LocalAddr().String(), conn.RemoteAddr().String(), conn.RemoteAddr().Network())) + err = storage.IncrSize(user, res_size) + if err != nil { + Log.Error(err) + return + } + err = storage.MarkUserOnline(user) + if err != nil { + Log.Error(err) + return + } + Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), res_size)) - // write extra bytes read from - if extra != nil { - // debug.Println("getRequest read extra data, writing to remote, len", len(extra)) - is_http, extra, _ = checkHttp(extra, conn) - if strings.HasSuffix(host, ":80") { - is_http = true - } - raw_req_header = extra - res_size, err = remote.Write(extra) - // size, err := remote.Write(extra) - if err != nil { - Log.Error("write request extra error:", err) - return - } - // debug conn info - Log.Debug(fmt.Sprintf("%d conn debug: local addr: %s | remote addr: %s network: %s ", user.GetPort(), - conn.LocalAddr().String(), conn.RemoteAddr().String(), conn.RemoteAddr().Network())) - err = storage.IncrSize(user, res_size) - if err != nil { - Log.Error(err) - return - } - err = storage.MarkUserOnline(user) - if err != nil { - Log.Error(err) - return - } - Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), res_size)) + Log.Info(fmt.Sprintf("piping %s<->%s ota=%v connOta=%v", conn.RemoteAddr(), host, ota, conn.IsOta())) + + if ota { + go PipeThenCloseOta(conn, remote, false, host, user) + } else { + go PipeThenClose(conn, remote, false, host, user) } - Log.Debug(fmt.Sprintf("piping %s<->%s", conn.RemoteAddr(), host)) - /** - go ss.PipeThenClose(conn, remote) - ss.PipeThenClose(remote, conn) - closed = true - return - **/ - go func() { - _, raw_header := PipeThenClose(conn, remote, is_http, false, host, user) - if is_http { - req_chan <- raw_header - } - }() - res_size, raw_res_header = PipeThenClose(remote, conn, is_http, true, host, user) - size += res_size + PipeThenClose(remote, conn, true, host, user) closed = true return } @@ -273,7 +253,7 @@ func runWithCustomMethod(user user.User) { os.Exit(1) } passwdManager.add(port, password, ln) - cipher, err := user.GetCipher() + cipher, err, auth := user.GetCipher() if err != nil { return } @@ -288,27 +268,34 @@ func runWithCustomMethod(user user.User) { // Creating cipher upon first connection. if cipher == nil { Log.Debug("creating cipher for port:", port) - cipher, err = ss.NewCipher(user.GetMethod(), password) + method := user.GetMethod() + + if strings.HasSuffix(method, "-auth") { + method = method[:len(method)-5] + auth = true + } else { + auth = false + } + + cipher, err = ss.NewCipher(method, password) if err != nil { Log.Error(fmt.Sprintf("Error generating cipher for port: %s %v\n", port, err)) conn.Close() continue } } - go handleConnection(user, ss.NewConn(conn, cipher.Copy())) + go handleConnection(user, ss.NewConn(conn, cipher.Copy()), auth) } } const bufSize = 4096 const nBuf = 2048 -func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, user user.User) (total int, raw_header []byte) { +func PipeThenClose(src, dst net.Conn, is_res bool, host string, user user.User) { var pipeBuf = leakybuf.NewLeakyBuf(nBuf, bufSize) defer dst.Close() buf := pipeBuf.Get() // defer pipeBuf.Put(buf) - var buffer = bytes.NewBuffer(nil) - var is_end = false var size int for { @@ -317,15 +304,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us // read may return EOF with n > 0 // should always process n > 0 bytes before handling error if n > 0 { - if is_http && !is_end { - buffer.Write(buf) - raw_header = buffer.Bytes() - lines := bytes.SplitN(raw_header, []byte("\r\n\r\n"), 2) - if len(lines) == 2 { - is_end = true - } - } - size, err = dst.Write(buf[0:n]) if is_res { err = storage.IncrSize(user, size) @@ -334,7 +312,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us } Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), size)) } - total += size if err != nil { Log.Debug("write:", err) break @@ -350,6 +327,69 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us return } +func PipeThenCloseOta(src *ss.Conn, dst net.Conn, is_res bool, host string, user user.User) { + const ( + dataLenLen = 2 + hmacSha1Len = 10 + idxData0 = dataLenLen + hmacSha1Len + ) + + defer func() { + dst.Close() + }() + var pipeBuf = leakybuf.NewLeakyBuf(nBuf, bufSize) + buf := pipeBuf.Get() + // sometimes it have to fill large block + for i := 1; ; i += 1 { + SetReadTimeout(src) + n, err := io.ReadFull(src, buf[:dataLenLen+hmacSha1Len]) + if err != nil { + if err == io.EOF { + break + } + Log.Debug(fmt.Sprintf("conn=%p #%v read header error n=%v: %v", src, i, n, err)) + break + } + dataLen := binary.BigEndian.Uint16(buf[:dataLenLen]) + expectedHmacSha1 := buf[dataLenLen:idxData0] + + var dataBuf []byte + if len(buf) < int(idxData0+dataLen) { + dataBuf = make([]byte, dataLen) + } else { + dataBuf = buf[idxData0 : idxData0+dataLen] + } + if n, err := io.ReadFull(src, dataBuf); err != nil { + if err == io.EOF { + break + } + Log.Debug(fmt.Sprintf("conn=%p #%v read data error n=%v: %v", src, i, n, err)) + break + } + chunkIdBytes := make([]byte, 4) + chunkId := src.GetAndIncrChunkId() + binary.BigEndian.PutUint32(chunkIdBytes, chunkId) + actualHmacSha1 := ss.HmacSha1(append(src.GetIv(), chunkIdBytes...), dataBuf) + if !bytes.Equal(expectedHmacSha1, actualHmacSha1) { + Log.Debug(fmt.Sprintf("conn=%p #%v read data hmac-sha1 mismatch, iv=%v chunkId=%v src=%v dst=%v len=%v expeced=%v actual=%v", src, i, src.GetIv(), chunkId, src.RemoteAddr(), dst.RemoteAddr(), dataLen, expectedHmacSha1, actualHmacSha1)) + break + } + + if n, err := dst.Write(dataBuf); err != nil { + Log.Debug(fmt.Sprintf("conn=%p #%v write data error n=%v: %v", dst, i, n, err)) + break + } + if is_res { + err := storage.IncrSize(user, n) + if err != nil { + Log.Error(err) + } + Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), n)) + } + } + return +} + var readTimeout time.Duration func SetReadTimeout(c net.Conn) { diff --git a/mu/mysql/mysql.go b/mu/mysql/mysql.go index fdbd2c2b..6060fcea 100644 --- a/mu/mysql/mysql.go +++ b/mu/mysql/mysql.go @@ -6,6 +6,7 @@ import ( "github.com/orvice/shadowsocks-go/mu/user" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" "time" + "strings" ) var client *Client @@ -65,8 +66,16 @@ func (u *User) IsEnable() bool { return true } -func (u *User) GetCipher() (*ss.Cipher, error) { - return ss.NewCipher(u.method, u.passwd) +func (u *User) GetCipher() (*ss.Cipher, error, bool) { + method := u.method + auth := false + + if strings.HasSuffix(method, "-auth") { + method = method[:len(method)-5] + auth = true + } + s,e := ss.NewCipher(method, u.passwd) + return s, e, auth } func (u *User) UpdateTraffic(storageSize int) error { diff --git a/mu/user/user.go b/mu/user/user.go index 1df28939..bb7edd79 100644 --- a/mu/user/user.go +++ b/mu/user/user.go @@ -32,7 +32,7 @@ type User interface { GetPasswd() string GetMethod() string IsEnable() bool - GetCipher() (*ss.Cipher, error) + GetCipher() (*ss.Cipher, error, bool) UpdateTraffic(storageSize int) error GetUserInfo() UserInfo } diff --git a/mu/webapi/user.go b/mu/webapi/user.go index e750946e..c96bce34 100644 --- a/mu/webapi/user.go +++ b/mu/webapi/user.go @@ -4,6 +4,7 @@ import ( "github.com/orvice/shadowsocks-go/mu/user" ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" "strconv" + "strings" ) type User struct { @@ -39,8 +40,16 @@ func (u User) IsEnable() bool { return true } -func (u User) GetCipher() (*ss.Cipher, error) { - return ss.NewCipher(u.Method, u.Passwd) +func (u User) GetCipher() (*ss.Cipher, error, bool) { + method := u.Method + auth := false + + if strings.HasSuffix(method, "-auth") { + method = method[:len(method)-5] + auth = true + } + s,e := ss.NewCipher(method, u.Passwd) + return s, e, auth } func (u User) UpdateTraffic(storageSize int) error {