From 5e76a4dc7d8572964af2e8dcd62c7e3453b2a45f Mon Sep 17 00:00:00 2001
From: Jacob Vosmaer <jacob@gitlab.com>
Date: Tue, 7 Feb 2017 13:59:39 +0100
Subject: [PATCH] Prevent writing the receive-pack response to early

---
 internal/git/receive-pack.go             |   4 +-
 internal/helper/writeafterreader.go      | 129 +++++++++++++++++++++++
 internal/helper/writeafterreader_test.go | 115 ++++++++++++++++++++
 3 files changed, 247 insertions(+), 1 deletion(-)
 create mode 100644 internal/helper/writeafterreader.go
 create mode 100644 internal/helper/writeafterreader_test.go

diff --git a/internal/git/receive-pack.go b/internal/git/receive-pack.go
index 4b8f1fc09aff2..a4dc34b5aa269 100644
--- a/internal/git/receive-pack.go
+++ b/internal/git/receive-pack.go
@@ -14,7 +14,9 @@ func handleReceivePack(w *GitHttpResponseWriter, r *http.Request, a *api.Respons
 	action := getService(r)
 	writePostRPCHeader(w, action)
 
-	cmd, err := startGitCommand(a, r.Body, w, action)
+	cr, cw := helper.NewWriteAfterReader(r.Body, w)
+	defer cw.Flush()
+	cmd, err := startGitCommand(a, cr, cw, action)
 	if err != nil {
 		return fmt.Errorf("startGitCommand: %v", err)
 	}
diff --git a/internal/helper/writeafterreader.go b/internal/helper/writeafterreader.go
new file mode 100644
index 0000000000000..b6955b249c7ee
--- /dev/null
+++ b/internal/helper/writeafterreader.go
@@ -0,0 +1,129 @@
+package helper
+
+import (
+	"io"
+	"io/ioutil"
+	"os"
+	"sync"
+)
+
+type WriteFlusher interface {
+	io.Writer
+	Flush() error
+}
+
+// Couple r and w so that until r has been drained (before r.Read() has
+// returned some error), all writes to w are sent to a tempfile first.
+// The caller must call Flush() on the returned WriteFlusher to ensure
+// all data is propagated to w.
+func NewWriteAfterReader(r io.Reader, w io.Writer) (io.Reader, WriteFlusher) {
+	br := &busyReader{Reader: r}
+	return br, &coupledWriter{Writer: w, busyReader: br}
+}
+
+type busyReader struct {
+	io.Reader
+
+	error
+	errorMutex sync.RWMutex
+}
+
+func (r *busyReader) Read(p []byte) (int, error) {
+	if err := r.getError(); err != nil {
+		return 0, err
+	}
+
+	n, err := r.Reader.Read(p)
+	if err != nil {
+		r.setError(err)
+	}
+	return n, err
+}
+
+func (r *busyReader) IsBusy() bool {
+	return r.getError() == nil
+}
+
+func (r *busyReader) getError() error {
+	r.errorMutex.RLock()
+	defer r.errorMutex.RUnlock()
+	return r.error
+}
+
+func (r *busyReader) setError(err error) {
+	if err == nil {
+		panic("busyReader: attempt to reset error to nil")
+	}
+	r.errorMutex.Lock()
+	defer r.errorMutex.Unlock()
+	r.error = err
+}
+
+type coupledWriter struct {
+	io.Writer
+	*busyReader
+
+	tempfile      *os.File
+	tempfileMutex sync.Mutex
+}
+
+func (w *coupledWriter) Write(data []byte) (int, error) {
+	if w.busyReader.IsBusy() {
+		return w.tempfileWrite(data)
+	}
+
+	if err := w.Flush(); err != nil {
+		return 0, err
+	}
+
+	return w.Writer.Write(data)
+}
+
+func (w *coupledWriter) Flush() error {
+	w.tempfileMutex.Lock()
+	defer w.tempfileMutex.Unlock()
+
+	tempfile := w.tempfile
+	if tempfile == nil {
+		return nil
+	}
+
+	w.tempfile = nil
+	defer tempfile.Close()
+
+	if _, err := tempfile.Seek(0, 0); err != nil {
+		return err
+	}
+	if _, err := io.Copy(w.Writer, tempfile); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (w *coupledWriter) tempfileWrite(data []byte) (int, error) {
+	w.tempfileMutex.Lock()
+	defer w.tempfileMutex.Unlock()
+
+	if w.tempfile == nil {
+		tempfile, err := w.newTempfile()
+		if err != nil {
+			return 0, err
+		}
+		w.tempfile = tempfile
+	}
+
+	return w.tempfile.Write(data)
+}
+
+func (*coupledWriter) newTempfile() (tempfile *os.File, err error) {
+	tempfile, err = ioutil.TempFile("", "gitlab-workhorse-coupledWriter")
+	if err != nil {
+		return nil, err
+	}
+	if err := os.Remove(tempfile.Name()); err != nil {
+		tempfile.Close()
+		return nil, err
+	}
+
+	return tempfile, nil
+}
diff --git a/internal/helper/writeafterreader_test.go b/internal/helper/writeafterreader_test.go
new file mode 100644
index 0000000000000..ea1504aa15e4c
--- /dev/null
+++ b/internal/helper/writeafterreader_test.go
@@ -0,0 +1,115 @@
+package helper
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"testing"
+	"testing/iotest"
+)
+
+func TestBusyReader(t *testing.T) {
+	testData := "test data"
+	r := testReader(testData)
+	br, _ := NewWriteAfterReader(r, &bytes.Buffer{})
+
+	result, err := ioutil.ReadAll(br)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if string(result) != testData {
+		t.Fatalf("expected %q, got %q", testData, result)
+	}
+}
+
+func TestFirstWriteAfterReadDone(t *testing.T) {
+	writeRecorder := &bytes.Buffer{}
+	br, cw := NewWriteAfterReader(&bytes.Buffer{}, writeRecorder)
+	if _, err := io.Copy(ioutil.Discard, br); err != nil {
+		t.Fatalf("copy from busyreader: %v", err)
+	}
+	testData := "test data"
+	if _, err := io.Copy(cw, testReader(testData)); err != nil {
+		t.Fatalf("copy test data: %v", err)
+	}
+	if err := cw.Flush(); err != nil {
+		t.Fatalf("flush error: %v", err)
+	}
+	if result := writeRecorder.String(); result != testData {
+		t.Fatalf("expected %q, got %q", testData, result)
+	}
+}
+
+func TestWriteDelay(t *testing.T) {
+	writeRecorder := &bytes.Buffer{}
+	w := &complainingWriter{Writer: writeRecorder}
+	br, cw := NewWriteAfterReader(&bytes.Buffer{}, w)
+
+	testData1 := "1 test"
+	if _, err := io.Copy(cw, testReader(testData1)); err != nil {
+		t.Fatalf("error on first copy: %v", err)
+	}
+
+	// Unblock the coupled writer by draining the reader
+	if _, err := io.Copy(ioutil.Discard, br); err != nil {
+		t.Fatalf("copy from busyreader: %v", err)
+	}
+	// Now it is no longer an error if 'w' receives a Write()
+	w.CheerUp()
+
+	testData2 := "2 experiment"
+	if _, err := io.Copy(cw, testReader(testData2)); err != nil {
+		t.Fatalf("error on second copy: %v", err)
+	}
+
+	if err := cw.Flush(); err != nil {
+		t.Fatalf("flush error: %v", err)
+	}
+
+	expected := testData1 + testData2
+	if result := writeRecorder.String(); result != expected {
+		t.Fatalf("total write: expected %q, got %q", expected, result)
+	}
+}
+
+func TestComplainingWriterSanity(t *testing.T) {
+	recorder := &bytes.Buffer{}
+	w := &complainingWriter{Writer: recorder}
+
+	testData := "test data"
+	if _, err := io.Copy(w, testReader(testData)); err == nil {
+		t.Error("error expected, none received")
+	}
+
+	w.CheerUp()
+	if _, err := io.Copy(w, testReader(testData)); err != nil {
+		t.Error("copy after CheerUp: %v", err)
+	}
+
+	if result := recorder.String(); result != testData {
+		t.Errorf("expected %q, got %q", testData, result)
+	}
+}
+
+func testReader(data string) io.Reader {
+	return iotest.OneByteReader(bytes.NewBuffer([]byte(data)))
+}
+
+type complainingWriter struct {
+	happy bool
+	io.Writer
+}
+
+func (comp *complainingWriter) Write(data []byte) (int, error) {
+	if comp.happy {
+		return comp.Writer.Write(data)
+	}
+
+	return 0, fmt.Errorf("I am unhappy about you wanting to write %q", data)
+}
+
+func (comp *complainingWriter) CheerUp() {
+	comp.happy = true
+}
-- 
GitLab