From babf86c064147ddf2a35c6dac8c6438d6d327ebd Mon Sep 17 00:00:00 2001 From: Nian Tang Date: Thu, 26 Sep 2024 08:57:20 +0800 Subject: [PATCH] retry for other region --- pkg/account/account.go | 3 + pkg/account/account_test.go | 5 +- pkg/connector/inet/inet.go | 4 +- pkg/proxy/proxy.go | 147 ++++++++++++++++++++++++++++-------- 4 files changed, 124 insertions(+), 35 deletions(-) diff --git a/pkg/account/account.go b/pkg/account/account.go index e038a49f..e69c6fd4 100644 --- a/pkg/account/account.go +++ b/pkg/account/account.go @@ -67,6 +67,7 @@ type Account struct { UserAgent string authHeader string Host string + Subject string client http.Client } @@ -74,6 +75,7 @@ type Account struct { type oauthPayload struct { Audiences []string `json:"aud"` OUCode string `json:"ou_code"` + Subject string `json:"sub"` } var domainRegEx = regexp.MustCompile(`^[A-Za-z0-9-.]+$`) // We're mostly interested in stopping paths; the http package handles the rest. @@ -136,6 +138,7 @@ func New(oauthToken, userAgent string) (*Account, error) { UserAgent: buildUserAgent(userAgent), authHeader: "Bearer " + strings.TrimSpace(oauthToken), Host: domain, + Subject: payload.Subject, }, nil } diff --git a/pkg/account/account_test.go b/pkg/account/account_test.go index 9a99d4ad..e6964562 100644 --- a/pkg/account/account_test.go +++ b/pkg/account/account_test.go @@ -67,7 +67,8 @@ func TestDomainExtraction(t *testing.T) { "https://fleet-api.prd.na.vn.cloud.tesla.com", "https://fleet-api.prd.eu.vn.cloud.tesla.com", }, - OUCode: "EU", + OUCode: "EU", + Subject: "SUBJECT", } acct, err := New(makeTestJWT(payload), "") @@ -75,7 +76,7 @@ func TestDomainExtraction(t *testing.T) { t.Fatalf("Returned error on valid JWT: %s", err) } expectedHost := "fleet-api.prd.eu.vn.cloud.tesla.com" - if acct == nil || acct.Host != expectedHost { + if acct == nil || acct.Host != expectedHost || acct.Subject != "SUBJECT" { t.Errorf("acct = %+v, expected Host = %s", acct, expectedHost) } } diff --git a/pkg/connector/inet/inet.go b/pkg/connector/inet/inet.go index c2e40b61..7b97cfec 100644 --- a/pkg/connector/inet/inet.go +++ b/pkg/connector/inet/inet.go @@ -21,7 +21,7 @@ import ( // MaxLatency is the default maximum latency permitted when updating the vehicle clock estimate. var MaxLatency = 10 * time.Second -func readWithContext(ctx context.Context, r io.Reader, p []byte) ([]byte, error) { +func ReadWithContext(ctx context.Context, r io.Reader, p []byte) ([]byte, error) { bytesRead := 0 for { if ctx.Err() != nil { @@ -108,7 +108,7 @@ func SendFleetAPICommand(ctx context.Context, client *http.Client, userAgent, au defer result.Body.Close() body = make([]byte, connector.MaxResponseLength+1) - body, err = readWithContext(ctx, result.Body, body) + body, err = ReadWithContext(ctx, result.Body, body) if err != nil { return nil, &protocol.CommandError{Err: err, PossibleSuccess: true, PossibleTemporary: false} } diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index bae76c7b..468a46e7 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -10,6 +10,8 @@ import ( "net" "net/http" "net/url" + "regexp" + "slices" "strings" "sync" "time" @@ -30,8 +32,13 @@ const ( maxRequestBodyBytes = 512 vinLength = 17 proxyProtocolVersion = "tesla-http-proxy/1.1.0" + MaxResponseLength = 10000000 + MaxAttempts = 2 ) +var baseDomainRE = regexp.MustCompile(`use base URL: https://([-a-z0-9.]*)`) +var h2Prefix = "h2=https://" + func getAccount(req *http.Request) (*account.Account, error) { token, ok := strings.CutPrefix(req.Header.Get("Authorization"), "Bearer ") if !ok { @@ -44,10 +51,23 @@ func getAccount(req *http.Request) (*account.Account, error) { type Proxy struct { Timeout time.Duration - commandKey protocol.ECDHPrivateKey - sessions *cache.SessionCache - vinLock sync.Map - unsupported sync.Map + commandKey protocol.ECDHPrivateKey + sessions *cache.SessionCache + vinLock sync.Map + unsupported sync.Map + domainForSubject sync.Map +} + +func (p *Proxy) updateDomainForSubject(subject, domain string) { + p.domainForSubject.Store(subject, domain) +} + +func (p *Proxy) fetchDomainForSubject(subject string) string { + domain, ok := p.domainForSubject.Load(subject) + if !ok { + return "" + } + return domain.(string) } func (p *Proxy) markUnsupportedVIN(vin string) { @@ -155,7 +175,7 @@ var connectionHeaders = []string{ // forwardRequest is the fallback handler for "/api/1/*". // It forwards GET and POST requests to Tesla using the proxy's OAuth token. -func (p *Proxy) forwardRequest(host string, w http.ResponseWriter, req *http.Request) { +func (p *Proxy) forwardRequest(acct *account.Account, w http.ResponseWriter, req *http.Request) { ctx, cancel := context.WithTimeout(context.Background(), p.Timeout) defer cancel() @@ -185,32 +205,91 @@ func (p *Proxy) forwardRequest(host string, w http.ResponseWriter, req *http.Req // If the client sent multiple XFF headers, flatten them. proxyReq.Header.Set(xff, strings.Join(previous, ", ")) } - proxyReq.URL.Host = host proxyReq.URL.Scheme = "https" + attempts := 0 - log.Debug("Forwarding request to %s", proxyReq.URL.String()) - client := http.Client{} - resp, err := client.Do(proxyReq) - if err != nil { - if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() { - writeJSONError(w, http.StatusGatewayTimeout, urlErr) - } else { + var requestBody []byte + if req.Body != nil { + requestBody, err = io.ReadAll(req.Body) + if err != nil { writeJSONError(w, http.StatusBadGateway, err) + return } - return + req.Body = io.NopCloser(bytes.NewBuffer(requestBody)) } - defer resp.Body.Close() - for _, hdr := range connectionHeaders { - resp.Header.Del(hdr) - } - outHeader := w.Header() - for name, value := range resp.Header { - outHeader[name] = value - } + for { + proxyReq.URL.Host = acct.Host + log.Debug("Forwarding request to %s", proxyReq.URL.String()) + client := http.Client{} + result, err := client.Do(proxyReq) + + if err != nil { + if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() { + writeJSONError(w, http.StatusGatewayTimeout, urlErr) + } else { + writeJSONError(w, http.StatusBadGateway, err) + } + return + } - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) + limitedReader := &io.LimitedReader{R: result.Body, N: MaxResponseLength + 1} + body, err := io.ReadAll(limitedReader) + result.Body.Close() + + if err != nil { + writeJSONError(w, http.StatusBadGateway, err) + return + } + + if len(body) == MaxResponseLength+1 { + writeJSONError(w, http.StatusBadGateway, protocol.NewError("response exceeds maximum length", true, true)) + return + } + + if result.StatusCode == http.StatusMisdirectedRequest && result.Header.Get("Alt-Svc") != "" { + altSvc := result.Header.Values("Alt-Svc") + idx := slices.IndexFunc(altSvc, func(str string) bool { return strings.HasPrefix(str, h2Prefix) }) + if idx == -1 { + writeJSONError(w, result.StatusCode, err) + return + } + + altHost := altSvc[idx][len(h2Prefix):] + log.Debug("Received HTTP Status 421. Updating server URL to %s", altHost) + acct.Host = altHost + p.updateDomainForSubject(acct.Subject, acct.Host) + if req.Body != nil { + req.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + } + } else { + for _, hdr := range connectionHeaders { + result.Header.Del(hdr) + } + outHeader := w.Header() + for name, value := range result.Header { + outHeader[name] = value + } + + w.WriteHeader(result.StatusCode) + w.Write(body) + return + } + + attempts += 1 + if attempts == MaxAttempts { + writeJSONError(w, http.StatusBadGateway, protocol.NewError("max retry exhausted", false, false)) + } + + log.Debug("Retrying transmission after error...") + select { + case <-ctx.Done(): + writeJSONError(w, http.StatusGatewayTimeout, ctx.Err()) + return + case <-time.After(1 * time.Second): + continue + } + } } func (p *Proxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -221,6 +300,9 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { writeJSONError(w, http.StatusForbidden, err) return } + if host := p.fetchDomainForSubject(acct.Subject); host != "" { + acct.Host = host + } if strings.HasPrefix(req.URL.Path, "/api/1/vehicles/") { path := strings.Split(req.URL.Path, "/") @@ -232,23 +314,26 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } if p.isNotSupported(vin) { - p.forwardRequest(acct.Host, w, req) + p.forwardRequest(acct, w, req) + if acct.Host != p.fetchDomainForSubject(acct.Subject) { + p.updateDomainForSubject(acct.Subject, acct.Host) + } } else { if err := p.handleVehicleCommand(acct, w, req, command, vin); err == ErrCommandUseRESTAPI { - p.forwardRequest(acct.Host, w, req) + p.forwardRequest(acct, w, req) } } return } if len(path) == 5 && path[4] == "fleet_telemetry_config" { - p.handleFleetTelemetryConfig(acct.Host, w, req) + p.handleFleetTelemetryConfig(acct, w, req) return } } - p.forwardRequest(acct.Host, w, req) + p.forwardRequest(acct, w, req) } -func (p *Proxy) handleFleetTelemetryConfig(host string, w http.ResponseWriter, req *http.Request) { +func (p *Proxy) handleFleetTelemetryConfig(acct *account.Account, w http.ResponseWriter, req *http.Request) { log.Info("Processing fleet telemetry configuration...") defer req.Body.Close() body, err := io.ReadAll(req.Body) @@ -294,7 +379,7 @@ func (p *Proxy) handleFleetTelemetryConfig(host string, w http.ResponseWriter, r return } log.Debug("Posting data to %s: %s", req.URL.String(), bodyJSON) - p.forwardRequest(host, w, req) + p.forwardRequest(acct, w, req) } func (p *Proxy) handleVehicleCommand(acct *account.Account, w http.ResponseWriter, req *http.Request, command, vin string) error { @@ -322,7 +407,7 @@ func (p *Proxy) handleVehicleCommand(acct *account.Account, w http.ResponseWrite if err := car.StartSession(ctx, nil); errors.Is(err, protocol.ErrProtocolNotSupported) { p.markUnsupportedVIN(vin) - p.forwardRequest(acct.Host, w, req) + p.forwardRequest(acct, w, req) return err } else if err != nil { writeJSONError(w, http.StatusInternalServerError, err)