Skip to content
代码片段 群组 项目
提交 f9e2c8c4 编辑于 作者: Alessio Caiazza's avatar Alessio Caiazza
浏览文件

Merge branch 'ck3g-add-streaming-to-code_completions-api-endpoint' into 'master'

No related branches found
No related tags found
无相关合并请求
显示
142 个添加21 个删除
...@@ -119,6 +119,7 @@ def gitlab_realm ...@@ -119,6 +119,7 @@ def gitlab_realm
::CodeSuggestions::InstructionsExtractor::INTENT_GENERATION ::CodeSuggestions::InstructionsExtractor::INTENT_GENERATION
], ],
desc: 'The intent of the completion request, current options are "completion" or "generation"' desc: 'The intent of the completion request, current options are "completion" or "generation"'
optional :stream, type: Boolean, default: false, desc: 'The option to stream code completion response'
end end
post do post do
if Gitlab.org_or_com? if Gitlab.org_or_com?
......
...@@ -245,6 +245,7 @@ def is_even(n: int) -> ...@@ -245,6 +245,7 @@ def is_even(n: int) ->
content_above_cursor: prefix, content_above_cursor: prefix,
content_below_cursor: '' content_below_cursor: ''
}, },
stream: false,
**additional_params **additional_params
} }
end end
...@@ -409,6 +410,21 @@ def request ...@@ -409,6 +410,21 @@ def request
end end
end end
end end
context 'when passing stream parameter' do
let(:additional_params) { { stream: true } }
it 'passes stream into TaskFactory.new' do
expect(::CodeSuggestions::TaskFactory).to receive(:new)
.with(
current_user,
params: hash_including(stream: true),
unsafe_passthrough_params: kind_of(Hash)
).and_call_original
post_api
end
end
end end
end end
......
...@@ -23,7 +23,7 @@ require ( ...@@ -23,7 +23,7 @@ require (
github.com/smartystreets/goconvey v1.8.1 github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
gitlab.com/gitlab-org/gitaly/v16 v16.4.1 gitlab.com/gitlab-org/gitaly/v16 v16.4.1
gitlab.com/gitlab-org/labkit v1.20.0 gitlab.com/gitlab-org/labkit v1.21.0
gocloud.dev v0.34.0 gocloud.dev v0.34.0
golang.org/x/image v0.7.0 golang.org/x/image v0.7.0
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 golang.org/x/lint v0.0.0-20210508222113-6edffad5e616
......
...@@ -452,8 +452,8 @@ github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPR ...@@ -452,8 +452,8 @@ github.com/yusufpapurcu/wmi v1.2.2 h1:KBNDSne4vP5mbSWnJbO+51IMOXJB67QiYCSBrubbPR
github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
gitlab.com/gitlab-org/gitaly/v16 v16.4.1 h1:Qh5TFK+Jy/mBV8hCfNro2VCqRrhgt3M2iTrdYVF5N6o= gitlab.com/gitlab-org/gitaly/v16 v16.4.1 h1:Qh5TFK+Jy/mBV8hCfNro2VCqRrhgt3M2iTrdYVF5N6o=
gitlab.com/gitlab-org/gitaly/v16 v16.4.1/go.mod h1:TdN/Q3OqxU75pcp8V5YWpnE8Gk6dagwlC/HefNnW1IE= gitlab.com/gitlab-org/gitaly/v16 v16.4.1/go.mod h1:TdN/Q3OqxU75pcp8V5YWpnE8Gk6dagwlC/HefNnW1IE=
gitlab.com/gitlab-org/labkit v1.20.0 h1:DGIVAdzbCR8sq2TppBvAh35wWBYIOy5dBL5wqFK3Wa8= gitlab.com/gitlab-org/labkit v1.21.0 h1:hLmdBDtXjD1yOmZ+uJOac3a5Tlo83QaezwhES4IYik4=
gitlab.com/gitlab-org/labkit v1.20.0/go.mod h1:zeATDAaSBelPcPLbTTq8J3ZJEHyPTLVBM1q3nva+/W4= gitlab.com/gitlab-org/labkit v1.21.0/go.mod h1:zeATDAaSBelPcPLbTTq8J3ZJEHyPTLVBM1q3nva+/W4=
go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
......
...@@ -60,3 +60,8 @@ func (b *blocker) WriteHeader(status int) { ...@@ -60,3 +60,8 @@ func (b *blocker) WriteHeader(status int) {
func (b *blocker) flush() { func (b *blocker) flush() {
b.WriteHeader(http.StatusOK) b.WriteHeader(http.StatusOK)
} }
// Unwrap lets http.ResponseController get the underlying http.ResponseWriter.
func (b *blocker) Unwrap() http.ResponseWriter {
return b.rw
}
...@@ -54,3 +54,13 @@ func TestBlocker(t *testing.T) { ...@@ -54,3 +54,13 @@ func TestBlocker(t *testing.T) {
}) })
} }
} }
func TestBlockerFlushable(t *testing.T) {
rw := httptest.NewRecorder()
b := blocker{rw: rw}
rc := http.NewResponseController(&b)
err := rc.Flush()
require.NoError(t, err, "the underlying response writer is not flushable")
require.True(t, rw.Flushed)
}
...@@ -54,3 +54,8 @@ func (c *countingResponseWriter) Count() int64 { ...@@ -54,3 +54,8 @@ func (c *countingResponseWriter) Count() int64 {
func (c *countingResponseWriter) Status() int { func (c *countingResponseWriter) Status() int {
return c.status return c.status
} }
// Unwrap lets http.ResponseController get the underlying http.ResponseWriter.
func (c *countingResponseWriter) Unwrap() http.ResponseWriter {
return c.rw
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"io" "io"
"net/http" "net/http"
"net/http/httptest"
"testing" "testing"
"testing/iotest" "testing/iotest"
...@@ -48,3 +49,13 @@ func TestCountingResponseWriterWrite(t *testing.T) { ...@@ -48,3 +49,13 @@ func TestCountingResponseWriterWrite(t *testing.T) {
require.Equal(t, string(testData), string(trw.data)) require.Equal(t, string(testData), string(trw.data))
} }
func TestCountingResponseWriterFlushable(t *testing.T) {
rw := httptest.NewRecorder()
crw := countingResponseWriter{rw: rw}
rc := http.NewResponseController(&crw)
err := rc.Flush()
require.NoError(t, err, "the underlying response writer is not flushable")
require.True(t, rw.Flushed)
}
...@@ -2,7 +2,6 @@ package contentprocessor ...@@ -2,7 +2,6 @@ package contentprocessor
import ( import (
"bytes" "bytes"
"io"
"net/http" "net/http"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/headers" "gitlab.com/gitlab-org/gitlab/workhorse/internal/headers"
...@@ -30,7 +29,7 @@ func SetContentHeaders(h http.Handler) http.Handler { ...@@ -30,7 +29,7 @@ func SetContentHeaders(h http.Handler) http.Handler {
status: http.StatusOK, status: http.StatusOK,
} }
defer cd.flush() defer cd.Flush()
h.ServeHTTP(cd, r) h.ServeHTTP(cd, r)
}) })
...@@ -71,7 +70,8 @@ func (cd *contentDisposition) flushBuffer() error { ...@@ -71,7 +70,8 @@ func (cd *contentDisposition) flushBuffer() error {
if cd.buf.Len() > 0 { if cd.buf.Len() > 0 {
cd.writeContentHeaders() cd.writeContentHeaders()
cd.WriteHeader(cd.status) cd.WriteHeader(cd.status)
_, err := io.Copy(cd.rw, cd.buf) _, err := cd.rw.Write(cd.buf.Bytes())
cd.buf.Reset()
return err return err
} }
...@@ -121,6 +121,20 @@ func (cd *contentDisposition) isUnbuffered() bool { ...@@ -121,6 +121,20 @@ func (cd *contentDisposition) isUnbuffered() bool {
return cd.flushed || !cd.active return cd.flushed || !cd.active
} }
func (cd *contentDisposition) flush() { func (cd *contentDisposition) Flush() {
cd.flushBuffer() cd.FlushError()
}
// FlushError lets http.ResponseController to be used to flush the underlying http.ResponseWriter.
func (cd *contentDisposition) FlushError() error {
err := cd.flushBuffer()
if err != nil {
return err
}
return http.NewResponseController(cd.rw).Flush()
}
// Unwrap lets http.ResponseController get the underlying http.ResponseWriter.
func (cd *contentDisposition) Unwrap() http.ResponseWriter {
return cd.rw
} }
...@@ -104,3 +104,8 @@ func (s *sendDataResponseWriter) tryInject() bool { ...@@ -104,3 +104,8 @@ func (s *sendDataResponseWriter) tryInject() bool {
func (s *sendDataResponseWriter) flush() { func (s *sendDataResponseWriter) flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
// Unwrap lets http.ResponseController get the underlying http.ResponseWriter.
func (s *sendDataResponseWriter) Unwrap() http.ResponseWriter {
return s.rw
}
...@@ -103,6 +103,11 @@ func (s *sendFileResponseWriter) WriteHeader(status int) { ...@@ -103,6 +103,11 @@ func (s *sendFileResponseWriter) WriteHeader(status int) {
s.rw.WriteHeader(s.status) s.rw.WriteHeader(s.status)
} }
// Unwrap lets http.ResponseController get the underlying http.ResponseWriter.
func (s *sendFileResponseWriter) Unwrap() http.ResponseWriter {
return s.rw
}
func sendFileFromDisk(w http.ResponseWriter, r *http.Request, file string) { func sendFileFromDisk(w http.ResponseWriter, r *http.Request, file string) {
log.WithContextFields(r.Context(), log.Fields{ log.WithContextFields(r.Context(), log.Fields{
"file": file, "file": file,
......
...@@ -170,3 +170,13 @@ func makeRequest(t *testing.T, fixturePath string, httpHeaders map[string]string ...@@ -170,3 +170,13 @@ func makeRequest(t *testing.T, fixturePath string, httpHeaders map[string]string
return resp return resp
} }
func TestSendFileResponseWriterFlushable(t *testing.T) {
rw := httptest.NewRecorder()
sfrw := sendFileResponseWriter{rw: rw}
rc := http.NewResponseController(&sfrw)
err := rc.Flush()
require.NoError(t, err, "the underlying response writer is not flushable")
require.True(t, rw.Flushed)
}
...@@ -158,7 +158,11 @@ func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string) ...@@ -158,7 +158,11 @@ func (e *entry) Inject(w http.ResponseWriter, r *http.Request, sendData string)
w.WriteHeader(resp.StatusCode) w.WriteHeader(resp.StatusCode)
defer resp.Body.Close() defer resp.Body.Close()
n, err := io.Copy(w, resp.Body)
// Flushes the response right after it received.
// Important for streaming responses, where content delivered in chunks.
// Without flushing the body gets buffered by the HTTP server's internal buffer.
n, err := io.Copy(newFlushingResponseWriter(w), resp.Body)
sendURLBytes.Add(float64(n)) sendURLBytes.Add(float64(n))
if err != nil { if err != nil {
...@@ -190,3 +194,25 @@ func newClient(params entryParams) *http.Client { ...@@ -190,3 +194,25 @@ func newClient(params entryParams) *http.Client {
return client return client
} }
func newFlushingResponseWriter(w http.ResponseWriter) *httpFlushingResponseWriter {
return &httpFlushingResponseWriter{
ResponseWriter: w,
controller: http.NewResponseController(w),
}
}
type httpFlushingResponseWriter struct {
http.ResponseWriter
controller *http.ResponseController
}
// Write flushes the response once its written
func (h *httpFlushingResponseWriter) Write(data []byte) (int, error) {
n, err := h.ResponseWriter.Write(data)
if err != nil {
return n, err
}
return n, h.controller.Flush()
}
...@@ -11,14 +11,12 @@ import ( ...@@ -11,14 +11,12 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
) )
var ( var staticErrorResponses = promauto.NewCounterVec(
staticErrorResponses = promauto.NewCounterVec( prometheus.CounterOpts{
prometheus.CounterOpts{ Name: "gitlab_workhorse_static_error_responses",
Name: "gitlab_workhorse_static_error_responses", Help: "How many HTTP responses have been changed to a static error page, by HTTP status code.",
Help: "How many HTTP responses have been changed to a static error page, by HTTP status code.", },
}, []string{"code"},
[]string{"code"},
)
) )
type ErrorFormat int type ErrorFormat int
...@@ -120,6 +118,11 @@ func (s *errorPageResponseWriter) flush() { ...@@ -120,6 +118,11 @@ func (s *errorPageResponseWriter) flush() {
s.WriteHeader(http.StatusOK) s.WriteHeader(http.StatusOK)
} }
// Unwrap lets http.ResponseController get the underlying http.ResponseWriter.
func (s *errorPageResponseWriter) Unwrap() http.ResponseWriter {
return s.rw
}
func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler { func (st *Static) ErrorPagesUnless(disabled bool, format ErrorFormat, handler http.Handler) http.Handler {
if disabled { if disabled {
return handler return handler
......
...@@ -17,7 +17,7 @@ func TestIfErrorPageIsPresented(t *testing.T) { ...@@ -17,7 +17,7 @@ func TestIfErrorPageIsPresented(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
errorPage := "ERROR" errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0600) os.WriteFile(filepath.Join(dir, "404.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
...@@ -57,7 +57,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) { ...@@ -57,7 +57,7 @@ func TestIfErrorPageIsIgnoredInDevelopment(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
errorPage := "ERROR" errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
serverError := "Interesting Server Error" serverError := "Interesting Server Error"
...@@ -76,7 +76,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) { ...@@ -76,7 +76,7 @@ func TestIfErrorPageIsIgnoredIfCustomError(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
errorPage := "ERROR" errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
serverError := "Interesting Server Error" serverError := "Interesting Server Error"
...@@ -107,7 +107,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) { ...@@ -107,7 +107,7 @@ func TestErrorPageInterceptedByContentType(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
errorPage := "ERROR" errorPage := "ERROR"
os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0600) os.WriteFile(filepath.Join(dir, "500.html"), []byte(errorPage), 0o600)
w := httptest.NewRecorder() w := httptest.NewRecorder()
serverError := "Interesting Server Error" serverError := "Interesting Server Error"
...@@ -168,3 +168,13 @@ func TestIfErrorPageIsPresentedText(t *testing.T) { ...@@ -168,3 +168,13 @@ func TestIfErrorPageIsPresentedText(t *testing.T) {
testhelper.RequireResponseBody(t, w, errorPage) testhelper.RequireResponseBody(t, w, errorPage)
testhelper.RequireResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8") testhelper.RequireResponseHeader(t, w, "Content-Type", "text/plain; charset=utf-8")
} }
func TestErrorPageResponseWriterFlushable(t *testing.T) {
rw := httptest.NewRecorder()
eprw := errorPageResponseWriter{rw: rw}
rc := http.NewResponseController(&eprw)
err := rc.Flush()
require.NoError(t, err, "the underlying response writer is not flushable")
require.True(t, rw.Flushed)
}
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册