From 2112448fd1777f25810e28f2696848b35e563475 Mon Sep 17 00:00:00 2001 From: Mark Pashmfouroush Date: Thu, 14 Mar 2024 16:47:32 +0000 Subject: [PATCH] misc: try to improve logging even further Signed-off-by: Mark Pashmfouroush --- app/app.go | 125 ++++++++++++------- ipscanner/internal/iterator/iterator.go | 9 +- ipscanner/internal/statute/default.go | 2 - main.go | 32 +++-- psiphon/p.go | 39 +++--- warp/account.go | 61 ++++----- warp/tls.go | 1 - wireguard/device/logger.go | 13 ++ wiresocks/proxy.go | 34 ++---- wiresocks/udpfw.go | 156 +----------------------- wiresocks/wiresocks.go | 47 ++----- 11 files changed, 172 insertions(+), 347 deletions(-) diff --git a/app/app.go b/app/app.go index 73e5d260c..ffa8f2703 100644 --- a/app/app.go +++ b/app/app.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "log" + "log/slog" "net" "net/netip" "os" @@ -20,7 +20,6 @@ const singleMTU = 1400 const doubleMTU = 1320 type WarpOptions struct { - LogLevel string Bind netip.AddrPort Endpoint string License string @@ -37,7 +36,7 @@ type ScanOptions struct { MaxRTT time.Duration } -func RunWarp(ctx context.Context, opts WarpOptions) error { +func RunWarp(ctx context.Context, l *slog.Logger, opts WarpOptions) error { if opts.Psiphon != nil && opts.Gool { return errors.New("can't use psiphon and gool at the same time") } @@ -50,16 +49,16 @@ func RunWarp(ctx context.Context, opts WarpOptions) error { if err := makeDirs(); err != nil { return err } - log.Println("'primary' and 'secondary' directories are ready") + l.Debug("'primary' and 'secondary' directories are ready") // Change the current working directory to 'stuff' if err := os.Chdir("stuff"); err != nil { return fmt.Errorf("error changing to 'stuff' directory: %w", err) } - log.Println("Changed working directory to 'stuff'") + l.Debug("Changed working directory to 'stuff'") // create identities - if err := createPrimaryAndSecondaryIdentities(opts.License); err != nil { + if err := createPrimaryAndSecondaryIdentities(l.With("subsystem", "warp/account"), opts.License); err != nil { return err } @@ -72,90 +71,110 @@ func RunWarp(ctx context.Context, opts WarpOptions) error { return err } - log.Printf("scan results: %+v", res) + l.Info("scan results", "endpoints", res) endpoints = make([]string, len(res)) for i := 0; i < len(res); i++ { endpoints[i] = res[i].AddrPort.String() } } - log.Printf("using warp endpoints: %+v", endpoints) + l.Info("using warp endpoints", "endpoints", endpoints) var warpErr error switch { case opts.Psiphon != nil: + l.Info("running in Psiphon (cfon) mode") // run primary warp on a random tcp port and run psiphon on bind address - warpErr = runWarpWithPsiphon(ctx, opts.Bind, endpoints, opts.Psiphon.Country, opts.LogLevel == "debug") + warpErr = runWarpWithPsiphon(ctx, l, opts.Bind, endpoints[0], opts.Psiphon.Country) case opts.Gool: + l.Info("running in warp-in-warp (gool) mode") // run warp in warp - warpErr = runWarpInWarp(ctx, opts.Bind, endpoints, opts.LogLevel == "debug") + warpErr = runWarpInWarp(ctx, l, opts.Bind, endpoints) default: + l.Info("running in normal warp mode") // just run primary warp on bindAddress - _, warpErr = runWarp(ctx, opts.Bind, endpoints, "./primary/wgcf-profile.ini", opts.LogLevel == "debug", true, true, singleMTU) + warpErr = runWarp(ctx, l, opts.Bind, endpoints[0]) } return warpErr } -func runWarp(ctx context.Context, bind netip.AddrPort, endpoints []string, confPath string, verbose, startProxy bool, trick bool, mtu int) (*wiresocks.VirtualTun, error) { - conf, err := wiresocks.ParseConfig(confPath, endpoints[0]) +func runWarp(ctx context.Context, l *slog.Logger, bind netip.AddrPort, endpoint string) error { + conf, err := wiresocks.ParseConfig("./primary/wgcf-profile.ini", endpoint) if err != nil { - log.Println(err) - return nil, err + return err } - conf.Interface.MTU = mtu + conf.Interface.MTU = singleMTU for i, peer := range conf.Peers { - peer.KeepAlive = 10 - if trick { - peer.Trick = true - peer.KeepAlive = 3 - } - + peer.Trick = true + peer.KeepAlive = 3 conf.Peers[i] = peer } - tnet, err := wiresocks.StartWireguard(ctx, conf, verbose) + tnet, err := wiresocks.StartWireguard(ctx, l, conf) if err != nil { - log.Println(err) - return nil, err + return err } - if startProxy { - tnet.StartProxy(bind) - } + tnet.StartProxy(bind) + l.Info("Serving proxy", "address", bind) - return tnet, nil + return nil } -func runWarpWithPsiphon(ctx context.Context, bind netip.AddrPort, endpoints []string, country string, verbose bool) error { +func runWarpWithPsiphon(ctx context.Context, l *slog.Logger, bind netip.AddrPort, endpoint string, country string) error { // make a random bind address for warp warpBindAddress, err := findFreePort("tcp") if err != nil { - log.Println("There are no free tcp ports on Device!") return err } - _, err = runWarp(ctx, warpBindAddress, endpoints, "./primary/wgcf-profile.ini", verbose, true, true, singleMTU) + conf, err := wiresocks.ParseConfig("./primary/wgcf-profile.ini", endpoint) + if err != nil { + return err + } + conf.Interface.MTU = singleMTU + + for i, peer := range conf.Peers { + peer.Trick = true + peer.KeepAlive = 3 + conf.Peers[i] = peer + } + + tnet, err := wiresocks.StartWireguard(ctx, l, conf) if err != nil { return err } + tnet.StartProxy(warpBindAddress) + // run psiphon - err = psiphon.RunPsiphon(warpBindAddress.String(), bind.String(), country, ctx) + err = psiphon.RunPsiphon(ctx, l.With("subsystem", "psiphon"), warpBindAddress.String(), bind.String(), country) if err != nil { - log.Printf("unable to run psiphon %v", err) return fmt.Errorf("unable to run psiphon %w", err) } - log.Printf("Serving on %s", bind) + l.Info("Serving proxy", "address", bind) return nil } -func runWarpInWarp(ctx context.Context, bind netip.AddrPort, endpoints []string, verbose bool) error { +func runWarpInWarp(ctx context.Context, l *slog.Logger, bind netip.AddrPort, endpoints []string) error { // Run outer warp - vTUN, err := runWarp(ctx, netip.AddrPort{}, endpoints, "./secondary/wgcf-profile.ini", verbose, false, true, singleMTU) + conf, err := wiresocks.ParseConfig("./primary/wgcf-profile.ini", endpoints[0]) + if err != nil { + return err + } + conf.Interface.MTU = singleMTU + + for i, peer := range conf.Peers { + peer.Trick = true + peer.KeepAlive = 3 + conf.Peers[i] = peer + } + + tnet, err := wiresocks.StartWireguard(ctx, l.With("gool", "outer"), conf) if err != nil { return err } @@ -163,21 +182,35 @@ func runWarpInWarp(ctx context.Context, bind netip.AddrPort, endpoints []string, // Run virtual endpoint virtualEndpointBindAddress, err := findFreePort("udp") if err != nil { - log.Println("There are no free udp ports on Device!") return err } - addr := endpoints[1] - err = wiresocks.NewVtunUDPForwarder(virtualEndpointBindAddress.String(), addr, vTUN, singleMTU, ctx) + + // Create a UDP port forward between localhost and the remote endpoint + err = wiresocks.NewVtunUDPForwarder(ctx, virtualEndpointBindAddress.String(), endpoints[1], tnet, singleMTU) if err != nil { - log.Println(err) return err } // Run inner warp - _, err = runWarp(ctx, bind, []string{virtualEndpointBindAddress.String()}, "./primary/wgcf-profile.ini", verbose, true, false, doubleMTU) + conf, err = wiresocks.ParseConfig("./secondary/wgcf-profile.ini", virtualEndpointBindAddress.String()) + if err != nil { + return err + } + conf.Interface.MTU = doubleMTU + + for i, peer := range conf.Peers { + peer.KeepAlive = 10 + conf.Peers[i] = peer + } + + tnet, err = wiresocks.StartWireguard(ctx, l.With("gool", "inner"), conf) if err != nil { return err } + + tnet.StartProxy(bind) + + l.Info("Serving proxy", "address", bind) return nil } @@ -207,22 +240,20 @@ func findFreePort(network string) (netip.AddrPort, error) { return netip.MustParseAddrPort(listener.Addr().String()), nil } -func createPrimaryAndSecondaryIdentities(license string) error { +func createPrimaryAndSecondaryIdentities(l *slog.Logger, license string) error { // make primary identity warp.UpdatePath("./primary") if !warp.CheckProfileExists(license) { - err := warp.LoadOrCreateIdentity(license) + err := warp.LoadOrCreateIdentity(l, license) if err != nil { - log.Printf("error: %v", err) return err } } // make secondary warp.UpdatePath("./secondary") if !warp.CheckProfileExists(license) { - err := warp.LoadOrCreateIdentity(license) + err := warp.LoadOrCreateIdentity(l, license) if err != nil { - log.Printf("error: %v", err) return err } } diff --git a/ipscanner/internal/iterator/iterator.go b/ipscanner/internal/iterator/iterator.go index c49c97069..83afa19d1 100644 --- a/ipscanner/internal/iterator/iterator.go +++ b/ipscanner/internal/iterator/iterator.go @@ -3,8 +3,6 @@ package iterator import ( "crypto/rand" "errors" - "fmt" - "log" "math/big" "net" "net/netip" @@ -262,17 +260,18 @@ func NewIterator(opts *statute.ScannerOptions) *IpGenerator { ipRange, err := newIPRange(cidr) if err != nil { - fmt.Printf("Error parsing CIDR %s: %v\n", cidr, err) + // TODO continue } ranges = append(ranges, ipRange) } if len(ranges) == 0 { - log.Fatal("No valid CIDR ranges found") + // TODO + return nil } err := shuffleSubnetsIpRange(ranges) if err != nil { - fmt.Println(err) + // TODO return nil } return &IpGenerator{ diff --git a/ipscanner/internal/statute/default.go b/ipscanner/internal/statute/default.go index 24265a4a4..681961746 100644 --- a/ipscanner/internal/statute/default.go +++ b/ipscanner/internal/statute/default.go @@ -3,7 +3,6 @@ package statute import ( "context" "crypto/tls" - "fmt" "net" "net/http" "net/netip" @@ -72,7 +71,6 @@ func DefaultHTTPClientFunc(rawDialer TDialerFunc, tlsDialer TDialerFunc, quicDia } func DefaultDialerFunc(ctx context.Context, network, addr string) (net.Conn, error) { - fmt.Println(addr) d := &net.Dialer{ Timeout: FinalOptions.ConnectionTimeout, // Connection timeout // Add other custom settings as needed diff --git a/main.go b/main.go index a3a20f2ff..676b413ef 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "log" + "log/slog" "net/netip" "os" "os/signal" @@ -76,13 +76,19 @@ func main() { os.Exit(1) } + l := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) + + if *verbose { + l = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + } + if *psiphon && *gool { - log.Fatal(errors.New("can't use cfon and gool at the same time")) + fatal(l, errors.New("can't use cfon and gool at the same time")) } bindAddrPort, err := netip.ParseAddrPort(*bind) if err != nil { - log.Fatal(fmt.Errorf("invalid bind address: %w", err)) + fatal(l, fmt.Errorf("invalid bind address: %w", err)) } opts := app.WarpOptions{ @@ -92,18 +98,13 @@ func main() { Gool: *gool, } - if *verbose { - opts.LogLevel = "debug" - log.Printf("setting log level to: %s", opts.LogLevel) - } - if *psiphon { - log.Printf("psiphon mode enabled, using country %s", *country) + l.Info("psiphon mode enabled", "country", *country) opts.Psiphon = &app.PsiphonOptions{Country: *country} } if *scan { - log.Printf("scanner mode enabled, using %s max RTT", rtt) + l.Info("scanner mode enabled", "max-rtt", rtt) opts.Scan = &app.ScanOptions{MaxRTT: *rtt} } @@ -111,17 +112,22 @@ func main() { if opts.Endpoint == "" { addrPort, err := warp.RandomWarpEndpoint() if err != nil { - log.Fatal(err) + fatal(l, err) } opts.Endpoint = addrPort.String() } ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) go func() { - if err := app.RunWarp(ctx, opts); err != nil { - log.Fatal(err) + if err := app.RunWarp(ctx, l, opts); err != nil { + fatal(l, err) } }() <-ctx.Done() } + +func fatal(l *slog.Logger, err error) { + l.Error(err.Error()) + os.Exit(1) +} diff --git a/psiphon/p.go b/psiphon/p.go index 3f2d6759d..610d03a66 100644 --- a/psiphon/p.go +++ b/psiphon/p.go @@ -5,7 +5,7 @@ import ( "encoding/json" "errors" "fmt" - "log" + "log/slog" "net" "path/filepath" "strings" @@ -316,7 +316,7 @@ func (tunnel *Tunnel) Stop() { psiphon.CloseDataStore() } -func RunPsiphon(wgBind, localSocksPort, country string, ctx context.Context) error { +func RunPsiphon(ctx context.Context, l *slog.Logger, wgBind, localSocksPort, country string) error { // Embedded configuration host, port, err := net.SplitHostPort(localSocksPort) if err != nil { @@ -355,33 +355,28 @@ func RunPsiphon(wgBind, localSocksPort, country string, ctx context.Context) err EmitDiagnosticNoticesToFiles: false, } - log.Println("Handshaking, Please Wait...") + l.Info("Handshaking, Please Wait...") - var tunnel *Tunnel - startTime := time.Now() - - internalCtx := context.Background() - - timeoutTimer := time.NewTimer(2 * time.Minute) - defer timeoutTimer.Stop() + ctx, _ = context.WithTimeout(ctx, 2*time.Minute) + t0 := time.Now() + t := time.NewTicker(1 * time.Second) + defer t.Stop() for { select { case <-ctx.Done(): - internalCtx.Done() - return errors.New("psiphon handshake operation canceled by user") - case <-timeoutTimer.C: - // Handle the internal timeout - internalCtx.Done() + if errors.Is(ctx.Err(), context.Canceled) { + return errors.New("psiphon handshake operation canceled") + } return errors.New("psiphon handshake maximum time exceeded") - default: - tunnel, err = StartTunnel(internalCtx, []byte(configJSON), "", p, nil, nil) - if err == nil { - log.Println("Psiphon started successfully on port", tunnel.SOCKSProxyPort, "handshake operation took", int64(time.Since(startTime)/time.Millisecond), "milliseconds") - return nil + case <-t.C: + tunnel, err := StartTunnel(ctx, []byte(configJSON), "", p, nil, nil) + if err != nil { + l.Info("Unable to start psiphon", err, "reconnecting...") + continue } - log.Println("Unable to start psiphon", err, "reconnecting...") - time.Sleep(1 * time.Second) + l.Info(fmt.Sprintf("Psiphon started successfully on port %d, handshake operation took %s", tunnel.SOCKSProxyPort, time.Since(t0))) + return nil } } } diff --git a/warp/account.go b/warp/account.go index 7a9a26389..e5808a8af 100644 --- a/warp/account.go +++ b/warp/account.go @@ -7,11 +7,10 @@ import ( "errors" "fmt" "io" - "log" + "log/slog" "net" "net/http" "os" - "path/filepath" "time" ) @@ -120,7 +119,6 @@ func genKeyPair() (string, string, error) { // Generate private key priv, err := GeneratePrivateKey() if err != nil { - fmt.Println("Error generating private key:", err) return "", "", err } privateKey := priv.String() @@ -163,14 +161,12 @@ func doRegister() (*AccountData, error) { // Create HTTP client and execute request response, err := client.Do(req) if err != nil { - fmt.Println("sending request to remote server", err) return nil, err } // convert response to byte array responseData, err := io.ReadAll(response.Body) if err != nil { - fmt.Println("reading response body", err) return nil, err } @@ -178,7 +174,6 @@ func doRegister() (*AccountData, error) { err = json.Unmarshal(responseData, &rspData) if err != nil { - fmt.Println("Error:", err) return nil, err } @@ -195,7 +190,6 @@ func doRegister() (*AccountData, error) { func saveIdentity(accountData *AccountData, identityPath string) error { file, err := os.Create(identityPath) if err != nil { - fmt.Println("Error:", err) return err } @@ -203,7 +197,6 @@ func saveIdentity(accountData *AccountData, identityPath string) error { encoder.SetIndent("", " ") err = encoder.Encode(accountData) if err != nil { - fmt.Println("Error:", err) return err } @@ -213,22 +206,14 @@ func saveIdentity(accountData *AccountData, identityPath string) error { func loadIdentity(identityPath string) (accountData *AccountData, err error) { file, err := os.Open(identityPath) if err != nil { - fmt.Println("Error:", err) return nil, err } - - defer func(file *os.File) { - err = file.Close() - if err != nil { - fmt.Println("Error:", err) - } - }(file) + defer file.Close() accountData = &AccountData{} decoder := json.NewDecoder(file) err = decoder.Decode(&accountData) if err != nil { - fmt.Println("Error:", err) return nil, err } @@ -510,11 +495,11 @@ func createConf(accountData *AccountData, confData *ConfigurationData) error { return os.WriteFile(profileFile, []byte(config), 0o600) } -func LoadOrCreateIdentity(license string) error { +func LoadOrCreateIdentity(l *slog.Logger, license string) error { var accountData *AccountData if _, err := os.Stat(identityFile); os.IsNotExist(err) { - fmt.Println("Creating new identity...") + l.Info("creating new identity") accountData, err = doRegister() if err != nil { return err @@ -522,21 +507,21 @@ func LoadOrCreateIdentity(license string) error { accountData.LicenseKey = license saveIdentity(accountData, identityFile) } else { - fmt.Println("Loading existing identity...") + l.Info("loading existing identity") accountData, err = loadIdentity(identityFile) if err != nil { return err } } - fmt.Println("Getting configuration...") + l.Info("getting server configuration") confData, err := getServerConf(accountData) if err != nil { return err } // updating license key - fmt.Println("Updating account license key...") + l.Info("updating account license key") result, err := updateLicenseKey(accountData, confData) if err != nil { return err @@ -553,16 +538,16 @@ func LoadOrCreateIdentity(license string) error { return err } if !deviceStatus { - fmt.Println("This device is not registered to the account!") + l.Warn("device is not registered to the account") } if confData.WarpPlusEnabled && !deviceStatus { - fmt.Println("Enabling device...") + l.Info("enabling device") deviceStatus, _ = setDeviceActive(accountData, true) } if !confData.WarpEnabled { - fmt.Println("Enabling Warp...") + l.Info("enabling Warp") err := enableWarp(accountData) if err != nil { return err @@ -570,20 +555,19 @@ func LoadOrCreateIdentity(license string) error { confData.WarpEnabled = true } - fmt.Printf("Warp+ enabled: %t\n", confData.WarpPlusEnabled) - fmt.Printf("Device activated: %t\n", deviceStatus) - fmt.Printf("Account type: %s\n", confData.AccountType) - fmt.Printf("Warp+ enabled: %t\n", confData.WarpPlusEnabled) - - fmt.Println("Creating WireGuard configuration...") + l.Info( + "Creating WireGuard configuration", + "device-active", deviceStatus, + "account-type", confData.AccountType, + "warp", confData.WarpEnabled, + "warp+", confData.WarpPlusEnabled, + ) err = createConf(accountData, confData) if err != nil { return fmt.Errorf("unable to enable write config file: %w", err) } - fmt.Println("All done! Find your files here:") - fmt.Println(filepath.Abs(identityFile)) - fmt.Println(filepath.Abs(profileFile)) + l.Info("successfully generated wireguard configuration") return nil } @@ -596,10 +580,7 @@ func fileExist(f string) bool { func removeFile(f string) { if fileExist(f) { - e := os.Remove(f) - if e != nil { - log.Fatal(e) - } + _ = os.Remove(f) } } @@ -631,7 +612,7 @@ func CheckProfileExists(license string) bool { return isOk } -func RemoveDevice(account AccountData) error { +func RemoveDevice(l *slog.Logger, account AccountData) error { headers := map[string]string{ "Content-Type": "application/json", "User-Agent": "okhttp/3.12.1", @@ -652,7 +633,7 @@ func RemoveDevice(account AccountData) error { // Create HTTP client and execute request response, err := client.Do(req) if err != nil { - fmt.Println("sending request to remote server", err) + l.Info("sending request to remote server", err) return err } diff --git a/warp/tls.go b/warp/tls.go index fa34096c4..b69c62c83 100644 --- a/warp/tls.go +++ b/warp/tls.go @@ -190,7 +190,6 @@ func (d *Dialer) TLSDial(plainDialer *net.Dialer, network, addr string) (net.Con utlsConn, handshakeErr := d.makeTLSHelloPacketWithSNICurve(plainConn, &config, sni) if handshakeErr != nil { _ = plainConn.Close() - fmt.Println(handshakeErr) return nil, handshakeErr } return utlsConn, nil diff --git a/wireguard/device/logger.go b/wireguard/device/logger.go index 22b0df028..9fc1f6ce7 100644 --- a/wireguard/device/logger.go +++ b/wireguard/device/logger.go @@ -6,7 +6,9 @@ package device import ( + "fmt" "log" + "log/slog" "os" ) @@ -46,3 +48,14 @@ func NewLogger(level int, prepend string) *Logger { } return logger } + +func NewSLogger(l *slog.Logger) *Logger { + return &Logger{ + Verbosef: func(format string, v ...any) { + l.Debug(fmt.Sprintf(format, v...)) + }, + Errorf: func(format string, v ...any) { + l.Error(fmt.Sprintf(format, v...)) + }, + } +} diff --git a/wiresocks/proxy.go b/wiresocks/proxy.go index e7baa1e98..99c58515a 100644 --- a/wiresocks/proxy.go +++ b/wiresocks/proxy.go @@ -3,7 +3,7 @@ package wiresocks import ( "context" "io" - "log" + "log/slog" "net/netip" "time" @@ -17,31 +17,17 @@ import ( type VirtualTun struct { Tnet *netstack.Net SystemDNS bool - Verbose bool - Logger DefaultLogger + Logger *slog.Logger Dev *device.Device Ctx context.Context } -type DefaultLogger struct { - verbose bool -} - -func (l DefaultLogger) Debug(v ...interface{}) { - if l.verbose { - log.Println(v...) - } -} - -func (l DefaultLogger) Error(v ...interface{}) { - log.Println(v...) -} - // StartProxy spawns a socks5 server. func (vt *VirtualTun) StartProxy(bindAddress netip.AddrPort) { proxy := mixed.NewProxy( mixed.WithBindAddress(bindAddress.String()), - mixed.WithLogger(vt.Logger), + // TODO + // mixed.WithLogger(vt.Logger), mixed.WithContext(vt.Ctx), mixed.WithUserHandler(func(request *statute.ProxyRequest) error { return vt.generalHandler(request) @@ -64,9 +50,7 @@ func (vt *VirtualTun) StartProxy(bindAddress netip.AddrPort) { } func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error { - if vt.Verbose { - log.Printf("handling %s request to %s", req.Network, req.Destination) - } + vt.Logger.Debug("handling request", "protocol", req.Network, "destination", req.Destination) conn, err := vt.Tnet.Dial(req.Network, req.Destination) if err != nil { return err @@ -89,8 +73,9 @@ func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error { // Wait for one of the copy operations to finish err = <-done if err != nil { - log.Println(err) + vt.Logger.Warn(err.Error()) } + // Close connections and wait for the other copy operation to finish conn.Close() req.Conn.Close() @@ -101,9 +86,8 @@ func (vt *VirtualTun) generalHandler(req *statute.ProxyRequest) error { func (vt *VirtualTun) Stop() { if vt.Dev != nil { - err := vt.Dev.Down() - if err != nil { - log.Println(err) + if err := vt.Dev.Down(); err != nil { + vt.Logger.Warn(err.Error()) } } } diff --git a/wiresocks/udpfw.go b/wiresocks/udpfw.go index 20dd8193d..fcd44d267 100644 --- a/wiresocks/udpfw.go +++ b/wiresocks/udpfw.go @@ -2,24 +2,11 @@ package wiresocks import ( "context" - "encoding/binary" - "errors" - "fmt" - "io" "net" "sync" ) -type Socks5UDPForwarder struct { - socks5Server string - destAddr *net.UDPAddr - proxyUDPAddr *net.UDPAddr - conn *net.UDPConn - listener *net.UDPConn - clientAddr *net.UDPAddr -} - -func NewVtunUDPForwarder(localBind, dest string, vtun *VirtualTun, mtu int, ctx context.Context) error { +func NewVtunUDPForwarder(ctx context.Context, localBind, dest string, vtun *VirtualTun, mtu int) error { localAddr, err := net.ResolveUDPAddr("udp", localBind) if err != nil { return err @@ -88,144 +75,3 @@ func NewVtunUDPForwarder(localBind, dest string, vtun *VirtualTun, mtu int, ctx }() return nil } - -func NewSocks5UDPForwarder(localBind, socks5Server, dest string) (*Socks5UDPForwarder, error) { - localAddr, err := net.ResolveUDPAddr("udp", localBind) - if err != nil { - return nil, err - } - - destAddr, err := net.ResolveUDPAddr("udp", dest) - if err != nil { - return nil, err - } - - listener, err := net.ListenUDP("udp", localAddr) - if err != nil { - return nil, err - } - - tcpConn, err := net.Dial("tcp", socks5Server) - if err != nil { - return nil, err - } - defer tcpConn.Close() - - if err := socks5Handshake(tcpConn); err != nil { - return nil, err - } - - proxyUDPAddr, err := requestUDPAssociate(tcpConn) - if err != nil { - return nil, err - } - - udpConn, err := net.DialUDP("udp", nil, proxyUDPAddr) - if err != nil { - return nil, err - } - - return &Socks5UDPForwarder{ - socks5Server: socks5Server, - destAddr: destAddr, - proxyUDPAddr: proxyUDPAddr, - conn: udpConn, - listener: listener, - }, nil -} - -func (f *Socks5UDPForwarder) Start() { - go f.listenAndServe() - go f.receiveFromProxy() -} - -func socks5Handshake(conn net.Conn) error { - // Send greeting - _, err := conn.Write([]byte{0x05, 0x01, 0x00}) // SOCKS5, 1 authentication method, No authentication - if err != nil { - return err - } - - // Receive server response - resp := make([]byte, 2) - if _, err := io.ReadFull(conn, resp); err != nil { - return err - } - - if resp[0] != 0x05 || resp[1] != 0x00 { - return errors.New("invalid SOCKS5 authentication response") - } - return nil -} - -func (f *Socks5UDPForwarder) listenAndServe() { - for { - buffer := make([]byte, 4096) - // Listen for incoming UDP packets - n, clientAddr, err := f.listener.ReadFromUDP(buffer) - if err != nil { - fmt.Printf("Error reading from listener: %v\n", err) - continue - } - - // Store client address for response mapping - f.clientAddr = clientAddr - - // Forward packet to destination via SOCKS5 proxy - go f.forwardPacketToRemote(buffer[:n]) - } -} - -func (f *Socks5UDPForwarder) forwardPacketToRemote(data []byte) { - packet := make([]byte, 10+len(data)) - packet[0] = 0x00 // Reserved - packet[1] = 0x00 // Reserved - packet[2] = 0x00 // Fragment - packet[3] = 0x01 // Address type (IPv4) - copy(packet[4:8], f.destAddr.IP.To4()) - binary.BigEndian.PutUint16(packet[8:10], uint16(f.destAddr.Port)) - copy(packet[10:], data) - - _, err := f.conn.Write(packet) - if err != nil { - fmt.Printf("Error forwarding packet to remote: %v\n", err) - } -} - -func (f *Socks5UDPForwarder) receiveFromProxy() { - for { - buffer := make([]byte, 4096) - n, err := f.conn.Read(buffer) - if err != nil { - fmt.Printf("Error reading from proxy connection: %v\n", err) - continue - } - - // Forward the packet to the original client - f.listener.WriteToUDP(buffer[10:n], f.clientAddr) - } -} - -func requestUDPAssociate(conn net.Conn) (*net.UDPAddr, error) { - // Send UDP associate request with local address and port set to zero - req := []byte{0x05, 0x03, 0x00, 0x01, 0, 0, 0, 0, 0, 0} // Command: UDP Associate - if _, err := conn.Write(req); err != nil { - return nil, err - } - - // Receive response - resp := make([]byte, 10) - if _, err := io.ReadFull(conn, resp); err != nil { - return nil, err - } - - if resp[1] != 0x00 { - return nil, errors.New("UDP ASSOCIATE request failed") - } - - // Parse the proxy UDP address - bindIP := net.IP(resp[4:8]) - bindPort := binary.BigEndian.Uint16(resp[8:10]) - - return &net.UDPAddr{IP: bindIP, Port: int(bindPort)}, nil -} diff --git a/wiresocks/wiresocks.go b/wiresocks/wiresocks.go index 7327b419a..75f5d641f 100644 --- a/wiresocks/wiresocks.go +++ b/wiresocks/wiresocks.go @@ -4,23 +4,15 @@ import ( "bytes" "context" "fmt" - "net/netip" + "log/slog" "github.com/bepass-org/warp-plus/wireguard/conn" "github.com/bepass-org/warp-plus/wireguard/device" "github.com/bepass-org/warp-plus/wireguard/tun/netstack" ) -// DeviceSetting contains the parameters for setting up a tun interface -type DeviceSetting struct { - ipcRequest string - dns []netip.Addr - deviceAddr []netip.Addr - mtu int -} - -// serialize the config into an IPC request and DeviceSetting -func createIPCRequest(conf *Configuration) (*DeviceSetting, error) { +// StartWireguard creates a tun interface on netstack given a configuration +func StartWireguard(ctx context.Context, l *slog.Logger, conf *Configuration) (*VirtualTun, error) { var request bytes.Buffer request.WriteString(fmt.Sprintf("private_key=%s\n", conf.Interface.PrivateKey)) @@ -37,29 +29,13 @@ func createIPCRequest(conf *Configuration) (*DeviceSetting, error) { } } - setting := &DeviceSetting{ipcRequest: request.String(), dns: conf.Interface.DNS, deviceAddr: conf.Interface.Addresses, mtu: conf.Interface.MTU} - return setting, nil -} - -// StartWireguard creates a tun interface on netstack given a configuration -func StartWireguard(ctx context.Context, conf *Configuration, verbose bool) (*VirtualTun, error) { - setting, err := createIPCRequest(conf) - if err != nil { - return nil, err - } - - tun, tnet, err := netstack.CreateNetTUN(setting.deviceAddr, setting.dns, setting.mtu) + tun, tnet, err := netstack.CreateNetTUN(conf.Interface.Addresses, conf.Interface.DNS, conf.Interface.MTU) if err != nil { return nil, err } - logLevel := device.LogLevelVerbose - if !verbose { - logLevel = device.LogLevelSilent - } - - dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(logLevel, "")) - err = dev.IpcSet(setting.ipcRequest) + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewSLogger(l.With("subsystem", "wireguard-go"))) + err = dev.IpcSet(request.String()) if err != nil { return nil, err } @@ -71,12 +47,9 @@ func StartWireguard(ctx context.Context, conf *Configuration, verbose bool) (*Vi return &VirtualTun{ Tnet: tnet, - SystemDNS: len(setting.dns) == 0, - Verbose: verbose, - Logger: DefaultLogger{ - verbose: verbose, - }, - Dev: dev, - Ctx: ctx, + SystemDNS: len(conf.Interface.DNS) == 0, + Logger: l.With("subsystem", "vtun"), + Dev: dev, + Ctx: ctx, }, nil }