Skip to content

Commit

Permalink
use early conn to support real ws 0-rtt
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Feb 24, 2023
1 parent 878c726 commit d247b3d
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 30 deletions.
27 changes: 17 additions & 10 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"time"

"github.com/Dreamacro/clash/common/callback"
"github.com/Dreamacro/clash/common/queue"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
Expand Down Expand Up @@ -40,9 +41,17 @@ func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {

// DialContext implements C.ProxyAdapter
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
return aliveContext(p, ctx, func(ctx context.Context) (C.Conn, error) {
return p.ProxyAdapter.DialContext(ctx, metadata, opts...)
})
beginTime := time.Now()
c, err := p.ProxyAdapter.DialContext(ctx, metadata, opts...)
aliveCallback(beginTime, err, p, ctx)

c = &callback.FirstWriteCallBackConn{
Conn: c,
Callback: func(err error) {
aliveCallback(beginTime, err, p, ctx)
},
}
return c, err
}

// DialUDP implements C.ProxyAdapter
Expand All @@ -54,21 +63,19 @@ func (p *Proxy) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {

// ListenPacketContext implements C.ProxyAdapter
func (p *Proxy) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
return aliveContext(p, ctx, func(ctx context.Context) (C.PacketConn, error) {
return p.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...)
})
beginTime := time.Now()
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...)
aliveCallback(beginTime, err, p, ctx)
return pc, err
}

func aliveContext[T any](p *Proxy, ctx context.Context, f func(context.Context) (T, error)) (T, error) {
beginTime := time.Now()
t, err := f(ctx)
func aliveCallback(beginTime time.Time, err error, p *Proxy, ctx context.Context) {
timeUsed := time.Now().Sub(beginTime)
if err != nil {
if ctx.Err() == nil || timeUsed > 1*time.Second { // context not cancelled or timeUsed>1s
p.alive.Store(false)
}
}
return t, err
}

// DelayHistory implements C.Proxy
Expand Down
3 changes: 3 additions & 0 deletions adapter/outbound/reject.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ func (rw *nopConn) Read(b []byte) (int, error) {
}

func (rw *nopConn) Write(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
return 0, io.EOF
}

Expand Down
4 changes: 2 additions & 2 deletions adapter/outbound/shadowsocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ func (ss *ShadowSocks) streamConn(c net.Conn, metadata *C.Metadata) (net.Conn, e
//_, err := c.Write(serializesSocksAddr(metadata))
//return c, err
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
return ss.method.DialConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443"))
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(uot.UOTMagicAddress+":443")), nil
}
return ss.method.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return ss.method.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
}

// DialContext implements C.ProxyAdapter
Expand Down
10 changes: 5 additions & 5 deletions adapter/outbound/vmess.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,12 @@ func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
//return v.client.StreamConn(c, parseVmessAddr(metadata))
if metadata.NetWork == C.UDP {
if v.option.XUDP {
return v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
} else {
return v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
}
} else {
return v.client.DialConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
return v.client.DialEarlyConn(c, M.ParseSocksaddr(metadata.RemoteAddress())), nil
}
}

Expand Down Expand Up @@ -286,9 +286,9 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, o

//c, err = v.client.StreamConn(c, parseVmessAddr(metadata))
if v.option.XUDP {
c, err = v.client.DialXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
c = v.client.DialEarlyXUDPPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
} else {
c, err = v.client.DialPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
c = v.client.DialEarlyPacketConn(c, M.ParseSocksaddr(metadata.RemoteAddress()))
}

if err != nil {
Expand Down
11 changes: 11 additions & 0 deletions adapter/outboundgroup/fallback.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"

"github.com/Dreamacro/clash/adapter/outbound"
"github.com/Dreamacro/clash/common/callback"
"github.com/Dreamacro/clash/common/singledo"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
Expand Down Expand Up @@ -33,6 +34,16 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts .
} else {
doHealthCheck(f.providers, proxy)
}

c = &callback.FirstWriteCallBackConn{
Conn: c,
Callback: func(err error) {
if err == nil {
} else {
doHealthCheck(f.providers, proxy)
}
},
}
return c, err
}

Expand Down
11 changes: 11 additions & 0 deletions adapter/outboundgroup/urltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/Dreamacro/clash/adapter/outbound"
"github.com/Dreamacro/clash/common/callback"
"github.com/Dreamacro/clash/common/singledo"
"github.com/Dreamacro/clash/component/dialer"
C "github.com/Dreamacro/clash/constant"
Expand Down Expand Up @@ -44,6 +45,16 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ..
} else {
doHealthCheck(u.providers, proxy)
}

c = &callback.FirstWriteCallBackConn{
Conn: c,
Callback: func(err error) {
if err == nil {
} else {
doHealthCheck(u.providers, proxy)
}
},
}
return c, err
}

Expand Down
25 changes: 25 additions & 0 deletions common/callback/callback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package callback

import (
C "github.com/Dreamacro/clash/constant"
)

type FirstWriteCallBackConn struct {
C.Conn
Callback func(error)
written bool
}

func (c *FirstWriteCallBackConn) Write(b []byte) (n int, err error) {
defer func() {
if !c.written {
c.written = true
c.Callback(err)
}
}()
return c.Conn.Write(b)
}

func (c *FirstWriteCallBackConn) Upstream() any {
return c.Conn
}
12 changes: 11 additions & 1 deletion common/net/bufconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,35 @@ import (
type BufferedConn struct {
r *bufio.Reader
net.Conn
peeked bool
}

func NewBufferedConn(c net.Conn) *BufferedConn {
if bc, ok := c.(*BufferedConn); ok {
return bc
}
return &BufferedConn{bufio.NewReader(c), c}
return &BufferedConn{bufio.NewReader(c), c, false}
}

// Reader returns the internal bufio.Reader.
func (c *BufferedConn) Reader() *bufio.Reader {
return c.r
}

func (c *BufferedConn) Peeked() bool {
return c.peeked
}

// Peek returns the next n bytes without advancing the reader.
func (c *BufferedConn) Peek(n int) ([]byte, error) {
c.peeked = true
return c.r.Peek(n)
}

func (c *BufferedConn) Discard(n int) (discarded int, err error) {
return c.r.Discard(n)
}

func (c *BufferedConn) Read(p []byte) (int, error) {
return c.r.Read(p)
}
Expand Down
9 changes: 2 additions & 7 deletions component/sniffer/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ type SnifferDispatcher struct {
parsePureIp bool
}

func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
bufConn, ok := conn.(*N.BufferedConn)
if !ok {
return
}

func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) {
switch metadata.Type {
case C.DNS, C.MTPROXY, C.PROVIDER: // ignore inner connection
return
Expand Down Expand Up @@ -81,7 +76,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
}
sd.rwMux.RUnlock()

if host, err := sd.sniffDomain(bufConn, metadata); err != nil {
if host, err := sd.sniffDomain(conn, metadata); err != nil {
sd.cacheSniffFailed(metadata)
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
return
Expand Down
4 changes: 3 additions & 1 deletion constant/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package constant
import (
"net"

N "github.com/Dreamacro/clash/common/net"

"github.com/gofrs/uuid"
)

Expand All @@ -13,7 +15,7 @@ type PlainContext interface {
type ConnContext interface {
PlainContext
Metadata() *Metadata
Conn() net.Conn
Conn() *N.BufferedConn
}

type PacketConnContext interface {
Expand Down
4 changes: 2 additions & 2 deletions context/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
type ConnContext struct {
id uuid.UUID
metadata *C.Metadata
conn net.Conn
conn *N.BufferedConn
}

func NewConnContext(conn net.Conn, metadata *C.Metadata) *ConnContext {
Expand All @@ -35,6 +35,6 @@ func (c *ConnContext) Metadata() *C.Metadata {
}

// Conn implement C.ConnContext Conn
func (c *ConnContext) Conn() net.Conn {
func (c *ConnContext) Conn() *N.BufferedConn {
return c.conn
}
32 changes: 30 additions & 2 deletions tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,20 @@ func handleTCPConn(connCtx C.ConnContext) {
return
}

conn := connCtx.Conn()
if sniffer.Dispatcher.Enable() && sniffingEnable {
sniffer.Dispatcher.TCPSniff(connCtx.Conn(), metadata)
sniffer.Dispatcher.TCPSniff(conn, metadata)
}

peekMutex := sync.Mutex{}
if !conn.Peeked() {
peekMutex.Lock()
go func() {
defer peekMutex.Unlock()
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
_, _ = conn.Peek(1)
_ = conn.SetReadDeadline(time.Time{})
}()
}

proxy, rule, err := resolveMetadata(connCtx, metadata)
Expand All @@ -401,10 +413,26 @@ func handleTCPConn(connCtx C.ConnContext) {
}
}

var peekBytes []byte

ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
remoteConn, err := retry(ctx, func(ctx context.Context) (C.Conn, error) {
return proxy.DialContext(ctx, dialMetadata)
remoteConn, err := proxy.DialContext(ctx, dialMetadata)
if err != nil {
return nil, err
}
peekMutex.Lock()
defer peekMutex.Unlock()
peekBytes, _ = conn.Peek(conn.Buffered())
_, err = remoteConn.Write(peekBytes)
if err != nil {
return nil, err
}
if peekLen := len(peekBytes); peekLen > 0 {
_, _ = conn.Discard(peekLen)
}
return remoteConn, err
}, func(err error) {
if rule == nil {
log.Warnln(
Expand Down

0 comments on commit d247b3d

Please sign in to comment.