From c48f2be19fe25a7bb82ed123482e8e63508e2058 Mon Sep 17 00:00:00 2001 From: Mark Pashmfouroush Date: Wed, 17 Apr 2024 12:12:12 +0100 Subject: [PATCH] custom copy func Signed-off-by: Mark Pashmfouroush --- proxy/pkg/socks5/server.go | 263 +++++++++++---------- wireguard/device/queueconstants_default.go | 2 +- wiresocks/proxy.go | 61 ++++- wiresocks/wiresocks.go | 10 +- 4 files changed, 194 insertions(+), 142 deletions(-) diff --git a/proxy/pkg/socks5/server.go b/proxy/pkg/socks5/server.go index 089faddd4..fcd539228 100644 --- a/proxy/pkg/socks5/server.go +++ b/proxy/pkg/socks5/server.go @@ -7,7 +7,6 @@ import ( "io" "log/slog" "net" - "time" "github.com/bepass-org/warp-plus/proxy/pkg/statute" ) @@ -236,9 +235,9 @@ func (s *Server) handle(req *request) error { } func (s *Server) handleConnect(req *request) error { - if s.UserConnectHandle == nil { - return s.embedHandleConnect(req) - } + // if s.UserConnectHandle == nil { + // return s.embedHandleConnect(req) + // } if err := sendReply(req.Conn, successReply, nil); err != nil { return fmt.Errorf("failed to send reply: %v", err) @@ -261,46 +260,46 @@ func (s *Server) handleConnect(req *request) error { return s.UserConnectHandle(proxyReq) } -func (s *Server) embedHandleConnect(req *request) error { - defer func() { - _ = req.Conn.Close() - }() - - target, err := s.ProxyDial(s.Context, "tcp", req.DestinationAddr.Address()) - if err != nil { - if err := sendReply(req.Conn, errToReply(err), nil); err != nil { - return fmt.Errorf("failed to send reply: %v", err) - } - return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err) - } - defer func() { - _ = target.Close() - }() - - localAddr := target.LocalAddr() - local, ok := localAddr.(*net.TCPAddr) - if !ok { - return fmt.Errorf("connect to %v failed: local address is %s://%s", req.DestinationAddr, localAddr.Network(), localAddr.String()) - } - bind := address{IP: local.IP, Port: local.Port} - if err := sendReply(req.Conn, successReply, &bind); err != nil { - return fmt.Errorf("failed to send reply: %v", err) - } - - var buf1, buf2 []byte - if s.BytesPool != nil { - buf1 = s.BytesPool.Get() - buf2 = s.BytesPool.Get() - defer func() { - s.BytesPool.Put(buf1) - s.BytesPool.Put(buf2) - }() - } else { - buf1 = make([]byte, 4*1024) - buf2 = make([]byte, 4*1024) - } - return statute.Tunnel(s.Context, target, req.Conn, buf1, buf2) -} +// func (s *Server) embedHandleConnect(req *request) error { +// defer func() { +// _ = req.Conn.Close() +// }() + +// target, err := s.ProxyDial(s.Context, "tcp", req.DestinationAddr.Address()) +// if err != nil { +// if err := sendReply(req.Conn, errToReply(err), nil); err != nil { +// return fmt.Errorf("failed to send reply: %v", err) +// } +// return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err) +// } +// defer func() { +// _ = target.Close() +// }() + +// localAddr := target.LocalAddr() +// local, ok := localAddr.(*net.TCPAddr) +// if !ok { +// return fmt.Errorf("connect to %v failed: local address is %s://%s", req.DestinationAddr, localAddr.Network(), localAddr.String()) +// } +// bind := address{IP: local.IP, Port: local.Port} +// if err := sendReply(req.Conn, successReply, &bind); err != nil { +// return fmt.Errorf("failed to send reply: %v", err) +// } + +// var buf1, buf2 []byte +// if s.BytesPool != nil { +// buf1 = s.BytesPool.Get() +// buf2 = s.BytesPool.Get() +// defer func() { +// s.BytesPool.Put(buf1) +// s.BytesPool.Put(buf2) +// }() +// } else { +// buf1 = make([]byte, 4*1024) +// buf2 = make([]byte, 4*1024) +// } +// return statute.Tunnel(s.Context, target, req.Conn, buf1, buf2) +// } func (s *Server) handleAssociate(req *request) error { destinationAddr := req.DestinationAddr.String() @@ -321,9 +320,9 @@ func (s *Server) handleAssociate(req *request) error { return fmt.Errorf("failed to send reply: %v", err) } - if s.UserAssociateHandle == nil { - return s.embedHandleAssociate(req, udpConn) - } + // if s.UserAssociateHandle == nil { + // return s.embedHandleAssociate(req, udpConn) + // } cConn := &udpCustomConn{ PacketConn: udpConn, @@ -350,89 +349,93 @@ func (s *Server) handleAssociate(req *request) error { return s.UserAssociateHandle(proxyReq) } -func (s *Server) embedHandleAssociate(req *request, udpConn net.PacketConn) error { - defer udpConn.Close() - - go func() { - var buf [1]byte - for { - req.Conn.SetReadDeadline(time.Now().Add(15 * time.Second)) - _, err := req.Conn.Read(buf[:]) - if err != nil { - _ = udpConn.Close() - break - } - } - }() - - var ( - sourceAddr net.Addr - wantSource string - targetAddr net.Addr - wantTarget string - replyPrefix []byte - buf [maxUdpPacket]byte - ) - - for { - udpConn.SetReadDeadline(time.Now().Add(15 * time.Second)) - n, addr, err := udpConn.ReadFrom(buf[:]) - if err != nil { - return err - } - - if sourceAddr == nil { - sourceAddr = addr - wantSource = sourceAddr.String() - } - - gotAddr := addr.String() - if wantSource == gotAddr { - if n < 3 { - continue - } - reader := bytes.NewBuffer(buf[3:n]) - addr, err := readAddr(reader) - if err != nil { - s.Logger.Debug(err.Error()) - continue - } - - if targetAddr == nil { - targetAddr = &net.UDPAddr{ - IP: addr.IP, - Port: addr.Port, - } - wantTarget = targetAddr.String() - } - - if addr.String() != wantTarget { - s.Logger.Debug("ignore non-target addresses", "address", addr) - continue - } - - _, err = udpConn.WriteTo(reader.Bytes(), targetAddr) - if err != nil { - return err - } - } else if targetAddr != nil && wantTarget == gotAddr { - if replyPrefix == nil { - b := bytes.NewBuffer(make([]byte, 3, 16)) - err = writeAddrWithStr(b, wantTarget) - if err != nil { - return err - } - replyPrefix = b.Bytes() - } - copy(buf[len(replyPrefix):len(replyPrefix)+n], buf[:n]) - copy(buf[:len(replyPrefix)], replyPrefix) - _, err = udpConn.WriteTo(buf[:len(replyPrefix)+n], sourceAddr) - if err != nil { - return err - } - } - } -} +// func (s *Server) embedHandleAssociate(req *request, udpConn net.PacketConn) error { +// s.Logger.Debug("EMBED HANDLE ASSOCIATE") +// defer udpConn.Close() +// defer req.Conn.Close() + +// go func() { +// var buf [1]byte +// for { +// s.Logger.Debug("ASSOC READ PACKET") +// req.Conn.SetReadDeadline(time.Now().Add(15 * time.Second)) +// _, err := req.Conn.Read(buf[:]) +// if err != nil { +// udpConn.Close() +// req.Conn.Close() +// break +// } +// } +// }() + +// var ( +// sourceAddr net.Addr +// wantSource string +// targetAddr net.Addr +// wantTarget string +// replyPrefix []byte +// buf [maxUdpPacket]byte +// ) + +// for { +// udpConn.SetReadDeadline(time.Now().Add(15 * time.Second)) +// n, addr, err := udpConn.ReadFrom(buf[:]) +// if err != nil { +// return err +// } + +// if sourceAddr == nil { +// sourceAddr = addr +// wantSource = sourceAddr.String() +// } + +// gotAddr := addr.String() +// if wantSource == gotAddr { +// if n < 3 { +// continue +// } +// reader := bytes.NewBuffer(buf[3:n]) +// addr, err := readAddr(reader) +// if err != nil { +// s.Logger.Debug(err.Error()) +// continue +// } + +// if targetAddr == nil { +// targetAddr = &net.UDPAddr{ +// IP: addr.IP, +// Port: addr.Port, +// } +// wantTarget = targetAddr.String() +// } + +// if addr.String() != wantTarget { +// s.Logger.Debug("ignore non-target addresses", "address", addr) +// continue +// } + +// _, err = udpConn.WriteTo(reader.Bytes(), targetAddr) +// if err != nil { +// return err +// } +// } else if targetAddr != nil && wantTarget == gotAddr { +// if replyPrefix == nil { +// b := bytes.NewBuffer(make([]byte, 3, 16)) +// err = writeAddrWithStr(b, wantTarget) +// if err != nil { +// return err +// } +// replyPrefix = b.Bytes() +// } +// copy(buf[len(replyPrefix):len(replyPrefix)+n], buf[:n]) +// copy(buf[:len(replyPrefix)], replyPrefix) +// _, err = udpConn.WriteTo(buf[:len(replyPrefix)+n], sourceAddr) +// if err != nil { +// return err +// } +// } +// } +// } func sendReply(w io.Writer, resp reply, addr *address) error { _, err := w.Write([]byte{socks5Version, byte(resp), 0}) diff --git a/wireguard/device/queueconstants_default.go b/wireguard/device/queueconstants_default.go index f22fa33df..ae02b6387 100644 --- a/wireguard/device/queueconstants_default.go +++ b/wireguard/device/queueconstants_default.go @@ -15,5 +15,5 @@ const ( QueueInboundSize = 1024 QueueHandshakeSize = 1024 MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram - PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth + PreallocatedBuffersPerPool = 4096 // Disable and allow for infinite memory growth ) diff --git a/wiresocks/proxy.go b/wiresocks/proxy.go index 89e177c1e..d47ee88b7 100644 --- a/wiresocks/proxy.go +++ b/wiresocks/proxy.go @@ -2,15 +2,18 @@ package wiresocks import ( "context" + "errors" "io" "log/slog" "net" "net/netip" + "time" "github.com/bepass-org/warp-plus/proxy/pkg/mixed" "github.com/bepass-org/warp-plus/proxy/pkg/statute" "github.com/bepass-org/warp-plus/wireguard/device" "github.com/bepass-org/warp-plus/wireguard/tun/netstack" + "github.com/things-go/go-socks5/bufferpool" ) // VirtualTun stores a reference to netstack network and DNS configuration @@ -19,6 +22,7 @@ type VirtualTun struct { Logger *slog.Logger Dev *device.Device Ctx context.Context + pool bufferpool.BufPool } // StartProxy spawns a socks5 server. @@ -60,15 +64,20 @@ func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error { // Channel to notify when copy operation is done done := make(chan error, 1) // Copy data from req.Conn to conn - buf1 := make([]byte, 4*1024) - buf2 := make([]byte, 4*1024) + go func() { - _, err := io.CopyBuffer(conn, req.Conn, buf1) + req.Conn.SetReadDeadline(time.Now().Add(15 * time.Second)) + buf1 := vt.pool.Get() + defer vt.pool.Put(buf1) + _, err := copyConnTimeout(conn, req.Conn, buf1[:cap(buf1)], 15*time.Second) done <- err }() // Copy data from conn to req.Conn go func() { - _, err := io.CopyBuffer(req.Conn, conn, buf2) + conn.SetReadDeadline(time.Now().Add(15 * time.Second)) + buf2 := vt.pool.Get() + defer vt.pool.Put(buf2) + _, err := copyConnTimeout(req.Conn, conn, buf2[:cap(buf2)], 15*time.Second) done <- err }() // Wait for one of the copy operations to finish @@ -79,9 +88,6 @@ func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error { // Close connections and wait for the other copy operation to finish <-done - conn.Close() - req.Conn.Close() - return nil } @@ -92,3 +98,44 @@ func (vt *VirtualTun) Stop() { } } } + +var errInvalidWrite = errors.New("invalid write result") + +func copyConnTimeout(dst net.Conn, src net.Conn, buf []byte, timeout time.Duration) (written int64, err error) { + if buf != nil && len(buf) == 0 { + panic("empty buffer in CopyBuffer") + } + + for { + if err := src.SetReadDeadline(time.Now().Add(timeout)); err != nil { + return 0, err + } + + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errInvalidWrite + } + } + written += int64(nw) + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +} diff --git a/wiresocks/wiresocks.go b/wiresocks/wiresocks.go index b75d4acd4..b50587c42 100644 --- a/wiresocks/wiresocks.go +++ b/wiresocks/wiresocks.go @@ -9,6 +9,7 @@ import ( "github.com/bepass-org/warp-plus/wireguard/conn" "github.com/bepass-org/warp-plus/wireguard/device" "github.com/bepass-org/warp-plus/wireguard/tun/netstack" + "github.com/things-go/go-socks5/bufferpool" ) // StartWireguard creates a tun interface on netstack given a configuration @@ -46,9 +47,10 @@ func StartWireguard(ctx context.Context, l *slog.Logger, conf *Configuration) (* } return &VirtualTun{ - Tnet: tnet, - Logger: l.With("subsystem", "vtun"), - Dev: dev, - Ctx: ctx, + Tnet: tnet, + Logger: l.With("subsystem", "vtun"), + Dev: dev, + Ctx: ctx, + pool: bufferpool.NewPool(256 * 1024), }, nil }