diff --git a/Makefile b/Makefile index c2c8c5975f208fb5e189060f4bf733879e8717b0..816323530a1ea65c87999bb2de733c03c7a2d9e3 100644 --- a/Makefile +++ b/Makefile @@ -8,12 +8,17 @@ install: gitlab-workhorse install gitlab-workhorse ${PREFIX}/bin/ .PHONY: test -test: test/data/test.git clean-workhorse gitlab-workhorse +test: test/data/group/test.git clean-workhorse gitlab-workhorse go fmt | awk '{ print "Please run go fmt"; exit 1 }' go test -test/data/test.git: test/data - git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/test.git +coverage: test/data/group/test.git + go test -cover -coverprofile=test.coverage + go tool cover -html=test.coverage -o coverage.html + rm -f test.coverage + +test/data/group/test.git: test/data + git clone --bare https://gitlab.com/gitlab-org/gitlab-test.git test/data/group/test.git test/data: mkdir -p test/data diff --git a/authorization.go b/authorization.go index 8c8c0df91a08f0052e4e8a9f585591f0e7364edb..c12ac90118d3710e0ebe569cb120531b566c4d1b 100644 --- a/authorization.go +++ b/authorization.go @@ -2,13 +2,55 @@ package main import ( "encoding/json" - "errors" "fmt" "io" "net/http" "strings" ) +func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) { + url := u.authBackend + r.URL.RequestURI() + suffix + authReq, err := http.NewRequest(r.Method, url, body) + if err != nil { + return nil, err + } + // Forward all headers from our client to the auth backend. This includes + // HTTP Basic authentication credentials (the 'Authorization' header). + for k, v := range r.Header { + authReq.Header[k] = v + } + + // Clean some headers when issuing a new request without body + if body == nil { + authReq.Header.Del("Content-Type") + authReq.Header.Del("Content-Encoding") + authReq.Header.Del("Content-Length") + authReq.Header.Del("Content-Disposition") + authReq.Header.Del("Accept-Encoding") + + // Hop-by-hop headers. These are removed when sent to the backend. 
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html + authReq.Header.Del("Transfer-Encoding") + authReq.Header.Del("Connection") + authReq.Header.Del("Keep-Alive") + authReq.Header.Del("Proxy-Authenticate") + authReq.Header.Del("Proxy-Authorization") + authReq.Header.Del("Te") + authReq.Header.Del("Trailers") + authReq.Header.Del("Upgrade") + } + + // Also forward the Host header, which is excluded from the Header map by the http library. + // This allows the Host header received by the backend to be consistent with other + // requests not going through gitlab-workhorse. + authReq.Host = r.Host + // Set a custom header for the request. This can be used in some + // configurations (Passenger) to solve auth request routing problems. + authReq.Header.Set("Gitlab-Workhorse", Version) + + return authReq, nil +} + func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHandleFunc { return func(w http.ResponseWriter, r *gitRequest) { authReq, err := r.u.newUpstreamRequest(r.Request, nil, suffix) @@ -65,19 +107,3 @@ func preAuthorizeHandler(handleFunc serviceHandleFunc, suffix string) serviceHan handleFunc(w, r) } } - -func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { - return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { - if r.RepoPath == "" { - fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty")) - return - } - - if !looksLikeRepo(r.RepoPath) { - http.Error(w, "Not Found", 404) - return - } - - handleFunc(w, r) - }, "") -} diff --git a/deploy_page.go b/deploy_page.go new file mode 100644 index 0000000000000000000000000000000000000000..5f81a55873de98491d26e3f9d0c55e7a005c7c56 --- /dev/null +++ b/deploy_page.go @@ -0,0 +1,23 @@ +package main + +import ( + "io/ioutil" + "net/http" + "path/filepath" +) + +func handleDeployPage(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc { + return func(w http.ResponseWriter, r *gitRequest) { + deployPage := 
filepath.Join(*documentRoot, "index.html") + data, err := ioutil.ReadFile(deployPage) + if err != nil { + handler(w, r) + return + } + + setNoCacheHeaders(w.Header()) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(data) + } +} diff --git a/deploy_page_test.go b/deploy_page_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6b6bcbeb4d3534f889cd088f59f7ad5746556d44 --- /dev/null +++ b/deploy_page_test.go @@ -0,0 +1,53 @@ +package main + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestIfNoDeployPageExist(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + w := httptest.NewRecorder() + + executed := false + handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { + executed = true + })(w, nil) + if !executed { + t.Error("The handler should get executed") + } +} + +func TestIfDeployPageExist(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + deployPage := "DEPLOY" + ioutil.WriteFile(filepath.Join(dir, "index.html"), []byte(deployPage), 0600) + + w := httptest.NewRecorder() + + executed := false + handleDeployPage(&dir, func(w http.ResponseWriter, r *gitRequest) { + executed = true + })(w, nil) + if executed { + t.Error("The handler should not get executed") + } + w.Flush() + + assertResponseCode(t, w, 200) + assertResponseBody(t, w, deployPage) +} diff --git a/development.go b/development.go new file mode 100644 index 0000000000000000000000000000000000000000..33c5a660cc3d66053c397c2df2fcd315be434959 --- /dev/null +++ b/development.go @@ -0,0 +1,14 @@ +package main + +import "net/http" + +func handleDevelopmentMode(developmentMode *bool, handler serviceHandleFunc) serviceHandleFunc { + return func(w http.ResponseWriter, r *gitRequest) { + if !*developmentMode { + 
http.NotFound(w, r.Request) + return + } + + handler(w, r) + } +} diff --git a/development_test.go b/development_test.go new file mode 100644 index 0000000000000000000000000000000000000000..88b60b5a36874d52930f794839ed05e47a9c5a43 --- /dev/null +++ b/development_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestDevelopmentModeEnabled(t *testing.T) { + developmentMode := true + + r, _ := http.NewRequest("GET", "/something", nil) + w := httptest.NewRecorder() + + executed := false + handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { + executed = true + })(w, &gitRequest{Request: r}) + if !executed { + t.Error("The handler should get executed") + } +} + +func TestDevelopmentModeDisabled(t *testing.T) { + developmentMode := false + + r, _ := http.NewRequest("GET", "/something", nil) + w := httptest.NewRecorder() + + executed := false + handleDevelopmentMode(&developmentMode, func(w http.ResponseWriter, r *gitRequest) { + executed = true + })(w, &gitRequest{Request: r}) + if executed { + t.Error("The handler should not get executed") + } + assertResponseCode(t, w, 404) +} diff --git a/error_pages.go b/error_pages.go new file mode 100644 index 0000000000000000000000000000000000000000..5e0bff333d02de2e52dcf0d2ef88f45ba813cf00 --- /dev/null +++ b/error_pages.go @@ -0,0 +1,71 @@ +package main + +import ( + "fmt" + "io/ioutil" + "log" + "net/http" + "path/filepath" +) + +type errorPageResponseWriter struct { + rw http.ResponseWriter + status int + hijacked bool + path *string +} + +func (s *errorPageResponseWriter) Header() http.Header { + return s.rw.Header() +} + +func (s *errorPageResponseWriter) Write(data []byte) (n int, err error) { + if s.status == 0 { + s.WriteHeader(http.StatusOK) + } + if s.hijacked { + return 0, nil + } + return s.rw.Write(data) +} + +func (s *errorPageResponseWriter) WriteHeader(status int) { + if s.status != 0 { + return + } + + s.status = status + + 
if 400 <= s.status && s.status <= 599 { + errorPageFile := filepath.Join(*s.path, fmt.Sprintf("%d.html", s.status)) + + // check if custom error page exists, serve this page instead + if data, err := ioutil.ReadFile(errorPageFile); err == nil { + s.hijacked = true + + log.Printf("ErrorPage: serving predefined error page: %d", s.status) + setNoCacheHeaders(s.rw.Header()) + s.rw.Header().Set("Content-Type", "text/html; charset=utf-8") + s.rw.WriteHeader(s.status) + s.rw.Write(data) + return + } + } + + s.rw.WriteHeader(status) +} + +func (s *errorPageResponseWriter) Flush() { + s.WriteHeader(http.StatusOK) +} + +func handleRailsError(documentRoot *string, handler serviceHandleFunc) serviceHandleFunc { + return func(w http.ResponseWriter, r *gitRequest) { + rw := errorPageResponseWriter{ + rw: w, + path: documentRoot, + } + defer rw.Flush() + handler(&rw, r) + } +} diff --git a/error_pages_test.go b/error_pages_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cb512029fbeba75da35e726b2114b87c1760df57 --- /dev/null +++ b/error_pages_test.go @@ -0,0 +1,53 @@ +package main + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestIfErrorPageIsPresented(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + errorPage := "ERROR" + ioutil.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) + + w := httptest.NewRecorder() + + handleRailsError(&dir, func(w http.ResponseWriter, r *gitRequest) { + w.WriteHeader(404) + fmt.Fprint(w, "Not Found") + })(w, nil) + w.Flush() + + assertResponseCode(t, w, 404) + assertResponseBody(t, w, errorPage) +} + +func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) { + dir, err := ioutil.TempDir("", "error_page") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + w := httptest.NewRecorder() + errorResponse := "ERROR" + + handleRailsError(&dir, 
func(w http.ResponseWriter, r *gitRequest) { + w.WriteHeader(404) + fmt.Fprint(w, errorResponse) + })(w, nil) + w.Flush() + + assertResponseCode(t, w, 404) + assertResponseBody(t, w, errorResponse) +} diff --git a/git-http.go b/git-http.go index 37b1b80a28044cf3b0a71e6189a2cd10e7835864..af13ce22d15456105b62eaab68cfdb89178405cc 100644 --- a/git-http.go +++ b/git-http.go @@ -5,13 +5,43 @@ In this file we handle the Git 'smart HTTP' protocol package main import ( + "errors" "fmt" "io" + "log" "net/http" + "os" + "path" "path/filepath" "strings" ) +func looksLikeRepo(p string) bool { + // If /path/to/foo.git/objects exists then let's assume it is a valid Git + // repository. + if _, err := os.Stat(path.Join(p, "objects")); err != nil { + log.Print(err) + return false + } + return true +} + +func repoPreAuthorizeHandler(handleFunc serviceHandleFunc) serviceHandleFunc { + return preAuthorizeHandler(func(w http.ResponseWriter, r *gitRequest) { + if r.RepoPath == "" { + fail500(w, errors.New("repoPreAuthorizeHandler: RepoPath empty")) + return + } + + if !looksLikeRepo(r.RepoPath) { + http.Error(w, "Not Found", 404) + return + } + + handleFunc(w, r) + }, "") +} + func handleGetInfoRefs(w http.ResponseWriter, r *gitRequest) { rpc := r.URL.Query().Get("service") if !(rpc == "git-upload-pack" || rpc == "git-receive-pack") { diff --git a/helpers.go b/helpers.go index 394707242b21b61e5185745f8cca3b7b1c1d97c1..bc4cfef318c5f11624931f05ee2a6d10a0d6022f 100644 --- a/helpers.go +++ b/helpers.go @@ -5,23 +5,16 @@ Miscellaneous helpers: logging, errors, subprocesses package main import ( + "errors" "fmt" - "io" - "io/ioutil" "log" "net/http" - "net/url" "os" "os/exec" - "strings" + "path" "syscall" ) -func fail400(w http.ResponseWriter, err error) { - http.Error(w, "Bad request", 400) - logError(err) -} - func fail500(w http.ResponseWriter, err error) { http.Error(w, "Internal server error", 500) logError(err) @@ -31,6 +24,15 @@ func logError(err error) { log.Printf("error: %v", err) 
} +func httpError(w http.ResponseWriter, r *http.Request, error string, code int) { + if r.ProtoAtLeast(1, 1) { + // Force client to disconnect if we render request error + w.Header().Set("Connection", "close") + } + + http.Error(w, error, code) +} + // Git subprocess helpers func gitCommand(gl_id string, name string, args ...string) *exec.Cmd { cmd := exec.Command(name, args...) @@ -63,20 +65,56 @@ func cleanUpProcessGroup(cmd *exec.Cmd) { cmd.Wait() } -func forwardResponseToClient(w http.ResponseWriter, r *http.Response) { - log.Printf("PROXY:%s %q %d", r.Request.Method, r.Request.URL, r.StatusCode) +func setNoCacheHeaders(header http.Header) { + header.Set("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + header.Set("Pragma", "no-cache") + header.Set("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") +} - for k, v := range r.Header { - w.Header()[k] = v +func openFile(path string) (file *os.File, fi os.FileInfo, err error) { + file, err = os.Open(path) + if err != nil { + return } - w.WriteHeader(r.StatusCode) - io.Copy(w, r.Body) + defer func() { + if err != nil { + file.Close() + } + }() + + fi, err = file.Stat() + if err != nil { + return + } + + // The os.Open can also open directories + if fi.IsDir() { + err = &os.PathError{ + Op: "open", + Path: path, + Err: errors.New("path is directory"), + } + return + } + + return } -func setHttpPostForm(r *http.Request, values url.Values) { - dataBuffer := strings.NewReader(values.Encode()) - r.Body = ioutil.NopCloser(dataBuffer) - r.ContentLength = int64(dataBuffer.Len()) - r.Header.Set("Content-Type", "application/x-www-form-urlencoded") +// Borrowed from: net/http/server.go +// Return the canonical path for p, eliminating . and .. elements. +func cleanURIPath(p string) string { + if p == "" { + return "/" + } + if p[0] != '/' { + p = "/" + p + } + np := path.Clean(p) + // path.Clean removes trailing slash except for root; + // put the trailing slash back if necessary. 
+ if p[len(p)-1] == '/' && np != "/" { + np += "/" + } + return np } diff --git a/helpers_test.go b/helpers_test.go index 8087d3af427d524025f9b08d652d40bbf0e3dc48..65c444e509bd21c0ce8f377d8879b60fce034e35 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -10,3 +10,15 @@ func assertResponseCode(t *testing.T, response *httptest.ResponseRecorder, expec t.Fatalf("for HTTP request expected to get %d, got %d instead", expectedCode, response.Code) } } + +func assertResponseBody(t *testing.T, response *httptest.ResponseRecorder, expectedBody string) { + if response.Body.String() != expectedBody { + t.Fatalf("for HTTP request expected to receive %q, got %q instead as body", expectedBody, response.Body.String()) + } +} + +func assertResponseHeader(t *testing.T, response *httptest.ResponseRecorder, header string, expectedValue string) { + if response.Header().Get(header) != expectedValue { + t.Fatalf("for HTTP request expected to receive the header %q with %q, got %q", header, expectedValue, response.Header().Get(header)) + } +} diff --git a/lfs.go b/lfs.go index 44638b21185bd49a8a39d8f8dc014ee8ed112e3c..66c62a3f37b5708989762b9f5622c4bd9c0f3216 100644 --- a/lfs.go +++ b/lfs.go @@ -5,6 +5,7 @@ In this file we handle git lfs objects downloads and uploads package main import ( + "bytes" "crypto/sha256" "encoding/hex" "errors" @@ -67,20 +68,12 @@ func handleStoreLfsObject(w http.ResponseWriter, r *gitRequest) { fail500(w, fmt.Errorf("handleStoreLfsObject: expected sha256 %s, got %s", r.LfsOid, shaStr)) return } - r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name())) - - storeReq, err := r.u.newUpstreamRequest(r.Request, nil, "") - if err != nil { - fail500(w, fmt.Errorf("handleStoreLfsObject: newUpstreamRequest: %v", err)) - return - } - storeResponse, err := r.u.httpClient.Do(storeReq) - if err != nil { - fail500(w, fmt.Errorf("handleStoreLfsObject: do %v: %v", storeReq.URL.Path, err)) - return - } - defer storeResponse.Body.Close() + // Inject header and body + 
r.Header.Set("X-GitLab-Lfs-Tmp", filepath.Base(file.Name())) + r.Body = ioutil.NopCloser(&bytes.Buffer{}) + r.ContentLength = 0 - forwardResponseToClient(w, storeResponse) + // And proxy the request + proxyRequest(w, r) } diff --git a/logging.go b/logging.go new file mode 100644 index 0000000000000000000000000000000000000000..bee171b2db36334064033cb8df9f7df37c5a2bf6 --- /dev/null +++ b/logging.go @@ -0,0 +1,52 @@ +package main + +import ( + "fmt" + "net/http" + "time" +) + +type loggingResponseWriter struct { + rw http.ResponseWriter + status int + written int64 + started time.Time +} + +func newLoggingResponseWriter(rw http.ResponseWriter) loggingResponseWriter { + return loggingResponseWriter{ + rw: rw, + started: time.Now(), + } +} + +func (l *loggingResponseWriter) Header() http.Header { + return l.rw.Header() +} + +func (l *loggingResponseWriter) Write(data []byte) (n int, err error) { + if l.status == 0 { + l.WriteHeader(http.StatusOK) + } + n, err = l.rw.Write(data) + l.written += int64(n) + return +} + +func (l *loggingResponseWriter) WriteHeader(status int) { + if l.status != 0 { + return + } + + l.status = status + l.rw.WriteHeader(status) +} + +func (l *loggingResponseWriter) Log(r *http.Request) { + duration := time.Since(l.started) + fmt.Printf("%s %s - - [%s] %q %d %d %q %q %f\n", + r.Host, r.RemoteAddr, l.started, + fmt.Sprintf("%s %s %s", r.Method, r.RequestURI, r.Proto), + l.status, l.written, r.Referer(), r.UserAgent(), duration.Seconds(), + ) +} diff --git a/main.go b/main.go index 8be1e664470cb68e44be376a828788f1c5563f14..0c5a9fcb2d4b5816b2149a8ea97a8ea7ff4acced 100644 --- a/main.go +++ b/main.go @@ -21,20 +21,97 @@ import ( "net/http" _ "net/http/pprof" "os" + "regexp" "syscall" "time" ) +// Current version of GitLab Workhorse var Version = "(unknown version)" // Set at build time in the Makefile +var printVersion = flag.Bool("version", false, "Print version and exit") +var listenAddr = flag.String("listenAddr", "localhost:8181", "Listen 
address for HTTP server") +var listenNetwork = flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") +var listenUmask = flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022") +var authBackend = flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend") +var authSocket = flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") +var pprofListenAddr = flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") +var relativeURLRoot = flag.String("relativeURLRoot", "/", "GitLab relative URL root") +var documentRoot = flag.String("documentRoot", "public", "Path to static files content") +var responseHeadersTimeout = flag.Duration("proxyHeadersTimeout", time.Minute, "How long to wait for response headers when proxying the request") +var developmentMode = flag.Bool("developmentMode", false, "Allow to serve assets from Rails app") + +type httpRoute struct { + method string + regex *regexp.Regexp + handleFunc serviceHandleFunc +} + +const projectPattern = `^/[^/]+/[^/]+/` +const gitProjectPattern = `^/[^/]+/[^/]+\.git/` + +const apiPattern = `^/api/` +const projectsAPIPattern = `^/api/v3/projects/[^/]+/` + +const ciAPIPattern = `^/ci/api/` + +// Routing table +// We match against URI not containing the relativeUrlRoot: +// see upstream.ServeHTTP +var httpRoutes = [...]httpRoute{ + // Git Clone + httpRoute{"GET", regexp.MustCompile(gitProjectPattern + `info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)}, + httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, + httpRoute{"POST", regexp.MustCompile(gitProjectPattern + `git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, + httpRoute{"PUT", regexp.MustCompile(gitProjectPattern + `gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)}, + + // 
Repository Archive + httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)}, + + // Repository Archive API + httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)}, + httpRoute{"GET", regexp.MustCompile(projectsAPIPattern + `repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)}, + + // CI Artifacts API + httpRoute{"POST", regexp.MustCompile(ciAPIPattern + `v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))}, + + // Explicitly proxy API requests + httpRoute{"", regexp.MustCompile(apiPattern), proxyRequest}, + httpRoute{"", regexp.MustCompile(ciAPIPattern), proxyRequest}, + + // Serve assets + httpRoute{"", regexp.MustCompile(`^/assets/`), + handleServeFile(documentRoot, CacheExpireMax, + handleDevelopmentMode(developmentMode, + handleDeployPage(documentRoot, + handleRailsError(documentRoot, + proxyRequest, + ), + ), + ), + ), + }, + + // Serve static files or 
forward the requests + httpRoute{"", nil, + handleServeFile(documentRoot, CacheDisabled, + handleDeployPage(documentRoot, + handleRailsError(documentRoot, + proxyRequest, + ), + ), + ), + }, +} + func main() { - printVersion := flag.Bool("version", false, "Print version and exit") - listenAddr := flag.String("listenAddr", "localhost:8181", "Listen address for HTTP server") - listenNetwork := flag.String("listenNetwork", "tcp", "Listen 'network' (tcp, tcp4, tcp6, unix)") - listenUmask := flag.Int("listenUmask", 022, "Umask for Unix socket, default: 022") - authBackend := flag.String("authBackend", "http://localhost:8080", "Authentication/authorization backend") - authSocket := flag.String("authSocket", "", "Optional: Unix domain socket to dial authBackend at") - pprofListenAddr := flag.String("pprofListenAddr", "", "pprof listening address, e.g. 'localhost:6060'") flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) fmt.Fprintf(os.Stderr, "\n %s [OPTIONS]\n\nOptions:\n", os.Args[0]) @@ -65,7 +142,8 @@ func main() { log.Fatal(err) } - var authTransport http.RoundTripper + // Create Proxy Transport + authTransport := http.DefaultTransport if *authSocket != "" { dialer := &net.Dialer{ // The values below are taken from http.DefaultTransport @@ -76,8 +154,10 @@ func main() { Dial: func(_, _ string) (net.Conn, error) { return dialer.Dial("unix", *authSocket) }, + ResponseHeaderTimeout: *responseHeadersTimeout, } } + proxyTransport := &proxyRoundTripper{transport: authTransport} // The profiler will only be activated by HTTP requests. HTTP // requests can only reach the profiler if we start a listener. So by @@ -89,9 +169,7 @@ func main() { }() } - // Because net/http/pprof installs itself in the DefaultServeMux - // we create a fresh one for the Git server. 
- serveMux := http.NewServeMux() - serveMux.Handle("/", newUpstream(*authBackend, authTransport)) - log.Fatal(http.Serve(listener, serveMux)) + upstream := newUpstream(*authBackend, proxyTransport) + upstream.SetRelativeURLRoot(*relativeURLRoot) + log.Fatal(http.Serve(listener, upstream)) } diff --git a/main_test.go b/main_test.go index 6bf4acb4a279b688a4d51d969eb0932f0007ec72..c41224d92a48b216fba02f1b4d8e1f4d120ad34d 100644 --- a/main_test.go +++ b/main_test.go @@ -18,8 +18,8 @@ import ( const scratchDir = "test/scratch" const testRepoRoot = "test/data" -const testRepo = "test.git" -const testProject = "test" +const testRepo = "group/test.git" +const testProject = "group/test" var checkoutDir = path.Join(scratchDir, "test") var cacheDir = path.Join(scratchDir, "cache") @@ -276,7 +276,6 @@ func preparePushRepo(t *testing.T) { } cloneCmd := exec.Command("git", "clone", path.Join(testRepoRoot, testRepo), checkoutDir) runOrFail(t, cloneCmd) - return } func newBranch() string { @@ -367,89 +366,3 @@ func repoPath(t *testing.T) string { } return path.Join(cwd, testRepoRoot, testRepo) } - -func TestDeniedLfsDownload(t *testing.T) { - contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80" - url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename) - - prepareDownloadDir(t) - deniedXSendfileDownload(t, contentFilename, url) -} - -func TestAllowedLfsDownload(t *testing.T) { - contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80" - url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename) - - prepareDownloadDir(t) - allowedXSendfileDownload(t, contentFilename, url) -} - -func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) { - contentPath := path.Join(cacheDir, contentFilename) - prepareDownloadDir(t) - - // Prepare test server and backend - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Println("UPSTREAM", r.Method, r.URL) - if 
xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" { - t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType) - } - w.Header().Set("X-Sendfile", contentPath) - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename)) - w.Header().Set("Content-Type", fmt.Sprintf(`application/octet-stream`)) - w.WriteHeader(200) - })) - defer ts.Close() - ws := startWorkhorseServer(ts.URL) - defer ws.Close() - - if err := os.MkdirAll(cacheDir, 0755); err != nil { - t.Fatal(err) - } - contentBytes := []byte("content") - if err := ioutil.WriteFile(contentPath, contentBytes, 0644); err != nil { - t.Fatal(err) - } - - downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath)) - downloadCmd.Dir = scratchDir - runOrFail(t, downloadCmd) - - actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename)) - if err != nil { - t.Fatal(err) - } - if bytes.Compare(actual, contentBytes) != 0 { - t.Fatal("Unexpected file contents in download") - } -} - -func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) { - prepareDownloadDir(t) - - // Prepare test server and backend - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Println("UPSTREAM", r.Method, r.URL) - if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" { - t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType) - } - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename)) - w.WriteHeader(200) - fmt.Fprint(w, "Denied") - })) - defer ts.Close() - ws := startWorkhorseServer(ts.URL) - defer ws.Close() - - downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath)) - downloadCmd.Dir = scratchDir - runOrFail(t, downloadCmd) - - actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename)) - if err != nil { - t.Fatal(err) - } - if 
bytes.Compare(actual, []byte("Denied")) != 0 { - t.Fatal("Unexpected file contents in download") - } -} diff --git a/proxy.go b/proxy.go index 4605b827b01f95cb44f71fd0da715c80f8851d73..0758196c3aa19aa6e3ba47a3a2a0f6ce7073b18e 100644 --- a/proxy.go +++ b/proxy.go @@ -1,23 +1,65 @@ package main import ( + "bytes" "fmt" + "io/ioutil" "net/http" ) -func proxyRequest(w http.ResponseWriter, r *gitRequest) { - upRequest, err := r.u.newUpstreamRequest(r.Request, r.Body, "") +type proxyRoundTripper struct { + transport http.RoundTripper +} + +func (p *proxyRoundTripper) RoundTrip(r *http.Request) (res *http.Response, err error) { + res, err = p.transport.RoundTrip(r) + + // httputil.ReverseProxy translates all errors from this + // RoundTrip function into 500 errors. But the most likely error + // is that the Rails app is not responding, in which case users + // and administrators expect to see a 502 error. To show 502s + // instead of 500s we catch the RoundTrip error here and inject a + // 502 response. 
if err != nil { - fail500(w, fmt.Errorf("proxyRequest: newUpstreamRequest: %v", err)) - return + logError(fmt.Errorf("proxyRoundTripper: %s %q failed with: %q", r.Method, r.RequestURI, err)) + + res = &http.Response{ + StatusCode: http.StatusBadGateway, + Status: http.StatusText(http.StatusBadGateway), + + Request: r, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + Proto: r.Proto, + Header: make(http.Header), + Trailer: make(http.Header), + Body: ioutil.NopCloser(bytes.NewBufferString(err.Error())), + } + res.Header.Set("Content-Type", "text/plain") + err = nil } + return +} - upResponse, err := r.u.httpClient.Do(upRequest) - if err != nil { - fail500(w, fmt.Errorf("proxyRequest: do %v: %v", upRequest.URL.Path, err)) - return +func headerClone(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 } - defer upResponse.Body.Close() + return h2 +} + +func proxyRequest(w http.ResponseWriter, r *gitRequest) { + // Clone request + req := *r.Request + req.Header = headerClone(r.Header) + + // Set Workhorse version + req.Header.Set("Gitlab-Workhorse", Version) + rw := newSendFileResponseWriter(w, &req) + defer rw.Flush() - forwardResponseToClient(w, upResponse) + r.u.httpProxy.ServeHTTP(&rw, &req) } diff --git a/proxy_test.go b/proxy_test.go index cbb6f12d7d70c675a22389ba7c546c220de7203c..88669cc335a17c931f7914d62ce1f8dcd054ad77 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -4,10 +4,12 @@ import ( "bytes" "fmt" "io" + "net" "net/http" "net/http/httptest" "regexp" "testing" + "time" ) func TestProxyRequest(t *testing.T) { @@ -42,15 +44,94 @@ func TestProxyRequest(t *testing.T) { u: newUpstream(ts.URL, nil), } - response := httptest.NewRecorder() - proxyRequest(response, &request) - assertResponseCode(t, response, 202) + w := httptest.NewRecorder() + proxyRequest(w, &request) + assertResponseCode(t, w, 202) + assertResponseBody(t, w, "RESPONSE") - if 
response.Body.String() != "RESPONSE" { - t.Fatal("Expected RESPONSE in response body:", response.Body.String()) + if w.Header().Get("Custom-Response-Header") != "test" { + t.Fatal("Expected custom response header") } +} - if response.Header().Get("Custom-Response-Header") != "test" { - t.Fatal("Expected custom response header") +func TestProxyError(t *testing.T) { + httpRequest, err := http.NewRequest("POST", "/url/path", bytes.NewBufferString("REQUEST")) + if err != nil { + t.Fatal(err) + } + httpRequest.Header.Set("Custom-Header", "test") + + transport := proxyRoundTripper{ + transport: http.DefaultTransport, + } + + request := gitRequest{ + Request: httpRequest, + u: newUpstream("http://localhost:655575/", &transport), + } + + w := httptest.NewRecorder() + proxyRequest(w, &request) + assertResponseCode(t, w, 502) + assertResponseBody(t, w, "dial tcp: invalid port 655575") +} + +func TestProxyReadTimeout(t *testing.T) { + ts := testServerWithHandler(nil, func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Minute) + }) + + httpRequest, err := http.NewRequest("POST", "http://localhost/url/path", nil) + if err != nil { + t.Fatal(err) } + + transport := &proxyRoundTripper{ + transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: time.Millisecond, + }, + } + + request := gitRequest{ + Request: httpRequest, + u: newUpstream(ts.URL, transport), + } + + w := httptest.NewRecorder() + proxyRequest(w, &request) + assertResponseCode(t, w, 502) + assertResponseBody(t, w, "net/http: timeout awaiting response headers") +} + +func TestProxyHandlerTimeout(t *testing.T) { + ts := testServerWithHandler(nil, + http.TimeoutHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Second) + }), time.Millisecond, "Request took too long").ServeHTTP, + ) + + httpRequest, err := 
http.NewRequest("POST", "http://localhost/url/path", nil) + if err != nil { + t.Fatal(err) + } + + transport := &proxyRoundTripper{ + transport: http.DefaultTransport, + } + + request := gitRequest{ + Request: httpRequest, + u: newUpstream(ts.URL, transport), + } + + w := httptest.NewRecorder() + proxyRequest(w, &request) + assertResponseCode(t, w, 503) + assertResponseBody(t, w, "Request took too long") } diff --git a/sendfile.go b/sendfile.go new file mode 100644 index 0000000000000000000000000000000000000000..64fdf5f39f1b75ad67912455eb75f8b4c2143ef9 --- /dev/null +++ b/sendfile.go @@ -0,0 +1,78 @@ +/* +The xSendFile middleware transparently sends static files in HTTP responses +via the X-Sendfile mechanism. All that is needed in the Rails code is the +'send_file' method. +*/ + +package main + +import ( + "log" + "net/http" +) + +type sendFileResponseWriter struct { + rw http.ResponseWriter + status int + hijacked bool + req *http.Request +} + +func newSendFileResponseWriter(rw http.ResponseWriter, req *http.Request) sendFileResponseWriter { + s := sendFileResponseWriter{ + rw: rw, + req: req, + } + req.Header.Set("X-Sendfile-Type", "X-Sendfile") + return s +} + +func (s *sendFileResponseWriter) Header() http.Header { + return s.rw.Header() +} + +func (s *sendFileResponseWriter) Write(data []byte) (n int, err error) { + if s.status == 0 { + s.WriteHeader(http.StatusOK) + } + if s.hijacked { + return + } + return s.rw.Write(data) +} + +func (s *sendFileResponseWriter) WriteHeader(status int) { + if s.status != 0 { + return + } + + s.status = status + + // Check X-Sendfile header + file := s.Header().Get("X-Sendfile") + s.Header().Del("X-Sendfile") + + // If file is empty or status is not 200 pass through header + if file == "" || s.status != http.StatusOK { + s.rw.WriteHeader(s.status) + return + } + + // Mark this connection as hijacked + s.hijacked = true + + // Serve the file + log.Printf("Send file %q for %s %q", file, s.req.Method, s.req.RequestURI) + 
content, fi, err := openFile(file) + if err != nil { + http.NotFound(s.rw, s.req) + return + } + defer content.Close() + + http.ServeContent(s.rw, s.req, "", fi.ModTime(), content) +} + +func (s *sendFileResponseWriter) Flush() { + s.WriteHeader(http.StatusOK) +} diff --git a/sendfile_test.go b/sendfile_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9ffa615428137fb12f52f2e9cd23e81783bc93ac --- /dev/null +++ b/sendfile_test.go @@ -0,0 +1,100 @@ +package main + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path" + "testing" +) + +func TestDeniedLfsDownload(t *testing.T) { + contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80" + url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename) + + prepareDownloadDir(t) + deniedXSendfileDownload(t, contentFilename, url) +} + +func TestAllowedLfsDownload(t *testing.T) { + contentFilename := "b68143e6463773b1b6c6fd009a76c32aeec041faff32ba2ed42fd7f708a17f80" + url := fmt.Sprintf("gitlab-lfs/objects/%s", contentFilename) + + prepareDownloadDir(t) + allowedXSendfileDownload(t, contentFilename, url) +} + +func allowedXSendfileDownload(t *testing.T, contentFilename string, filePath string) { + contentPath := path.Join(cacheDir, contentFilename) + prepareDownloadDir(t) + + // Prepare test server and backend + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println("UPSTREAM", r.Method, r.URL) + if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" { + t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType) + } + w.Header().Set("X-Sendfile", contentPath) + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename)) + w.Header().Set("Content-Type", fmt.Sprintf(`application/octet-stream`)) + w.WriteHeader(200) + })) + defer ts.Close() + ws := startWorkhorseServer(ts.URL) + defer ws.Close() 
+ + if err := os.MkdirAll(cacheDir, 0755); err != nil { + t.Fatal(err) + } + contentBytes := []byte("content") + if err := ioutil.WriteFile(contentPath, contentBytes, 0644); err != nil { + t.Fatal(err) + } + + downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath)) + downloadCmd.Dir = scratchDir + runOrFail(t, downloadCmd) + + actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename)) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(actual, contentBytes) != 0 { + t.Fatal("Unexpected file contents in download") + } +} + +func deniedXSendfileDownload(t *testing.T, contentFilename string, filePath string) { + prepareDownloadDir(t) + + // Prepare test server and backend + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Println("UPSTREAM", r.Method, r.URL) + if xSendfileType := r.Header.Get("X-Sendfile-Type"); xSendfileType != "X-Sendfile" { + t.Fatalf(`X-Sendfile-Type want "X-Sendfile" got %q`, xSendfileType) + } + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, contentFilename)) + w.WriteHeader(200) + fmt.Fprint(w, "Denied") + })) + defer ts.Close() + ws := startWorkhorseServer(ts.URL) + defer ws.Close() + + downloadCmd := exec.Command("curl", "-J", "-O", fmt.Sprintf("%s/%s", ws.URL, filePath)) + downloadCmd.Dir = scratchDir + runOrFail(t, downloadCmd) + + actual, err := ioutil.ReadFile(path.Join(scratchDir, contentFilename)) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(actual, []byte("Denied")) != 0 { + t.Fatal("Unexpected file contents in download") + } +} diff --git a/servefile.go b/servefile.go new file mode 100644 index 0000000000000000000000000000000000000000..5af65cac604eecb0b89a4049cf500568a1f86862 --- /dev/null +++ b/servefile.go @@ -0,0 +1,70 @@ +package main + +import ( + "log" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +type CacheMode int + +const ( + CacheDisabled CacheMode = iota + CacheExpireMax 
+) + +func handleServeFile(documentRoot *string, cache CacheMode, notFoundHandler serviceHandleFunc) serviceHandleFunc { + return func(w http.ResponseWriter, r *gitRequest) { + file := filepath.Join(*documentRoot, r.relativeURIPath) + + // The filepath.Join does Clean traversing directories up + if !strings.HasPrefix(file, *documentRoot) { + fail500(w, &os.PathError{ + Op: "open", + Path: file, + Err: os.ErrInvalid, + }) + return + } + + var content *os.File + var fi os.FileInfo + var err error + + // Serve pre-gzipped assets + if acceptEncoding := r.Header.Get("Accept-Encoding"); strings.Contains(acceptEncoding, "gzip") { + content, fi, err = openFile(file + ".gz") + if err == nil { + w.Header().Set("Content-Encoding", "gzip") + } + } + + // If not found, open the original file + if content == nil || err != nil { + content, fi, err = openFile(file) + } + if err != nil { + if notFoundHandler != nil { + notFoundHandler(w, r) + } else { + http.NotFound(w, r.Request) + } + return + } + defer content.Close() + + switch cache { + case CacheExpireMax: + // Cache statically served files for 1 year + cacheUntil := time.Now().AddDate(1, 0, 0).Format(http.TimeFormat) + w.Header().Set("Cache-Control", "public") + w.Header().Set("Expires", cacheUntil) + } + + log.Printf("Send static file %q (%q) for %s %q", file, w.Header().Get("Content-Encoding"), r.Method, r.RequestURI) + http.ServeContent(w, r.Request, filepath.Base(file), fi.ModTime(), content) + } +} diff --git a/servefile_test.go b/servefile_test.go new file mode 100644 index 0000000000000000000000000000000000000000..681d14e7e0efc43f4063c86a957968a961bfa963 --- /dev/null +++ b/servefile_test.go @@ -0,0 +1,149 @@ +package main + +import ( + "bytes" + "compress/gzip" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestServingNonExistingFile(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/file", nil) + request := 
&gitRequest{ + Request: httpRequest, + relativeURIPath: "/static/file", + } + + w := httptest.NewRecorder() + handleServeFile(&dir, CacheDisabled, nil)(w, request) + assertResponseCode(t, w, 404) +} + +func TestServingDirectory(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + request := &gitRequest{ + Request: httpRequest, + relativeURIPath: "/", + } + + w := httptest.NewRecorder() + handleServeFile(&dir, CacheDisabled, nil)(w, request) + assertResponseCode(t, w, 404) +} + +func TestServingMalformedUri(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/file", nil) + request := &gitRequest{ + Request: httpRequest, + relativeURIPath: "/../../../static/file", + } + + w := httptest.NewRecorder() + handleServeFile(&dir, CacheDisabled, nil)(w, request) + assertResponseCode(t, w, 500) +} + +func TestExecutingHandlerWhenNoFileFound(t *testing.T) { + dir := "/path/to/non/existing/directory" + httpRequest, _ := http.NewRequest("GET", "/file", nil) + request := &gitRequest{ + Request: httpRequest, + relativeURIPath: "/static/file", + } + + executed := false + handleServeFile(&dir, CacheDisabled, func(w http.ResponseWriter, r *gitRequest) { + executed = (r == request) + })(nil, request) + if !executed { + t.Error("The handler should get executed") + } +} + +func TestServingTheActualFile(t *testing.T) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + request := &gitRequest{ + Request: httpRequest, + relativeURIPath: "/file", + } + + fileContent := "STATIC" + ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) + + w := httptest.NewRecorder() + handleServeFile(&dir, CacheDisabled, nil)(w, request) + assertResponseCode(t, w, 200) + if w.Body.String() != 
fileContent { + t.Error("We should serve the file: ", w.Body.String()) + } +} + +func testServingThePregzippedFile(t *testing.T, enableGzip bool) { + dir, err := ioutil.TempDir("", "deploy") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(dir) + + httpRequest, _ := http.NewRequest("GET", "/file", nil) + request := &gitRequest{ + Request: httpRequest, + relativeURIPath: "/file", + } + + if enableGzip { + httpRequest.Header.Set("Accept-Encoding", "gzip, deflate") + } + + fileContent := "STATIC" + + var fileGzipContent bytes.Buffer + fileGzip := gzip.NewWriter(&fileGzipContent) + fileGzip.Write([]byte(fileContent)) + fileGzip.Close() + + ioutil.WriteFile(filepath.Join(dir, "file.gz"), fileGzipContent.Bytes(), 0600) + ioutil.WriteFile(filepath.Join(dir, "file"), []byte(fileContent), 0600) + + w := httptest.NewRecorder() + handleServeFile(&dir, CacheDisabled, nil)(w, request) + assertResponseCode(t, w, 200) + if enableGzip { + assertResponseHeader(t, w, "Content-Encoding", "gzip") + if bytes.Compare(w.Body.Bytes(), fileGzipContent.Bytes()) != 0 { + t.Error("We should serve the pregzipped file") + } + } else { + assertResponseCode(t, w, 200) + assertResponseHeader(t, w, "Content-Encoding", "") + if w.Body.String() != fileContent { + t.Error("We should serve the file: ", w.Body.String()) + } + } +} + +func TestServingThePregzippedFile(t *testing.T) { + testServingThePregzippedFile(t, true) +} + +func TestServingThePregzippedFileWithoutEncoding(t *testing.T) { + testServingThePregzippedFile(t, false) +} diff --git a/uploads.go b/uploads.go index 9468474d176389c3d2fd1ef586cbad3506d859f7..7848cacceaa733a284e245241bbdfef284d94a4c 100644 --- a/uploads.go +++ b/uploads.go @@ -111,25 +111,11 @@ func handleFileUploads(w http.ResponseWriter, r *gitRequest) { // Close writer writer.Close() - // Create request - upstreamRequest, err := r.u.newUpstreamRequest(r.Request, nil, "") - if err != nil { - fail500(w, fmt.Errorf("handleFileUploads: newUpstreamRequest: %v", err)) - 
return - } - - // Set multipart form data - upstreamRequest.Body = ioutil.NopCloser(&body) - upstreamRequest.ContentLength = int64(body.Len()) - upstreamRequest.Header.Set("Content-Type", writer.FormDataContentType()) - - // Forward request to backend - upstreamResponse, err := r.u.httpClient.Do(upstreamRequest) - if err != nil { - fail500(w, fmt.Errorf("handleFileUploads: do request %v: %v", upstreamRequest.URL.Path, err)) - return - } - defer upstreamResponse.Body.Close() + // Hijack the request + r.Body = ioutil.NopCloser(&body) + r.ContentLength = int64(body.Len()) + r.Header.Set("Content-Type", writer.FormDataContentType()) - forwardResponseToClient(w, upstreamResponse) + // Proxy the request + proxyRequest(w, r) } diff --git a/upstream.go b/upstream.go index cbe0b5f5913d4bdd326bd8adc1524c0034ae58ec..b6ac589714f675b5057cfc62830f94445c7b337e 100644 --- a/upstream.go +++ b/upstream.go @@ -7,25 +7,21 @@ In this file we handle request routing and interaction with the authBackend. package main import ( - "io" + "fmt" "log" "net/http" - "os" - "path" - "regexp" + "net/http/httputil" + "net/url" + "strings" ) type serviceHandleFunc func(w http.ResponseWriter, r *gitRequest) type upstream struct { - httpClient *http.Client - authBackend string -} - -type gitService struct { - method string - regex *regexp.Regexp - handleFunc serviceHandleFunc + httpClient *http.Client + httpProxy *httputil.ReverseProxy + authBackend string + relativeURLRoot string } type authorizationResponse struct { @@ -56,50 +52,79 @@ type authorizationResponse struct { TempPath string } -// A gitReqest is an *http.Request decorated with attributes returned by the +// A gitRequest is an *http.Request decorated with attributes returned by the // GitLab Rails application. 
type gitRequest struct { *http.Request authorizationResponse u *upstream -} -// Routing table -var gitServices = [...]gitService{ - gitService{"GET", regexp.MustCompile(`/info/refs\z`), repoPreAuthorizeHandler(handleGetInfoRefs)}, - gitService{"POST", regexp.MustCompile(`/git-upload-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, - gitService{"POST", regexp.MustCompile(`/git-receive-pack\z`), repoPreAuthorizeHandler(contentEncodingHandler(handlePostRPC))}, - gitService{"GET", regexp.MustCompile(`/repository/archive\z`), repoPreAuthorizeHandler(handleGetArchive)}, - gitService{"GET", regexp.MustCompile(`/repository/archive.zip\z`), repoPreAuthorizeHandler(handleGetArchive)}, - gitService{"GET", regexp.MustCompile(`/repository/archive.tar\z`), repoPreAuthorizeHandler(handleGetArchive)}, - gitService{"GET", regexp.MustCompile(`/repository/archive.tar.gz\z`), repoPreAuthorizeHandler(handleGetArchive)}, - gitService{"GET", regexp.MustCompile(`/repository/archive.tar.bz2\z`), repoPreAuthorizeHandler(handleGetArchive)}, - gitService{"GET", regexp.MustCompile(`/uploads/`), handleSendFile}, - - // Git LFS - gitService{"PUT", regexp.MustCompile(`/gitlab-lfs/objects/([0-9a-f]{64})/([0-9]+)\z`), lfsAuthorizeHandler(handleStoreLfsObject)}, - gitService{"GET", regexp.MustCompile(`/gitlab-lfs/objects/([0-9a-f]{64})\z`), handleSendFile}, - - // CI artifacts - gitService{"GET", regexp.MustCompile(`/builds/download\z`), handleSendFile}, - gitService{"GET", regexp.MustCompile(`/ci/api/v1/builds/[0-9]+/artifacts\z`), handleSendFile}, - gitService{"POST", regexp.MustCompile(`/ci/api/v1/builds/[0-9]+/artifacts\z`), artifactsAuthorizeHandler(contentEncodingHandler(handleFileUploads))}, - gitService{"DELETE", regexp.MustCompile(`/ci/api/v1/builds/[0-9]+/artifacts\z`), proxyRequest}, + // This field contains the URL.Path stripped from RelativeUrlRoot + relativeURIPath string } func newUpstream(authBackend string, authTransport http.RoundTripper) *upstream { - 
return &upstream{&http.Client{Transport: authTransport}, authBackend} + u, err := url.Parse(authBackend) + if err != nil { + log.Fatalln(err) + } + + up := &upstream{ + authBackend: authBackend, + httpClient: &http.Client{Transport: authTransport}, + httpProxy: httputil.NewSingleHostReverseProxy(u), + relativeURLRoot: "/", + } + up.httpProxy.Transport = authTransport + return up +} + +func (u *upstream) SetRelativeURLRoot(relativeURLRoot string) { + u.relativeURLRoot = relativeURLRoot + + if !strings.HasSuffix(u.relativeURLRoot, "/") { + u.relativeURLRoot += "/" + } } -func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var g gitService +func (u *upstream) ServeHTTP(ow http.ResponseWriter, r *http.Request) { + var g httpRoute + + w := newLoggingResponseWriter(ow) + defer w.Log(r) - log.Printf("%s %q", r.Method, r.URL) + // Drop WebSocket connection and CONNECT method + if r.RequestURI == "*" { + httpError(&w, r, "Connection upgrade not allowed", http.StatusBadRequest) + return + } + + // Disallow connect + if r.Method == "CONNECT" { + httpError(&w, r, "CONNECT not allowed", http.StatusBadRequest) + return + } + + // Check URL Root + URIPath := cleanURIPath(r.URL.Path) + if !strings.HasPrefix(URIPath, u.relativeURLRoot) { + httpError(&w, r, fmt.Sprintf("Not found %q", URIPath), http.StatusNotFound) + return + } + + // Strip prefix and add "/" + // To match against non-relative URL + // Making it simpler for our matcher + relativeURIPath := cleanURIPath(strings.TrimPrefix(URIPath, u.relativeURLRoot)) // Look for a matching Git service foundService := false - for _, g = range gitServices { - if r.Method == g.method && g.regex.MatchString(r.URL.Path) { + for _, g = range httpRoutes { + if g.method != "" && r.Method != g.method { + continue + } + + if g.regex == nil || g.regex.MatchString(relativeURIPath) { foundService = true break } @@ -107,57 +132,15 @@ func (u *upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !foundService { // 
The protocol spec in git/Documentation/technical/http-protocol.txt // says we must return 403 if no matching service is found. - http.Error(w, "Forbidden", 403) + httpError(&w, r, "Forbidden", http.StatusForbidden) return } request := gitRequest{ - Request: r, - u: u, - } - - g.handleFunc(w, &request) -} - -func looksLikeRepo(p string) bool { - // If /path/to/foo.git/objects exists then let's assume it is a valid Git - // repository. - if _, err := os.Stat(path.Join(p, "objects")); err != nil { - log.Print(err) - return false + Request: r, + relativeURIPath: relativeURIPath, + u: u, } - return true -} - -func (u *upstream) newUpstreamRequest(r *http.Request, body io.Reader, suffix string) (*http.Request, error) { - url := u.authBackend + r.URL.RequestURI() + suffix - authReq, err := http.NewRequest(r.Method, url, body) - if err != nil { - return nil, err - } - // Forward all headers from our client to the auth backend. This includes - // HTTP Basic authentication credentials (the 'Authorization' header). - for k, v := range r.Header { - authReq.Header[k] = v - } - - // Clean some headers when issuing a new request without body - if body == nil { - authReq.Header.Del("Content-Type") - authReq.Header.Del("Content-Encoding") - authReq.Header.Del("Content-Length") - authReq.Header.Del("Content-Disposition") - authReq.Header.Del("Accept-Encoding") - authReq.Header.Del("Transfer-Encoding") - } - - // Also forward the Host header, which is excluded from the Header map by the http libary. - // This allows the Host header received by the backend to be consistent with other - // requests not going through gitlab-workhorse. - authReq.Host = r.Host - // Set a custom header for the request. This can be used in some - // configurations (Passenger) to solve auth request routing problems. 
- authReq.Header.Set("Gitlab-Workhorse", Version) - return authReq, nil + g.handleFunc(&w, &request) } diff --git a/xsendfile.go b/xsendfile.go deleted file mode 100644 index 591d56546edf5accfb07a98c685fdb2cd595af28..0000000000000000000000000000000000000000 --- a/xsendfile.go +++ /dev/null @@ -1,69 +0,0 @@ -/* -The xSendFile middleware transparently sends static files in HTTP responses -via the X-Sendfile mechanism. All that is needed in the Rails code is the -'send_file' method. -*/ - -package main - -import ( - "fmt" - "io" - "log" - "net/http" - "os" -) - -func handleSendFile(w http.ResponseWriter, r *gitRequest) { - upRequest, err := r.u.newUpstreamRequest(r.Request, r.Body, "") - if err != nil { - fail500(w, fmt.Errorf("handleSendFile: newUpstreamRequest: %v", err)) - return - } - - upRequest.Header.Set("X-Sendfile-Type", "X-Sendfile") - upResponse, err := r.u.httpClient.Do(upRequest) - r.Body.Close() - if err != nil { - fail500(w, fmt.Errorf("handleSendfile: do upstream request: %v", err)) - return - } - - defer upResponse.Body.Close() - // Get X-Sendfile - sendfile := upResponse.Header.Get("X-Sendfile") - upResponse.Header.Del("X-Sendfile") - - // Copy headers from Rails upResponse - for k, v := range upResponse.Header { - w.Header()[k] = v - } - - // Use accelerated file serving - if sendfile == "" { - // Copy request body otherwise - w.WriteHeader(upResponse.StatusCode) - - // Copy body from Rails upResponse - if _, err := io.Copy(w, upResponse.Body); err != nil { - fail500(w, fmt.Errorf("handleSendFile: copy upstream response: %v", err)) - } - return - } - - log.Printf("Serving file %q", sendfile) - upResponse.Body.Close() - content, err := os.Open(sendfile) - if err != nil { - fail500(w, fmt.Errorf("handleSendile: open sendfile: %v", err)) - return - } - defer content.Close() - - fi, err := content.Stat() - if err != nil { - fail500(w, fmt.Errorf("handleSendfile: get mtime: %v", err)) - return - } - http.ServeContent(w, r.Request, "", fi.ModTime(), 
content) -}