diff --git a/internal/git/archive.go b/internal/git/archive.go index 55a74f0f407a46dc520d37ca2d1f96ff7cdce030..a0367ccdc44b2d165956c04d28ccc8edea916c39 100644 --- a/internal/git/archive.go +++ b/internal/git/archive.go @@ -138,14 +138,14 @@ func (a *archive) Inject(w http.ResponseWriter, r *http.Request, sendData string func setArchiveHeaders(w http.ResponseWriter, format string, archiveFilename string) { w.Header().Del("Content-Length") - w.Header().Add("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, archiveFilename)) + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, archiveFilename)) if format == "zip" { - w.Header().Add("Content-Type", "application/zip") + w.Header().Set("Content-Type", "application/zip") } else { - w.Header().Add("Content-Type", "application/octet-stream") + w.Header().Set("Content-Type", "application/octet-stream") } - w.Header().Add("Content-Transfer-Encoding", "binary") - w.Header().Add("Cache-Control", "private") + w.Header().Set("Content-Transfer-Encoding", "binary") + w.Header().Set("Cache-Control", "private") } func parseArchiveFormat(format string) (*exec.Cmd, string) { diff --git a/internal/git/archive_test.go b/internal/git/archive_test.go index e3e85baf263060f93899f896ada0572d7b1a92d8..8d7e67911019986875f08dbe01863d7dafd7ef44 100644 --- a/internal/git/archive_test.go +++ b/internal/git/archive_test.go @@ -2,7 +2,10 @@ package git import ( "io/ioutil" + "net/http/httptest" "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" ) func TestParseBasename(t *testing.T) { @@ -42,3 +45,27 @@ func TestFinalizeArchive(t *testing.T) { t.Fatalf("expected nil from finalizeCachedArchive, received %v", err) } } + +func TestSetArchiveHeaders(t *testing.T) { + for _, testCase := range []struct{ in, out string }{ + {"zip", "application/zip"}, + {"zippy", "application/octet-stream"}, + {"rezip", "application/octet-stream"}, + {"_anything_", "application/octet-stream"}, + } { + w := httptest.NewRecorder() + + // These should be replaced, not appended to + w.Header().Set("Content-Type", "test") + w.Header().Set("Content-Length", "test") + w.Header().Set("Content-Disposition", "test") + w.Header().Set("Cache-Control", "test") + + setArchiveHeaders(w, testCase.in, "filename") + + testhelper.AssertResponseHeader(t, w, "Content-Type", testCase.out) + testhelper.AssertResponseHeader(t, w, "Content-Length") + testhelper.AssertResponseHeader(t, w, "Content-Disposition", `attachment; filename="filename"`) + testhelper.AssertResponseHeader(t, w, "Cache-Control", "private") + } +} diff --git a/internal/git/git-http.go b/internal/git/git-http.go index 99cf80f5ac67b73b8db0337634b9174706a52a5d..818be8ea6f1ddeb380a5262c72a94e0e19fc485a 100644 --- a/internal/git/git-http.go +++ b/internal/git/git-http.go @@ -77,8 +77,8 @@ func handleGetInfoRefs(w http.ResponseWriter, r *http.Request, a *api.Response) defer helper.CleanUpProcessGroup(cmd) // Ensure brute force subprocess clean-up // Start writing the response - w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc)) - w.Header().Add("Cache-Control", "no-cache") + w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", rpc)) + w.Header().Set("Cache-Control", "no-cache") w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return if err := pktLine(w, fmt.Sprintf("# service=%s\n", rpc)); err != nil { helper.LogError(r, fmt.Errorf("handleGetInfoRefs: pktLine: %v", err)) @@ -164,8 +164,8 @@ func handlePostRPC(w http.ResponseWriter, r *http.Request, a *api.Response) { r.Body.Close() // Start writing the response - w.Header().Add("Content-Type", fmt.Sprintf("application/x-%s-result", action)) - w.Header().Add("Cache-Control", "no-cache") + w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-result", action)) + w.Header().Set("Cache-Control", "no-cache") w.WriteHeader(200) // Don't bother with HTTP 500 from this point on, just return // This io.Copy may take a long time, both for Git push and pull. diff --git a/internal/staticpages/servefile_test.go b/internal/staticpages/servefile_test.go index 652573d799d43865f50b8194bc9f04151fc6eda3..2140262024f09cf2bb46acec0aacdc06bccf83a1 100644 --- a/internal/staticpages/servefile_test.go +++ b/internal/staticpages/servefile_test.go @@ -116,7 +116,7 @@ func testServingThePregzippedFile(t *testing.T, enableGzip bool) { } } else { testhelper.AssertResponseCode(t, w, 200) - testhelper.AssertResponseHeader(t, w, "Content-Encoding", "") + testhelper.AssertResponseHeader(t, w, "Content-Encoding") if w.Body.String() != fileContent { t.Error("We should serve the file: ", w.Body.String()) } diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go index 4bb48de11ab8505ea575a7f7251c8da243a2ec3b..6b82db898a97bef1e3a196a14cdcba7b1ad68edd 100644 --- a/internal/testhelper/testhelper.go +++ b/internal/testhelper/testhelper.go @@ -63,9 +63,17 @@ func AssertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expec } } -func AssertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) { - if response.Header().Get(header) != expectedValue { - t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header)) +func AssertResponseHeader(t *testing.T, w http.ResponseWriter, header string, expected ...string) { + actual := w.Header()[http.CanonicalHeaderKey(header)] + + if len(expected) != len(actual) { + t.Fatalf("for HTTP request expected to receive the header %q with %+v, got %+v", header, expected, actual) + } + + for i, value := range expected { + if value != actual[i] { + t.Fatalf("for HTTP request expected to receive the header %q with %+v, got %+v", header, expected, actual) + } } }