From 5e76a4dc7d8572964af2e8dcd62c7e3453b2a45f Mon Sep 17 00:00:00 2001 From: Jacob Vosmaer <jacob@gitlab.com> Date: Tue, 7 Feb 2017 13:59:39 +0100 Subject: [PATCH] Prevent writing the receive-pack response to early --- internal/git/receive-pack.go | 4 +- internal/helper/writeafterreader.go | 129 +++++++++++++++++++++++ internal/helper/writeafterreader_test.go | 115 ++++++++++++++++++++ 3 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 internal/helper/writeafterreader.go create mode 100644 internal/helper/writeafterreader_test.go diff --git a/internal/git/receive-pack.go b/internal/git/receive-pack.go index 4b8f1fc09aff2..a4dc34b5aa269 100644 --- a/internal/git/receive-pack.go +++ b/internal/git/receive-pack.go @@ -14,7 +14,9 @@ func handleReceivePack(w *GitHttpResponseWriter, r *http.Request, a *api.Respons action := getService(r) writePostRPCHeader(w, action) - cmd, err := startGitCommand(a, r.Body, w, action) + cr, cw := helper.NewWriteAfterReader(r.Body, w) + defer cw.Flush() + cmd, err := startGitCommand(a, cr, cw, action) if err != nil { return fmt.Errorf("startGitCommand: %v", err) } diff --git a/internal/helper/writeafterreader.go b/internal/helper/writeafterreader.go new file mode 100644 index 0000000000000..b6955b249c7ee --- /dev/null +++ b/internal/helper/writeafterreader.go @@ -0,0 +1,129 @@ +package helper + +import ( + "io" + "io/ioutil" + "os" + "sync" +) + +type WriteFlusher interface { + io.Writer + Flush() error +} + +// Couple r and w so that until r has been drained (before r.Read() has +// returned some error), all writes to w are sent to a tempfile first. +// The caller must call Flush() on the returned WriteFlusher to ensure +// all data is propagated to w. +func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) { + br := &busyReader{Reader: r} + return br, &coupledWriter{Writer: w, busyReader: br} +} + +type busyReader struct { + io.Reader + + error + errorMutex sync.RWMutex +} + +func (r *busyReader) Read(p []byte) (int, error) { + if err := r.getError(); err != nil { + return 0, err + } + + n, err := r.Reader.Read(p) + if err != nil { + r.setError(err) + } + return n, err +} + +func (r *busyReader) IsBusy() bool { + return r.getError() == nil +} + +func (r *busyReader) getError() error { + r.errorMutex.RLock() + defer r.errorMutex.RUnlock() + return r.error +} + +func (r *busyReader) setError(err error) { + if err == nil { + panic("busyReader: attempt to reset error to nil") + } + r.errorMutex.Lock() + defer r.errorMutex.Unlock() + r.error = err +} + +type coupledWriter struct { + io.Writer + *busyReader + + tempfile *os.File + tempfileMutex sync.Mutex +} + +func (w *coupledWriter) Write(data []byte) (int, error) { + if w.busyReader.IsBusy() { + return w.tempfileWrite(data) + } + + if err := w.Flush(); err != nil { + return 0, err + } + + return w.Writer.Write(data) +} + +func (w *coupledWriter) Flush() error { + w.tempfileMutex.Lock() + defer w.tempfileMutex.Unlock() + + tempfile := w.tempfile + if tempfile == nil { + return nil + } + + w.tempfile = nil + defer tempfile.Close() + + if _, err := tempfile.Seek(0, 0); err != nil { + return err + } + if _, err := io.Copy(w.Writer, tempfile); err != nil { + return err + } + return nil +} + +func (w *coupledWriter) tempfileWrite(data []byte) (int, error) { + w.tempfileMutex.Lock() + defer w.tempfileMutex.Unlock() + + if w.tempfile == nil { + tempfile, err := w.newTempfile() + if err != nil { + return 0, err + } + w.tempfile = tempfile + } + + return w.tempfile.Write(data) +} + +func (*coupledWriter) newTempfile() (tempfile *os.File, err error) { + tempfile, err = ioutil.TempFile("", "gitlab-workhorse-coupledWriter") + if err != nil { + return nil, err + } + if err := os.Remove(tempfile.Name()); err != nil { + tempfile.Close() + return nil, err + } + + return tempfile, nil +} diff --git a/internal/helper/writeafterreader_test.go b/internal/helper/writeafterreader_test.go new file mode 100644 index 0000000000000..ea1504aa15e4c --- /dev/null +++ b/internal/helper/writeafterreader_test.go @@ -0,0 +1,115 @@ +package helper + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "testing" + "testing/iotest" +) + +func TestBusyReader(t *testing.T) { + testData := "test data" + r := testReader(testData) + br, _ := NewWriteAfterReader(r, &bytes.Buffer{}) + + result, err := ioutil.ReadAll(br) + if err != nil { + t.Fatal(err) + } + + if string(result) != testData { + t.Fatalf("expected %q, got %q", testData, result) + } +} + +func TestFirstWriteAfterReadDone(t *testing.T) { + writeRecorder := &bytes.Buffer{} + br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder) + if _, err := io.Copy(ioutil.Discard, br); err != nil { + t.Fatalf("copy from busyreader: %v", err) + } + testData := "test data" + if _, err := io.Copy(cw, testReader(testData)); err != nil { + t.Fatalf("copy test data: %v", err) + } + if err := cw.Flush(); err != nil { + t.Fatalf("flush error: %v", err) + } + if result := writeRecorder.String(); result != testData { + t.Fatalf("expected %q, got %q", testData, result) + } +} + +func TestWriteDelay(t *testing.T) { + writeRecorder := &bytes.Buffer{} + w := &complainingWriter{Writer: writeRecorder} + br, cw := NewWriteAfterReader(&bytes.Buffer{}, w) + + testData1 := "1 test" + if _, err := io.Copy(cw, testReader(testData1)); err != nil { + t.Fatalf("error on first copy: %v", err) + } + + // Unblock the coupled writer by draining the reader + if _, err := io.Copy(ioutil.Discard, br); err != nil { + t.Fatalf("copy from busyreader: %v", err) + } + // Now it is no longer an error if 'w' receives a Write() + w.CheerUp() + + testData2 := "2 experiment" + if _, err := io.Copy(cw, testReader(testData2)); err != nil { + t.Fatalf("error on second copy: %v", err) + } + + if err := cw.Flush(); err != nil { + t.Fatalf("flush error: %v", err) + } + + expected := testData1 + testData2 + if result := writeRecorder.String(); result != expected { + t.Fatalf("total write: expected %q, got %q", expected, result) + } +} + +func TestComplainingWriterSanity(t *testing.T) { + recorder := &bytes.Buffer{} + w := &complainingWriter{Writer: recorder} + + testData := "test data" + if _, err := io.Copy(w, testReader(testData)); err == nil { + t.Error("error expected, none received") + } + + w.CheerUp() + if _, err := io.Copy(w, testReader(testData)); err != nil { + t.Error("copy after CheerUp: %v", err) + } + + if result := recorder.String(); result != testData { + t.Errorf("expected %q, got %q", testData, result) + } +} + +func testReader(data string) io.Reader { + return iotest.OneByteReader(bytes.NewBuffer([]byte(data))) +} + +type complainingWriter struct { + happy bool + io.Writer +} + +func (comp *complainingWriter) Write(data []byte) (int, error) { + if comp.happy { + return comp.Writer.Write(data) + } + + return 0, fmt.Errorf("I am unhappy about you wanting to write %q", data) +} + +func (comp *complainingWriter) CheerUp() { + comp.happy = true +} -- GitLab