diff --git a/ee/lib/api/code_suggestions.rb b/ee/lib/api/code_suggestions.rb index 8f6e39f9970b896346d7b22df2caee0b1e076eba..69e2a29e68ee7f7231fdf8e526f4f70760f7b152 100644 --- a/ee/lib/api/code_suggestions.rb +++ b/ee/lib/api/code_suggestions.rb @@ -119,6 +119,7 @@ def gitlab_realm ::CodeSuggestions::InstructionsExtractor::INTENT_GENERATION ], desc: 'The intent of the completion request, current options are "completion" or "generation"' + optional :stream, type: Boolean, default: false, desc: 'The option to stream code completion response' end post do if Gitlab.org_or_com? diff --git a/ee/spec/requests/api/code_suggestions_spec.rb b/ee/spec/requests/api/code_suggestions_spec.rb index 7221a534728af20ee766f5093c4e1d67fff775b6..b8a76901a6b43f63a1d4f0726defb6e33d61d96e 100644 --- a/ee/spec/requests/api/code_suggestions_spec.rb +++ b/ee/spec/requests/api/code_suggestions_spec.rb @@ -245,6 +245,7 @@ def is_even(n: int) -> content_above_cursor: prefix, content_below_cursor: '' }, + stream: false, **additional_params } end @@ -409,6 +410,21 @@ def request end end end + + context 'when passing stream parameter' do + let(:additional_params) { { stream: true } } + + it 'passes stream into TaskFactory.new' do + expect(::CodeSuggestions::TaskFactory).to receive(:new) + .with( + current_user, + params: hash_including(stream: true), + unsafe_passthrough_params: kind_of(Hash) + ).and_call_original + + post_api + end + end end end diff --git a/workhorse/go.mod b/workhorse/go.mod index 00fddbe2323ec969832e40b7f3786319978ecd3d..35ccb2847548f3f92b2911f17b70a9225e14b600 100644 --- a/workhorse/go.mod +++ b/workhorse/go.mod @@ -23,7 +23,7 @@ require ( github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.8.4 gitlab.com/gitlab-org/gitaly/v16 v16.4.1 - gitlab.com/gitlab-org/labkit v1.20.0 + gitlab.com/gitlab-org/labkit v1.21.0 gocloud.dev v0.34.0 golang.org/x/image v0.7.0 golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 diff --git a/workhorse/go.sum b/workhorse/go.sum index d35e2948db78d4cc6166530f6f65d0effdca045f..99006c779e8bf0dd6952da592c019942cfd7424b 100644 --- a/workhorse/go.sum +++ b/workhorse/go.sum @@ -452,8 +452,8 @@ github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPR github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= gitlab.com/gitlab-org/gitaly/v16 v16.4.1 h1:Qh5TFK+Jy/mBV8hCfNro2VCqRrhgt3M2iTrdYVF5N6o= gitlab.com/gitlab-org/gitaly/v16 v16.4.1/go.mod h1:TdN/Q3OqxU75pcp8V5YWpnE8Gk6dagwlC/HefNnW1IE= -gitlab.com/gitlab-org/labkit v1.20.0 h1:DGIVAdzbCR8sq2TppBvAh35wWBYIOy5dBL5wqFK3Wa8= -gitlab.com/gitlab-org/labkit v1.20.0/go.mod h1:zeATDAaSBelPcPLbTTq8J3ZJEHyPTLVBM1q3nva+/W4= +gitlab.com/gitlab-org/labkit v1.21.0 h1:hLmdBDtXjD1yOmZ+uJOac3a5Tlo83QaezwhES4IYik4= +gitlab.com/gitlab-org/labkit v1.21.0/go.mod h1:zeATDAaSBelPcPLbTTq8J3ZJEHyPTLVBM1q3nva+/W4= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= diff --git a/workhorse/internal/api/block.go b/workhorse/internal/api/block.go index aac43f8cf77ca28d8e72bbd96ae992e4158c1e12..f6eefcdf2ccf1640f5e473f8a5244d79f470397a 100644 --- a/workhorse/internal/api/block.go +++ b/workhorse/internal/api/block.go @@ -60,3 +60,8 @@ func (b *blocker) WriteHeader(status int) { func (b *blocker) flush() { b.WriteHeader(http.StatusOK) } + +// Unwrap lets http.ResponseController get the underlying http.ResponseWriter. +func (b *blocker) Unwrap() http.ResponseWriter { + return b.rw +} diff --git a/workhorse/internal/api/block_test.go b/workhorse/internal/api/block_test.go index c1ffe93dfb859a864e30704a99981032e492914f..a28dd8d203c909339bf6d2a83d3fd8736174f5d2 100644 --- a/workhorse/internal/api/block_test.go +++ b/workhorse/internal/api/block_test.go @@ -54,3 +54,13 @@ func TestBlocker(t *testing.T) { }) } } + +func TestBlockerFlushable(t *testing.T) { + rw := httptest.NewRecorder() + b := blocker{rw: rw} + rc := http.NewResponseController(&b) + + err := rc.Flush() + require.NoError(t, err, "the underlying response writer is not flushable") + require.True(t, rw.Flushed) +} diff --git a/workhorse/internal/helper/countingresponsewriter.go b/workhorse/internal/helper/countingresponsewriter.go index a79d51d4c6ac26061b4f2332a91cf29525bc60d5..9bcecf2d9b11962a3c838b4a6bff84cd474610c2 100644 --- a/workhorse/internal/helper/countingresponsewriter.go +++ b/workhorse/internal/helper/countingresponsewriter.go @@ -54,3 +54,8 @@ func (c *countingResponseWriter) Count() int64 { func (c *countingResponseWriter) Status() int { return c.status } + +// Unwrap lets http.ResponseController get the underlying http.ResponseWriter. +func (c *countingResponseWriter) Unwrap() http.ResponseWriter { + return c.rw +} diff --git a/workhorse/internal/helper/countingresponsewriter_test.go b/workhorse/internal/helper/countingresponsewriter_test.go index f9f2f4ced5b0d84d9d7d6a328a2e355d99a10b96..d070b215b90a61a84dbcae7f262a2f9458d05703 100644 --- a/workhorse/internal/helper/countingresponsewriter_test.go +++ b/workhorse/internal/helper/countingresponsewriter_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io" "net/http" + "net/http/httptest" "testing" "testing/iotest" @@ -48,3 +49,13 @@ func TestCountingResponseWriterWrite(t *testing.T) { require.Equal(t, string(testData), string(trw.data)) } + +func TestCountingResponseWriterFlushable(t *testing.T) { + rw := httptest.NewRecorder() + crw := countingResponseWriter{rw: rw} + rc := http.NewResponseController(&crw) + + err := rc.Flush() + require.NoError(t, err, "the underlying response writer is not flushable") + require.True(t, rw.Flushed) +} diff --git a/workhorse/internal/senddata/contentprocessor/contentprocessor.go b/workhorse/internal/senddata/contentprocessor/contentprocessor.go index 1c97bd923aa9006c3e93d8e1a8288de3863901b4..f604a481fa264dc9aa6c625b030352cb28f35d19 100644 --- a/workhorse/internal/senddata/contentprocessor/contentprocessor.go +++ b/workhorse/internal/senddata/contentprocessor/contentprocessor.go @@ -2,7 +2,6 @@ package contentprocessor import ( "bytes" - "io" "net/http" "gitlab.com/gitlab-org/gitlab/workhorse/internal/headers" @@ -30,7 +29,7 @@ func SetContentHeaders(h http.Handler) http.Handler { status: http.StatusOK, } - defer cd.flush() + defer cd.Flush() h.ServeHTTP(cd, r) }) @@ -71,7 +70,8 @@ func (cd *contentDisposition) flushBuffer() error { if cd.buf.Len() > 0 { cd.writeContentHeaders() cd.WriteHeader(cd.status) - _, err := io.Copy(cd.rw, cd.buf) + _, err := cd.rw.Write(cd.buf.Bytes()) + cd.buf.Reset() return err } @@ -121,6 +121,20 @@ func (cd *contentDisposition) isUnbuffered() bool { return cd.flushed || !cd.active } -func (cd *contentDisposition) flush() { - cd.flushBuffer() +func (cd *contentDisposition) Flush() { + cd.FlushError() +} + +// FlushError lets http.ResponseController to be used to flush the underlying http.ResponseWriter. +func (cd *contentDisposition) FlushError() error { + err := cd.flushBuffer() + if err != nil { + return err + } + return http.NewResponseController(cd.rw).Flush() +} + +// Unwrap lets http.ResponseController get the underlying http.ResponseWriter. +func (cd *contentDisposition) Unwrap() http.ResponseWriter { + return cd.rw } diff --git a/workhorse/internal/senddata/senddata.go b/workhorse/internal/senddata/senddata.go index 4cb96890ee2e8d431510ba52b0e5c9213fa8f1f8..f0d1da021799dd7d115343a24dd8c67e38ac172a 100644 --- a/workhorse/internal/senddata/senddata.go +++ b/workhorse/internal/senddata/senddata.go @@ -104,3 +104,8 @@ func (s *sendDataResponseWriter) tryInject() bool { func (s *sendDataResponseWriter) flush() { s.WriteHeader(http.StatusOK) } + +// Unwrap lets http.ResponseController get the underlying http.ResponseWriter. +func (s *sendDataResponseWriter) Unwrap() http.ResponseWriter { + return s.rw +} diff --git a/workhorse/internal/sendfile/sendfile.go b/workhorse/internal/sendfile/sendfile.go index 70d93f1109c7ba9ad44a7296f858b7c1942c0c60..09562cd27ffb28770e24f203b0d24a53f50f5913 100644 --- a/workhorse/internal/sendfile/sendfile.go +++ b/workhorse/internal/sendfile/sendfile.go @@ -103,6 +103,11 @@ func (s *sendFileResponseWriter) WriteHeader(status int) { s.rw.WriteHeader(s.status) } +// Unwrap lets http.ResponseController get the underlying http.ResponseWriter. +func (s *sendFileResponseWriter) Unwrap() http.ResponseWriter { + return s.rw +} + func sendFileFromDisk(w http.ResponseWriter, r *http.Request, file string) { log.WithContextFields(r.Context(), log.Fields{ "file": file, diff --git a/workhorse/internal/sendfile/sendfile_test.go b/workhorse/internal/sendfile/sendfile_test.go index 002de7f9f3e82599d4a43c18b280658856e60125..72dac5339f944bea49309fe68f04a5d04edfa680 100644 --- a/workhorse/internal/sendfile/sendfile_test.go +++ b/workhorse/internal/sendfile/sendfile_test.go @@ -170,3 +170,13 @@ func makeRequest(t *testing.T, fixturePath string, httpHeaders map[string]string return resp } + +func TestSendFileResponseWriterFlushable(t *testing.T) { + rw := httptest.NewRecorder() + sfrw := sendFileResponseWriter{rw: rw} + rc := http.NewResponseController(&sfrw) + + err := rc.Flush() + require.NoError(t, err, "the underlying response writer is not flushable") + require.True(t, rw.Flushed) +} diff --git a/workhorse/internal/sendurl/sendurl.go b/workhorse/internal/sendurl/sendurl.go index e011f57c6bc11a87ad63ef4d66afba1a27f66a58..116c68ecba914b4e84b068f75b3bdaca87681d20 100644 --- a/workhorse/internal/sendurl/sendurl.go +++ b/workhorse/internal/sendurl/sendurl.go @@ -158,7 +158,11 @@ func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) w.WriteHeader(resp.StatusCode) defer resp.Body.Close() - n, err := io.Copy(w, resp.Body) + + // Flushes the response right after it received. + // Important for streaming responses, where content delivered in chunks. + // Without flushing the body gets buffered by the HTTP server's internal buffer. + n, err := io.Copy(newFlushingResponseWriter(w), resp.Body) sendURLBytes.Add(float64(n)) if err != nil { @@ -190,3 +194,25 @@ func newClient(params entryParams) *http.Client { return client } + +func newFlushingResponseWriter(w http.ResponseWriter) *httpFlushingResponseWriter { + return &httpFlushingResponseWriter{ + ResponseWriter: w, + controller: http.NewResponseController(w), + } +} + +type httpFlushingResponseWriter struct { + http.ResponseWriter + controller *http.ResponseController +} + +// Write flushes the response once its written +func (h *httpFlushingResponseWriter) Write(data []byte) (int, error) { + n, err := h.ResponseWriter.Write(data) + if err != nil { + return n, err + } + + return n, h.controller.Flush() +} diff --git a/workhorse/internal/staticpages/error_pages.go b/workhorse/internal/staticpages/error_pages.go index d1aa7603658f69c7c96432ed1b7634b9cb0825bf..118886dfa40d1be0778f30c46725b698b764f5a2 100644 --- a/workhorse/internal/staticpages/error_pages.go +++ b/workhorse/internal/staticpages/error_pages.go @@ -11,14 +11,12 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) -var ( - staticErrorResponses = promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: "gitlab_workhorse_static_error_responses", - Help: "How many HTTP responses have been changed to a static error page, by HTTP status code.", - }, - []string{"code"}, - ) +var staticErrorResponses = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "gitlab_workhorse_static_error_responses", + Help: "How many HTTP responses have been changed to a static error page, by HTTP status code.", + }, + []string{"code"}, ) type ErrorFormat int @@ -120,6 +118,11 @@ func (s *errorPageResponseWriter) flush() { s.WriteHeader(http.StatusOK) } +// Unwrap lets http.ResponseController get the underlying http.ResponseWriter. +func (s *errorPageResponseWriter) Unwrap() http.ResponseWriter { + return s.rw +} + func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler { if disabled { return handler diff --git a/workhorse/internal/staticpages/error_pages_test.go b/workhorse/internal/staticpages/error_pages_test.go index 12c268fb40b04f484ab73ba00f8d15d36658ca5c..2d8646e83020cab10da53ab08ab173afe7f8bd01 100644 --- a/workhorse/internal/staticpages/error_pages_test.go +++ b/workhorse/internal/staticpages/error_pages_test.go @@ -17,7 +17,7 @@ func TestIfErrorPageIsPresented(t *testing.T) { dir := t.TempDir() errorPage := "ERROR" - os.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) + os.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0o600) w := httptest.NewRecorder() h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -57,7 +57,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { dir := t.TempDir() errorPage := "ERROR" - os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600) w := httptest.NewRecorder() serverError := "Interesting Server Error" @@ -76,7 +76,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) { dir := t.TempDir() errorPage := "ERROR" - os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600) w := httptest.NewRecorder() serverError := "Interesting Server Error" @@ -107,7 +107,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) { dir := t.TempDir() errorPage := "ERROR" - os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) + os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600) w := httptest.NewRecorder() serverError := "Interesting Server Error" @@ -168,3 +168,13 @@ func TestIfErrorPageIsPresentedText(t *testing.T) { testhelper.RequireResponseBody(t, w, errorPage) testhelper.RequireResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8") } + +func TestErrorPageResponseWriterFlushable(t *testing.T) { + rw := httptest.NewRecorder() + eprw := errorPageResponseWriter{rw: rw} + rc := http.NewResponseController(&eprw) + + err := rc.Flush() + require.NoError(t, err, "the underlying response writer is not flushable") + require.True(t, rw.Flushed) +}