diff --git a/authorization_test.go b/authorization_test.go index 602389cc2f88fefcf32c30045dfbde359a758973..3d154d6c71d7a9062e69e95d6d43d4ab5bc97aae 100644 --- a/authorization_test.go +++ b/authorization_test.go @@ -10,6 +10,7 @@ import ( "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" "github.com/dgrijalva/jwt-go" @@ -32,7 +33,8 @@ func runPreAuthorizeHandler(t *testing.T, ts *httptest.Server, suffix string, ur t.Fatal(err) } parsedURL := helper.URLMustParse(ts.URL) - a := api.NewAPI(parsedURL, "123", testhelper.SecretPath(), badgateway.TestRoundTripper(parsedURL)) + testhelper.ConfigureSecret() + a := api.NewAPI(parsedURL, "123", badgateway.TestRoundTripper(parsedURL)) response := httptest.NewRecorder() a.PreAuthorizeHandler(okHandler, suffix).ServeHTTP(response, httpRequest) @@ -86,7 +88,8 @@ func TestPreAuthorizeJWT(t *testing.T) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) } - secretBytes, err := (&api.Secret{Path: testhelper.SecretPath()}).Bytes() + testhelper.ConfigureSecret() + secretBytes, err := secret.Bytes() if err != nil { return nil, fmt.Errorf("read secret from file: %v", err) } diff --git a/internal/api/api.go b/internal/api/api.go index 55700b01b8cc06b224fed05f6d6648a684a51f16..d642dae9d66b0a7cde62b55d8d1ab6561f6feb7d 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -11,8 +11,7 @@ import ( "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" - - "github.com/dgrijalva/jwt-go" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" ) // Custom content type for API responses, to catch routing / programming mistakes @@ -24,15 +23,13 @@ type API struct { Client *http.Client URL *url.URL Version string - Secret *Secret } -func NewAPI(myURL *url.URL, version, secretPath string, roundTripper *badgateway.RoundTripper) *API { +func NewAPI(myURL *url.URL, version string, roundTripper *badgateway.RoundTripper) *API { return &API{ Client: &http.Client{Transport: roundTripper}, URL: myURL, Version: version, - Secret: &Secret{Path: secretPath}, } } @@ -130,13 +127,7 @@ func (api *API) newRequest(r *http.Request, body io.Reader, suffix string) (*htt // configurations (Passenger) to solve auth request routing problems. authReq.Header.Set("Gitlab-Workhorse", api.Version) - secretBytes, err := api.Secret.Bytes() - if err != nil { - return nil, fmt.Errorf("newRequest: %v", err) - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{Issuer: "gitlab-workhorse"}) - tokenString, err := token.SignedString(secretBytes) + tokenString, err := secret.JWTTokenString(secret.DefaultClaims) if err != nil { return nil, fmt.Errorf("newRequest: sign JWT: %v", err) } diff --git a/internal/api/secret.go b/internal/api/secret.go deleted file mode 100644 index bc9d27f3ef1ddfbf6db193e2e193cea9f1a54340..0000000000000000000000000000000000000000 --- a/internal/api/secret.go +++ /dev/null @@ -1,60 +0,0 @@ -package api - -import ( - "encoding/base64" - "fmt" - "io/ioutil" - "sync" -) - -const numSecretBytes = 32 - -type Secret struct { - Path string - bytes []byte - sync.RWMutex -} - -// Lazy access to the HMAC secret key. We must be lazy because if the key -// is not already there, it will be generated by gitlab-rails, and -// gitlab-rails is slow. -func (s *Secret) Bytes() ([]byte, error) { - if bytes := s.getBytes(); bytes != nil { - return bytes, nil - } - - return s.setBytes() -} - -func (s *Secret) getBytes() []byte { - s.RLock() - defer s.RUnlock() - return s.bytes -} - -func (s *Secret) setBytes() ([]byte, error) { - s.Lock() - defer s.Unlock() - - if s.bytes != nil { - return s.bytes, nil - } - - base64Bytes, err := ioutil.ReadFile(s.Path) - if err != nil { - return nil, fmt.Errorf("read Secret.Path: %v", err) - } - - secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes))) - n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes) - if err != nil { - return nil, fmt.Errorf("decode secret: %v", err) - } - - if n != numSecretBytes { - return nil, fmt.Errorf("expected %d secretBytes in %s, found %d", numSecretBytes, s.Path, n) - } - - s.bytes = secretBytes - return s.bytes, nil -} diff --git a/internal/artifacts/artifacts_upload.go b/internal/artifacts/artifacts_upload.go index cc1095bf15d9d1ae7844855aab8f0388f29e3cb4..65d823c46f6026af739e430ba039de0872694bb5 100644 --- a/internal/artifacts/artifacts_upload.go +++ b/internal/artifacts/artifacts_upload.go @@ -65,6 +65,10 @@ func (a *artifactsUploadProcessor) ProcessField(formName string, writer *multipa return nil } +func (a *artifactsUploadProcessor) Finalize() error { + return nil +} + func (a *artifactsUploadProcessor) Cleanup() { if a.metadataFile != "" { os.Remove(a.metadataFile) diff --git a/internal/artifacts/artifacts_upload_test.go b/internal/artifacts/artifacts_upload_test.go index e9a214e99e3c36396958422818863de1a500e13a..ccc26d7d699759a7879276c43ec806b35d8a65a4 100644 --- a/internal/artifacts/artifacts_upload_test.go +++ b/internal/artifacts/artifacts_upload_test.go @@ -93,7 +93,8 @@ func testUploadArtifacts(contentType string, body io.Reader, t *testing.T, ts *h response := httptest.NewRecorder() parsedURL := helper.URLMustParse(ts.URL) roundTripper := badgateway.TestRoundTripper(parsedURL) - apiClient := api.NewAPI(parsedURL, "123", testhelper.SecretPath(), roundTripper) + testhelper.ConfigureSecret() + apiClient := api.NewAPI(parsedURL, "123", roundTripper) proxyClient := proxy.NewProxy(parsedURL, "123", roundTripper) UploadArtifacts(apiClient, proxyClient).ServeHTTP(response, httpRequest) return response diff --git a/internal/secret/jwt.go b/internal/secret/jwt.go new file mode 100644 index 0000000000000000000000000000000000000000..04335e58f760e2f6e8c0c038f3d5f40fd554ce2e --- /dev/null +++ b/internal/secret/jwt.go @@ -0,0 +1,25 @@ +package secret + +import ( + "fmt" + + "github.com/dgrijalva/jwt-go" +) + +var ( + DefaultClaims = jwt.StandardClaims{Issuer: "gitlab-workhorse"} +) + +func JWTTokenString(claims jwt.Claims) (string, error) { + secretBytes, err := Bytes() + if err != nil { + return "", fmt.Errorf("secret.JWTTokenString: %v", err) + } + + tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(secretBytes) + if err != nil { + return "", fmt.Errorf("secret.JWTTokenString: sign JWT: %v", err) + } + + return tokenString, nil +} diff --git a/internal/secret/secret.go b/internal/secret/secret.go new file mode 100644 index 0000000000000000000000000000000000000000..e8c7c25393c34a6779722dcb14e4b272fd29239a --- /dev/null +++ b/internal/secret/secret.go @@ -0,0 +1,77 @@ +package secret + +import ( + "encoding/base64" + "fmt" + "io/ioutil" + "sync" +) + +const numSecretBytes = 32 + +type sec struct { + path string + bytes []byte + sync.RWMutex +} + +var ( + theSecret = &sec{} +) + +func SetPath(path string) { + theSecret.Lock() + defer theSecret.Unlock() + theSecret.path = path + theSecret.bytes = nil +} + +// Lazy access to the HMAC secret key. We must be lazy because if the key +// is not already there, it will be generated by gitlab-rails, and +// gitlab-rails is slow. +func Bytes() ([]byte, error) { + if bytes := getBytes(); bytes != nil { + return copyBytes(bytes), nil + } + + return setBytes() +} + +func getBytes() []byte { + theSecret.RLock() + defer theSecret.RUnlock() + return theSecret.bytes +} + +func copyBytes(bytes []byte) []byte { + out := make([]byte, len(bytes)) + copy(out, bytes) + return out +} + +func setBytes() ([]byte, error) { + theSecret.Lock() + defer theSecret.Unlock() + + if theSecret.bytes != nil { + return theSecret.bytes, nil + } + + base64Bytes, err := ioutil.ReadFile(theSecret.path) + if err != nil { + return nil, fmt.Errorf("secret.setBytes: read %q: %v", theSecret.path, err) + } + + secretBytes := make([]byte, base64.StdEncoding.DecodedLen(len(base64Bytes))) + n, err := base64.StdEncoding.Decode(secretBytes, base64Bytes) + if err != nil { + return nil, fmt.Errorf("secret.setBytes: decode secret: %v", err) + } + + if n != numSecretBytes { + return nil, fmt.Errorf("secret.setBytes: expected %d secretBytes in %s, found %d", numSecretBytes, theSecret.path, n) + } + + theSecret.bytes = secretBytes + return copyBytes(theSecret.bytes), nil +} diff --git a/internal/testhelper/testhelper.go b/internal/testhelper/testhelper.go index 6b82db898a97bef1e3a196a14cdcba7b1ad68edd..bf5e518e3557997cf74cedf30ca2cbc1281d5d9c 100644 --- a/internal/testhelper/testhelper.go +++ b/internal/testhelper/testhelper.go @@ -16,10 +16,12 @@ import ( "runtime" "strings" "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" ) -func SecretPath() string { - return path.Join(RootDir(), "testdata/test-secret") +func ConfigureSecret() { + secret.SetPath(path.Join(RootDir(), "testdata/test-secret")) } var extractPatchSeriesMatcher = regexp.MustCompile(`^From (\w+)`) diff --git a/internal/upload/accelerate.go b/internal/upload/accelerate.go new file mode 100644 index 0000000000000000000000000000000000000000..8e7a856e8673eda839c570f3c8bc6d1b6eea39e3 --- /dev/null +++ b/internal/upload/accelerate.go @@ -0,0 +1,57 @@ +package upload + +import ( + "fmt" + "mime/multipart" + "net/http" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" + + "github.com/dgrijalva/jwt-go" +) + +const RewrittenFieldsHeader = "Gitlab-Workhorse-Multipart-Fields" + +type savedFileTracker struct { + request *http.Request + rewrittenFields map[string]string +} + +type MultipartClaims struct { + RewrittenFields map[string]string `json:"rewritten_fields"` + jwt.StandardClaims +} + +func Accelerate(tempDir string, h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s := &savedFileTracker{request: r} + HandleFileUploads(w, r, h, tempDir, s) + }) +} + +func (s *savedFileTracker) ProcessFile(fieldName, fileName string, _ *multipart.Writer) error { + if s.rewrittenFields == nil { + s.rewrittenFields = make(map[string]string) + } + s.rewrittenFields[fieldName] = fileName + return nil +} + +func (_ *savedFileTracker) ProcessField(_ string, _ *multipart.Writer) error { + return nil +} + +func (s *savedFileTracker) Finalize() error { + if s.rewrittenFields == nil { + return nil + } + + claims := MultipartClaims{s.rewrittenFields, secret.DefaultClaims} + tokenString, err := secret.JWTTokenString(claims) + if err != nil { + return fmt.Errorf("savedFileTracker.Finalize: %v", err) + } + + s.request.Header.Set(RewrittenFieldsHeader, tokenString) + return nil +} diff --git a/internal/upload/uploads.go b/internal/upload/uploads.go index c548167c48cd07c6b51e53b5b7fc657436a092bd..eb630cc084a81d0c5e2c46cbe9d87ea2c221babe 100644 --- a/internal/upload/uploads.go +++ b/internal/upload/uploads.go @@ -8,13 +8,17 @@ import ( "mime/multipart" "net/http" "os" + "path" + "strings" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" ) +// These methods are allowed to have thread-unsafe implementations. type MultipartFormProcessor interface { ProcessFile(formName, fileName string, writer *multipart.Writer) error ProcessField(formName string, writer *multipart.Writer) error + Finalize() error } func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, tempPath string, filter MultipartFormProcessor) (cleanup func(), err error) { @@ -28,11 +32,11 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te return nil, fmt.Errorf("get multipart reader: %v", err) } - var files []string + var directories []string cleanup = func() { - for _, file := range files { - os.Remove(file) + for _, dir := range directories { + os.RemoveAll(dir) } } @@ -56,22 +60,30 @@ func rewriteFormFilesFromMultipart(r *http.Request, writer *multipart.Writer, te // Copy form field if filename := p.FileName(); filename != "" { + if strings.Contains(filename, "/") || filename == "." || filename == ".." { + return cleanup, fmt.Errorf("illegal filename: %q", filename) + } + // Create temporary directory where the uploaded file will be stored if err := os.MkdirAll(tempPath, 0700); err != nil { return cleanup, fmt.Errorf("mkdir for tempfile: %v", err) } - // Create temporary file in path returned by Authorization filter - file, err := ioutil.TempFile(tempPath, "upload_") + tempDir, err := ioutil.TempDir(tempPath, "multipart-") + if err != nil { + return cleanup, fmt.Errorf("create tempdir: %v", err) + } + directories = append(directories, tempDir) + + file, err := os.OpenFile(path.Join(tempDir, filename), os.O_WRONLY|os.O_CREATE, 0600) if err != nil { - return cleanup, fmt.Errorf("create tempfile: %v", err) + return cleanup, fmt.Errorf("rewriteFormFilesFromMultipart: temp file: %v", err) } defer file.Close() // Add file entry writer.WriteField(name+".path", file.Name()) writer.WriteField(name+".name", filename) - files = append(files, file.Name()) _, err = io.Copy(file, p) if err != nil { @@ -135,6 +147,11 @@ func HandleFileUploads(w http.ResponseWriter, r *http.Request, h http.Handler, t r.ContentLength = int64(body.Len()) r.Header.Set("Content-Type", writer.FormDataContentType()) + if err := filter.Finalize(); err != nil { + helper.Fail500(w, r, fmt.Errorf("handleFileUploads: Finalize: %v", err)) + return + } + // Proxy the request h.ServeHTTP(w, r) } diff --git a/internal/upload/uploads_test.go b/internal/upload/uploads_test.go index f0435cdefa67d8742de4ce448f0c83bc118d7271..048ff4adb24e775ae1b5126c698e27264ea342dc 100644 --- a/internal/upload/uploads_test.go +++ b/internal/upload/uploads_test.go @@ -22,8 +22,7 @@ import ( var nilHandler = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) -type testFormProcessor struct { -} +type testFormProcessor struct{} func (a *testFormProcessor) ProcessFile(formName, fileName string, writer *multipart.Writer) error { if formName != "file" && fileName != "my.file" { @@ -39,6 +38,10 @@ func (a *testFormProcessor) ProcessField(formName string, writer *multipart.Writ return nil } +func (a *testFormProcessor) Finalize() error { + return nil +} + func TestUploadTempPathRequirement(t *testing.T) { response := httptest.NewRecorder() request, err := http.NewRequest("", "", nil) @@ -214,6 +217,47 @@ func TestUploadProcessingFile(t *testing.T) { testhelper.AssertResponseCode(t, response, 500) } +func TestInvalidFileNames(t *testing.T) { + testhelper.ConfigureSecret() + + tempPath, err := ioutil.TempDir("", "uploads") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempPath) + + for _, testCase := range []struct { + filename string + code int + }{ + {"foobar", 200}, // sanity check for test setup below + {"foo/bar", 500}, + {"/../../foobar", 500}, + {".", 500}, + {"..", 500}, + } { + buffer := &bytes.Buffer{} + + writer := multipart.NewWriter(buffer) + file, err := writer.CreateFormFile("file", testCase.filename) + if err != nil { + t.Fatal(err) + } + fmt.Fprint(file, "test") + writer.Close() + + httpRequest, err := http.NewRequest("POST", "/example", buffer) + if err != nil { + t.Fatal(err) + } + httpRequest.Header.Set("Content-Type", writer.FormDataContentType()) + + response := httptest.NewRecorder() + HandleFileUploads(response, httpRequest, nilHandler, tempPath, &savedFileTracker{request: httpRequest}) + testhelper.AssertResponseCode(t, response, testCase.code) + } +} + func newProxy(url string) *proxy.Proxy { parsedURL := helper.URLMustParse(url) return proxy.NewProxy(parsedURL, "123", badgateway.TestRoundTripper(parsedURL)) diff --git a/internal/upstream/routes.go b/internal/upstream/routes.go index 17db205227cc94fae3dc7c95776f5de1aa5d20a2..e17b156557d757e9e3d1f0f4dcbea00d2ae3839f 100644 --- a/internal/upstream/routes.go +++ b/internal/upstream/routes.go @@ -2,6 +2,7 @@ package upstream import ( "net/http" + "path" "regexp" "github.com/gorilla/websocket" @@ -17,6 +18,7 @@ import ( "gitlab.com/gitlab-org/gitlab-workhorse/internal/sendfile" "gitlab.com/gitlab-org/gitlab-workhorse/internal/staticpages" "gitlab.com/gitlab-org/gitlab-workhorse/internal/terminal" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" ) type matcherFunc func(*http.Request) bool @@ -91,7 +93,6 @@ func (u *Upstream) configureRoutes() { api := apipkg.NewAPI( u.Backend, u.Version, - u.SecretPath, u.RoundTripper, ) static := &staticpages.Static{u.DocumentRoot} @@ -109,7 +110,9 @@ func (u *Upstream) configureRoutes() { git.SendPatch, artifacts.SendEntry, ) - ciAPIProxyQueue := queueing.QueueRequests(proxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout) + + uploadAccelerateProxy := upload.Accelerate(path.Join(u.DocumentRoot, "uploads/tmp"), proxy) + ciAPIProxyQueue := queueing.QueueRequests(uploadAccelerateProxy, u.APILimit, u.APIQueueLimit, u.APIQueueTimeout) u.Routes = []routeEntry{ // Git Clone @@ -153,7 +156,7 @@ func (u *Upstream) configureRoutes() { static.ServeExisting( u.URLPrefix, staticpages.CacheDisabled, - static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, proxy)), + static.DeployPage(static.ErrorPagesUnless(u.DevelopmentMode, uploadAccelerateProxy)), ), ), } diff --git a/internal/upstream/upstream.go b/internal/upstream/upstream.go index 9bf48a1d3dc54edcfc648372ad0ed66cbc89ba4c..60ff8596fe6b1dfc5c6064e97f76d027cddbf990 100644 --- a/internal/upstream/upstream.go +++ b/internal/upstream/upstream.go @@ -15,15 +15,20 @@ import ( "gitlab.com/gitlab-org/gitlab-workhorse/internal/badgateway" "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" "gitlab.com/gitlab-org/gitlab-workhorse/internal/urlprefix" ) -var DefaultBackend = helper.URLMustParse("http://localhost:8080") +var ( + DefaultBackend = helper.URLMustParse("http://localhost:8080") + requestHeaderBlacklist = []string{ + upload.RewrittenFieldsHeader, + } +) type Config struct { Backend *url.URL Version string - SecretPath string DocumentRoot string DevelopmentMode bool Socket string @@ -103,5 +108,9 @@ func (u *Upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { return } + for _, h := range requestHeaderBlacklist { + r.Header.Del(h) + } + route.handler.ServeHTTP(w, r) } diff --git a/main.go b/main.go index 169abd713d644756335edb4394ed6c8ce8c42998..fc3422786e4adee46d7da38092f30869e0b47eba 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ import ( "time" "gitlab.com/gitlab-org/gitlab-workhorse/internal/queueing" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" "gitlab.com/gitlab-org/gitlab-workhorse/internal/upstream" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -106,11 +107,11 @@ func main() { }() } + secret.SetPath(*secretPath) upConfig := upstream.Config{ Backend: backendURL, Socket: *authSocket, Version: Version, - SecretPath: *secretPath, DocumentRoot: *documentRoot, DevelopmentMode: *developmentMode, ProxyHeadersTimeout: *proxyHeadersTimeout, diff --git a/main_test.go b/main_test.go index 7047ce9c92770ea150d79812b0679e96a52661ed..d36a89477bd7700e08baaff3e5b3d5ae566b341b 100644 --- a/main_test.go +++ b/main_test.go @@ -8,7 +8,6 @@ import ( "io" "io/ioutil" "log" - "mime/multipart" "net/http" "net/http/httptest" "os" @@ -371,52 +370,6 @@ func TestDeniedPublicUploadsFile(t *testing.T) { } } -func TestArtifactsUpload(t *testing.T) { - reqBody := &bytes.Buffer{} - writer := multipart.NewWriter(reqBody) - file, err := writer.CreateFormFile("file", "my.file") - if err != nil { - t.Fatal(err) - } - fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART") - writer.Close() - - ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { - if strings.HasSuffix(r.URL.Path, "/authorize") { - w.Header().Set("Content-Type", api.ResponseContentType) - if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil { - t.Fatal(err) - } - return - } - err := r.ParseMultipartForm(100000) - if err != nil { - t.Fatal(err) - } - nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file) - if len(r.MultipartForm.Value) != nValues { - t.Errorf("Expected to receive exactly %d values", nValues) - } - if len(r.MultipartForm.File) != 0 { - t.Error("Expected to not receive any files") - } - w.WriteHeader(200) - }) - defer ts.Close() - ws := startWorkhorseServer(ts.URL) - defer ws.Close() - - resource := `/ci/api/v1/builds/123/artifacts` - resp, err := http.Post(ws.URL+resource, writer.FormDataContentType(), reqBody) - if err != nil { - t.Error(err) - } - defer resp.Body.Close() - if resp.StatusCode != 200 { - t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode) - } -} - var sendDataHeader = "Gitlab-Workhorse-Send-Data" func sendDataResponder(command string, literalJSON string) *httptest.Server { @@ -691,10 +644,10 @@ func archiveOKServer(t *testing.T, archiveName string) *httptest.Server { } func startWorkhorseServer(authBackend string) *httptest.Server { + testhelper.ConfigureSecret() config := upstream.Config{ Backend: helper.URLMustParse(authBackend), Version: "123", - SecretPath: testhelper.SecretPath(), DocumentRoot: testDocumentRoot, } diff --git a/upload_test.go b/upload_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1c57d5f47cd1d35617d8e842771bb5fe08e0b453 --- /dev/null +++ b/upload_test.go @@ -0,0 +1,177 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "testing" + + "gitlab.com/gitlab-org/gitlab-workhorse/internal/api" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/secret" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/testhelper" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/upload" + + "github.com/dgrijalva/jwt-go" +) + +func TestArtifactsUpload(t *testing.T) { + reqBody, contentType, err := multipartBodyWithFile() + if err != nil { + t.Fatal(err) + } + + ts := uploadTestServer(t, nil) + defer ts.Close() + ws := startWorkhorseServer(ts.URL) + defer ws.Close() + + resource := `/ci/api/v1/builds/123/artifacts` + resp, err := http.Post(ws.URL+resource, contentType, reqBody) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode) + } +} + +func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.Server { + return testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/authorize") { + w.Header().Set("Content-Type", api.ResponseContentType) + if _, err := fmt.Fprintf(w, `{"TempPath":"%s"}`, scratchDir); err != nil { + t.Fatal(err) + } + return + } + err := r.ParseMultipartForm(100000) + if err != nil { + t.Fatal(err) + } + nValues := 2 // filename + path for just the upload (no metadata because we are not POSTing a valid zip file) + if len(r.MultipartForm.Value) != nValues { + t.Errorf("Expected to receive exactly %d values", nValues) + } + if len(r.MultipartForm.File) != 0 { + t.Error("Expected to not receive any files") + } + if extraTests != nil { + extraTests(r) + } + w.WriteHeader(200) + }) +} + +func TestAcceleratedUpload(t *testing.T) { + reqBody, contentType, err := multipartBodyWithFile() + if err != nil { + t.Fatal(err) + } + ts := uploadTestServer(t, func(r *http.Request) { + jwtToken, err := jwt.Parse(r.Header.Get(upload.RewrittenFieldsHeader), func(token *jwt.Token) (interface{}, error) { + // Don't forget to validate the alg is what you expect: + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + testhelper.ConfigureSecret() + secretBytes, err := secret.Bytes() + if err != nil { + return nil, fmt.Errorf("read secret from file: %v", err) + } + + return secretBytes, nil + }) + if err != nil { + t.Fatal(err) + } + + rewrittenFields := jwtToken.Claims.(jwt.MapClaims)["rewritten_fields"].(map[string]interface{}) + if len(rewrittenFields) != 1 || len(rewrittenFields["file"].(string)) == 0 { + t.Fatalf("Unexpected rewritten_fields value: %v", rewrittenFields) + } + + }) + + defer ts.Close() + ws := startWorkhorseServer(ts.URL) + defer ws.Close() + + resource := `/example` + resp, err := http.Post(ws.URL+resource, contentType, reqBody) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("GET %q: expected 200, got %d", resource, resp.StatusCode) + } +} + +func multipartBodyWithFile() (io.Reader, string, error) { + result := &bytes.Buffer{} + writer := multipart.NewWriter(result) + file, err := writer.CreateFormFile("file", "my.file") + if err != nil { + return nil, "", err + } + fmt.Fprint(file, "SHOULD BE ON DISK, NOT IN MULTIPART") + return result, writer.FormDataContentType(), writer.Close() +} + +func TestBlockingRewrittenFieldsHeader(t *testing.T) { + canary := "untrusted header passed by user" + testCases := []struct { + desc string + contentType string + body io.Reader + present bool + }{ + {"multipart with file", "", nil, true}, // placeholder + {"no multipart", "text/plain", nil, false}, + } + + if b, c, err := multipartBodyWithFile(); err == nil { + testCases[0].contentType = c + testCases[0].body = b + } else { + t.Fatal(err) + } + + for _, tc := range testCases { + ts := testhelper.TestServerWithHandler(regexp.MustCompile(`.`), func(w http.ResponseWriter, r *http.Request) { + h := upload.RewrittenFieldsHeader + if _, ok := r.Header[h]; ok != tc.present { + t.Errorf("Expectation of presence (%v) violated", tc.present) + } + if r.Header.Get(h) == canary { + t.Errorf("Found canary %q in header %q", canary, h) + } + }) + defer ts.Close() + ws := startWorkhorseServer(ts.URL) + defer ws.Close() + + req, err := http.NewRequest("POST", ws.URL+"/something", tc.body) + if err != nil { + t.Fatal(err) + } + + req.Header.Set("Content-Type", tc.contentType) + req.Header.Set(upload.RewrittenFieldsHeader, canary) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("%s: expected HTTP 200, got %d", tc.desc, resp.StatusCode) + } + + } +}