Skip to content

Commit

Permalink
fix: fix proxy error handling and race conditions (#585)
Browse files Browse the repository at this point in the history
* feat: improve proxy error handling and allow CORS

* fix: fix live proxy race condition

* fix: handle proxy stop error

* feat: add more tests for proxy

* fix: fix reviewer's comments

---------

Co-authored-by: Neemias Almeida <[email protected]>
  • Loading branch information
ndajr and Neemias Almeida authored May 26, 2024
1 parent 9f85057 commit f673320
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 59 deletions.
4 changes: 3 additions & 1 deletion runner/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,9 @@ func (e *Engine) cleanup() {

if e.config.Proxy.Enabled {
e.mainDebug("powering down the proxy...")
e.proxy.Stop()
if err := e.proxy.Stop(); err != nil {
e.mainLog("failed to stop proxy: %+v", err)
}
}

e.withLock(func() {
Expand Down
57 changes: 32 additions & 25 deletions runner/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

type Reloader interface {
AddSubscriber() *Subscriber
RemoveSubscriber(id int)
RemoveSubscriber(id int32)
Reload()
Stop()
}
Expand Down Expand Up @@ -47,46 +47,47 @@ func (p *Proxy) Run() {
http.HandleFunc("/", p.proxyHandler)
http.HandleFunc("/internal/reload", p.reloadHandler)
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("failed to start proxy server: %v", err)
log.Fatal(p.Stop())
}
}

func (p *Proxy) Stop() {
p.server.Close()
p.stream.Stop()
}

func (p *Proxy) Reload() {
p.stream.Reload()
}

func (p *Proxy) injectLiveReload(respBody io.ReadCloser) string {
func (p *Proxy) injectLiveReload(resp *http.Response) (page string, modified bool) {
if !strings.Contains(resp.Header.Get("Content-Type"), "text/html") {
return page, false
}

buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(respBody); err != nil {
log.Fatalf("failed to convert request body to bytes buffer, err: %+v\n", err)
if _, err := buf.ReadFrom(resp.Body); err != nil {
return page, false
}
original := buf.String()
page = buf.String()

// the script will be injected before the end of the body tag. In case the tag is missing, the injection will be skipped without an error to ensure that a page with partial reloads only has at most one injected script.
body := strings.LastIndex(original, "</body>")
// the script will be injected before the end of the body tag. In case the tag is missing, the injection will be skipped.
body := strings.LastIndex(page, "</body>")
if body == -1 {
return original
return page, false
}

script := fmt.Sprintf(
`<script>new EventSource("http://localhost:%d/internal/reload").onmessage = () => { location.reload() }</script>`,
p.config.ProxyPort,
)
return original[:body] + script + original[body:]
return page[:body] + script + page[body:], true
}

func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")

appURL := r.URL
appURL.Scheme = "http"
appURL.Host = fmt.Sprintf("localhost:%d", p.config.AppPort)

if err := r.ParseForm(); err != nil {
log.Fatalf("failed to read form data from request, err: %+v\n", err)
http.Error(w, "proxy handler: bad form", http.StatusInternalServerError)
}
var body io.Reader
if len(r.Form) > 0 {
Expand All @@ -96,7 +97,7 @@ func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) {
}
req, err := http.NewRequest(r.Method, appURL.String(), body)
if err != nil {
log.Fatalf("proxy could not create request, err: %+v\n", err)
http.Error(w, "proxy handler: unable to create request", http.StatusInternalServerError)
}

// Copy the headers from the original request
Expand All @@ -115,7 +116,7 @@ func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) {
break
}
if !errors.Is(err, syscall.ECONNREFUSED) {
log.Fatalf("proxy failed to call %s, err: %+v\n", appURL.String(), err)
http.Error(w, "proxy handler: unable to reach app", http.StatusInternalServerError)
}
time.Sleep(100 * time.Millisecond)
}
Expand All @@ -132,27 +133,28 @@ func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) {
}
w.WriteHeader(resp.StatusCode)

if strings.Contains(resp.Header.Get("Content-Type"), "text/html") {
newPage := p.injectLiveReload(resp.Body)
w.Header().Set("Content-Length", strconv.Itoa((len([]byte(newPage)))))
if _, err := io.WriteString(w, newPage); err != nil {
log.Fatalf("proxy failed injected live reloading script, err: %+v\n", err)
page, modified := p.injectLiveReload(resp)
if modified {
w.Header().Set("Content-Length", strconv.Itoa((len([]byte(page)))))
if _, err := io.WriteString(w, page); err != nil {
http.Error(w, "proxy handler: unable to inject live reload script", http.StatusInternalServerError)
}
} else {
w.Header().Set("Content-Length", resp.Header.Get("Content-Length"))
if _, err := io.Copy(w, resp.Body); err != nil {
log.Fatalf("proxy failed to forward the response body, err: %+v\n", err)
http.Error(w, "proxy handler: failed to forward the response body", http.StatusInternalServerError)
}
}
}

func (p *Proxy) reloadHandler(w http.ResponseWriter, r *http.Request) {
flusher, err := w.(http.Flusher)
if !err {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
http.Error(w, "reload handler: streaming unsupported", http.StatusInternalServerError)
return
}

w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
Expand All @@ -171,3 +173,8 @@ func (p *Proxy) reloadHandler(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
}
}

func (p *Proxy) Stop() error {
p.stream.Stop()
return p.server.Close()
}
36 changes: 20 additions & 16 deletions runner/proxy_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,49 @@ package runner

import (
"sync"
"sync/atomic"
)

type ProxyStream struct {
sync.Mutex
subscribers map[int]*Subscriber
count int
mu sync.Mutex
subscribers map[int32]*Subscriber
count atomic.Int32
}

type Subscriber struct {
id int
id int32
reloadCh chan struct{}
}

func NewProxyStream() *ProxyStream {
return &ProxyStream{subscribers: make(map[int]*Subscriber)}
return &ProxyStream{subscribers: make(map[int32]*Subscriber)}
}

func (stream *ProxyStream) Stop() {
for id := range stream.subscribers {
stream.RemoveSubscriber(id)
}
stream.count = 0
stream.count = atomic.Int32{}
}

func (stream *ProxyStream) AddSubscriber() *Subscriber {
stream.Lock()
defer stream.Unlock()
stream.count++
stream.mu.Lock()
defer stream.mu.Unlock()
stream.count.Add(1)

sub := &Subscriber{id: stream.count, reloadCh: make(chan struct{})}
stream.subscribers[stream.count] = sub
sub := &Subscriber{id: stream.count.Load(), reloadCh: make(chan struct{})}
stream.subscribers[stream.count.Load()] = sub
return sub
}

func (stream *ProxyStream) RemoveSubscriber(id int) {
stream.Lock()
defer stream.Unlock()
close(stream.subscribers[id].reloadCh)
delete(stream.subscribers, id)
func (stream *ProxyStream) RemoveSubscriber(id int32) {
stream.mu.Lock()
defer stream.mu.Unlock()

if _, ok := stream.subscribers[id]; ok {
close(stream.subscribers[id].reloadCh)
delete(stream.subscribers, id)
}
}

func (stream *ProxyStream) Reload() {
Expand Down
29 changes: 17 additions & 12 deletions runner/proxy_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package runner

import (
"sync"
"sync/atomic"
"testing"
)

func find(s map[int]*Subscriber, id int) bool {
func find(s map[int32]*Subscriber, id int32) bool {
for _, sub := range s {
if sub.id == id {
return true
Expand All @@ -28,39 +29,43 @@ func TestProxyStream(t *testing.T) {
wg.Wait()

if got, exp := len(stream.subscribers), 10; got != exp {
t.Errorf("expected %d but got %d", exp, got)
t.Errorf("expect subscribers count to be %d, got %d", exp, got)
}

doneCh := make(chan struct{})
go func() {
stream.Reload()
doneCh <- struct{}{}
}()

reloadCount := 0
var reloadCount atomic.Int32
for _, sub := range stream.subscribers {
wg.Add(1)
go func(sub *Subscriber) {
defer wg.Done()
<-sub.reloadCh
reloadCount++
reloadCount.Add(1)
}(sub)
}
wg.Wait()
<-doneCh

if got, exp := reloadCount, 10; got != exp {
t.Errorf("expected %d but got %d", exp, got)
if got, exp := reloadCount.Load(), int32(10); got != exp {
t.Errorf("expect reloadCount %d, got %d", exp, got)
}

stream.RemoveSubscriber(2)
stream.AddSubscriber()
if got, exp := find(stream.subscribers, 2), false; got != exp {
t.Errorf("expected subscriber found to be %t but got %t", exp, got)
if find(stream.subscribers, 2) {
t.Errorf("expected subscriber 2 not to be found")
}
if got, exp := find(stream.subscribers, 11), true; got != exp {
t.Errorf("expected subscriber found to be %t but got %t", exp, got)

stream.AddSubscriber()
if !find(stream.subscribers, 11) {
t.Errorf("expected subscriber 11 to be found")
}

stream.Stop()
if got, exp := len(stream.subscribers), 0; got != exp {
t.Errorf("expected %d but got %d", exp, got)
t.Errorf("expected subscribers count to be %d, got %d", exp, got)
}
}
76 changes: 71 additions & 5 deletions runner/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (r *reloader) AddSubscriber() *Subscriber {
return &Subscriber{reloadCh: r.reloadCh}
}

func (r *reloader) RemoveSubscriber(_ int) {
func (r *reloader) RemoveSubscriber(_ int32) {
close(r.subCh)
}

Expand All @@ -48,7 +48,7 @@ func getServerPort(t *testing.T, srv *httptest.Server) int {
return port
}

func TestNewProxy(t *testing.T) {
func TestProxy_run(t *testing.T) {
_ = os.Unsetenv(airWd)
cfg := &cfgProxy{
Enabled: true,
Expand All @@ -62,6 +62,12 @@ func TestNewProxy(t *testing.T) {
if proxy.server.Addr == "" {
t.Fatal("server address should not be nil")
}
go func() {
proxy.Run()
}()
if err := proxy.Stop(); err != nil {
t.Errorf("failed stopping the proxy: %v", err)
}
}

func TestProxy_proxyHandler(t *testing.T) {
Expand All @@ -72,14 +78,14 @@ func TestProxy_proxyHandler(t *testing.T) {
}{
{
name: "get_request_with_headers",
assert: func(resp *http.Request) {
assert.Equal(t, "bar", resp.Header.Get("foo"))
},
req: func() *http.Request {
req := httptest.NewRequest("GET", fmt.Sprintf("http://localhost:%d", proxyPort), nil)
req.Header.Set("foo", "bar")
return req
},
assert: func(resp *http.Request) {
assert.Equal(t, "bar", resp.Header.Get("foo"))
},
},
{
name: "post_form_request",
Expand Down Expand Up @@ -141,6 +147,66 @@ func TestProxy_proxyHandler(t *testing.T) {
}
}

func TestProxy_injectLiveReload(t *testing.T) {
tests := []struct {
name string
given *http.Response
expect string
}{
{
name: "when_no_body_should_not_be_injected",
given: &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: http.StatusOK,
Body: http.NoBody,
},
expect: "",
},
{
name: "when_missing_body_should_not_be_injected",
given: &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/html"},
},
Body: io.NopCloser(strings.NewReader(`<h1>test</h1>`)),
},
expect: "<h1>test</h1>",
},
{
name: "when_text_html_and_body_is_present_should_be_injected",
given: &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"text/html"},
},
Body: io.NopCloser(strings.NewReader(`<body><h1>test</h1></body>`)),
},
expect: `<body><h1>test</h1><script>new EventSource("http://localhost:1111/internal/reload").onmessage = () => { location.reload() }</script></body>`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proxy := NewProxy(&cfgProxy{
Enabled: true,
ProxyPort: 1111,
AppPort: 2222,
})
if got, _ := proxy.injectLiveReload(tt.given); got != tt.expect {
t.Errorf("expected page %+v, got %v", tt.expect, got)
}
})
}
}

func TestProxy_reloadHandler(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, "thin air")
Expand Down

0 comments on commit f673320

Please sign in to comment.