Skip to content
This repository has been archived by the owner on Jun 16, 2019. It is now read-only.

Commit

Permalink
Merge pull request #48 from jemyzhang/master
Browse files Browse the repository at this point in the history
try to enable ota
  • Loading branch information
orvice committed May 17, 2016
2 parents 6be75ab + ca3214f commit 64ced42
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 119 deletions.
268 changes: 154 additions & 114 deletions mu/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,61 +25,65 @@ import (
"time"
)

const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index

typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address

lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
lenDmBase = 2 // 1addrLen + 2port, plus addrLen
lenHmacSha1 = 10
)

var ssdebug ss.DebugLog

func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index

typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address

lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port
lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port
lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen
)
func getRequest(conn *ss.Conn, auth bool) (host string, res_size int, ota bool, err error) {
var n int
ss.SetReadTimeout(conn)

// buf size should at least have the same size with the largest possible
// request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 260)
var n int
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1)
buf := make([]byte, 270)
// read till we get possible domain length field
ss.SetReadTimeout(conn)
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
if n, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
return
}
res_size += n

reqLen := -1
switch buf[idType] {
var reqStart, reqEnd int
addrType := buf[idType]
switch addrType & ss.AddrMask {
case typeIPv4:
reqLen = lenIPv4
reqStart, reqEnd = idIP0, idIP0+lenIPv4
case typeIPv6:
reqLen = lenIPv6
reqStart, reqEnd = idIP0, idIP0+lenIPv6
case typeDm:
reqLen = int(buf[idDmLen]) + lenDmBase
if n, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
return
}
reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase)
default:
err = fmt.Errorf("addr type %d not supported", buf[idType])
err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask)
return
}
res_size += n

if n < reqLen { // rare case
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
} else if n > reqLen {
// it's possible to read more than just the request head
extra = buf[reqLen:n]
if n, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
return
}
res_size += n

// Return string for typeIP is not most efficient, but browsers (Chrome,
// Safari, Firefox) all seems using typeDm exclusively. So this is not a
// big problem.
switch buf[idType] {
switch addrType & ss.AddrMask {
case typeIPv4:
host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
case typeIPv6:
Expand All @@ -88,8 +92,23 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
// parse port
port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen])
port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
host = net.JoinHostPort(host, strconv.Itoa(int(port)))
// if specified one time auth enabled, we should verify this
if auth || addrType&ss.OneTimeAuthMask > 0 {
ota = true
if n, err = io.ReadFull(conn, buf[reqEnd:reqEnd+lenHmacSha1]); err != nil {
return
}
iv := conn.GetIv()
key := conn.GetKey()
actualHmacSha1Buf := ss.HmacSha1(append(iv, key...), buf[:reqEnd])
if !bytes.Equal(buf[reqEnd:reqEnd+lenHmacSha1], actualHmacSha1Buf) {
err = fmt.Errorf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, buf[:reqEnd])
return
}
res_size += n
}
return
}

Expand All @@ -98,13 +117,9 @@ const logCntDelta = 100
var connCnt int
var nextLogConnCnt int = logCntDelta

func handleConnection(user user.User, conn *ss.Conn) {
func handleConnection(user user.User, conn *ss.Conn, auth bool) {
var host string
var size = 0
var raw_req_header, raw_res_header []byte
var is_http = false
var res_size = 0
var req_chan = make(chan []byte)

connCnt++ // this maybe not accurate, but should be enough
if connCnt-nextLogConnCnt >= 0 {
// XXX There's no xadd in the atomic package, so it's difficult to log
Expand All @@ -128,7 +143,7 @@ func handleConnection(user user.User, conn *ss.Conn) {
}
}()

host, extra, err := getRequest(conn)
host, res_size, ota, err := getRequest(conn, auth)
if err != nil {
Log.Error("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err)
return
Expand All @@ -151,65 +166,30 @@ func handleConnection(user user.User, conn *ss.Conn) {
}
}()

defer func() {
if is_http {
tmp_req_header := <-req_chan
buffer := bytes.NewBuffer(raw_req_header)
buffer.Write(tmp_req_header)
raw_req_header = buffer.Bytes()
}
showConn(raw_req_header, raw_res_header, host, user, size, is_http)
close(req_chan)
if !closed {
remote.Close()
}
}()
// debug conn info
Log.Debug(fmt.Sprintf("%d conn debug: local addr: %s | remote addr: %s network: %s ", user.GetPort(),
conn.LocalAddr().String(), conn.RemoteAddr().String(), conn.RemoteAddr().Network()))
err = storage.IncrSize(user, res_size)
if err != nil {
Log.Error(err)
return
}
err = storage.MarkUserOnline(user)
if err != nil {
Log.Error(err)
return
}
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), res_size))

// write extra bytes read from
if extra != nil {
// debug.Println("getRequest read extra data, writing to remote, len", len(extra))
is_http, extra, _ = checkHttp(extra, conn)
if strings.HasSuffix(host, ":80") {
is_http = true
}
raw_req_header = extra
res_size, err = remote.Write(extra)
// size, err := remote.Write(extra)
if err != nil {
Log.Error("write request extra error:", err)
return
}
// debug conn info
Log.Debug(fmt.Sprintf("%d conn debug: local addr: %s | remote addr: %s network: %s ", user.GetPort(),
conn.LocalAddr().String(), conn.RemoteAddr().String(), conn.RemoteAddr().Network()))
err = storage.IncrSize(user, res_size)
if err != nil {
Log.Error(err)
return
}
err = storage.MarkUserOnline(user)
if err != nil {
Log.Error(err)
return
}
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), res_size))
Log.Info(fmt.Sprintf("piping %s<->%s ota=%v connOta=%v", conn.RemoteAddr(), host, ota, conn.IsOta()))

if ota {
go PipeThenCloseOta(conn, remote, false, host, user)
} else {
go PipeThenClose(conn, remote, false, host, user)
}
Log.Debug(fmt.Sprintf("piping %s<->%s", conn.RemoteAddr(), host))
/**
go ss.PipeThenClose(conn, remote)
ss.PipeThenClose(remote, conn)
closed = true
return
**/
go func() {
_, raw_header := PipeThenClose(conn, remote, is_http, false, host, user)
if is_http {
req_chan <- raw_header
}
}()

res_size, raw_res_header = PipeThenClose(remote, conn, is_http, true, host, user)
size += res_size
PipeThenClose(remote, conn, true, host, user)
closed = true
return
}
Expand Down Expand Up @@ -273,7 +253,7 @@ func runWithCustomMethod(user user.User) {
os.Exit(1)
}
passwdManager.add(port, password, ln)
cipher, err := user.GetCipher()
cipher, err, auth := user.GetCipher()
if err != nil {
return
}
Expand All @@ -288,27 +268,34 @@ func runWithCustomMethod(user user.User) {
// Creating cipher upon first connection.
if cipher == nil {
Log.Debug("creating cipher for port:", port)
cipher, err = ss.NewCipher(user.GetMethod(), password)
method := user.GetMethod()

if strings.HasSuffix(method, "-auth") {
method = method[:len(method)-5]
auth = true
} else {
auth = false
}

cipher, err = ss.NewCipher(method, password)
if err != nil {
Log.Error(fmt.Sprintf("Error generating cipher for port: %s %v\n", port, err))
conn.Close()
continue
}
}
go handleConnection(user, ss.NewConn(conn, cipher.Copy()))
go handleConnection(user, ss.NewConn(conn, cipher.Copy()), auth)
}
}

const bufSize = 4096
const nBuf = 2048

func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, user user.User) (total int, raw_header []byte) {
func PipeThenClose(src, dst net.Conn, is_res bool, host string, user user.User) {
var pipeBuf = leakybuf.NewLeakyBuf(nBuf, bufSize)
defer dst.Close()
buf := pipeBuf.Get()
// defer pipeBuf.Put(buf)
var buffer = bytes.NewBuffer(nil)
var is_end = false
var size int

for {
Expand All @@ -317,15 +304,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
// read may return EOF with n > 0
// should always process n > 0 bytes before handling error
if n > 0 {
if is_http && !is_end {
buffer.Write(buf)
raw_header = buffer.Bytes()
lines := bytes.SplitN(raw_header, []byte("\r\n\r\n"), 2)
if len(lines) == 2 {
is_end = true
}
}

size, err = dst.Write(buf[0:n])
if is_res {
err = storage.IncrSize(user, size)
Expand All @@ -334,7 +312,6 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
}
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), size))
}
total += size
if err != nil {
Log.Debug("write:", err)
break
Expand All @@ -350,6 +327,69 @@ func PipeThenClose(src, dst net.Conn, is_http bool, is_res bool, host string, us
return
}

func PipeThenCloseOta(src *ss.Conn, dst net.Conn, is_res bool, host string, user user.User) {
const (
dataLenLen = 2
hmacSha1Len = 10
idxData0 = dataLenLen + hmacSha1Len
)

defer func() {
dst.Close()
}()
var pipeBuf = leakybuf.NewLeakyBuf(nBuf, bufSize)
buf := pipeBuf.Get()
// sometimes it have to fill large block
for i := 1; ; i += 1 {
SetReadTimeout(src)
n, err := io.ReadFull(src, buf[:dataLenLen+hmacSha1Len])
if err != nil {
if err == io.EOF {
break
}
Log.Debug(fmt.Sprintf("conn=%p #%v read header error n=%v: %v", src, i, n, err))
break
}
dataLen := binary.BigEndian.Uint16(buf[:dataLenLen])
expectedHmacSha1 := buf[dataLenLen:idxData0]

var dataBuf []byte
if len(buf) < int(idxData0+dataLen) {
dataBuf = make([]byte, dataLen)
} else {
dataBuf = buf[idxData0 : idxData0+dataLen]
}
if n, err := io.ReadFull(src, dataBuf); err != nil {
if err == io.EOF {
break
}
Log.Debug(fmt.Sprintf("conn=%p #%v read data error n=%v: %v", src, i, n, err))
break
}
chunkIdBytes := make([]byte, 4)
chunkId := src.GetAndIncrChunkId()
binary.BigEndian.PutUint32(chunkIdBytes, chunkId)
actualHmacSha1 := ss.HmacSha1(append(src.GetIv(), chunkIdBytes...), dataBuf)
if !bytes.Equal(expectedHmacSha1, actualHmacSha1) {
Log.Debug(fmt.Sprintf("conn=%p #%v read data hmac-sha1 mismatch, iv=%v chunkId=%v src=%v dst=%v len=%v expeced=%v actual=%v", src, i, src.GetIv(), chunkId, src.RemoteAddr(), dst.RemoteAddr(), dataLen, expectedHmacSha1, actualHmacSha1))
break
}

if n, err := dst.Write(dataBuf); err != nil {
Log.Debug(fmt.Sprintf("conn=%p #%v write data error n=%v: %v", dst, i, n, err))
break
}
if is_res {
err := storage.IncrSize(user, n)
if err != nil {
Log.Error(err)
}
Log.Debug(fmt.Sprintf("[port-%d] store size: %d", user.GetPort(), n))
}
}
return
}

var readTimeout time.Duration

func SetReadTimeout(c net.Conn) {
Expand Down
Loading

0 comments on commit 64ced42

Please sign in to comment.