diff --git a/workhorse/internal/builds/register.go b/workhorse/internal/builds/register.go index f28ad75e1d83230d77c26a1f223aba0e04434c5c..d033c72cce9873f317da03a7116b876b1acc40ea 100644 --- a/workhorse/internal/builds/register.go +++ b/workhorse/internal/builds/register.go @@ -1,8 +1,10 @@ package builds import ( + "bytes" "encoding/json" "errors" + "io" "net/http" "time" @@ -63,11 +65,18 @@ func readRunnerBody(w http.ResponseWriter, r *http.Request) ([]byte, error) { registerHandlerOpenAtReading.Inc() defer registerHandlerOpenAtReading.Dec() - return helper.ReadRequestBody(w, r, maxRegisterBodySize) + return readRequestBody(w, r, maxRegisterBodySize) +} + +func readRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) { + limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize) + defer limitedBody.Close() + + return io.ReadAll(limitedBody) } func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) { - if !helper.IsApplicationJson(r) { + if !isApplicationJson(r) { return nil, errors.New("invalid content-type received") } @@ -80,6 +89,11 @@ func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) { return &runnerRequest, nil } +func isApplicationJson(r *http.Request) bool { + contentType := r.Header.Get("Content-Type") + return helper.IsContentType("application/json", contentType) +} + func proxyRegisterRequest(h http.Handler, w http.ResponseWriter, r *http.Request) { registerHandlerOpenAtProxying.Inc() defer registerHandlerOpenAtProxying.Dec() @@ -109,7 +123,7 @@ func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDurati return } - newRequest := helper.CloneRequestWithNewBody(r, requestBody) + newRequest := cloneRequestWithNewBody(r, requestBody) runnerRequest, err := readRunnerRequest(r, requestBody) if err != nil { @@ -161,3 +175,11 @@ func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDurati } }) } + +func cloneRequestWithNewBody(r *http.Request, body []byte) *http.Request { + newReq := *r + newReq.Body = io.NopCloser(bytes.NewReader(body)) + newReq.Header = helper.HeaderClone(r.Header) + newReq.ContentLength = int64(len(body)) + return &newReq +} diff --git a/workhorse/internal/builds/register_test.go b/workhorse/internal/builds/register_test.go index 3c975f610035089752841948e1b966a109a8d714..d5cbebd500bf6bcb8d8a3ab2c0c50fd537028e96 100644 --- a/workhorse/internal/builds/register_test.go +++ b/workhorse/internal/builds/register_test.go @@ -106,3 +106,50 @@ func TestRegisterHandlerWatcherNoChange(t *testing.T) { expectWatcherToBeExecuted(t, redis.WatchKeyStatusNoChange, nil, http.StatusNoContent) } + +func TestReadRequestBody(t *testing.T) { + data := []byte("123456") + rw := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) + + result, err := readRequestBody(rw, req, 1000) + require.NoError(t, err) + require.Equal(t, data, result) +} + +func TestReadRequestBodyLimit(t *testing.T) { + data := []byte("123456") + rw := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) + + _, err := readRequestBody(rw, req, 2) + require.Error(t, err) +} + +func TestApplicationJson(t *testing.T) { + req, _ := http.NewRequest("POST", "/test", nil) + req.Header.Set("Content-Type", "application/json") + + require.True(t, isApplicationJson(req), "expected to match 'application/json' as 'application/json'") + + req.Header.Set("Content-Type", "application/json; charset=utf-8") + require.True(t, isApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'") + + req.Header.Set("Content-Type", "text/plain") + require.False(t, isApplicationJson(req), "expected not to match 'text/plain' as 'application/json'") +} + +func TestCloneRequestWithBody(t *testing.T) { + input := []byte("test") + newInput := []byte("new body") + req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input)) + newReq := cloneRequestWithNewBody(req, newInput) + + require.NotEqual(t, req, newReq) + require.NotEqual(t, req.Body, newReq.Body) + require.NotEqual(t, len(newInput), newReq.ContentLength) + + var buffer bytes.Buffer + io.Copy(&buffer, newReq.Body) + require.Equal(t, newInput, buffer.Bytes()) +} diff --git a/workhorse/internal/channel/channel.go b/workhorse/internal/channel/channel.go index e740015d54a4a5971079665c4002fdfea5339015..deb4c32d6611afaa3d096ba65a3e23838b74a9e6 100644 --- a/workhorse/internal/channel/channel.go +++ b/workhorse/internal/channel/channel.go @@ -2,7 +2,9 @@ package channel import ( "fmt" + "net" "net/http" + "strings" "time" "github.com/gorilla/websocket" @@ -109,7 +111,7 @@ func pingLoop(conn Connection) { func connectToServer(settings *api.ChannelSettings, r *http.Request) (Connection, error) { settings = settings.Clone() - helper.SetForwardedFor(&settings.Header, r) + setForwardedFor(&settings.Header, r) conn, _, err := settings.Dial() if err != nil { @@ -130,3 +132,19 @@ func closeAfterMaxTime(proxy *Proxy, maxSessionTime int) { maxSessionTime, ) } + +func setForwardedFor(newHeaders *http.Header, originalRequest *http.Request) { + if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil { + var header string + + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok { + header = strings.Join(prior, ", ") + ", " + clientIP + } else { + header = clientIP + } + newHeaders.Set("X-Forwarded-For", header) + } +} diff --git a/workhorse/internal/channel/channel_test.go b/workhorse/internal/channel/channel_test.go new file mode 100644 index 0000000000000000000000000000000000000000..fade6e42c27ff9b698c6d15ea08bdb91bfb1b64d --- /dev/null +++ b/workhorse/internal/channel/channel_test.go @@ -0,0 +1,49 @@ +package channel + +import ( + "net/http" + "testing" +) + +func TestSetForwardedForGeneratesHeader(t *testing.T) { + testCases := []struct { + remoteAddr string + previousForwardedFor []string + expected string + }{ + { + "8.8.8.8:3000", + nil, + "8.8.8.8", + }, + { + "8.8.8.8:3000", + []string{"138.124.33.63, 151.146.211.237"}, + "138.124.33.63, 151.146.211.237, 8.8.8.8", + }, + { + "8.8.8.8:3000", + []string{"8.154.76.107", "115.206.118.179"}, + "8.154.76.107, 115.206.118.179, 8.8.8.8", + }, + } + for _, tc := range testCases { + headers := http.Header{} + originalRequest := http.Request{ + RemoteAddr: tc.remoteAddr, + } + + if tc.previousForwardedFor != nil { + originalRequest.Header = http.Header{ + "X-Forwarded-For": tc.previousForwardedFor, + } + } + + setForwardedFor(&headers, &originalRequest) + + result := headers.Get("X-Forwarded-For") + if result != tc.expected { + t.Fatalf("Expected %v, got %v", tc.expected, result) + } + } +} diff --git a/workhorse/internal/helper/writeafterreader.go b/workhorse/internal/git/io.go similarity index 78% rename from workhorse/internal/helper/writeafterreader.go rename to workhorse/internal/git/io.go index 3626d70e4931efd3fd11496a9f44ae613ba3bed8..7b62b04395c08134cc263040e65bcfa23f108131 100644 --- a/workhorse/internal/helper/writeafterreader.go +++ b/workhorse/internal/git/io.go @@ -1,13 +1,48 @@ -package helper +package git import ( + "context" "fmt" "io" "os" "sync" ) -type WriteFlusher interface { +type contextReader struct { + ctx context.Context + underlyingReader io.Reader +} + +func newContextReader(ctx context.Context, underlyingReader io.Reader) *contextReader { + return &contextReader{ + ctx: ctx, + underlyingReader: underlyingReader, + } +} + +func (r *contextReader) Read(b []byte) (int, error) { + if r.canceled() { + return 0, r.err() + } + + n, err := r.underlyingReader.Read(b) + + if r.canceled() { + err = r.err() + } + + return n, err +} + +func (r *contextReader) canceled() bool { + return r.err() != nil +} + +func (r *contextReader) err() error { + return r.ctx.Err() +} + +type writeFlusher interface { io.Writer Flush() error } @@ -16,7 +51,7 @@ type WriteFlusher interface { // returned some error), all writes to w are sent to a tempfile first. // The caller must call Flush() on the returned WriteFlusher to ensure // all data is propagated to w. -func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) { +func newWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, writeFlusher) { br := &busyReader{Reader: r} return br, &coupledWriter{Writer: w, busyReader: br} } diff --git a/workhorse/internal/helper/writeafterreader_test.go b/workhorse/internal/git/io_test.go similarity index 58% rename from workhorse/internal/helper/writeafterreader_test.go rename to workhorse/internal/git/io_test.go index c3da428184b3573f5c61eeddfeceef2451840af2..f283c20c23c702cc8197717e58fff99f57ea4a35 100644 --- a/workhorse/internal/helper/writeafterreader_test.go +++ b/workhorse/internal/git/io_test.go @@ -1,17 +1,94 @@ -package helper +package git import ( "bytes" + "context" "fmt" "io" "testing" "testing/iotest" + "time" + + "github.com/stretchr/testify/require" ) +type fakeReader struct { + n int + err error +} + +func (f *fakeReader) Read(b []byte) (int, error) { + return f.n, f.err +} + +type fakeContextWithTimeout struct { + n int + threshold int +} + +func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*fakeContextWithTimeout) Done() <-chan struct{} { + return nil +} + +func (*fakeContextWithTimeout) Value(key interface{}) interface{} { + return nil +} + +func (f *fakeContextWithTimeout) Err() error { + f.n++ + if f.n > f.threshold { + return context.DeadlineExceeded + } + + return nil +} + +func TestContextReaderRead(t *testing.T) { + underlyingReader := &fakeReader{n: 1, err: io.EOF} + + for _, tc := range []struct { + desc string + ctx *fakeContextWithTimeout + expectedN int + expectedErr error + }{ + { + desc: "Before and after read deadline checks are fine", + ctx: &fakeContextWithTimeout{n: 0, threshold: 2}, + expectedN: underlyingReader.n, + expectedErr: underlyingReader.err, + }, + { + desc: "Before read deadline check fails", + ctx: &fakeContextWithTimeout{n: 0, threshold: 0}, + expectedN: 0, + expectedErr: context.DeadlineExceeded, + }, + { + desc: "After read deadline check fails", + ctx: &fakeContextWithTimeout{n: 0, threshold: 1}, + expectedN: underlyingReader.n, + expectedErr: context.DeadlineExceeded, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + cr := newContextReader(tc.ctx, underlyingReader) + + n, err := cr.Read(nil) + require.Equal(t, tc.expectedN, n) + require.Equal(t, tc.expectedErr, err) + }) + } +} + func TestBusyReader(t *testing.T) { testData := "test data" r := testReader(testData) - br, _ := NewWriteAfterReader(r, &bytes.Buffer{}) + br, _ := newWriteAfterReader(r, &bytes.Buffer{}) result, err := io.ReadAll(br) if err != nil { @@ -25,7 +102,7 @@ func TestBusyReader(t *testing.T) { func TestFirstWriteAfterReadDone(t *testing.T) { writeRecorder := &bytes.Buffer{} - br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder) + br, cw := newWriteAfterReader(&bytes.Buffer{}, writeRecorder) if _, err := io.Copy(io.Discard, br); err != nil { t.Fatalf("copy from busyreader: %v", err) } @@ -44,7 +121,7 @@ func TestFirstWriteAfterReadDone(t *testing.T) { func TestWriteDelay(t *testing.T) { writeRecorder := &bytes.Buffer{} w := &complainingWriter{Writer: writeRecorder} - br, cw := NewWriteAfterReader(&bytes.Buffer{}, w) + br, cw := newWriteAfterReader(&bytes.Buffer{}, w) testData1 := "1 test" if _, err := io.Copy(cw, testReader(testData1)); err != nil { diff --git a/workhorse/internal/git/receive-pack.go b/workhorse/internal/git/receive-pack.go index e3af472fffae407034a88bd94ec756aa765c4d42..5e93c0f36d112a1c2e1fe08829619bb4554a8301 100644 --- a/workhorse/internal/git/receive-pack.go +++ b/workhorse/internal/git/receive-pack.go @@ -6,7 +6,6 @@ import ( "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) // Will not return a non-nil error after the response body has been @@ -15,7 +14,7 @@ func handleReceivePack(w *HttpResponseWriter, r *http.Request, a *api.Response) action := getService(r) writePostRPCHeader(w, action) - cr, cw := helper.NewWriteAfterReader(r.Body, w) + cr, cw := newWriteAfterReader(r.Body, w) defer cw.Flush() gitProtocol := r.Header.Get("Git-Protocol") diff --git a/workhorse/internal/git/upload-pack.go b/workhorse/internal/git/upload-pack.go index 74995fb61c8b7540dd0c6b5c7810bed5dc44173e..ef2a00bf3ac5941f40ae732932f0c97e1cdfc7a7 100644 --- a/workhorse/internal/git/upload-pack.go +++ b/workhorse/internal/git/upload-pack.go @@ -9,7 +9,6 @@ import ( "gitlab.com/gitlab-org/gitlab/workhorse/internal/api" "gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly" - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) var ( @@ -31,8 +30,8 @@ func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) e readerCtx, cancel := context.WithTimeout(ctx, uploadPackTimeout) defer cancel() - limited := helper.NewContextReader(readerCtx, r.Body) - cr, cw := helper.NewWriteAfterReader(limited, w) + limited := newContextReader(readerCtx, r.Body) + cr, cw := newWriteAfterReader(limited, w) defer cw.Flush() action := getService(r) diff --git a/workhorse/internal/helper/context_reader.go b/workhorse/internal/helper/context_reader.go deleted file mode 100644 index a476404314764f73ee3e76511fff5540c3c95527..0000000000000000000000000000000000000000 --- a/workhorse/internal/helper/context_reader.go +++ /dev/null @@ -1,40 +0,0 @@ -package helper - -import ( - "context" - "io" -) - -type ContextReader struct { - ctx context.Context - underlyingReader io.Reader -} - -func NewContextReader(ctx context.Context, underlyingReader io.Reader) *ContextReader { - return &ContextReader{ - ctx: ctx, - underlyingReader: underlyingReader, - } -} - -func (r *ContextReader) Read(b []byte) (int, error) { - if r.canceled() { - return 0, r.err() - } - - n, err := r.underlyingReader.Read(b) - - if r.canceled() { - err = r.err() - } - - return n, err -} - -func (r *ContextReader) canceled() bool { - return r.err() != nil -} - -func (r *ContextReader) err() error { - return r.ctx.Err() -} diff --git a/workhorse/internal/helper/context_reader_test.go b/workhorse/internal/helper/context_reader_test.go deleted file mode 100644 index 257ec4e35f2872f2fb4c9e7c0e1062e315207a47..0000000000000000000000000000000000000000 --- a/workhorse/internal/helper/context_reader_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package helper - -import ( - "context" - "io" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -type fakeReader struct { - n int - err error -} - -func (f *fakeReader) Read(b []byte) (int, error) { - return f.n, f.err -} - -type fakeContextWithTimeout struct { - n int - threshold int -} - -func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) { - return -} - -func (*fakeContextWithTimeout) Done() <-chan struct{} { - return nil -} - -func (*fakeContextWithTimeout) Value(key interface{}) interface{} { - return nil -} - -func (f *fakeContextWithTimeout) Err() error { - f.n++ - if f.n > f.threshold { - return context.DeadlineExceeded - } - - return nil -} - -func TestContextReaderRead(t *testing.T) { - underlyingReader := &fakeReader{n: 1, err: io.EOF} - - for _, tc := range []struct { - desc string - ctx *fakeContextWithTimeout - expectedN int - expectedErr error - }{ - { - desc: "Before and after read deadline checks are fine", - ctx: &fakeContextWithTimeout{n: 0, threshold: 2}, - expectedN: underlyingReader.n, - expectedErr: underlyingReader.err, - }, - { - desc: "Before read deadline check fails", - ctx: &fakeContextWithTimeout{n: 0, threshold: 0}, - expectedN: 0, - expectedErr: context.DeadlineExceeded, - }, - { - desc: "After read deadline check fails", - ctx: &fakeContextWithTimeout{n: 0, threshold: 1}, - expectedN: underlyingReader.n, - expectedErr: context.DeadlineExceeded, - }, - } { - t.Run(tc.desc, func(t *testing.T) { - cr := NewContextReader(tc.ctx, underlyingReader) - - n, err := cr.Read(nil) - require.Equal(t, tc.expectedN, n) - require.Equal(t, tc.expectedErr, err) - }) - } -} diff --git a/workhorse/internal/helper/helpers.go b/workhorse/internal/helper/helpers.go index 4b458372629626a5644f93df37812f474cfd1557..7d7e7df3c74af4f9106adddc65ce8fa4d09648ec 100644 --- a/workhorse/internal/helper/helpers.go +++ b/workhorse/internal/helper/helpers.go @@ -1,17 +1,11 @@ package helper import ( - "bytes" "errors" - "io" "mime" - "net" "net/http" "net/url" "os" - "strings" - - "github.com/sebest/xff" "gitlab.com/gitlab-org/gitlab/workhorse/internal/log" ) @@ -38,12 +32,6 @@ func printError(r *http.Request, err error, fields log.Fields) { log.WithRequest(r).WithFields(fields).WithError(err).Error() } -func SetNoCacheHeaders(header http.Header) { - header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") - header.Set("Pragma", "no-cache") - header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") -} - func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) { file, err = os.Open(path) if err != nil { @@ -82,15 +70,6 @@ func URLMustParse(s string) *url.URL { return u } -func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) { - if r.ProtoAtLeast(1, 1) { - // Force client to disconnect if we render request error - w.Header().Set("Connection", "close") - } - - http.Error(w, error, code) -} - func HeaderClone(h http.Header) http.Header { h2 := make(http.Header, len(h)) for k, vv := range h { @@ -101,52 +80,7 @@ func HeaderClone(h http.Header) http.Header { return h2 } -func FixRemoteAddr(r *http.Request) { - // Unix domain sockets have a remote addr of @. This will make the - // xff package lookup the X-Forwarded-For address if available. - if r.RemoteAddr == "@" { - r.RemoteAddr = "127.0.0.1:0" - } - r.RemoteAddr = xff.GetRemoteAddr(r) -} - -func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) { - if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil { - var header string - - // If we aren't the first proxy retain prior - // X-Forwarded-For information as a comma+space - // separated list and fold multiple headers into one. - if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok { - header = strings.Join(prior, ", ") + ", " + clientIP - } else { - header = clientIP - } - newHeaders.Set("X-Forwarded-For", header) - } -} - func IsContentType(expected, actual string) bool { parsed, _, err := mime.ParseMediaType(actual) return err == nil && parsed == expected } - -func IsApplicationJson(r *http.Request) bool { - contentType := r.Header.Get("Content-Type") - return IsContentType("application/json", contentType) -} - -func ReadRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) { - limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize) - defer limitedBody.Close() - - return io.ReadAll(limitedBody) -} - -func CloneRequestWithNewBody(r *http.Request, body []byte) *http.Request { - newReq := *r - newReq.Body = io.NopCloser(bytes.NewReader(body)) - newReq.Header = HeaderClone(r.Header) - newReq.ContentLength = int64(len(body)) - return &newReq -} diff --git a/workhorse/internal/helper/helpers_test.go b/workhorse/internal/helper/helpers_test.go index 93d1ee33d59624f75ecee9a65965fa9007c197d0..f303b22d424d8468ed3d85e41a2b5ab8b6a41a0b 100644 --- a/workhorse/internal/helper/helpers_test.go +++ b/workhorse/internal/helper/helpers_test.go @@ -2,7 +2,6 @@ package helper import ( "bytes" - "io" "net/http" "net/http/httptest" "testing" @@ -10,126 +9,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestFixRemoteAddr(t *testing.T) { - testCases := []struct { - initial string - forwarded string - expected string - }{ - {initial: "@", forwarded: "", expected: "127.0.0.1:0"}, - {initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, - {initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"}, - {initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"}, - {initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"}, - {initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, - } - - for _, tc := range testCases { - req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil) - require.NoError(t, err) - - req.RemoteAddr = tc.initial - - if tc.forwarded != "" { - req.Header.Add("X-Forwarded-For", tc.forwarded) - } - - FixRemoteAddr(req) - - require.Equal(t, tc.expected, req.RemoteAddr) - } -} - -func TestSetForwardedForGeneratesHeader(t *testing.T) { - testCases := []struct { - remoteAddr string - previousForwardedFor []string - expected string - }{ - { - "8.8.8.8:3000", - nil, - "8.8.8.8", - }, - { - "8.8.8.8:3000", - []string{"138.124.33.63, 151.146.211.237"}, - "138.124.33.63, 151.146.211.237, 8.8.8.8", - }, - { - "8.8.8.8:3000", - []string{"8.154.76.107", "115.206.118.179"}, - "8.154.76.107, 115.206.118.179, 8.8.8.8", - }, - } - for _, tc := range testCases { - headers := http.Header{} - originalRequest := http.Request{ - RemoteAddr: tc.remoteAddr, - } - - if tc.previousForwardedFor != nil { - originalRequest.Header = http.Header{ - "X-Forwarded-For": tc.previousForwardedFor, - } - } - - SetForwardedFor(&headers, &originalRequest) - - result := headers.Get("X-Forwarded-For") - if result != tc.expected { - t.Fatalf("Expected %v, got %v", tc.expected, result) - } - } -} - -func TestReadRequestBody(t *testing.T) { - data := []byte("123456") - rw := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) - - result, err := ReadRequestBody(rw, req, 1000) - require.NoError(t, err) - require.Equal(t, data, result) -} - -func TestReadRequestBodyLimit(t *testing.T) { - data := []byte("123456") - rw := httptest.NewRecorder() - req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data)) - - _, err := ReadRequestBody(rw, req, 2) - require.Error(t, err) -} - -func TestCloneRequestWithBody(t *testing.T) { - input := []byte("test") - newInput := []byte("new body") - req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input)) - newReq := CloneRequestWithNewBody(req, newInput) - - require.NotEqual(t, req, newReq) - require.NotEqual(t, req.Body, newReq.Body) - require.NotEqual(t, len(newInput), newReq.ContentLength) - - var buffer bytes.Buffer - io.Copy(&buffer, newReq.Body) - require.Equal(t, newInput, buffer.Bytes()) -} - -func TestApplicationJson(t *testing.T) { - req, _ := http.NewRequest("POST", "/test", nil) - req.Header.Set("Content-Type", "application/json") - - require.True(t, IsApplicationJson(req), "expected to match 'application/json' as 'application/json'") - - req.Header.Set("Content-Type", "application/json; charset=utf-8") - require.True(t, IsApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'") - - req.Header.Set("Content-Type", "text/plain") - require.False(t, IsApplicationJson(req), "expected not to match 'text/plain' as 'application/json'") -} - func TestFail500WorksWithNils(t *testing.T) { body := bytes.NewBuffer(nil) w := httptest.NewRecorder() diff --git a/workhorse/internal/staticpages/deploy_page.go b/workhorse/internal/staticpages/deploy_page.go index 3dc2d982981cde924be7ba8daaf01b99964b430e..ca0931addd0f88aa1851f3d0fd03a5edba7b925e 100644 --- a/workhorse/internal/staticpages/deploy_page.go +++ b/workhorse/internal/staticpages/deploy_page.go @@ -4,8 +4,6 @@ import ( "net/http" "os" "path/filepath" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) func (s *Static) DeployPage(handler http.Handler) http.Handler { @@ -18,7 +16,7 @@ func (s *Static) DeployPage(handler http.Handler) http.Handler { return } - helper.SetNoCacheHeaders(w.Header()) + setNoCacheHeaders(w.Header()) w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusOK) w.Write(data) diff --git a/workhorse/internal/staticpages/error_pages.go b/workhorse/internal/staticpages/error_pages.go index e0ba7a5ceef90f5f9ca84dde648e07ae61ea1a85..d1aa7603658f69c7c96432ed1b7634b9cb0825bf 100644 --- a/workhorse/internal/staticpages/error_pages.go +++ b/workhorse/internal/staticpages/error_pages.go @@ -9,8 +9,6 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - - "gitlab.com/gitlab-org/gitlab/workhorse/internal/helper" ) var ( @@ -84,7 +82,7 @@ func (s *errorPageResponseWriter) WriteHeader(status int) { s.hijacked = true staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc() - helper.SetNoCacheHeaders(s.rw.Header()) + setNoCacheHeaders(s.rw.Header()) s.rw.Header().Set("Content-Type", contentType) s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) s.rw.Header().Del("Transfer-Encoding") diff --git a/workhorse/internal/staticpages/static.go b/workhorse/internal/staticpages/static.go index 5b804e4d6448f38b24a9d802de148e92f0a7b6ff..c5c0573090b4bec6f423f027a5ccb0a7ee3283b1 100644 --- a/workhorse/internal/staticpages/static.go +++ b/workhorse/internal/staticpages/static.go @@ -1,6 +1,14 @@ package staticpages +import "net/http" + type Static struct { DocumentRoot string Exclude []string } + +func setNoCacheHeaders(header http.Header) { + header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + header.Set("Pragma", "no-cache") + header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") +} diff --git a/workhorse/internal/upstream/routes.go b/workhorse/internal/upstream/routes.go index c47053ad682b9ad3559110390183ef1750a43a21..982f3a5b5f8804d565d926d4d91f090cc9f18754 100644 --- a/workhorse/internal/upstream/routes.go +++ b/workhorse/internal/upstream/routes.go @@ -425,7 +425,7 @@ func configureRoutes(u *upstream) { func denyWebsocket(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if websocket.IsWebSocketUpgrade(r) { - helper.HTTPError(w, r, "websocket upgrade not allowed", http.StatusBadRequest) + httpError(w, r, "websocket upgrade not allowed", http.StatusBadRequest) return } diff --git a/workhorse/internal/upstream/upstream.go b/workhorse/internal/upstream/upstream.go index cde1967460c511d008724656f2d3379b1cd42c7c..34fe300192f94f3b2f36cec083224a3d697ad9de 100644 --- a/workhorse/internal/upstream/upstream.go +++ b/workhorse/internal/upstream/upstream.go @@ -16,6 +16,7 @@ import ( "net/url" "strings" + "github.com/sebest/xff" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/labkit/correlation" @@ -125,19 +126,19 @@ func (u *upstream) configureURLPrefix() { } func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { - helper.FixRemoteAddr(r) + fixRemoteAddr(r) nginx.DisableResponseBuffering(w) // Drop RequestURI == "*" (FIXME: why?) if r.RequestURI == "*" { - helper.HTTPError(w, r, "Connection upgrade not allowed", http.StatusBadRequest) + httpError(w, r, "Connection upgrade not allowed", http.StatusBadRequest) return } // Disallow connect if r.Method == "CONNECT" { - helper.HTTPError(w, r, "CONNECT not allowed", http.StatusBadRequest) + httpError(w, r, "CONNECT not allowed", http.StatusBadRequest) return } @@ -145,7 +146,7 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { URIPath := urlprefix.CleanURIPath(r.URL.EscapedPath()) prefix := u.URLPrefix if !prefix.Match(URIPath) { - helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound) + httpError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound) return } @@ -156,7 +157,7 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { if route == nil { // The protocol spec in git/Documentation/technical/http-protocol.txt // says we must return 403 if no matching service is found. - helper.HTTPError(w, r, "Forbidden", http.StatusForbidden) + httpError(w, r, "Forbidden", http.StatusForbidden) return } @@ -276,3 +277,21 @@ func (u *upstream) updateGeoProxyFieldsFromData(geoProxyData *apipkg.GeoProxyDat u.geoProxyCableRoute = u.wsRoute(`^/-/cable\z`, geoProxyUpstream) u.geoProxyRoute = u.route("", "", geoProxyUpstream, withGeoProxy()) } + +func httpError(w http.ResponseWriter, r *http.Request, error string, code int) { + if r.ProtoAtLeast(1, 1) { + // Force client to disconnect if we render request error + w.Header().Set("Connection", "close") + } + + http.Error(w, error, code) +} + +func fixRemoteAddr(r *http.Request) { + // Unix domain sockets have a remote addr of @. This will make the + // xff package lookup the X-Forwarded-For address if available. + if r.RemoteAddr == "@" { + r.RemoteAddr = "127.0.0.1:0" + } + r.RemoteAddr = xff.GetRemoteAddr(r) +} diff --git a/workhorse/internal/upstream/upstream_test.go b/workhorse/internal/upstream/upstream_test.go index 7ab3e67116fa5cb9ef6a2ed6a16b4f60fae81fbb..705e40c74d53653afc1fe4bb16cb89080543e4f3 100644 --- a/workhorse/internal/upstream/upstream_test.go +++ b/workhorse/internal/upstream/upstream_test.go @@ -435,3 +435,33 @@ func startWorkhorseServer(railsServerURL string, enableGeoProxyFeature bool) (*h return ws, ws.Close, waitForNextApiPoll } + +func TestFixRemoteAddr(t *testing.T) { + testCases := []struct { + initial string + forwarded string + expected string + }{ + {initial: "@", forwarded: "", expected: "127.0.0.1:0"}, + {initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, + {initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"}, + {initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"}, + {initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"}, + {initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"}, + } + + for _, tc := range testCases { + req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil) + require.NoError(t, err) + + req.RemoteAddr = tc.initial + + if tc.forwarded != "" { + req.Header.Add("X-Forwarded-For", tc.forwarded) + } + + fixRemoteAddr(req) + + require.Equal(t, tc.expected, req.RemoteAddr) + } +}