Skip to content
代码片段 群组 项目
提交 4ce5a9a3 编辑于 作者: Jacob Vosmaer's avatar Jacob Vosmaer 提交者: Igor Drozdov
浏览文件

Move single use helpers out of helper package

This moves 9 Workhorse functions in `internal/helper` that get called from only one package to their calling package.

The goal is to de-clutter and eventually remove the `internal/helper` package.

- Move FixRemoteAddr helper
- Move HTTPError helper
- Move CloneRequestWithNewBody helper
- Move IsApplicationJson helper
- Move ReadRequestBody helper
- Move SetNoCacheHeaders helper
- Move Workhorse git IO helpers
- Move SetForwardedFor helper
上级 a9d4fc03
No related branches found
No related tags found
无相关合并请求
显示
327 个添加338 个删除
package builds
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"time"
......@@ -63,11 +65,18 @@ func readRunnerBody(w http.ResponseWriter, r *http.Request) ([]byte, error) {
registerHandlerOpenAtReading.Inc()
defer registerHandlerOpenAtReading.Dec()
return helper.ReadRequestBody(w, r, maxRegisterBodySize)
return readRequestBody(w, r, maxRegisterBodySize)
}
func readRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) {
limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize)
defer limitedBody.Close()
return io.ReadAll(limitedBody)
}
func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) {
if !helper.IsApplicationJson(r) {
if !isApplicationJson(r) {
return nil, errors.New("invalid content-type received")
}
......@@ -80,6 +89,11 @@ func readRunnerRequest(r *http.Request, body []byte) (*runnerRequest, error) {
return &runnerRequest, nil
}
func isApplicationJson(r *http.Request) bool {
contentType := r.Header.Get("Content-Type")
return helper.IsContentType("application/json", contentType)
}
func proxyRegisterRequest(h http.Handler, w http.ResponseWriter, r *http.Request) {
registerHandlerOpenAtProxying.Inc()
defer registerHandlerOpenAtProxying.Dec()
......@@ -109,7 +123,7 @@ func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDurati
return
}
newRequest := helper.CloneRequestWithNewBody(r, requestBody)
newRequest := cloneRequestWithNewBody(r, requestBody)
runnerRequest, err := readRunnerRequest(r, requestBody)
if err != nil {
......@@ -161,3 +175,11 @@ func RegisterHandler(h http.Handler, watchHandler WatchKeyHandler, pollingDurati
}
})
}
func cloneRequestWithNewBody(r *http.Request, body []byte) *http.Request {
newReq := *r
newReq.Body = io.NopCloser(bytes.NewReader(body))
newReq.Header = helper.HeaderClone(r.Header)
newReq.ContentLength = int64(len(body))
return &newReq
}
......@@ -106,3 +106,50 @@ func TestRegisterHandlerWatcherNoChange(t *testing.T) {
expectWatcherToBeExecuted(t, redis.WatchKeyStatusNoChange, nil,
http.StatusNoContent)
}
func TestReadRequestBody(t *testing.T) {
data := []byte("123456")
rw := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
result, err := readRequestBody(rw, req, 1000)
require.NoError(t, err)
require.Equal(t, data, result)
}
func TestReadRequestBodyLimit(t *testing.T) {
data := []byte("123456")
rw := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
_, err := readRequestBody(rw, req, 2)
require.Error(t, err)
}
func TestApplicationJson(t *testing.T) {
req, _ := http.NewRequest("POST", "/test", nil)
req.Header.Set("Content-Type", "application/json")
require.True(t, isApplicationJson(req), "expected to match 'application/json' as 'application/json'")
req.Header.Set("Content-Type", "application/json; charset=utf-8")
require.True(t, isApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'")
req.Header.Set("Content-Type", "text/plain")
require.False(t, isApplicationJson(req), "expected not to match 'text/plain' as 'application/json'")
}
func TestCloneRequestWithBody(t *testing.T) {
input := []byte("test")
newInput := []byte("new body")
req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input))
newReq := cloneRequestWithNewBody(req, newInput)
require.NotEqual(t, req, newReq)
require.NotEqual(t, req.Body, newReq.Body)
require.NotEqual(t, len(newInput), newReq.ContentLength)
var buffer bytes.Buffer
io.Copy(&buffer, newReq.Body)
require.Equal(t, newInput, buffer.Bytes())
}
......@@ -2,7 +2,9 @@ package channel
import (
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/gorilla/websocket"
......@@ -109,7 +111,7 @@ func pingLoop(conn Connection) {
func connectToServer(settings *api.ChannelSettings, r *http.Request) (Connection, error) {
settings = settings.Clone()
helper.SetForwardedFor(&settings.Header, r)
setForwardedFor(&settings.Header, r)
conn, _, err := settings.Dial()
if err != nil {
......@@ -130,3 +132,19 @@ func closeAfterMaxTime(proxy *Proxy, maxSessionTime int) {
maxSessionTime,
)
}
func setForwardedFor(newHeaders *http.Header, originalRequest *http.Request) {
if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil {
var header string
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok {
header = strings.Join(prior, ", ") + ", " + clientIP
} else {
header = clientIP
}
newHeaders.Set("X-Forwarded-For", header)
}
}
package channel
import (
"net/http"
"testing"
)
func TestSetForwardedForGeneratesHeader(t *testing.T) {
testCases := []struct {
remoteAddr string
previousForwardedFor []string
expected string
}{
{
"8.8.8.8:3000",
nil,
"8.8.8.8",
},
{
"8.8.8.8:3000",
[]string{"138.124.33.63, 151.146.211.237"},
"138.124.33.63, 151.146.211.237, 8.8.8.8",
},
{
"8.8.8.8:3000",
[]string{"8.154.76.107", "115.206.118.179"},
"8.154.76.107, 115.206.118.179, 8.8.8.8",
},
}
for _, tc := range testCases {
headers := http.Header{}
originalRequest := http.Request{
RemoteAddr: tc.remoteAddr,
}
if tc.previousForwardedFor != nil {
originalRequest.Header = http.Header{
"X-Forwarded-For": tc.previousForwardedFor,
}
}
setForwardedFor(&headers, &originalRequest)
result := headers.Get("X-Forwarded-For")
if result != tc.expected {
t.Fatalf("Expected %v, got %v", tc.expected, result)
}
}
}
package helper
package git
import (
"context"
"fmt"
"io"
"os"
"sync"
)
type WriteFlusher interface {
type contextReader struct {
ctx context.Context
underlyingReader io.Reader
}
func newContextReader(ctx context.Context, underlyingReader io.Reader) *contextReader {
return &contextReader{
ctx: ctx,
underlyingReader: underlyingReader,
}
}
func (r *contextReader) Read(b []byte) (int, error) {
if r.canceled() {
return 0, r.err()
}
n, err := r.underlyingReader.Read(b)
if r.canceled() {
err = r.err()
}
return n, err
}
func (r *contextReader) canceled() bool {
return r.err() != nil
}
func (r *contextReader) err() error {
return r.ctx.Err()
}
type writeFlusher interface {
io.Writer
Flush() error
}
......@@ -16,7 +51,7 @@ type WriteFlusher interface {
// returned some error), all writes to w are sent to a tempfile first.
// The caller must call Flush() on the returned WriteFlusher to ensure
// all data is propagated to w.
func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) {
func newWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, writeFlusher) {
br := &busyReader{Reader: r}
return br, &coupledWriter{Writer: w, busyReader: br}
}
......
package helper
package git
import (
"bytes"
"context"
"fmt"
"io"
"testing"
"testing/iotest"
"time"
"github.com/stretchr/testify/require"
)
type fakeReader struct {
n int
err error
}
func (f *fakeReader) Read(b []byte) (int, error) {
return f.n, f.err
}
type fakeContextWithTimeout struct {
n int
threshold int
}
func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) {
return
}
func (*fakeContextWithTimeout) Done() <-chan struct{} {
return nil
}
func (*fakeContextWithTimeout) Value(key interface{}) interface{} {
return nil
}
func (f *fakeContextWithTimeout) Err() error {
f.n++
if f.n > f.threshold {
return context.DeadlineExceeded
}
return nil
}
func TestContextReaderRead(t *testing.T) {
underlyingReader := &fakeReader{n: 1, err: io.EOF}
for _, tc := range []struct {
desc string
ctx *fakeContextWithTimeout
expectedN int
expectedErr error
}{
{
desc: "Before and after read deadline checks are fine",
ctx: &fakeContextWithTimeout{n: 0, threshold: 2},
expectedN: underlyingReader.n,
expectedErr: underlyingReader.err,
},
{
desc: "Before read deadline check fails",
ctx: &fakeContextWithTimeout{n: 0, threshold: 0},
expectedN: 0,
expectedErr: context.DeadlineExceeded,
},
{
desc: "After read deadline check fails",
ctx: &fakeContextWithTimeout{n: 0, threshold: 1},
expectedN: underlyingReader.n,
expectedErr: context.DeadlineExceeded,
},
} {
t.Run(tc.desc, func(t *testing.T) {
cr := newContextReader(tc.ctx, underlyingReader)
n, err := cr.Read(nil)
require.Equal(t, tc.expectedN, n)
require.Equal(t, tc.expectedErr, err)
})
}
}
func TestBusyReader(t *testing.T) {
testData := "test data"
r := testReader(testData)
br, _ := NewWriteAfterReader(r, &bytes.Buffer{})
br, _ := newWriteAfterReader(r, &bytes.Buffer{})
result, err := io.ReadAll(br)
if err != nil {
......@@ -25,7 +102,7 @@ func TestBusyReader(t *testing.T) {
func TestFirstWriteAfterReadDone(t *testing.T) {
writeRecorder := &bytes.Buffer{}
br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder)
br, cw := newWriteAfterReader(&bytes.Buffer{}, writeRecorder)
if _, err := io.Copy(io.Discard, br); err != nil {
t.Fatalf("copy from busyreader: %v", err)
}
......@@ -44,7 +121,7 @@ func TestFirstWriteAfterReadDone(t *testing.T) {
func TestWriteDelay(t *testing.T) {
writeRecorder := &bytes.Buffer{}
w := &complainingWriter{Writer: writeRecorder}
br, cw := NewWriteAfterReader(&bytes.Buffer{}, w)
br, cw := newWriteAfterReader(&bytes.Buffer{}, w)
testData1 := "1 test"
if _, err := io.Copy(cw, testReader(testData1)); err != nil {
......
......@@ -6,7 +6,6 @@ import (
"gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
)
// Will not return a non-nil error after the response body has been
......@@ -15,7 +14,7 @@ func handleReceivePack(w *HttpResponseWriter, r *http.Request, a *api.Response)
action := getService(r)
writePostRPCHeader(w, action)
cr, cw := helper.NewWriteAfterReader(r.Body, w)
cr, cw := newWriteAfterReader(r.Body, w)
defer cw.Flush()
gitProtocol := r.Header.Get("Git-Protocol")
......
......@@ -9,7 +9,6 @@ import (
"gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/gitaly"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
)
var (
......@@ -31,8 +30,8 @@ func handleUploadPack(w *HttpResponseWriter, r *http.Request, a *api.Response) e
readerCtx, cancel := context.WithTimeout(ctx, uploadPackTimeout)
defer cancel()
limited := helper.NewContextReader(readerCtx, r.Body)
cr, cw := helper.NewWriteAfterReader(limited, w)
limited := newContextReader(readerCtx, r.Body)
cr, cw := newWriteAfterReader(limited, w)
defer cw.Flush()
action := getService(r)
......
package helper
import (
"context"
"io"
)
type ContextReader struct {
ctx context.Context
underlyingReader io.Reader
}
func NewContextReader(ctx context.Context, underlyingReader io.Reader) *ContextReader {
return &ContextReader{
ctx: ctx,
underlyingReader: underlyingReader,
}
}
func (r *ContextReader) Read(b []byte) (int, error) {
if r.canceled() {
return 0, r.err()
}
n, err := r.underlyingReader.Read(b)
if r.canceled() {
err = r.err()
}
return n, err
}
func (r *ContextReader) canceled() bool {
return r.err() != nil
}
func (r *ContextReader) err() error {
return r.ctx.Err()
}
package helper
import (
"context"
"io"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type fakeReader struct {
n int
err error
}
func (f *fakeReader) Read(b []byte) (int, error) {
return f.n, f.err
}
type fakeContextWithTimeout struct {
n int
threshold int
}
func (*fakeContextWithTimeout) Deadline() (deadline time.Time, ok bool) {
return
}
func (*fakeContextWithTimeout) Done() <-chan struct{} {
return nil
}
func (*fakeContextWithTimeout) Value(key interface{}) interface{} {
return nil
}
func (f *fakeContextWithTimeout) Err() error {
f.n++
if f.n > f.threshold {
return context.DeadlineExceeded
}
return nil
}
func TestContextReaderRead(t *testing.T) {
underlyingReader := &fakeReader{n: 1, err: io.EOF}
for _, tc := range []struct {
desc string
ctx *fakeContextWithTimeout
expectedN int
expectedErr error
}{
{
desc: "Before and after read deadline checks are fine",
ctx: &fakeContextWithTimeout{n: 0, threshold: 2},
expectedN: underlyingReader.n,
expectedErr: underlyingReader.err,
},
{
desc: "Before read deadline check fails",
ctx: &fakeContextWithTimeout{n: 0, threshold: 0},
expectedN: 0,
expectedErr: context.DeadlineExceeded,
},
{
desc: "After read deadline check fails",
ctx: &fakeContextWithTimeout{n: 0, threshold: 1},
expectedN: underlyingReader.n,
expectedErr: context.DeadlineExceeded,
},
} {
t.Run(tc.desc, func(t *testing.T) {
cr := NewContextReader(tc.ctx, underlyingReader)
n, err := cr.Read(nil)
require.Equal(t, tc.expectedN, n)
require.Equal(t, tc.expectedErr, err)
})
}
}
package helper
import (
"bytes"
"errors"
"io"
"mime"
"net"
"net/http"
"net/url"
"os"
"strings"
"github.com/sebest/xff"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/log"
)
......@@ -38,12 +32,6 @@ func printError(r *http.Request, err error, fields log.Fields) {
log.WithRequest(r).WithFields(fields).WithError(err).Error()
}
func SetNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
func OpenFile(path string) (file *os.File, fi os.FileInfo, err error) {
file, err = os.Open(path)
if err != nil {
......@@ -82,15 +70,6 @@ func URLMustParse(s string) *url.URL {
return u
}
func HTTPError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
func HeaderClone(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
......@@ -101,52 +80,7 @@ func HeaderClone(h http.Header) http.Header {
return h2
}
func FixRemoteAddr(r *http.Request) {
// Unix domain sockets have a remote addr of @. This will make the
// xff package lookup the X-Forwarded-For address if available.
if r.RemoteAddr == "@" {
r.RemoteAddr = "127.0.0.1:0"
}
r.RemoteAddr = xff.GetRemoteAddr(r)
}
func SetForwardedFor(newHeaders *http.Header, originalRequest *http.Request) {
if clientIP, _, err := net.SplitHostPort(originalRequest.RemoteAddr); err == nil {
var header string
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := originalRequest.Header["X-Forwarded-For"]; ok {
header = strings.Join(prior, ", ") + ", " + clientIP
} else {
header = clientIP
}
newHeaders.Set("X-Forwarded-For", header)
}
}
func IsContentType(expected, actual string) bool {
parsed, _, err := mime.ParseMediaType(actual)
return err == nil && parsed == expected
}
func IsApplicationJson(r *http.Request) bool {
contentType := r.Header.Get("Content-Type")
return IsContentType("application/json", contentType)
}
func ReadRequestBody(w http.ResponseWriter, r *http.Request, maxBodySize int64) ([]byte, error) {
limitedBody := http.MaxBytesReader(w, r.Body, maxBodySize)
defer limitedBody.Close()
return io.ReadAll(limitedBody)
}
func CloneRequestWithNewBody(r *http.Request, body []byte) *http.Request {
newReq := *r
newReq.Body = io.NopCloser(bytes.NewReader(body))
newReq.Header = HeaderClone(r.Header)
newReq.ContentLength = int64(len(body))
return &newReq
}
......@@ -2,7 +2,6 @@ package helper
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
......@@ -10,126 +9,6 @@ import (
"github.com/stretchr/testify/require"
)
func TestFixRemoteAddr(t *testing.T) {
testCases := []struct {
initial string
forwarded string
expected string
}{
{initial: "@", forwarded: "", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
{initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"},
{initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"},
{initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
}
for _, tc := range testCases {
req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil)
require.NoError(t, err)
req.RemoteAddr = tc.initial
if tc.forwarded != "" {
req.Header.Add("X-Forwarded-For", tc.forwarded)
}
FixRemoteAddr(req)
require.Equal(t, tc.expected, req.RemoteAddr)
}
}
func TestSetForwardedForGeneratesHeader(t *testing.T) {
testCases := []struct {
remoteAddr string
previousForwardedFor []string
expected string
}{
{
"8.8.8.8:3000",
nil,
"8.8.8.8",
},
{
"8.8.8.8:3000",
[]string{"138.124.33.63, 151.146.211.237"},
"138.124.33.63, 151.146.211.237, 8.8.8.8",
},
{
"8.8.8.8:3000",
[]string{"8.154.76.107", "115.206.118.179"},
"8.154.76.107, 115.206.118.179, 8.8.8.8",
},
}
for _, tc := range testCases {
headers := http.Header{}
originalRequest := http.Request{
RemoteAddr: tc.remoteAddr,
}
if tc.previousForwardedFor != nil {
originalRequest.Header = http.Header{
"X-Forwarded-For": tc.previousForwardedFor,
}
}
SetForwardedFor(&headers, &originalRequest)
result := headers.Get("X-Forwarded-For")
if result != tc.expected {
t.Fatalf("Expected %v, got %v", tc.expected, result)
}
}
}
func TestReadRequestBody(t *testing.T) {
data := []byte("123456")
rw := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
result, err := ReadRequestBody(rw, req, 1000)
require.NoError(t, err)
require.Equal(t, data, result)
}
func TestReadRequestBodyLimit(t *testing.T) {
data := []byte("123456")
rw := httptest.NewRecorder()
req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(data))
_, err := ReadRequestBody(rw, req, 2)
require.Error(t, err)
}
func TestCloneRequestWithBody(t *testing.T) {
input := []byte("test")
newInput := []byte("new body")
req, _ := http.NewRequest("POST", "/test", bytes.NewBuffer(input))
newReq := CloneRequestWithNewBody(req, newInput)
require.NotEqual(t, req, newReq)
require.NotEqual(t, req.Body, newReq.Body)
require.NotEqual(t, len(newInput), newReq.ContentLength)
var buffer bytes.Buffer
io.Copy(&buffer, newReq.Body)
require.Equal(t, newInput, buffer.Bytes())
}
func TestApplicationJson(t *testing.T) {
req, _ := http.NewRequest("POST", "/test", nil)
req.Header.Set("Content-Type", "application/json")
require.True(t, IsApplicationJson(req), "expected to match 'application/json' as 'application/json'")
req.Header.Set("Content-Type", "application/json; charset=utf-8")
require.True(t, IsApplicationJson(req), "expected to match 'application/json; charset=utf-8' as 'application/json'")
req.Header.Set("Content-Type", "text/plain")
require.False(t, IsApplicationJson(req), "expected not to match 'text/plain' as 'application/json'")
}
func TestFail500WorksWithNils(t *testing.T) {
body := bytes.NewBuffer(nil)
w := httptest.NewRecorder()
......
......@@ -4,8 +4,6 @@ import (
"net/http"
"os"
"path/filepath"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
)
func (s *Static) DeployPage(handler http.Handler) http.Handler {
......@@ -18,7 +16,7 @@ func (s *Static) DeployPage(handler http.Handler) http.Handler {
return
}
helper.SetNoCacheHeaders(w.Header())
setNoCacheHeaders(w.Header())
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write(data)
......
......@@ -9,8 +9,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/helper"
)
var (
......@@ -84,7 +82,7 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.hijacked = true
staticErrorResponses.WithLabelValues(fmt.Sprintf("%d", s.status)).Inc()
helper.SetNoCacheHeaders(s.rw.Header())
setNoCacheHeaders(s.rw.Header())
s.rw.Header().Set("Content-Type", contentType)
s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
s.rw.Header().Del("Transfer-Encoding")
......
package staticpages
import "net/http"
type Static struct {
DocumentRoot string
Exclude []string
}
func setNoCacheHeaders(header http.Header) {
header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
header.Set("Pragma", "no-cache")
header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT")
}
......@@ -425,7 +425,7 @@ func configureRoutes(u *upstream) {
func denyWebsocket(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if websocket.IsWebSocketUpgrade(r) {
helper.HTTPError(w, r, "websocket upgrade not allowed", http.StatusBadRequest)
httpError(w, r, "websocket upgrade not allowed", http.StatusBadRequest)
return
}
......
......@@ -16,6 +16,7 @@ import (
"net/url"
"strings"
"github.com/sebest/xff"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/labkit/correlation"
......@@ -125,19 +126,19 @@ func (u *upstream) configureURLPrefix() {
}
func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
helper.FixRemoteAddr(r)
fixRemoteAddr(r)
nginx.DisableResponseBuffering(w)
// Drop RequestURI == "*" (FIXME: why?)
if r.RequestURI == "*" {
helper.HTTPError(w, r, "Connection upgrade not allowed", http.StatusBadRequest)
httpError(w, r, "Connection upgrade not allowed", http.StatusBadRequest)
return
}
// Disallow connect
if r.Method == "CONNECT" {
helper.HTTPError(w, r, "CONNECT not allowed", http.StatusBadRequest)
httpError(w, r, "CONNECT not allowed", http.StatusBadRequest)
return
}
......@@ -145,7 +146,7 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
URIPath := urlprefix.CleanURIPath(r.URL.EscapedPath())
prefix := u.URLPrefix
if !prefix.Match(URIPath) {
helper.HTTPError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
httpError(w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound)
return
}
......@@ -156,7 +157,7 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if route == nil {
// The protocol spec in git/Documentation/technical/http-protocol.txt
// says we must return 403 if no matching service is found.
helper.HTTPError(w, r, "Forbidden", http.StatusForbidden)
httpError(w, r, "Forbidden", http.StatusForbidden)
return
}
......@@ -276,3 +277,21 @@ func (u *upstream) updateGeoProxyFieldsFromData(geoProxyData *apipkg.GeoProxyDat
u.geoProxyCableRoute = u.wsRoute(`^/-/cable\z`, geoProxyUpstream)
u.geoProxyRoute = u.route("", "", geoProxyUpstream, withGeoProxy())
}
func httpError(w http.ResponseWriter, r *http.Request, error string, code int) {
if r.ProtoAtLeast(1, 1) {
// Force client to disconnect if we render request error
w.Header().Set("Connection", "close")
}
http.Error(w, error, code)
}
func fixRemoteAddr(r *http.Request) {
// Unix domain sockets have a remote addr of @. This will make the
// xff package lookup the X-Forwarded-For address if available.
if r.RemoteAddr == "@" {
r.RemoteAddr = "127.0.0.1:0"
}
r.RemoteAddr = xff.GetRemoteAddr(r)
}
......@@ -435,3 +435,33 @@ func startWorkhorseServer(railsServerURL string, enableGeoProxyFeature bool) (*h
return ws, ws.Close, waitForNextApiPoll
}
func TestFixRemoteAddr(t *testing.T) {
testCases := []struct {
initial string
forwarded string
expected string
}{
{initial: "@", forwarded: "", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
{initial: "@", forwarded: "127.0.0.1", expected: "127.0.0.1:0"},
{initial: "@", forwarded: "192.168.0.1", expected: "127.0.0.1:0"},
{initial: "192.168.1.1:0", forwarded: "", expected: "192.168.1.1:0"},
{initial: "192.168.1.1:0", forwarded: "18.245.0.1", expected: "18.245.0.1:0"},
}
for _, tc := range testCases {
req, err := http.NewRequest("POST", "unix:///tmp/test.socket/info/refs", nil)
require.NoError(t, err)
req.RemoteAddr = tc.initial
if tc.forwarded != "" {
req.Header.Add("X-Forwarded-For", tc.forwarded)
}
fixRemoteAddr(req)
require.Equal(t, tc.expected, req.RemoteAddr)
}
}
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册