Skip to content
代码片段 群组 项目
未验证 提交 0a6ca27e 编辑于 作者: Archish Thakkar's avatar Archish Thakkar 提交者: GitLab
浏览文件

Lint fixes for staticpages package

上级 65826a17
No related branches found
No related tags found
无相关合并请求
// Package staticpages provides functionality for serving static pages and handling errors.
package staticpages
import (
"fmt"
"net/http"
"os"
"path/filepath"
)
// DeployPage deploys the index.html page by serving it using the provided handler.
func (s *Static) DeployPage(handler http.Handler) http.Handler {
deployPage := filepath.Join(s.DocumentRoot, "index.html")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, err := os.ReadFile(deployPage)
cleanURL := filepath.Clean(deployPage)
data, err := os.ReadFile(cleanURL)
if err != nil {
handler.ServeHTTP(w, r)
return
......@@ -19,6 +23,9 @@ func (s *Static) DeployPage(handler http.Handler) http.Handler {
setNoCacheHeaders(w.Header())
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
w.Write(data)
_, err = w.Write(data)
if err != nil {
fmt.Printf("Error reading deploy page file: %v\n", err)
}
})
}
......@@ -19,11 +19,15 @@ var staticErrorResponses = promauto.NewCounterVec(
[]string{"code"},
)
// ErrorFormat represents the format for error handling or reporting.
type ErrorFormat int
const (
// ErrorFormatHTML represents the HTML format for error handling.
ErrorFormatHTML ErrorFormat = iota
// ErrorFormatJSON represents the JSON format for error handling.
ErrorFormatJSON
// ErrorFormatText represents the text format for error handling.
ErrorFormatText
)
......@@ -85,7 +89,10 @@ func (s *errorPageResponseWriter) WriteHeader(status int) {
s.rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
s.rw.Header().Del("Transfer-Encoding")
s.rw.WriteHeader(s.status)
s.rw.Write(data)
_, err := s.rw.Write(data)
if err != nil {
fmt.Printf("Error reading deploy page file: %v\n", err)
}
}
func (s *errorPageResponseWriter) writeHTML() (string, []byte) {
......@@ -93,7 +100,8 @@ func (s *errorPageResponseWriter) writeHTML() (string, []byte) {
errorPageFile := filepath.Join(s.path, fmt.Sprintf("%d.html", s.status))
// check if custom error page exists, serve this page instead
if data, err := os.ReadFile(errorPageFile); err == nil {
cleanPath := filepath.Clean(errorPageFile)
if data, err := os.ReadFile(cleanPath); err == nil {
return "text/html; charset=utf-8", data
}
}
......@@ -123,14 +131,15 @@ func (s *errorPageResponseWriter) Unwrap() http.ResponseWriter {
return s.rw
}
func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler {
// ErrorPagesUnless sets up error pages for specific formats unless explicitly disabled.
func (s *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler {
if disabled {
return handler
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rw := errorPageResponseWriter{
rw: w,
path: st.DocumentRoot,
path: s.DocumentRoot,
format: format,
}
defer rw.flush()
......
......@@ -13,10 +13,14 @@ import (
"gitlab.com/gitlab-org/gitlab/workhorse/internal/testhelper"
)
const (
errorPage = "ERROR"
serverError = "Interesting Server Error"
)
func TestIfErrorPageIsPresented(t *testing.T) {
dir := t.TempDir()
errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder()
......@@ -25,7 +29,7 @@ func TestIfErrorPageIsPresented(t *testing.T) {
upstreamBody := "Not Found"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
require.Len(t, upstreamBody, n, "bytes written")
})
st := &Static{DocumentRoot: dir}
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
......@@ -40,27 +44,24 @@ func TestIfErrorPassedIfNoErrorPageIsFound(t *testing.T) {
dir := t.TempDir()
w := httptest.NewRecorder()
errorResponse := "ERROR"
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(404)
fmt.Fprint(w, errorResponse)
fmt.Fprint(w, errorPage)
})
st := &Static{DocumentRoot: dir}
st.ErrorPagesUnless(false, ErrorFormatHTML, h).ServeHTTP(w, nil)
w.Flush()
require.Equal(t, 404, w.Code)
testhelper.RequireResponseBody(t, w, errorResponse)
testhelper.RequireResponseBody(t, w, errorPage)
}
func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
dir := t.TempDir()
errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder()
serverError := "Interesting Server Error"
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(500)
fmt.Fprint(w, serverError)
......@@ -75,11 +76,9 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) {
dir := t.TempDir()
errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder()
serverError := "Interesting Server Error"
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Add("X-GitLab-Custom-Error", "1")
w.WriteHeader(500)
......@@ -106,11 +105,9 @@ func TestErrorPageInterceptedByContentType(t *testing.T) {
for _, tc := range testCases {
dir := t.TempDir()
errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder()
serverError := "Interesting Server Error"
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Add("Content-Type", tc.contentType)
w.WriteHeader(500)
......@@ -138,7 +135,7 @@ func TestIfErrorPageIsPresentedJSON(t *testing.T) {
upstreamBody := "This string is ignored"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
require.Len(t, upstreamBody, n, "bytes written")
})
st := &Static{}
st.ErrorPagesUnless(false, ErrorFormatJSON, h).ServeHTTP(w, nil)
......@@ -158,7 +155,7 @@ func TestIfErrorPageIsPresentedText(t *testing.T) {
upstreamBody := "This string is ignored"
n, err := fmt.Fprint(w, upstreamBody)
require.NoError(t, err)
require.Equal(t, len(upstreamBody), n, "bytes written")
require.Len(t, upstreamBody, n, "bytes written")
})
st := &Static{}
st.ErrorPagesUnless(false, ErrorFormatText, h).ServeHTTP(w, nil)
......
......@@ -16,10 +16,13 @@ import (
"gitlab.com/gitlab-org/gitlab/workhorse/internal/urlprefix"
)
// CacheMode represents the caching mode used in the application.
type CacheMode int
const (
// CacheDisabled represents a cache mode where caching is disabled.
CacheDisabled CacheMode = iota
// CacheExpireMax represents the maximum duration for cache expiration.
CacheExpireMax
)
......@@ -29,12 +32,11 @@ const (
func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoundHandler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if notFoundHandler == nil {
notFoundHandler = http.HandlerFunc(http.NotFound)
notFoundHandler = http.NotFoundHandler()
}
// We intentionally use r.URL.Path instead of r.URL.EscaptedPath() below.
// This is to make it possible to serve static files with e.g. a space
// %20 in their name.
// This is to make it possible to serve static files with e.g. a space %20 in their name.
relativePath, err := s.validatePath(prefix.Strip(r.URL.Path))
if err != nil {
notFoundHandler.ServeHTTP(w, r)
......@@ -47,7 +49,6 @@ func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoun
notFoundHandler.ServeHTTP(w, r)
return
}
var content *os.File
var fi os.FileInfo
......@@ -67,13 +68,15 @@ func (s *Static) ServeExisting(prefix urlprefix.Prefix, cache CacheMode, notFoun
notFoundHandler.ServeHTTP(w, r)
return
}
w.Header().Set("X-Content-Type-Options", "nosniff")
defer content.Close()
defer func() {
if err := content.Close(); err != nil {
fmt.Printf("Error closing file: %v\n", err)
}
}()
switch cache {
case CacheExpireMax:
if cache == CacheExpireMax {
// Cache statically served files for 1 year
cacheUntil := time.Now().AddDate(1, 0, 0).Format(http.TimeFormat)
w.Header().Set("Cache-Control", "public")
......
......@@ -14,12 +14,15 @@ import (
"github.com/stretchr/testify/require"
)
const (
nonExistingDir = "/path/to/non/existing/directory"
)
func TestServingNonExistingFile(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
w := httptest.NewRecorder()
st := &Static{DocumentRoot: dir}
st := &Static{DocumentRoot: nonExistingDir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
require.Equal(t, 404, w.Code)
}
......@@ -35,21 +38,19 @@ func TestServingDirectory(t *testing.T) {
}
func TestServingMalformedUri(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/../../../static/file", nil)
w := httptest.NewRecorder()
st := &Static{DocumentRoot: dir}
st := &Static{DocumentRoot: nonExistingDir}
st.ServeExisting("/", CacheDisabled, nil).ServeHTTP(w, httpRequest)
require.Equal(t, 404, w.Code)
}
func TestExecutingHandlerWhenNoFileFound(t *testing.T) {
dir := "/path/to/non/existing/directory"
httpRequest, _ := http.NewRequest("GET", "/file", nil)
executed := false
st := &Static{DocumentRoot: dir}
st := &Static{DocumentRoot: nonExistingDir}
st.ServeExisting("/", CacheDisabled, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
executed = (r == httpRequest)
})).ServeHTTP(nil, httpRequest)
......
......@@ -2,6 +2,7 @@ package staticpages
import "net/http"
// Static represents a package for serving static pages and handling errors.
type Static struct {
DocumentRoot string
Exclude []string
......
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册