diff --git a/authorization_test.go b/authorization_test.go index f7119dfe1b21c100a3ff6f2f131a0c7c6dfdc4ca..8be1fdcf73a658c201eb4f560cfd2fea56e44e36 100644 --- a/authorization_test.go +++ b/authorization_test.go @@ -1,12 +1,15 @@ package main import ( + "context" "fmt" "net/http" "net/http/httptest" "regexp" "testing" + "gitlab.com/gitlab-org/labkit/correlation" + "github.com/dgrijalva/jwt-go" "github.com/stretchr/testify/require" @@ -24,12 +27,13 @@ func okHandler(w http.ResponseWriter, _ *http.Request, _ *api.Response) { func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, url *regexp.Regexp, apiResponse interface{}, returnCode, expectedCode int) *httptest.ResponseRecorder { if ts == nil { - ts = testAuthServer(url, nil, returnCode, apiResponse) + ts = testAuthServer(t, url, nil, returnCode, apiResponse) defer ts.Close() } // Create http request - httpRequest, err := http.NewRequest("GET", "/address", nil) + ctx := correlation.ContextWithCorrelation(context.Background(), "12345678") + httpRequest, err := http.NewRequestWithContext(ctx, "GET", "/address", nil) require.NoError(t, err) parsedURL := helper.URLMustParse(ts.URL) testhelper.ConfigureSecret() diff --git a/changelogs/unreleased/sh-fix-correlation-id-for-preauth.yml b/changelogs/unreleased/sh-fix-correlation-id-for-preauth.yml new file mode 100644 index 0000000000000000000000000000000000000000..41947ab00eaa0feb5f2f44535417cf3ec8157d18 --- /dev/null +++ b/changelogs/unreleased/sh-fix-correlation-id-for-preauth.yml @@ -0,0 +1,5 @@ +--- +title: Fix correlation IDs not being propagated in preauth check +merge_request: 607 +author: +type: fixed diff --git a/channel_test.go b/channel_test.go index d294f646bc07a2d2a25fc4d7f125261bc5c7dc83..cd8957ed829526ff488c1c9c98ad12671e73dea4 100644 --- a/channel_test.go +++ b/channel_test.go @@ -42,7 +42,7 @@ func TestChannelHappyPath(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - serverConns, clientURL, close := wireupChannel(test.channelPath, nil, "channel.k8s.io") + serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io") defer close() client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") @@ -70,7 +70,7 @@ func TestChannelHappyPath(t *testing.T) { } func TestChannelBadTLS(t *testing.T) { - _, clientURL, close := wireupChannel(envTerminalPath, badCA, "channel.k8s.io") + _, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io") defer close() _, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") @@ -78,7 +78,7 @@ func TestChannelBadTLS(t *testing.T) { } func TestChannelSessionTimeout(t *testing.T) { - serverConns, clientURL, close := wireupChannel(envTerminalPath, timeout, "channel.k8s.io") + serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io") defer close() client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") @@ -96,7 +96,7 @@ func TestChannelSessionTimeout(t *testing.T) { func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { hdr := make(http.Header) hdr.Set("Random-Header", "Value") - serverConns, clientURL, close := wireupChannel(envTerminalPath, setHeader(hdr), "channel.k8s.io") + serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io") defer close() client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com") @@ -109,7 +109,7 @@ func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) { } func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { - serverConns, clientURL, close := wireupChannel(envTerminalPath, nil, "channel.k8s.io") + serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io") defer close() hdr := make(http.Header) @@ -127,13 +127,13 @@ func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) { require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote") } -func wireupChannel(channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) { +func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) { serverConns, remote := startWebsocketServer(subprotocols...) authResponse := channelOkBody(remote, nil, subprotocols...) if modifier != nil { modifier(authResponse) } - upstream := testAuthServer(nil, nil, 200, authResponse) + upstream := testAuthServer(t, nil, nil, 200, authResponse) workhorse := startWorkhorseServer(upstream.URL) return serverConns, websocketURL(workhorse.URL, channelPath), func() { diff --git a/gitaly_integration_test.go b/gitaly_integration_test.go index 1aa73c851aed13b41a1725b91fc4bae70652913d..418d9589235e4a155d26dcd8a0e3637811653c31 100644 --- a/gitaly_integration_test.go +++ b/gitaly_integration_test.go @@ -90,7 +90,7 @@ func TestAllowedClone(t *testing.T) { require.NoError(t, ensureGitalyRepository(t, apiResponse)) // Prepare test server and backend - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() @@ -114,7 +114,7 @@ func TestAllowedShallowClone(t *testing.T) { require.NoError(t, ensureGitalyRepository(t, apiResponse)) // Prepare test server and backend - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() @@ -138,7 +138,7 @@ func TestAllowedPush(t *testing.T) { require.NoError(t, ensureGitalyRepository(t, apiResponse)) // Prepare the test server and backend - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() diff --git a/gitaly_test.go b/gitaly_test.go index d571e697ff0c6db53999d003787ae01b5c6264a7..95d6907ac6a190caf8afff86591853ee57843c3a 100644 --- a/gitaly_test.go +++ b/gitaly_test.go @@ -43,7 +43,7 @@ func TestFailedCloneNoGitaly(t *testing.T) { } // Prepare test server and backend - ts := testAuthServer(nil, nil, 200, authBody) + ts := testAuthServer(t, nil, nil, 200, authBody) defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() @@ -95,7 +95,7 @@ func TestGetInfoRefsProxiedToGitalySuccessfully(t *testing.T) { t.Run(fmt.Sprintf("ShowAllRefs=%v,gitRpc=%v", tc.showAllRefs, tc.gitRpc), func(t *testing.T) { apiResponse.ShowAllRefs = tc.showAllRefs - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) @@ -147,7 +147,7 @@ func TestGetInfoRefsProxiedToGitalyInterruptedStream(t *testing.T) { gitalyAddress := "unix:" + socketPath apiResponse.GitalyServer.Address = gitalyAddress - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) @@ -187,7 +187,7 @@ func TestPostReceivePackProxiedToGitalySuccessfully(t *testing.T) { apiResponse.GitalyServer.Address = "unix:" + socketPath apiResponse.GitConfigOptions = []string{"git-config-hello=world"} - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) @@ -232,7 +232,7 @@ func TestPostReceivePackProxiedToGitalyInterrupted(t *testing.T) { defer gitalyServer.GracefulStop() apiResponse.GitalyServer.Address = "unix:" + socketPath - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) @@ -282,7 +282,7 @@ func TestPostUploadPackProxiedToGitalySuccessfully(t *testing.T) { defer gitalyServer.GracefulStop() apiResponse.GitalyServer.Address = "unix:" + socketPath - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) @@ -349,7 +349,7 @@ func TestPostUploadPackProxiedToGitalyInterrupted(t *testing.T) { defer gitalyServer.GracefulStop() apiResponse.GitalyServer.Address = "unix:" + socketPath - ts := testAuthServer(nil, nil, 200, apiResponse) + ts := testAuthServer(t, nil, nil, 200, apiResponse) defer ts.Close() ws := startWorkhorseServer(ts.URL) diff --git a/go.sum b/go.sum index 4a07f03801f76b0453bb875087e5d6159df730fb..28707bcc4553e9709cca278c71d19fd7aad1cdbe 100644 --- a/go.sum +++ b/go.sum @@ -456,6 +456,7 @@ github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnIn github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= @@ -837,11 +838,11 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= -honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4 h1:UoveltGrhghAA7ePc+e+QYDHXrBps2PqFZiHkGR/xK8= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= +honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= +honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/api/api.go b/internal/api/api.go index 5db947a157f82aa90a8215081ec088370219e632..4ddc66bf26e9570dd44ab81d77510021f91d91f6 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -188,6 +188,8 @@ func (api *API) newRequest(r *http.Request, suffix string) (*http.Request, error Header: helper.HeaderClone(r.Header), } + authReq = authReq.WithContext(r.Context()) + // Clean some headers when issuing a new request without body authReq.Header.Del("Content-Type") authReq.Header.Del("Content-Encoding") diff --git a/main_test.go b/main_test.go index a5c0a8aac42785ce4e6e9864d236e2925efbed89..bc4b8e2408272500a6a8b7e8a4ab847fa4b74a2b 100644 --- a/main_test.go +++ b/main_test.go @@ -63,7 +63,7 @@ func TestDeniedClone(t *testing.T) { require.NoError(t, os.RemoveAll(scratchDir)) // Prepare test server and backend - ts := testAuthServer(nil, nil, 403, "Access denied") + ts := testAuthServer(t, nil, nil, 403, "Access denied") defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() @@ -77,7 +77,7 @@ func TestDeniedClone(t *testing.T) { func TestDeniedPush(t *testing.T) { // Prepare the test server and backend - ts := testAuthServer(nil, nil, 403, "Access denied") + ts := testAuthServer(t, nil, nil, 403, "Access denied") defer ts.Close() ws := startWorkhorseServer(ts.URL) defer ws.Close() @@ -594,8 +594,10 @@ func newBranch() string { return fmt.Sprintf("branch-%d", time.Now().UnixNano()) } -func testAuthServer(url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server { +func testAuthServer(t *testing.T, url *regexp.Regexp, params url.Values, code int, body interface{}) *httptest.Server { return testhelper.TestServerWithHandler(url, func(w http.ResponseWriter, r *http.Request) { + require.NotEmpty(t, r.Header.Get("X-Request-Id")) + w.Header().Set("Content-Type", api.ResponseContentType) logEntry := log.WithFields(log.Fields{