From 252baf5b8b82a87eb50c8e0915b5646821ca7988 Mon Sep 17 00:00:00 2001 From: Roberto Santalla Date: Thu, 7 Sep 2023 14:18:59 +0200 Subject: [PATCH] poc: tcp proxy --- pkg/agent/protocol/tcp/handler.go | 67 ++++++++ pkg/agent/protocol/tcp/proxy.go | 90 +++++++++++ pkg/agent/protocol/tcp/proxy_test.go | 224 +++++++++++++++++++++++++++ 3 files changed, 381 insertions(+) create mode 100644 pkg/agent/protocol/tcp/handler.go create mode 100644 pkg/agent/protocol/tcp/proxy.go create mode 100644 pkg/agent/protocol/tcp/proxy_test.go diff --git a/pkg/agent/protocol/tcp/handler.go b/pkg/agent/protocol/tcp/handler.go new file mode 100644 index 00000000..c1a25420 --- /dev/null +++ b/pkg/agent/protocol/tcp/handler.go @@ -0,0 +1,67 @@ +package tcp + +import ( + "errors" + "hash/crc64" + "io" + "net" + "time" +) + +// ConnMeta holds metadata about a TCP connection. +type ConnMeta struct { + Opened time.Time + ClientAddress net.Addr + ServerAddress net.Addr +} + +// Hash returns a semi-unique number to every connection. +// The implementation of Hash is not guaranteed to be stable between updates of this package. +func (c ConnMeta) Hash() uint64 { + // We use CRC64 as this hash does not need to be cryptographically secure, and it's easy to get an uint64 from it. + hash := crc64.New(crc64.MakeTable(crc64.ISO)) + _, _ = hash.Write([]byte(c.Opened.String())) + _, _ = hash.Write([]byte(c.ClientAddress.String())) + _, _ = hash.Write([]byte(c.ServerAddress.String())) + + return hash.Sum64() +} + +// Handler is an object capable of acting when TCP messages are either sent or received. +type Handler interface { + // HandleUpward forwards data from the client to the server. Proxy will call HandleUpward once for every + // connection, expecting it to keep consuming data until an error occurs, in which case the Proxy will close both + // upstream and downstream connections. If ErrTerminate is returned, the connection is still closed but no error + // message is logged. + HandleUpward(client io.Reader, server io.Writer, meta ConnMeta) error + // HandleDownward provides is the equivalent of HandleUpward for data sent from the server to the client. + HandleDownward(server io.Reader, client io.Writer, meta ConnMeta) error +} + +// ErrTerminate may be returned by Handler implementations that wish to willingly terminate a connection. Connection +// will be closed, but no error log will be generated. +var ErrTerminate = errors.New("connection terminated by proxy handler") + +// ForwardHandler is a handler that forwards data between client and server without taking any actions. +type ForwardHandler struct{} + +func (ForwardHandler) HandleUpward(client io.Reader, server io.Writer, _ ConnMeta) error { + _, err := io.Copy(server, client) + return err +} + +func (ForwardHandler) HandleDownward(server io.Reader, client io.Writer, _ ConnMeta) error { + _, err := io.Copy(client, server) + return err +} + +// RejectHandler is a handler that closes connections immediately after being opened. +type RejectHandler struct{} + +func (RejectHandler) HandleUpward(client io.Reader, server io.Writer, _ ConnMeta) error { + return ErrTerminate +} + +func (RejectHandler) HandleDownward(server io.Reader, client io.Writer, _ ConnMeta) error { + return ErrTerminate +} diff --git a/pkg/agent/protocol/tcp/proxy.go b/pkg/agent/protocol/tcp/proxy.go new file mode 100644 index 00000000..d364c27d --- /dev/null +++ b/pkg/agent/protocol/tcp/proxy.go @@ -0,0 +1,90 @@ +package tcp + +import ( + "errors" + "fmt" + "log" + "net" + "time" +) + +// Proxy implements a TCP transparent proxy between a client and a server. +type Proxy struct { + l net.Listener + upstream net.Addr + handler Handler +} + +func NewProxy(l net.Listener, upstream net.Addr, handler Handler) *Proxy { + return &Proxy{ + l: l, + upstream: upstream, + handler: handler, + } +} + +func (p *Proxy) Start() error { + for { + conn, err := p.l.Accept() + if err != nil { + return err + } + + go func() { + err := p.handleConn(conn) + // TODO: Better error handling + log.Printf("handling connection: %v", err) + }() + } +} + +func (p *Proxy) Stop() error { + // TODO: Harvest open connections and close them. + return nil +} + +func (p *Proxy) handleConn(downstreamConn net.Conn) error { + defer func() { + _ = downstreamConn.Close() + }() + + upstreamConn, err := net.Dial("tcp", p.upstream.String()) + if err != nil { + return fmt.Errorf("opening upstream connection: %w", err) + } + + defer func() { + _ = upstreamConn.Close() + }() + + metadata := ConnMeta{ + Opened: time.Now(), + ClientAddress: downstreamConn.RemoteAddr(), + ServerAddress: upstreamConn.RemoteAddr(), + } + + errChan := make(chan error, 2) + go func() { + errChan <- func() error { + err := p.handler.HandleUpward(downstreamConn, upstreamConn, metadata) + if err != nil && !errors.Is(err, ErrTerminate) { + return err + } + + return nil + }() + }() + go func() { + errChan <- func() error { + err := p.handler.HandleDownward(upstreamConn, downstreamConn, metadata) + if err != nil && !errors.Is(err, ErrTerminate) { + return err + } + + return nil + }() + }() + + err = <-errChan + return fmt.Errorf("forwarding data: %w", err) +} diff --git a/pkg/agent/protocol/tcp/proxy_test.go b/pkg/agent/protocol/tcp/proxy_test.go new file mode 100644 index 00000000..a0f0c440 --- /dev/null +++ b/pkg/agent/protocol/tcp/proxy_test.go @@ -0,0 +1,224 @@ +package tcp_test + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/grafana/xk6-disruptor/pkg/agent/protocol/tcp" +) + +const localv4 = "127.0.0.1:0" + +// Test_Proxy_Forwards tests the tcp.Proxy using tcp.ForwardHandler, ensuring messages are forwarded to and from the +// proxy. +func Test_Proxy_Forwards(t *testing.T) { + t.Parallel() + + upstreamL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating upstream listener: %v", err) + } + + serverCh := make(chan string) + serverErr := make(chan error) + go func() { + serverErr <- echoServer(upstreamL, serverCh) + }() + + proxyL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating proxy listener: %v", err) + } + + proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.ForwardHandler{}) + go func() { + err := proxy.Start() + if err != nil { + // t.Fatal cannot be used inside a goroutine. + t.Errorf("couldn't start poxy: %v", err) + } + }() + + proxyConn, err := net.Dial("tcp", proxyL.Addr().String()) + if err != nil { + t.Fatalf("dialing proxy address: %v", err) + } + + bufReader := bufio.NewReader(proxyConn) + + // Write a first line. + _, err = fmt.Fprintln(proxyConn, "a line") + if err != nil { + t.Fatalf("writing to proxy conn: %v", err) + } + + // Check the server received the line + select { + case <-time.After(time.Second): + t.Fatalf("upstream did not receive the line before the deadline") + case serverLine := <-serverCh: + if serverLine != "a line\n" { + t.Fatalf("upstream received unexpected data %q", serverLine) + } + } + + // Check we received the echoed data + clientLine, err := bufReader.ReadString('\n') + if err != nil { + t.Fatalf("reading upstream response from proxyconn: %v", err) + } + if clientLine != "a line\n" { + t.Fatalf("downstream received unexpected data %q", clientLine) + } + + // Write a second line. + _, err = fmt.Fprintln(proxyConn, "another line") + if err != nil { + t.Fatalf("writing to proxy conn: %v", err) + } + + // Check the server received the line + select { + case <-time.After(time.Second): + t.Fatalf("upstream did not receive the line before the deadline") + case serverLine := <-serverCh: + if serverLine != "another line\n" { + t.Fatalf("upstream received unexpected data %q", serverLine) + } + } + + // Check we received the echoed data + clientLine, err = bufReader.ReadString('\n') + if err != nil { + t.Fatalf("reading upstream response from proxyconn: %v", err) + } + if clientLine != "another line\n" { + t.Fatalf("downstream received unexpected data %q", clientLine) + } + + // Close the connection to the proxy. + _ = proxyConn.Close() + + select { + case <-time.After(time.Second): + t.Fatalf("upstream connection was not closed") + case line, ok := <-serverCh: + if ok { + t.Fatalf("upstream receive unexpected data: %q", line) + } + } + + select { + case <-time.After(time.Second): + t.Fatalf("server did not terminate") + case err = <-serverErr: + if err != nil { + t.Fatalf("server returned an error: %v", err) + } + } +} + +// Test_Proxy_Forwards tests the tcp.Proxy using tcp.RejectHandler, ensuring both client and server connections are +// closed properly and cleanly when handlers return errors. +func Test_Proxy_Rejects(t *testing.T) { + t.Parallel() + + upstreamL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating upstream listener: %v", err) + } + + serverCh := make(chan string) + serverErr := make(chan error) + go func() { + serverErr <- echoServer(upstreamL, serverCh) + }() + + proxyL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating proxy listener: %v", err) + } + + proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.RejectHandler{}) + go func() { + err := proxy.Start() + if err != nil { + // t.Fatal cannot be used inside a goroutine. + t.Errorf("couldn't start poxy: %v", err) + } + }() + + proxyConn, err := net.Dial("tcp", proxyL.Addr().String()) + if err != nil { + t.Fatalf("dialing proxy address: %v", err) + } + + // Attempt to write a first line. + _, err = fmt.Fprintln(proxyConn, "a line") + if err != nil { + t.Fatalf("error writing data: %v", err) + } + + singleByte := make([]byte, 1) + _, err = proxyConn.Read(singleByte) + if err == nil { + t.Fatalf("expected connection to be closed by rejectHandler: %v", err) + } + + select { + case <-time.After(time.Second): + t.Fatalf("upstream connection was not closed") + case line, ok := <-serverCh: + if ok { + t.Fatalf("upstream receive unexpected data: %q", line) + } + } + + select { + case <-time.After(time.Second): + t.Fatalf("server did not terminate") + case err = <-serverErr: + if err != nil { + t.Fatalf("server returned an error: %v", err) + } + } +} + +// echoServer is a helper function for testing that accepts a single connection from the given listener, and pushes +// each received line to lineCh. When the connection is closed, it also closes lineCh. +func echoServer(l net.Listener, lineCh chan string) error { + defer close(lineCh) + + conn, err := l.Accept() + if err != nil { + return fmt.Errorf("accepting conn: %w", err) + } + + reader := bufio.NewReader(conn) + for { + line, err := reader.ReadString('\n') + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return fmt.Errorf("reading from conn: %w", err) + } + + _, err = conn.Write([]byte(line)) + if err != nil { + return fmt.Errorf("echoing back to conn: %w", err) + } + + select { + case lineCh <- line: + continue + case <-time.After(time.Second): + return fmt.Errorf("reader did not consume line %q", line) + } + } +}