diff --git a/workhorse/internal/upstream/routes.go b/workhorse/internal/upstream/routes.go index 7bf03f662db5d1fd4cb22ed6d5d8b83fee25b9cf..00bd32c7326387535be4765e025cb8c4a5d4737f 100644 --- a/workhorse/internal/upstream/routes.go +++ b/workhorse/internal/upstream/routes.go @@ -43,6 +43,7 @@ type routeOptions struct { tracing bool isGeoProxyRoute bool matchers []matcherFunc + allowOrigins *regexp.Regexp } const ( @@ -92,6 +93,12 @@ func withGeoProxy() func(*routeOptions) { } } +func withAllowOrigins(pattern string) func(*routeOptions) { + return func(options *routeOptions) { + options.allowOrigins = compileRegexp(pattern) + } +} + func (u *upstream) observabilityMiddlewares(handler http.Handler, method string, regexpStr string, opts *routeOptions) http.Handler { handler = log.AccessLogger( handler, @@ -128,6 +135,9 @@ func (u *upstream) route(method, regexpStr string, handler http.Handler, opts .. // Add distributed tracing handler = tracing.Handler(handler, tracing.WithRouteIdentifier(regexpStr)) } + if options.allowOrigins != nil { + handler = corsMiddleware(handler, options.allowOrigins) + } return routeEntry{ method: method, @@ -360,6 +370,7 @@ func configureRoutes(u *upstream) { assetsNotFoundHandler, ), withoutTracing(), // Tracing on assets is very noisy + withAllowOrigins("^https://.*\\.web-ide\\.gitlab-static\\.net$"), ), // Uploads @@ -425,6 +436,7 @@ func configureRoutes(u *upstream) { assetsNotFoundHandler, ), withoutTracing(), // Tracing on assets is very noisy + withAllowOrigins("^https://.*\\.web-ide\\.gitlab-static\\.net$"), ), // Don't define a catch-all route. If a route does not match, then we know @@ -442,3 +454,20 @@ func denyWebsocket(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +func corsMiddleware(next http.Handler, allowOriginRegex *regexp.Regexp) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestOrigin := r.Header.Get("Origin") + hasOriginMatch := allowOriginRegex.MatchString(requestOrigin) + hasMethodMatch := r.Method == "GET" || r.Method == "HEAD" || r.Method == "OPTIONS" + + if hasOriginMatch && hasMethodMatch { + w.Header().Set("Access-Control-Allow-Origin", requestOrigin) + // why: `Vary: Origin` is needed because allowable origin is variable + // https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#the_http_response_headers + w.Header().Set("Vary", "Origin") + } + + next.ServeHTTP(w, r) + }) +} diff --git a/workhorse/internal/upstream/routes_test.go b/workhorse/internal/upstream/routes_test.go index 09551b7f60589dfd125943479578d0af4855fb27..fde73e6371f7f3f57f430c5c34729f6ecca078d6 100644 --- a/workhorse/internal/upstream/routes_test.go +++ b/workhorse/internal/upstream/routes_test.go @@ -7,6 +7,24 @@ import ( "gitlab.com/gitlab-org/gitlab/workhorse/internal/testhelper" ) +func TestStaticCORS(t *testing.T) { + path := "/assets/static.txt" + content := "local geo asset" + testhelper.SetupStaticFileHelper(t, path, content, testDocumentRoot) + + testCases := []testCaseRequest{ + {"With no origin, does not set cors headers", "GET", "/assets/static.txt", map[string]string{}, map[string]string{"Access-Control-Allow-Origin": ""}}, + {"With unknown origin, does not set cors headers", "GET", "/assets/static.txt", map[string]string{"Origin": "https://example.com"}, map[string]string{"Access-Control-Allow-Origin": ""}}, + {"With known origin, sets cors headers", "GET", "/assets/static.txt", map[string]string{"Origin": "https://123.cdn.web-ide.gitlab-static.net"}, map[string]string{"Access-Control-Allow-Origin": "https://123.cdn.web-ide.gitlab-static.net", "Vary": "Origin"}}, + {"With known origin HEAD, sets cors headers", "HEAD", "/assets/static.txt", map[string]string{"Origin": "https://123.cdn.web-ide.gitlab-static.net"}, map[string]string{"Access-Control-Allow-Origin": "https://123.cdn.web-ide.gitlab-static.net", "Vary": "Origin"}}, + {"With known origin OPTIONS, sets cors headers", "OPTIONS", "/assets/static.txt", map[string]string{"Origin": "https://123.cdn.web-ide.gitlab-static.net"}, map[string]string{"Access-Control-Allow-Origin": "https://123.cdn.web-ide.gitlab-static.net", "Vary": "Origin"}}, + {"With known origin POST, does not set cors headers", "POST", "/assets/static.txt", map[string]string{"Origin": "https://123.cdn.web-ide.gitlab-static.net"}, map[string]string{"Access-Control-Allow-Origin": ""}}, + {"With evil origin, does not set cors headers", "GET", "/assets/static.txt", map[string]string{"Origin": "https://123.cdn.web-ide.gitlab-static.net.evil.com"}, map[string]string{"Access-Control-Allow-Origin": ""}}, + } + + runTestCasesWithGeoProxyEnabledRequest(t, testCases) +} + func TestAdminGeoPathsWithGeoProxy(t *testing.T) { testCases := []testCase{ {"Regular admin/geo", "/admin/geo", "Geo primary received request to path /admin/geo"}, diff --git a/workhorse/internal/upstream/upstream_test.go b/workhorse/internal/upstream/upstream_test.go index b852627aa14244019701ddc8e899d8e325db9c85..9fb8d127a3c3ec7409cd01c6ef2eefa21a3f407a 100644 --- a/workhorse/internal/upstream/upstream_test.go +++ b/workhorse/internal/upstream/upstream_test.go @@ -37,6 +37,14 @@ type testCasePost struct { body io.Reader } +type testCaseRequest struct { + desc string + method string + path string + headers map[string]string + expectedHeaders map[string]string +} + func TestMain(m *testing.M) { // Secret should be configured before any Geo API poll happens to prevent // race conditions where the first API call happens without a secret path @@ -367,6 +375,29 @@ func runTestCasesPost(t *testing.T, ws *httptest.Server, testCases []testCasePos } } +func runTestCasesRequest(t *testing.T, ws *httptest.Server, testCases []testCaseRequest) { + t.Helper() + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + client := http.Client{} + request, err := http.NewRequest(tc.method, ws.URL+tc.path, nil) + require.NoError(t, err) + for key, value := range tc.headers { + request.Header.Set(key, value) + } + + resp, err := client.Do(request) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, 200, resp.StatusCode, "response code") + for key, value := range tc.expectedHeaders { + require.Equal(t, resp.Header.Get(key), value, fmt.Sprint("response header ", key)) + } + }) + } +} + func runTestCasesWithGeoProxyEnabled(t *testing.T, testCases []testCase) { remoteServer := startRemoteServer(t) @@ -389,6 +420,17 @@ func runTestCasesWithGeoProxyEnabledPost(t *testing.T, testCases []testCasePost) runTestCasesPost(t, ws, testCases) } +func runTestCasesWithGeoProxyEnabledRequest(t *testing.T, testCases []testCaseRequest) { + remoteServer := startRemoteServer(t) + + geoProxyEndpointResponseBody := fmt.Sprintf(`{"geo_enabled":true,"geo_proxy_url":"%v"}`, remoteServer.URL) + railsServer := startRailsServer(t, &geoProxyEndpointResponseBody) + + ws, _ := startWorkhorseServer(t, railsServer.URL, true) + + runTestCasesRequest(t, ws, testCases) +} + func newUpstreamConfig(authBackend string) *config.Config { return &config.Config{ Version: "123",