Skip to content

Commit

Permalink
retry for other region (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
tnsetting2023 authored Oct 18, 2024
1 parent 3ee2bec commit 611f3d7
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 35 deletions.
3 changes: 3 additions & 0 deletions pkg/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ type Account struct {
UserAgent string
authHeader string
Host string
Subject string
client http.Client
}

// We don't parse JWTs beyond what's required to extract the API server domain name
type oauthPayload struct {
Audiences []string `json:"aud"`
OUCode string `json:"ou_code"`
Subject string `json:"sub"`
}

var domainRegEx = regexp.MustCompile(`^[A-Za-z0-9-.]+$`) // We're mostly interested in stopping paths; the http package handles the rest.
Expand Down Expand Up @@ -136,6 +138,7 @@ func New(oauthToken, userAgent string) (*Account, error) {
UserAgent: buildUserAgent(userAgent),
authHeader: "Bearer " + strings.TrimSpace(oauthToken),
Host: domain,
Subject: payload.Subject,
}, nil
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ func TestDomainExtraction(t *testing.T) {
"https://fleet-api.prd.na.vn.cloud.tesla.com",
"https://fleet-api.prd.eu.vn.cloud.tesla.com",
},
OUCode: "EU",
OUCode: "EU",
Subject: "SUBJECT",
}

acct, err := New(makeTestJWT(payload), "")
if err != nil {
t.Fatalf("Returned error on valid JWT: %s", err)
}
expectedHost := "fleet-api.prd.eu.vn.cloud.tesla.com"
if acct == nil || acct.Host != expectedHost {
if acct == nil || acct.Host != expectedHost || acct.Subject != "SUBJECT" {
t.Errorf("acct = %+v, expected Host = %s", acct, expectedHost)
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/connector/inet/inet.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
// MaxLatency is the default maximum latency permitted when updating the vehicle clock estimate.
var MaxLatency = 10 * time.Second

func readWithContext(ctx context.Context, r io.Reader, p []byte) ([]byte, error) {
func ReadWithContext(ctx context.Context, r io.Reader, p []byte) ([]byte, error) {
bytesRead := 0
for {
if ctx.Err() != nil {
Expand Down Expand Up @@ -108,7 +108,7 @@ func SendFleetAPICommand(ctx context.Context, client *http.Client, userAgent, au
defer result.Body.Close()

body = make([]byte, connector.MaxResponseLength+1)
body, err = readWithContext(ctx, result.Body, body)
body, err = ReadWithContext(ctx, result.Body, body)
if err != nil {
return nil, &protocol.CommandError{Err: err, PossibleSuccess: true, PossibleTemporary: false}
}
Expand Down
147 changes: 116 additions & 31 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"net"
"net/http"
"net/url"
"regexp"
"slices"
"strings"
"sync"
"time"
Expand All @@ -30,8 +32,13 @@ const (
maxRequestBodyBytes = 512
vinLength = 17
proxyProtocolVersion = "tesla-http-proxy/1.1.0"
MaxResponseLength = 10000000
MaxAttempts = 2
)

var baseDomainRE = regexp.MustCompile(`use base URL: https://([-a-z0-9.]*)`)
var h2Prefix = "h2=https://"

func getAccount(req *http.Request) (*account.Account, error) {
token, ok := strings.CutPrefix(req.Header.Get("Authorization"), "Bearer ")
if !ok {
Expand All @@ -44,10 +51,23 @@ func getAccount(req *http.Request) (*account.Account, error) {
type Proxy struct {
Timeout time.Duration

commandKey protocol.ECDHPrivateKey
sessions *cache.SessionCache
vinLock sync.Map
unsupported sync.Map
commandKey protocol.ECDHPrivateKey
sessions *cache.SessionCache
vinLock sync.Map
unsupported sync.Map
domainForSubject sync.Map
}

func (p *Proxy) updateDomainForSubject(subject, domain string) {
p.domainForSubject.Store(subject, domain)
}

func (p *Proxy) fetchDomainForSubject(subject string) string {
domain, ok := p.domainForSubject.Load(subject)
if !ok {
return ""
}
return domain.(string)
}

func (p *Proxy) markUnsupportedVIN(vin string) {
Expand Down Expand Up @@ -155,7 +175,7 @@ var connectionHeaders = []string{

// forwardRequest is the fallback handler for "/api/1/*".
// It forwards GET and POST requests to Tesla using the proxy's OAuth token.
func (p *Proxy) forwardRequest(host string, w http.ResponseWriter, req *http.Request) {
func (p *Proxy) forwardRequest(acct *account.Account, w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithTimeout(context.Background(), p.Timeout)
defer cancel()

Expand Down Expand Up @@ -185,32 +205,91 @@ func (p *Proxy) forwardRequest(host string, w http.ResponseWriter, req *http.Req
// If the client sent multiple XFF headers, flatten them.
proxyReq.Header.Set(xff, strings.Join(previous, ", "))
}
proxyReq.URL.Host = host
proxyReq.URL.Scheme = "https"
attempts := 0

log.Debug("Forwarding request to %s", proxyReq.URL.String())
client := http.Client{}
resp, err := client.Do(proxyReq)
if err != nil {
if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() {
writeJSONError(w, http.StatusGatewayTimeout, urlErr)
} else {
var requestBody []byte
if req.Body != nil {
requestBody, err = io.ReadAll(req.Body)
if err != nil {
writeJSONError(w, http.StatusBadGateway, err)
return
}
return
req.Body = io.NopCloser(bytes.NewBuffer(requestBody))
}
defer resp.Body.Close()

for _, hdr := range connectionHeaders {
resp.Header.Del(hdr)
}
outHeader := w.Header()
for name, value := range resp.Header {
outHeader[name] = value
}
for {
proxyReq.URL.Host = acct.Host
log.Debug("Forwarding request to %s", proxyReq.URL.String())
client := http.Client{}
result, err := client.Do(proxyReq)

if err != nil {
if urlErr, ok := err.(*url.Error); ok && urlErr.Timeout() {
writeJSONError(w, http.StatusGatewayTimeout, urlErr)
} else {
writeJSONError(w, http.StatusBadGateway, err)
}
return
}

w.WriteHeader(resp.StatusCode)
io.Copy(w, resp.Body)
limitedReader := &io.LimitedReader{R: result.Body, N: MaxResponseLength + 1}
body, err := io.ReadAll(limitedReader)
result.Body.Close()

if err != nil {
writeJSONError(w, http.StatusBadGateway, err)
return
}

if len(body) == MaxResponseLength+1 {
writeJSONError(w, http.StatusBadGateway, protocol.NewError("response exceeds maximum length", true, true))
return
}

if result.StatusCode == http.StatusMisdirectedRequest && result.Header.Get("Alt-Svc") != "" {
altSvc := result.Header.Values("Alt-Svc")
idx := slices.IndexFunc(altSvc, func(str string) bool { return strings.HasPrefix(str, h2Prefix) })
if idx == -1 {
writeJSONError(w, result.StatusCode, err)
return
}

altHost := altSvc[idx][len(h2Prefix):]
log.Debug("Received HTTP Status 421. Updating server URL to %s", altHost)
acct.Host = altHost
p.updateDomainForSubject(acct.Subject, acct.Host)
if req.Body != nil {
req.Body = io.NopCloser(bytes.NewBuffer(requestBody))
}
} else {
for _, hdr := range connectionHeaders {
result.Header.Del(hdr)
}
outHeader := w.Header()
for name, value := range result.Header {
outHeader[name] = value
}

w.WriteHeader(result.StatusCode)
w.Write(body)
return
}

attempts += 1
if attempts == MaxAttempts {
writeJSONError(w, http.StatusBadGateway, protocol.NewError("max retry exhausted", false, false))
}

log.Debug("Retrying transmission after error...")
select {
case <-ctx.Done():
writeJSONError(w, http.StatusGatewayTimeout, ctx.Err())
return
case <-time.After(1 * time.Second):
continue
}
}
}

func (p *Proxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
Expand All @@ -221,6 +300,9 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
writeJSONError(w, http.StatusForbidden, err)
return
}
if host := p.fetchDomainForSubject(acct.Subject); host != "" {
acct.Host = host
}

if strings.HasPrefix(req.URL.Path, "/api/1/vehicles/") {
path := strings.Split(req.URL.Path, "/")
Expand All @@ -232,23 +314,26 @@ func (p *Proxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
if p.isNotSupported(vin) {
p.forwardRequest(acct.Host, w, req)
p.forwardRequest(acct, w, req)
if acct.Host != p.fetchDomainForSubject(acct.Subject) {
p.updateDomainForSubject(acct.Subject, acct.Host)
}
} else {
if err := p.handleVehicleCommand(acct, w, req, command, vin); err == ErrCommandUseRESTAPI {
p.forwardRequest(acct.Host, w, req)
p.forwardRequest(acct, w, req)
}
}
return
}
if len(path) == 5 && path[4] == "fleet_telemetry_config" {
p.handleFleetTelemetryConfig(acct.Host, w, req)
p.handleFleetTelemetryConfig(acct, w, req)
return
}
}
p.forwardRequest(acct.Host, w, req)
p.forwardRequest(acct, w, req)
}

func (p *Proxy) handleFleetTelemetryConfig(host string, w http.ResponseWriter, req *http.Request) {
func (p *Proxy) handleFleetTelemetryConfig(acct *account.Account, w http.ResponseWriter, req *http.Request) {
log.Info("Processing fleet telemetry configuration...")
defer req.Body.Close()
body, err := io.ReadAll(req.Body)
Expand Down Expand Up @@ -294,7 +379,7 @@ func (p *Proxy) handleFleetTelemetryConfig(host string, w http.ResponseWriter, r
return
}
log.Debug("Posting data to %s: %s", req.URL.String(), bodyJSON)
p.forwardRequest(host, w, req)
p.forwardRequest(acct, w, req)
}

func (p *Proxy) handleVehicleCommand(acct *account.Account, w http.ResponseWriter, req *http.Request, command, vin string) error {
Expand Down Expand Up @@ -322,7 +407,7 @@ func (p *Proxy) handleVehicleCommand(acct *account.Account, w http.ResponseWrite

if err := car.StartSession(ctx, nil); errors.Is(err, protocol.ErrProtocolNotSupported) {
p.markUnsupportedVIN(vin)
p.forwardRequest(acct.Host, w, req)
p.forwardRequest(acct, w, req)
return err
} else if err != nil {
writeJSONError(w, http.StatusInternalServerError, err)
Expand Down

0 comments on commit 611f3d7

Please sign in to comment.