diff --git a/main.go b/main.go index cd62ba1628c58d00c557c4391966a39175f0b2b0..e5f61b155e09041c5c188918796753ab09364f19 100644 --- a/main.go +++ b/main.go @@ -113,7 +113,8 @@ func main() { log.Fatal(err) } - var authTransport http.RoundTripper + // Create Proxy Transport + authTransport := http.DefaultTransport if *authSocket != "" { dialer := &net.Dialer{ // The values below are taken from http.DefaultTransport @@ -126,6 +127,7 @@ func main() { }, } } + proxyTransport := &proxyRoundTripper{transport: authTransport} // The profiler will only be activated by HTTP requests. HTTP // requests can only reach the profiler if we start a listener. So by @@ -137,7 +139,7 @@ func main() { }() } - upstream := newUpstream(*authBackend, authTransport) + upstream := newUpstream(*authBackend, proxyTransport) upstream.SetRelativeUrlRoot(*relativeUrlRoot) upstream.SetProxyTimeout(*proxyTimeout) diff --git a/proxy.go b/proxy.go index 15a63a4f55161e3d982e9a8fd5192c337de422e6..ef65a673b4a7c1aba60bab9161d0863cb066ce3d 100644 --- a/proxy.go +++ b/proxy.go @@ -1,9 +1,37 @@ package main import ( + "bytes" + "io/ioutil" "net/http" ) +type proxyRoundTripper struct { + transport http.RoundTripper +} + +func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { + res, err = p.transport.RoundTrip(r) + + // Map error to 502 response + if err != nil { + res = &http.Response{ + StatusCode: 502, + Status: err.Error(), + + Request: r, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + Proto: r.Proto, + Header: make(http.Header), + Trailer: make(http.Header), + Body: ioutil.NopCloser(&bytes.Buffer{}), + } + err = nil + } + return +} + func headerClone(h http.Header) http.Header { h2 := make(http.Header, len(h)) for k, vv := range h {