Skip to content

Commit

Permalink
custom copy func
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Pashmfouroush <[email protected]>
  • Loading branch information
markpash committed Apr 17, 2024
1 parent 6eb445d commit c48f2be
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 142 deletions.
263 changes: 133 additions & 130 deletions proxy/pkg/socks5/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"io"
"log/slog"
"net"
"time"

"github.com/bepass-org/warp-plus/proxy/pkg/statute"
)
Expand Down Expand Up @@ -236,9 +235,9 @@ func (s *Server) handle(req *request) error {
}

func (s *Server) handleConnect(req *request) error {
if s.UserConnectHandle == nil {
return s.embedHandleConnect(req)
}
// if s.UserConnectHandle == nil {
// return s.embedHandleConnect(req)
// }

if err := sendReply(req.Conn, successReply, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
Expand All @@ -261,46 +260,46 @@ func (s *Server) handleConnect(req *request) error {
return s.UserConnectHandle(proxyReq)
}

func (s *Server) embedHandleConnect(req *request) error {
defer func() {
_ = req.Conn.Close()
}()

target, err := s.ProxyDial(s.Context, "tcp", req.DestinationAddr.Address())
if err != nil {
if err := sendReply(req.Conn, errToReply(err), nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err)
}
defer func() {
_ = target.Close()
}()

localAddr := target.LocalAddr()
local, ok := localAddr.(*net.TCPAddr)
if !ok {
return fmt.Errorf("connect to %v failed: local address is %s://%s", req.DestinationAddr, localAddr.Network(), localAddr.String())
}
bind := address{IP: local.IP, Port: local.Port}
if err := sendReply(req.Conn, successReply, &bind); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}

var buf1, buf2 []byte
if s.BytesPool != nil {
buf1 = s.BytesPool.Get()
buf2 = s.BytesPool.Get()
defer func() {
s.BytesPool.Put(buf1)
s.BytesPool.Put(buf2)
}()
} else {
buf1 = make([]byte, 4*1024)
buf2 = make([]byte, 4*1024)
}
return statute.Tunnel(s.Context, target, req.Conn, buf1, buf2)
}
// func (s *Server) embedHandleConnect(req *request) error {
// defer func() {
// _ = req.Conn.Close()
// }()

// target, err := s.ProxyDial(s.Context, "tcp", req.DestinationAddr.Address())
// if err != nil {
// if err := sendReply(req.Conn, errToReply(err), nil); err != nil {
// return fmt.Errorf("failed to send reply: %v", err)
// }
// return fmt.Errorf("connect to %v failed: %w", req.DestinationAddr, err)
// }
// defer func() {
// _ = target.Close()
// }()

// localAddr := target.LocalAddr()
// local, ok := localAddr.(*net.TCPAddr)
// if !ok {
// return fmt.Errorf("connect to %v failed: local address is %s://%s", req.DestinationAddr, localAddr.Network(), localAddr.String())
// }
// bind := address{IP: local.IP, Port: local.Port}
// if err := sendReply(req.Conn, successReply, &bind); err != nil {
// return fmt.Errorf("failed to send reply: %v", err)
// }

// var buf1, buf2 []byte
// if s.BytesPool != nil {
// buf1 = s.BytesPool.Get()
// buf2 = s.BytesPool.Get()
// defer func() {
// s.BytesPool.Put(buf1)
// s.BytesPool.Put(buf2)
// }()
// } else {
// buf1 = make([]byte, 4*1024)
// buf2 = make([]byte, 4*1024)
// }
// return statute.Tunnel(s.Context, target, req.Conn, buf1, buf2)
// }

func (s *Server) handleAssociate(req *request) error {
destinationAddr := req.DestinationAddr.String()
Expand All @@ -321,9 +320,9 @@ func (s *Server) handleAssociate(req *request) error {
return fmt.Errorf("failed to send reply: %v", err)
}

if s.UserAssociateHandle == nil {
return s.embedHandleAssociate(req, udpConn)
}
// if s.UserAssociateHandle == nil {
// return s.embedHandleAssociate(req, udpConn)
// }

cConn := &udpCustomConn{
PacketConn: udpConn,
Expand All @@ -350,89 +349,93 @@ func (s *Server) handleAssociate(req *request) error {
return s.UserAssociateHandle(proxyReq)
}

func (s *Server) embedHandleAssociate(req *request, udpConn net.PacketConn) error {
defer udpConn.Close()

go func() {
var buf [1]byte
for {
req.Conn.SetReadDeadline(time.Now().Add(15 * time.Second))
_, err := req.Conn.Read(buf[:])
if err != nil {
_ = udpConn.Close()
break
}
}
}()

var (
sourceAddr net.Addr
wantSource string
targetAddr net.Addr
wantTarget string
replyPrefix []byte
buf [maxUdpPacket]byte
)

for {
udpConn.SetReadDeadline(time.Now().Add(15 * time.Second))
n, addr, err := udpConn.ReadFrom(buf[:])
if err != nil {
return err
}

if sourceAddr == nil {
sourceAddr = addr
wantSource = sourceAddr.String()
}

gotAddr := addr.String()
if wantSource == gotAddr {
if n < 3 {
continue
}
reader := bytes.NewBuffer(buf[3:n])
addr, err := readAddr(reader)
if err != nil {
s.Logger.Debug(err.Error())
continue
}

if targetAddr == nil {
targetAddr = &net.UDPAddr{
IP: addr.IP,
Port: addr.Port,
}
wantTarget = targetAddr.String()
}

if addr.String() != wantTarget {
s.Logger.Debug("ignore non-target addresses", "address", addr)
continue
}

_, err = udpConn.WriteTo(reader.Bytes(), targetAddr)
if err != nil {
return err
}
} else if targetAddr != nil && wantTarget == gotAddr {
if replyPrefix == nil {
b := bytes.NewBuffer(make([]byte, 3, 16))
err = writeAddrWithStr(b, wantTarget)
if err != nil {
return err
}
replyPrefix = b.Bytes()
}
copy(buf[len(replyPrefix):len(replyPrefix)+n], buf[:n])
copy(buf[:len(replyPrefix)], replyPrefix)
_, err = udpConn.WriteTo(buf[:len(replyPrefix)+n], sourceAddr)
if err != nil {
return err
}
}
}
}
// func (s *Server) embedHandleAssociate(req *request, udpConn net.PacketConn) error {
// s.Logger.Debug("EMBED HANDLE ASSOCIATE")
// defer udpConn.Close()
// defer req.Conn.Close()

// go func() {
// var buf [1]byte
// for {
// s.Logger.Debug("ASSOC READ PACKET")
// req.Conn.SetReadDeadline(time.Now().Add(15 * time.Second))
// _, err := req.Conn.Read(buf[:])
// if err != nil {
// udpConn.Close()
// req.Conn.Close()
// break
// }
// }
// }()

// var (
// sourceAddr net.Addr
// wantSource string
// targetAddr net.Addr
// wantTarget string
// replyPrefix []byte
// buf [maxUdpPacket]byte
// )

// for {
// udpConn.SetReadDeadline(time.Now().Add(15 * time.Second))
// n, addr, err := udpConn.ReadFrom(buf[:])
// if err != nil {
// return err
// }

// if sourceAddr == nil {
// sourceAddr = addr
// wantSource = sourceAddr.String()
// }

// gotAddr := addr.String()
// if wantSource == gotAddr {
// if n < 3 {
// continue
// }
// reader := bytes.NewBuffer(buf[3:n])
// addr, err := readAddr(reader)
// if err != nil {
// s.Logger.Debug(err.Error())
// continue
// }

// if targetAddr == nil {
// targetAddr = &net.UDPAddr{
// IP: addr.IP,
// Port: addr.Port,
// }
// wantTarget = targetAddr.String()
// }

// if addr.String() != wantTarget {
// s.Logger.Debug("ignore non-target addresses", "address", addr)
// continue
// }

// _, err = udpConn.WriteTo(reader.Bytes(), targetAddr)
// if err != nil {
// return err
// }
// } else if targetAddr != nil && wantTarget == gotAddr {
// if replyPrefix == nil {
// b := bytes.NewBuffer(make([]byte, 3, 16))
// err = writeAddrWithStr(b, wantTarget)
// if err != nil {
// return err
// }
// replyPrefix = b.Bytes()
// }
// copy(buf[len(replyPrefix):len(replyPrefix)+n], buf[:n])
// copy(buf[:len(replyPrefix)], replyPrefix)
// _, err = udpConn.WriteTo(buf[:len(replyPrefix)+n], sourceAddr)
// if err != nil {
// return err
// }
// }
// }
// }

func sendReply(w io.Writer, resp reply, addr *address) error {
_, err := w.Write([]byte{socks5Version, byte(resp), 0})
Expand Down
2 changes: 1 addition & 1 deletion wireguard/device/queueconstants_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ const (
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
PreallocatedBuffersPerPool = 4096 // Disable and allow for infinite memory growth
)
Loading

0 comments on commit c48f2be

Please sign in to comment.