diff --git a/.github/workflows/docker.yml b/.github/workflows/publish.yml similarity index 100% rename from .github/workflows/docker.yml rename to .github/workflows/publish.yml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..91a43eb7 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,26 @@ +name: Unit tests + +on: + push: + branches: + - master + pull_request: + +jobs: + test-go: + runs-on: ubuntu-latest + strategy: + matrix: + go: [ '1.18', '1.19' ] + name: Go ${{ matrix.go }} tests + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Go + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + cache: true + - name: Run go test + run: go test -v ./... + diff --git a/cmd/serve/main.go b/cmd/serve/main.go index ffa63a16..478743f0 100644 --- a/cmd/serve/main.go +++ b/cmd/serve/main.go @@ -184,9 +184,10 @@ func (cmd *servecmd) Run() { ListenAddr: listenAddr, }) if err != nil { - logrus.Error(errors.Wrap(err, "failed to start dns server")) + logrus.Error(errors.Wrap(err, "failed to create dns server")) return } + dns.ListenAndServe() defer dns.Close() if conf.DNS.Domain != "" { // Generate initial DNS zone for registered devices diff --git a/internal/dnsproxy/proxy.go b/internal/dnsproxy/proxy.go index c01fb33e..f7ff9e2e 100644 --- a/internal/dnsproxy/proxy.go +++ b/internal/dnsproxy/proxy.go @@ -13,11 +13,13 @@ import ( ) type DNSProxy struct { - client *dns.Client - cache *cache.Cache - upstream []string + udpClient *dns.Client + tcpClient *dns.Client + cache *cache.Cache + upstream []string } +// ServeDNS is called by the mux from the listening servers. func (d *DNSProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { defer func() { if err := recover(); err != nil { @@ -32,13 +34,17 @@ func (d *DNSProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { case dns.OpcodeQuery: // Remove EDNS0 Client Subnet information as we don't handle them in the cache purgeECS(r) - m, err := d.Lookup(r) + outQuery := r.Copy() + // Set EDNS BufSize for forwarding to upstream + ensureEDNS0BufSize(outQuery) + m, err := d.Lookup(outQuery) if err != nil { logrus.Errorf("failed lookup record with error: %s\n%s", err.Error(), r) HandleFailed(w, r) return } m.SetReply(r) + truncateIfRequired(m, r, w.RemoteAddr().Network()) err = w.WriteMsg(m) if err != nil { logrus.Errorf("failed write response for client with error: %s\n%s", err.Error(), r) @@ -56,13 +62,14 @@ func (d *DNSProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } +// Lookup first checks the cache for a matching response, and if unsuccessful queries the upstream resolvers. func (d *DNSProxy) Lookup(m *dns.Msg) (*dns.Msg, error) { key := makekey(m) // check the cache first if item, found := d.cache.Get(key); found { logrus.Debugf("dns cache hit %s", prettyPrintMsg(m)) - return item.(*dns.Msg), nil + return item.(*dns.Msg).Copy(), nil } // fallback to upstream exchange @@ -70,11 +77,21 @@ func (d *DNSProxy) Lookup(m *dns.Msg) (*dns.Msg, error) { var response *dns.Msg var firstErr error for _, upstream := range d.upstream { - resp, _, err := d.client.Exchange(m, net.JoinHostPort(upstream, "53")) + target := net.JoinHostPort(upstream, "53") + resp, _, err := d.udpClient.Exchange(m, target) if err != nil && firstErr == nil { logrus.Warnf(errors.Wrap(err, fmt.Sprintf("DNS lookup failed for upstream %s", upstream)).Error()) firstErr = err } else if err == nil { + // Retry truncated responses over TCP + if resp.Truncated { + resp, _, err = d.tcpClient.Exchange(m, target) + if err != nil && firstErr == nil { + logrus.Warnf(errors.Wrap(err, fmt.Sprintf("DNS lookup failed over TCP for upstream %s", upstream)).Error()) + firstErr = err + continue + } + } response = resp break } @@ -89,7 +106,7 @@ func (d *DNSProxy) Lookup(m *dns.Msg) (*dns.Msg, error) { d.cache.Set(key, response, ttl) } - return response, nil + return response.Copy(), nil } func purgeECS(m *dns.Msg) { @@ -101,3 +118,22 @@ func purgeECS(m *dns.Msg) { } } } + +func ensureEDNS0BufSize(m *dns.Msg) { + if opt := m.IsEdns0(); opt != nil { + opt.SetUDPSize(1232) + } else { + m.SetEdns0(1232, false) + } +} + +func truncateIfRequired(response *dns.Msg, original *dns.Msg, transport string) { + size := dns.MinMsgSize + if transport == "tcp" { + size = dns.MaxMsgSize + } else if opt := original.IsEdns0(); opt != nil { + size = int(opt.UDPSize()) + } + logrus.Debugf("truncating to %d", size) + response.Truncate(size) +} diff --git a/internal/dnsproxy/proxy_test.go b/internal/dnsproxy/proxy_test.go new file mode 100644 index 00000000..40fab589 --- /dev/null +++ b/internal/dnsproxy/proxy_test.go @@ -0,0 +1,58 @@ +package dnsproxy + +import ( + "context" + "net" + "testing" + "time" +) + +var ffmucUpstreams, _ = net.LookupHost("dns.ffmuc.net") + +func TestDNSProxy_ServeDNS(t *testing.T) { + const listen = "[::1]:8053" + + resolver := net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{Timeout: time.Second} + return d.DialContext(ctx, network, listen) + }, + } + + server, err := New(DNSServerOpts{ + Domain: "", + ListenAddr: []string{listen}, + Upstream: ffmucUpstreams, + }) + server.ListenAndServe() + defer func() { _ = server.Close() }() + + if err != nil { + t.Fatal(err) + } + + t.Run("Reply over 1300 bytes", func(t *testing.T) { + _, err := resolver.LookupTXT(context.Background(), "cloudflare.com.") + if err != nil { + t.Error(err) + return + } + }) + t.Run("Reply over 1500 bytes", func(t *testing.T) { + records, err := resolver.LookupTXT(context.Background(), "txtfill1500.go.dnscheck.tools.") + if err != nil { + t.Error(err) + return + } + var containsBigRecord bool + for _, r := range records { + if len(r) >= 1500 { + containsBigRecord = true + } + } + if !containsBigRecord { + t.Error("missing big TXT record, packet probably truncated") + } + }) +} diff --git a/internal/dnsproxy/server.go b/internal/dnsproxy/server.go index cec485e7..7fc20791 100644 --- a/internal/dnsproxy/server.go +++ b/internal/dnsproxy/server.go @@ -24,17 +24,22 @@ type DNSServer struct { auth *DNSAuth } +// New returns a pointer to a DNSServer configured using opts DNSServerOpts. +// The returned server needs to be started using DNSServer.ListenAndServe() func New(opts DNSServerOpts) (*DNSServer, error) { if len(opts.Upstream) == 0 { return nil, errors.New("at least 1 upstream dns server is required for the dns proxy server to function") } - logrus.Infof("starting dns server on %s with upstreams: %s", strings.Join(opts.ListenAddr, ", "), strings.Join(opts.Upstream, ", ")) - dnsServer := &DNSServer{ servers: []*dns.Server{}, proxy: &DNSProxy{ - client: &dns.Client{ + udpClient: &dns.Client{ + SingleInflight: true, + Timeout: 5 * time.Second, + }, + tcpClient: &dns.Client{ + Net: "tcp", SingleInflight: true, Timeout: 5 * time.Second, }, @@ -72,15 +77,39 @@ func New(opts DNSServerOpts) (*DNSServer, error) { dnsServer.servers = append(dnsServer.servers, tcpServer) } - for _, server := range dnsServer.servers { + return dnsServer, nil +} + +// ListenAndServe starts the DNSServer and waits until all listeners are up. +func (d *DNSServer) ListenAndServe() { + var sb strings.Builder + for i, s := range d.servers { + sb.WriteString(s.Addr) + sb.WriteString("/") + sb.WriteString(s.Net) + if i < len(d.servers)-1 { + sb.WriteString(", ") + } + } + + logrus.Infof("starting dns server on %s with upstreams: %s", sb.String(), strings.Join(d.proxy.upstream, ", ")) + + var wg sync.WaitGroup + + for _, server := range d.servers { + wg.Add(1) + server.NotifyStartedFunc = func() { + wg.Done() + } go func(server *dns.Server) { if err := server.ListenAndServe(); err != nil { logrus.Error(errors.Errorf("failed to start DNS server on %s/%s: %s", server.Addr, server.Net, err)) + wg.Done() } }(server) } - return dnsServer, nil + wg.Wait() } func (d *DNSServer) Close() error {