diff --git a/changelogs/unreleased/jv-uploader-readerfrom.yml b/changelogs/unreleased/jv-uploader-readerfrom.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6f8a06865dc4f3d1b137a649f903412826d1b588
--- /dev/null
+++ b/changelogs/unreleased/jv-uploader-readerfrom.yml
@@ -0,0 +1,5 @@
+---
+title: Push uploader control flow into objectstore package
+merge_request: 608
+author:
+type: other
diff --git a/internal/filestore/file_handler.go b/internal/filestore/file_handler.go
index fc70e89dcf91bc11b395dd035a1b3c996af98f9a..935eb3b7f3b2f3c285a73c3ed0eb3bd43fec420c 100644
--- a/internal/filestore/file_handler.go
+++ b/internal/filestore/file_handler.go
@@ -8,6 +8,7 @@ import (
 	"io/ioutil"
 	"os"
 	"strconv"
+	"time"
 
 	"github.com/dgrijalva/jwt-go"
 
@@ -98,38 +99,28 @@ func (fh *FileHandler) GitLabFinalizeFields(prefix string) (map[string]string, e
 	return data, nil
 }
 
-// Upload represents a destination where we store an upload
-type uploadWriter interface {
-	io.WriteCloser
-	CloseWithError(error) error
-	ETag() string
+type consumer interface {
+	Consume(context.Context, io.Reader, time.Time) (int64, error)
 }
 
 // SaveFileFromReader persists the provided reader content to all the location specified in opts. A cleanup will be performed once ctx is Done
 // Make sure the provided context will not expire before finalizing upload with GitLab Rails.
 func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts *SaveFileOpts) (fh *FileHandler, err error) {
-	var uploadWriter uploadWriter
+	var uploadDestination consumer
 	fh = &FileHandler{
 		Name:      opts.TempFilePrefix,
 		RemoteID:  opts.RemoteID,
 		RemoteURL: opts.RemoteURL,
 	}
 	hashes := newMultiHash()
-	writers := []io.Writer{hashes.Writer}
-	defer func() {
-		for _, w := range writers {
-			if closer, ok := w.(io.WriteCloser); ok {
-				closer.Close()
-			}
-		}
-	}()
+	reader = io.TeeReader(reader, hashes.Writer)
 
 	var clientMode string
 
 	switch {
 	case opts.IsLocal():
 		clientMode = "local"
-		uploadWriter, err = fh.uploadLocalFile(ctx, opts)
+		uploadDestination, err = fh.uploadLocalFile(ctx, opts)
 	case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsGoCloud():
 		clientMode = fmt.Sprintf("go_cloud:%s", opts.ObjectStorageConfig.Provider)
 		p := &objectstore.GoCloudObjectParams{
@@ -137,38 +128,31 @@ func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts
 			Mux:        opts.ObjectStorageConfig.URLMux,
 			BucketURL:  opts.ObjectStorageConfig.GoCloudConfig.URL,
 			ObjectName: opts.RemoteTempObjectID,
-			Deadline:   opts.Deadline,
 		}
-		uploadWriter, err = objectstore.NewGoCloudObject(p)
+		uploadDestination, err = objectstore.NewGoCloudObject(p)
 	case opts.UseWorkhorseClientEnabled() && opts.ObjectStorageConfig.IsAWS() && opts.ObjectStorageConfig.IsValid():
 		clientMode = "s3"
-		uploadWriter, err = objectstore.NewS3Object(
-			ctx,
+		uploadDestination, err = objectstore.NewS3Object(
 			opts.RemoteTempObjectID,
 			opts.ObjectStorageConfig.S3Credentials,
 			opts.ObjectStorageConfig.S3Config,
-			opts.Deadline,
 		)
 	case opts.IsMultipart():
 		clientMode = "multipart"
-		uploadWriter, err = objectstore.NewMultipart(
-			ctx,
+		uploadDestination, err = objectstore.NewMultipart(
 			opts.PresignedParts,
 			opts.PresignedCompleteMultipart,
 			opts.PresignedAbortMultipart,
 			opts.PresignedDelete,
 			opts.PutHeaders,
-			opts.Deadline,
 			opts.PartSize,
 		)
 	default:
 		clientMode = "http"
-		uploadWriter, err = objectstore.NewObject(
-			ctx,
+		uploadDestination, err = objectstore.NewObject(
 			opts.PresignedPut,
 			opts.PresignedDelete,
 			opts.PutHeaders,
-			opts.Deadline,
 			size,
 		)
 	}
@@ -177,34 +161,22 @@ func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts
 		return nil, err
 	}
 
-	writers = append(writers, uploadWriter)
-
-	defer func() {
-		if err != nil {
-			uploadWriter.CloseWithError(err)
-		}
-	}()
-
 	if opts.MaximumSize > 0 {
 		if size > opts.MaximumSize {
 			return nil, SizeError(fmt.Errorf("the upload size %d is over maximum of %d bytes", size, opts.MaximumSize))
 		}
 
-		// We allow to read an extra byte to check later if we exceed the max size
-		reader = &io.LimitedReader{R: reader, N: opts.MaximumSize + 1}
+		reader = &hardLimitReader{r: reader, n: opts.MaximumSize}
 	}
 
-	multiWriter := io.MultiWriter(writers...)
-	fh.Size, err = io.Copy(multiWriter, reader)
+	fh.Size, err = uploadDestination.Consume(ctx, reader, opts.Deadline)
 	if err != nil {
+		if err == objectstore.ErrNotEnoughParts {
+			err = ErrEntityTooLarge
+		}
 		return nil, err
 	}
 
-	if opts.MaximumSize > 0 && fh.Size > opts.MaximumSize {
-		// An extra byte was read thus exceeding the max size
-		return nil, ErrEntityTooLarge
-	}
-
 	if size != -1 && size != fh.Size {
 		return nil, SizeError(fmt.Errorf("expected %d bytes but got only %d", size, fh.Size))
 	}
@@ -226,25 +198,11 @@ func SaveFileFromReader(ctx context.Context, reader io.Reader, size int64, opts
 	}
 	logger.Info("saved file")
-
 	fh.hashes = hashes.finish()
-
-	// we need to close the writer in order to get ETag header
-	err = uploadWriter.Close()
-	if err != nil {
-		if err == objectstore.ErrNotEnoughParts {
-			return nil, ErrEntityTooLarge
-		}
-		return nil, err
-	}
-
-	etag := uploadWriter.ETag()
-	fh.hashes["etag"] = etag
-
-	return fh, err
+	return fh, nil
 }
 
-func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts) (uploadWriter, error) {
+func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts) (consumer, error) {
 	// make sure TempFolder exists
 	err := os.MkdirAll(opts.LocalTempPath, 0700)
 	if err != nil {
@@ -262,13 +220,19 @@ func (fh *FileHandler) uploadLocalFile(ctx context.Context, opts *SaveFileOpts)
 	}()
 
 	fh.LocalPath = file.Name()
-	return &nopUpload{file}, nil
+	return &localUpload{file}, nil
 }
 
-type nopUpload struct{ io.WriteCloser }
+type localUpload struct{ io.WriteCloser }
 
-func (nop *nopUpload) CloseWithError(error) error { return nop.Close() }
-func (nop *nopUpload) ETag() string               { return "" }
+func (loc *localUpload) Consume(_ context.Context, r io.Reader, _ time.Time) (int64, error) {
+	n, err := io.Copy(loc.WriteCloser, r)
+	errClose := loc.Close()
+	if err == nil {
+		err = errClose
+	}
+	return n, err
+}
 
 // SaveFileFromDisk open the local file fileName and calls SaveFileFromReader
 func SaveFileFromDisk(ctx context.Context, fileName string, opts *SaveFileOpts) (fh *FileHandler, err error) {
diff --git a/internal/filestore/file_handler_test.go b/internal/filestore/file_handler_test.go
index 9f7ebba3b43f7b467ac946f929bdd369fa95b735..6967c4b223e7293778f056df9601975656b5f04c 100644
--- a/internal/filestore/file_handler_test.go
+++ b/internal/filestore/file_handler_test.go
@@ -413,5 +413,4 @@ func checkFileHandlerWithFields(t *testing.T, fh *filestore.FileHandler, fields
 	require.Equal(t, test.ObjectSHA1, fields[key("sha1")])
 	require.Equal(t, test.ObjectSHA256, fields[key("sha256")])
 	require.Equal(t, test.ObjectSHA512, fields[key("sha512")])
-	require.Contains(t, fields, key("etag"))
 }
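Editor note: with uploadWriter gone, the only thing SaveFileFromReader asks of a destination is the single-method consumer interface above — read the stream to EOF (or error), honor the deadline, and report how many bytes were accepted. A minimal sketch of a conforming implementation; discardConsumer is hypothetical and exists only to illustrate the contract, the real implementations being localUpload above and the objectstore types in the rest of this diff:

    package filestore

    import (
        "context"
        "io"
        "io/ioutil"
        "time"
    )

    // discardConsumer is a hypothetical consumer: it drains the reader and
    // reports how many bytes it accepted, ignoring the deadline.
    type discardConsumer struct{}

    func (discardConsumer) Consume(_ context.Context, r io.Reader, _ time.Time) (int64, error) {
        return io.Copy(ioutil.Discard, r)
    }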
diff --git a/internal/filestore/reader.go b/internal/filestore/reader.go
new file mode 100644
index 0000000000000000000000000000000000000000..b1045b991fc06cb3b04745094c25b04cc2ae5263
--- /dev/null
+++ b/internal/filestore/reader.go
@@ -0,0 +1,17 @@
+package filestore
+
+import "io"
+
+type hardLimitReader struct {
+	r io.Reader
+	n int64
+}
+
+func (h *hardLimitReader) Read(p []byte) (int, error) {
+	nRead, err := h.r.Read(p)
+	h.n -= int64(nRead)
+	if h.n < 0 {
+		err = ErrEntityTooLarge
+	}
+	return nRead, err
+}
diff --git a/internal/filestore/reader_test.go b/internal/filestore/reader_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..424d921ecaf70078728b79fb30d2caa1a893e279
--- /dev/null
+++ b/internal/filestore/reader_test.go
@@ -0,0 +1,46 @@
+package filestore
+
+import (
+	"fmt"
+	"io/ioutil"
+	"strings"
+	"testing"
+	"testing/iotest"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestHardLimitReader(t *testing.T) {
+	const text = "hello world"
+	r := iotest.OneByteReader(
+		&hardLimitReader{
+			r: strings.NewReader(text),
+			n: int64(len(text)),
+		},
+	)
+
+	out, err := ioutil.ReadAll(r)
+	require.NoError(t, err)
+	require.Equal(t, text, string(out))
+}
+
+func TestHardLimitReaderFail(t *testing.T) {
+	const text = "hello world"
+
+	for bufSize := len(text) / 2; bufSize < len(text)*2; bufSize++ {
+		t.Run(fmt.Sprintf("bufsize:%d", bufSize), func(t *testing.T) {
+			r := &hardLimitReader{
+				r: iotest.DataErrReader(strings.NewReader(text)),
+				n: int64(len(text)) - 1,
+			}
+			buf := make([]byte, bufSize)
+
+			var err error
+			for i := 0; err == nil && i < 1000; i++ {
+				_, err = r.Read(buf)
+			}
+
+			require.Equal(t, ErrEntityTooLarge, err)
+		})
+	}
+}
diff --git a/internal/objectstore/gocloud_object.go b/internal/objectstore/gocloud_object.go
index 3cb05d431b673391f5180d7ad9eb1dcccd6983af..38545086994b16a86f6c8af082d9a81555d32331 100644
--- a/internal/objectstore/gocloud_object.go
+++ b/internal/objectstore/gocloud_object.go
@@ -15,7 +15,7 @@ type GoCloudObject struct {
 	mux        *blob.URLMux
 	bucketURL  string
 	objectName string
-	uploader
+	*uploader
 }
 
 type GoCloudObjectParams struct {
@@ -23,7 +23,6 @@ type GoCloudObjectParams struct {
 	Mux        *blob.URLMux
 	BucketURL  string
 	ObjectName string
-	Deadline   time.Time
 }
 
 func NewGoCloudObject(p *GoCloudObjectParams) (*GoCloudObject, error) {
@@ -40,8 +39,6 @@ func NewGoCloudObject(p *GoCloudObjectParams) (*GoCloudObject, error) {
 	}
 
 	o.uploader = newUploader(o)
-	o.Execute(p.Ctx, p.Deadline)
-
 	return o, nil
 }
diff --git a/internal/objectstore/gocloud_object_test.go b/internal/objectstore/gocloud_object_test.go
index f9260c3009dfbfcb130b88c190866195feb2eb94..4dc9d2d75cc3fe1d6971fcb795524ed11ae5c5c4 100644
--- a/internal/objectstore/gocloud_object_test.go
+++ b/internal/objectstore/gocloud_object_test.go
@@ -3,7 +3,6 @@ package objectstore_test
 import (
 	"context"
 	"fmt"
-	"io"
 	"strings"
 	"testing"
 	"time"
@@ -24,20 +23,16 @@ func TestGoCloudObjectUpload(t *testing.T) {
 	objectName := "test.png"
 	testURL := "azuretest://azure.example.com/test-container"
-	p := &objectstore.GoCloudObjectParams{Ctx: ctx, Mux: mux, BucketURL: testURL, ObjectName: objectName, Deadline: deadline}
+	p := &objectstore.GoCloudObjectParams{Ctx: ctx, Mux: mux, BucketURL: testURL, ObjectName: objectName}
 	object, err := objectstore.NewGoCloudObject(p)
 	require.NotNil(t, object)
 	require.NoError(t, err)
 
 	// copy data
-	n, err := io.Copy(object, strings.NewReader(test.ObjectContent))
+	n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
 	require.NoError(t, err)
 	require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")
 
-	// close HTTP stream
-	err = object.Close()
-	require.NoError(t, err)
-
 	bucket, err := mux.OpenBucket(ctx, testURL)
 	require.NoError(t, err)
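Editor note on the new hardLimitReader (internal/filestore/reader.go above): io.LimitedReader stops silently with io.EOF after N bytes, which is why the old code had to read one extra byte and compare sizes after the copy. hardLimitReader instead turns crossing the limit into an error on the Read itself, so the consumer aborts mid-stream. A standalone sketch of the idea — errTooLarge stands in for filestore.ErrEntityTooLarge, and the type is copied here only for illustration:

    package main

    import (
        "errors"
        "fmt"
        "io"
        "io/ioutil"
        "strings"
    )

    var errTooLarge = errors.New("entity is too large")

    // Same shape as the unexported type in internal/filestore/reader.go.
    type hardLimitReader struct {
        r io.Reader
        n int64
    }

    func (h *hardLimitReader) Read(p []byte) (int, error) {
        nRead, err := h.r.Read(p)
        h.n -= int64(nRead)
        if h.n < 0 {
            err = errTooLarge
        }
        return nRead, err
    }

    func main() {
        // io.LimitedReader{N: 5} would end silently after 5 bytes; here the
        // Read that exceeds the limit reports an error instead.
        r := &hardLimitReader{r: strings.NewReader("hello world"), n: 5}
        _, err := io.Copy(ioutil.Discard, r)
        fmt.Println(err) // "entity is too large"
    }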
diff --git a/internal/objectstore/multipart.go b/internal/objectstore/multipart.go
index 8ab936d3d023a73b57b0b246a7a8c8246493e7c0..fd1c0ed487dcf3b8367dda083cfe17ad1165d5d1 100644
--- a/internal/objectstore/multipart.go
+++ b/internal/objectstore/multipart.go
@@ -10,7 +10,6 @@ import (
 	"io/ioutil"
 	"net/http"
 	"os"
-	"time"
 
 	"gitlab.com/gitlab-org/labkit/log"
 	"gitlab.com/gitlab-org/labkit/mask"
@@ -33,13 +32,13 @@ type Multipart struct {
 	partSize int64
 	etag     string
 
-	uploader
+	*uploader
 }
 
 // NewMultipart provides Multipart pointer that can be used for uploading. Data written will be split buffered on disk up to size bytes
 // then uploaded with S3 Upload Part. Once Multipart is Closed a final call to CompleteMultipartUpload will be sent.
 // In case of any error a call to AbortMultipartUpload will be made to cleanup all the resources
-func NewMultipart(ctx context.Context, partURLs []string, completeURL, abortURL, deleteURL string, putHeaders map[string]string, deadline time.Time, partSize int64) (*Multipart, error) {
+func NewMultipart(partURLs []string, completeURL, abortURL, deleteURL string, putHeaders map[string]string, partSize int64) (*Multipart, error) {
 	m := &Multipart{
 		PartURLs:    partURLs,
 		CompleteURL: completeURL,
@@ -50,8 +49,6 @@ func NewMultipart(ctx context.Context, partURLs []string, completeURL, abortURL,
 	}
 
 	m.uploader = newUploader(m)
-	m.Execute(ctx, deadline)
-
 	return m, nil
 }
@@ -109,7 +106,7 @@ func (m *Multipart) readAndUploadOnePart(ctx context.Context, partURL string, pu
 
 	n, err := io.Copy(file, src)
 	if err != nil {
-		return nil, fmt.Errorf("write part %d to disk: %v", partNumber, err)
+		return nil, err
 	}
 	if n == 0 {
 		return nil, nil
@@ -132,18 +129,15 @@ func (m *Multipart) uploadPart(ctx context.Context, url string, headers map[stri
 		return "", fmt.Errorf("missing deadline")
 	}
 
-	part, err := newObject(ctx, url, "", headers, deadline, size, false)
+	part, err := newObject(url, "", headers, size, false)
 	if err != nil {
 		return "", err
 	}
 
-	_, err = io.CopyN(part, body, size)
-	if err != nil {
-		return "", err
-	}
-
-	err = part.Close()
-	if err != nil {
+	if n, err := part.Consume(ctx, io.LimitReader(body, size), deadline); err != nil || n < size {
+		if err == nil {
+			err = io.ErrUnexpectedEOF
+		}
 		return "", err
 	}
diff --git a/internal/objectstore/multipart_test.go b/internal/objectstore/multipart_test.go
index c7b6e4ded7611c576e855a06fe7cd07898c075f2..00d6efc0982d43bec85b41b6d465c7177ab61ff3 100644
--- a/internal/objectstore/multipart_test.go
+++ b/internal/objectstore/multipart_test.go
@@ -48,19 +48,17 @@ func TestMultipartUploadWithUpcaseETags(t *testing.T) {
 	deadline := time.Now().Add(testTimeout)
 
-	m, err := objectstore.NewMultipart(ctx,
+	m, err := objectstore.NewMultipart(
 		[]string{ts.URL},    // a single presigned part URL
 		ts.URL,              // the complete multipart upload URL
 		"",                  // no abort
 		"",                  // no delete
 		map[string]string{}, // no custom headers
-		deadline,
 		test.ObjectSize) // parts size equal to the whole content. Only 1 part
 	require.NoError(t, err)
 
-	_, err = m.Write([]byte(test.ObjectContent))
+	_, err = m.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
 	require.NoError(t, err)
-	require.NoError(t, m.Close())
 
 	require.Equal(t, 1, putCnt, "1 part expected")
 	require.Equal(t, 1, postCnt, "1 complete multipart upload expected")
 }
diff --git a/internal/objectstore/object.go b/internal/objectstore/object.go
index 2a6bd8004d3698480594cbca713f809480a5964a..eaf3bfb2e3612b1c0d2bf09755ad65c71c73d80e 100644
--- a/internal/objectstore/object.go
+++ b/internal/objectstore/object.go
@@ -47,17 +47,17 @@ type Object struct {
 	etag    string
 	metrics bool
 
-	uploader
+	*uploader
 }
 
 type StatusCodeError error
 
 // NewObject opens an HTTP connection to Object Store and returns an Object pointer that can be used for uploading.
-func NewObject(ctx context.Context, putURL, deleteURL string, putHeaders map[string]string, deadline time.Time, size int64) (*Object, error) {
-	return newObject(ctx, putURL, deleteURL, putHeaders, deadline, size, true)
+func NewObject(putURL, deleteURL string, putHeaders map[string]string, size int64) (*Object, error) {
+	return newObject(putURL, deleteURL, putHeaders, size, true)
 }
 
-func newObject(ctx context.Context, putURL, deleteURL string, putHeaders map[string]string, deadline time.Time, size int64, metrics bool) (*Object, error) {
+func newObject(putURL, deleteURL string, putHeaders map[string]string, size int64, metrics bool) (*Object, error) {
 	o := &Object{
 		putURL:     putURL,
 		deleteURL:  deleteURL,
@@ -66,9 +66,7 @@ func newObject(ctx context.Context, putURL, deleteURL string, putHeaders map[str
 		metrics:    metrics,
 	}
 
-	o.uploader = newMD5Uploader(o, metrics)
-	o.Execute(ctx, deadline)
-
+	o.uploader = newETagCheckUploader(o, metrics)
 	return o, nil
 }
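Editor note: the NewObject / io.Copy / Close / ETag four-step that callers used to perform is now a single blocking Consume call, with the MD5-vs-ETag comparison moved inside the uploader. A rough sketch of the new call shape — the helper function, its presigned URLs and the module import path are assumptions for illustration only:

    package main

    import (
        "context"
        "io"
        "time"

        "gitlab.com/gitlab-org/gitlab-workhorse/internal/objectstore"
    )

    // putViaPresignedURL is illustrative; putURL/deleteURL would come from a
    // presigned-upload exchange with gitlab-rails, and size may be -1 when
    // unknown (as TestObjectUploadBrokenConnection below exercises).
    func putViaPresignedURL(ctx context.Context, putURL, deleteURL string, body io.Reader, size int64) (int64, error) {
        object, err := objectstore.NewObject(putURL, deleteURL, map[string]string{}, size)
        if err != nil {
            return 0, err
        }
        // Consume blocks until the PUT finishes, fails, or the deadline expires.
        return object.Consume(ctx, body, time.Now().Add(10*time.Minute))
    }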
StatusCodeError") @@ -152,13 +146,10 @@ func TestObjectUploadBrokenConnection(t *testing.T) { deadline := time.Now().Add(testTimeout) objectURL := ts.URL + test.ObjectPath - object, err := objectstore.NewObject(ctx, objectURL, "", map[string]string{}, deadline, -1) + object, err := objectstore.NewObject(objectURL, "", map[string]string{}, -1) require.NoError(t, err) - _, copyErr := io.Copy(object, &endlessReader{}) + _, copyErr := object.Consume(ctx, &endlessReader{}, deadline) require.Error(t, copyErr) require.NotEqual(t, io.ErrClosedPipe, copyErr, "We are shadowing the real error") - - closeErr := object.Close() - require.Equal(t, copyErr, closeErr) } diff --git a/internal/objectstore/s3_object.go b/internal/objectstore/s3_object.go index 7444283bbc7d60cd399958c5d56d77eafabe761a..d29ecaea3453d3233c170210560cae9f5f76997c 100644 --- a/internal/objectstore/s3_object.go +++ b/internal/objectstore/s3_object.go @@ -19,10 +19,10 @@ type S3Object struct { objectName string uploaded bool - uploader + *uploader } -func NewS3Object(ctx context.Context, objectName string, s3Credentials config.S3Credentials, s3Config config.S3Config, deadline time.Time) (*S3Object, error) { +func NewS3Object(objectName string, s3Credentials config.S3Credentials, s3Config config.S3Config) (*S3Object, error) { o := &S3Object{ credentials: s3Credentials, config: s3Config, @@ -30,8 +30,6 @@ func NewS3Object(ctx context.Context, objectName string, s3Credentials config.S3 } o.uploader = newUploader(o) - o.Execute(ctx, deadline) - return o, nil } diff --git a/internal/objectstore/s3_object_test.go b/internal/objectstore/s3_object_test.go index 1f4e530321b6d893a9865f5eb5679c7131f8a3ed..86b1827934809394256e454d6c8b2c218ebfefb9 100644 --- a/internal/objectstore/s3_object_test.go +++ b/internal/objectstore/s3_object_test.go @@ -3,7 +3,6 @@ package objectstore_test import ( "context" "fmt" - "io" "io/ioutil" "os" "path/filepath" @@ -44,18 +43,14 @@ func TestS3ObjectUpload(t *testing.T) { objectName := filepath.Join(tmpDir, "s3-test-data") ctx, cancel := context.WithCancel(context.Background()) - object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline) + object, err := objectstore.NewS3Object(objectName, creds, config) require.NoError(t, err) // copy data - n, err := io.Copy(object, strings.NewReader(test.ObjectContent)) + n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) require.NoError(t, err) require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch") - // close HTTP stream - err = object.Close() - require.NoError(t, err) - test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent) test.CheckS3Metadata(t, sess, config, objectName) @@ -107,17 +102,14 @@ func TestConcurrentS3ObjectUpload(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline) + object, err := objectstore.NewS3Object(objectName, creds, config) require.NoError(t, err) // copy data - n, err := io.Copy(object, strings.NewReader(test.ObjectContent)) + n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline) require.NoError(t, err) require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch") - // close HTTP stream - require.NoError(t, object.Close()) - test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent) wg.Done() }(i) @@ -139,7 +131,7 @@ func TestS3ObjectUploadCancel(t *testing.T) { objectName := filepath.Join(tmpDir, "s3-test-data") - object, err := 
diff --git a/internal/objectstore/s3_object_test.go b/internal/objectstore/s3_object_test.go
index 1f4e530321b6d893a9865f5eb5679c7131f8a3ed..86b1827934809394256e454d6c8b2c218ebfefb9 100644
--- a/internal/objectstore/s3_object_test.go
+++ b/internal/objectstore/s3_object_test.go
@@ -3,7 +3,6 @@ package objectstore_test
 import (
 	"context"
 	"fmt"
-	"io"
 	"io/ioutil"
 	"os"
 	"path/filepath"
@@ -44,18 +43,14 @@ func TestS3ObjectUpload(t *testing.T) {
 			objectName := filepath.Join(tmpDir, "s3-test-data")
 			ctx, cancel := context.WithCancel(context.Background())
 
-			object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline)
+			object, err := objectstore.NewS3Object(objectName, creds, config)
 			require.NoError(t, err)
 
 			// copy data
-			n, err := io.Copy(object, strings.NewReader(test.ObjectContent))
+			n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
 			require.NoError(t, err)
 			require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")
 
-			// close HTTP stream
-			err = object.Close()
-			require.NoError(t, err)
-
 			test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent)
 			test.CheckS3Metadata(t, sess, config, objectName)
 
@@ -107,17 +102,14 @@ func TestConcurrentS3ObjectUpload(t *testing.T) {
 			ctx, cancel := context.WithCancel(context.Background())
 			defer cancel()
 
-			object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline)
+			object, err := objectstore.NewS3Object(objectName, creds, config)
 			require.NoError(t, err)
 
 			// copy data
-			n, err := io.Copy(object, strings.NewReader(test.ObjectContent))
+			n, err := object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
 			require.NoError(t, err)
 			require.Equal(t, test.ObjectSize, n, "Uploaded file mismatch")
 
-			// close HTTP stream
-			require.NoError(t, object.Close())
-
 			test.S3ObjectExists(t, sess, config, objectName, test.ObjectContent)
 			wg.Done()
 		}(i)
@@ -139,7 +131,7 @@ func TestS3ObjectUploadCancel(t *testing.T) {
 
 	objectName := filepath.Join(tmpDir, "s3-test-data")
 
-	object, err := objectstore.NewS3Object(ctx, objectName, creds, config, deadline)
+	object, err := objectstore.NewS3Object(objectName, creds, config)
 	require.NoError(t, err)
 
@@ -147,6 +139,6 @@ func TestS3ObjectUploadCancel(t *testing.T) {
 	// we handle this gracefully.
 	cancel()
 
-	_, err = io.Copy(object, strings.NewReader(test.ObjectContent))
+	_, err = object.Consume(ctx, strings.NewReader(test.ObjectContent), deadline)
 	require.Error(t, err)
 }
diff --git a/internal/objectstore/uploader.go b/internal/objectstore/uploader.go
index 82f420dd0eae3e06ccb6b9858f6c5a2463dc04ca..fb93cb836308799c98649ca56c77be28472c999a 100644
--- a/internal/objectstore/uploader.go
+++ b/internal/objectstore/uploader.go
@@ -8,177 +8,89 @@ import (
 	"hash"
 	"io"
 	"strings"
-	"sync"
 	"time"
 
 	"gitlab.com/gitlab-org/labkit/log"
 )
 
-// uploader is an io.WriteCloser that can be used as write end of the uploading pipe.
+// uploader consumes an io.Reader and uploads it using a pluggable uploadStrategy.
 type uploader struct {
-	// etag is the object storage provided checksum
-	etag string
-
-	// md5 is an optional hasher for calculating md5 on the fly
-	md5 hash.Hash
-
-	w io.Writer
-
-	// uploadError is the last error occourred during upload
-	uploadError error
-	// ctx is the internal context bound to the upload request
-	ctx context.Context
-
-	pr       *io.PipeReader
-	pw       *io.PipeWriter
 	strategy uploadStrategy
-	metrics  bool
 
-	// closeOnce is used to prevent multiple calls to pw.Close
-	// which may result to Close overriding the error set by CloseWithError
-	// Bug fixed in v1.14: https://github.com/golang/go/commit/f45eb9ff3c96dfd951c65d112d033ed7b5e02432
-	closeOnce sync.Once
-}
+	// In the case of S3 uploads, we have a multipart upload which
+	// instantiates uploads for the individual parts. We don't want to
+	// increment metrics for the individual parts, so that is why we have
+	// this boolean flag.
+	metrics bool
 
-func newUploader(strategy uploadStrategy) uploader {
-	pr, pw := io.Pipe()
-	return uploader{w: pw, pr: pr, pw: pw, strategy: strategy, metrics: true}
+	// With S3 we compare the MD5 of the data we sent with the ETag returned
+	// by the object storage server.
+	checkETag bool
 }
 
-func newMD5Uploader(strategy uploadStrategy, metrics bool) uploader {
-	pr, pw := io.Pipe()
-	hasher := md5.New()
-	mw := io.MultiWriter(pw, hasher)
-	return uploader{w: mw, pr: pr, pw: pw, md5: hasher, strategy: strategy, metrics: metrics}
+func newUploader(strategy uploadStrategy) *uploader {
+	return &uploader{strategy: strategy, metrics: true}
 }
 
-// Close implements the standard io.Closer interface: it closes the http client request.
-// This method will also wait for the connection to terminate and return any error occurred during the upload
-func (u *uploader) Close() error {
-	var closeError error
-	u.closeOnce.Do(func() {
-		closeError = u.pw.Close()
-	})
-	if closeError != nil {
-		return closeError
-	}
-
-	<-u.ctx.Done()
-
-	if err := u.ctx.Err(); err == context.DeadlineExceeded {
-		return err
-	}
-
-	return u.uploadError
+func newETagCheckUploader(strategy uploadStrategy, metrics bool) *uploader {
+	return &uploader{strategy: strategy, metrics: metrics, checkETag: true}
 }
 
-func (u *uploader) CloseWithError(err error) error {
-	u.closeOnce.Do(func() {
-		u.pw.CloseWithError(err)
-	})
-
-	return nil
-}
+func hexString(h hash.Hash) string { return hex.EncodeToString(h.Sum(nil)) }
 
-func (u *uploader) Write(p []byte) (int, error) {
-	return u.w.Write(p)
-}
-
-func (u *uploader) md5Sum() string {
-	if u.md5 == nil {
-		return ""
-	}
-
-	checksum := u.md5.Sum(nil)
-	return hex.EncodeToString(checksum)
-}
-
-// ETag returns the checksum of the uploaded object returned by the ObjectStorage provider via ETag Header.
-// This method will wait until upload context is done before returning.
-func (u *uploader) ETag() string {
-	<-u.ctx.Done()
-
-	return u.etag
-}
-
-func (u *uploader) Execute(ctx context.Context, deadline time.Time) {
+// Consume reads the reader until it reaches EOF or an error. It spawns a
+// goroutine that waits for outerCtx to be done, after which the remote
+// file is deleted. The deadline applies to the upload performed inside
+// Consume, not to outerCtx.
+func (u *uploader) Consume(outerCtx context.Context, reader io.Reader, deadline time.Time) (_ int64, err error) {
 	if u.metrics {
 		objectStorageUploadsOpen.Inc()
+		defer func(started time.Time) {
+			objectStorageUploadsOpen.Dec()
+			objectStorageUploadTime.Observe(time.Since(started).Seconds())
+			if err != nil {
+				objectStorageUploadRequestsRequestFailed.Inc()
+			}
+		}(time.Now())
 	}
 
-	uploadCtx, cancelFn := context.WithDeadline(ctx, deadline)
-	u.ctx = uploadCtx
-
-	if u.metrics {
-		go u.trackUploadTime()
-	}
-
-	uploadDone := make(chan struct{})
-	go u.cleanup(ctx, uploadDone)
-	go func() {
-		defer cancelFn()
-		defer close(uploadDone)
-
-		if u.metrics {
-			defer objectStorageUploadsOpen.Dec()
-		}
-
-		defer func() {
-			// This will be returned as error to the next write operation on the pipe
-			u.pr.CloseWithError(u.uploadError)
-		}()
-
-		err := u.strategy.Upload(uploadCtx, u.pr)
+	defer func() {
+		// We do this mainly to abort S3 multipart uploads: it is not enough to
+		// "delete" them.
 		if err != nil {
-			u.uploadError = err
-			if u.metrics {
-				objectStorageUploadRequestsRequestFailed.Inc()
-			}
-			return
+			u.strategy.Abort()
 		}
+	}()
 
-		u.etag = u.strategy.ETag()
-
-		if u.md5 != nil {
-			err := compareMD5(u.md5Sum(), u.etag)
-			if err != nil {
-				log.ContextLogger(ctx).WithError(err).Error("error comparing MD5 checksum")
-
-				u.uploadError = err
-				if u.metrics {
-					objectStorageUploadRequestsRequestFailed.Inc()
-				}
-			}
-		}
+	go func() {
+		// Once gitlab-rails is done handling the request, we are supposed to
+		// delete the upload from its temporary location.
+		<-outerCtx.Done()
+		u.strategy.Delete()
 	}()
-}
 
-func (u *uploader) trackUploadTime() {
-	started := time.Now()
-	<-u.ctx.Done()
+	uploadCtx, cancelFn := context.WithDeadline(outerCtx, deadline)
+	defer cancelFn()
 
-	if u.metrics {
-		objectStorageUploadTime.Observe(time.Since(started).Seconds())
+	var hasher hash.Hash
+	if u.checkETag {
+		hasher = md5.New()
+		reader = io.TeeReader(reader, hasher)
 	}
-}
 
-func (u *uploader) cleanup(ctx context.Context, uploadDone chan struct{}) {
-	// wait for the upload to finish
-	<-u.ctx.Done()
+	cr := &countReader{r: reader}
+	if err := u.strategy.Upload(uploadCtx, cr); err != nil {
+		return cr.n, err
+	}
 
-	<-uploadDone
-	if u.uploadError != nil {
-		if u.metrics {
-			objectStorageUploadRequestsRequestFailed.Inc()
+	if u.checkETag {
+		if err := compareMD5(hexString(hasher), u.strategy.ETag()); err != nil {
+			log.ContextLogger(uploadCtx).WithError(err).Error("error comparing MD5 checksum")
+			return cr.n, err
 		}
-		u.strategy.Abort()
-		return
 	}
 
-	// We have now successfully uploaded the file to object storage. Another
-	// goroutine will hand off the object to gitlab-rails.
-	<-ctx.Done()
-
-	// gitlab-rails is now done with the object so it's time to delete it.
-	u.strategy.Delete()
+	return cr.n, nil
 }
@@ -188,3 +100,14 @@ func compareMD5(local, remote string) error {
 
 	return nil
 }
+
+type countReader struct {
+	r io.Reader
+	n int64
+}
+
+func (cr *countReader) Read(p []byte) (int, error) {
+	nRead, err := cr.r.Read(p)
+	cr.n += int64(nRead)
+	return nRead, err
+}
diff --git a/internal/upload/uploads_test.go b/internal/upload/uploads_test.go
index c7da72649c56facda6e6e95d76d7abaa974d5045..fc1a1ac57ef55cbfb0b88ae2a9dd22b2579e5605 100644
--- a/internal/upload/uploads_test.go
+++ b/internal/upload/uploads_test.go
@@ -123,7 +123,7 @@ func TestUploadHandlerRewritingMultiPartData(t *testing.T) {
 			require.Equal(t, hash, r.FormValue("file."+algo), "file hash %s", algo)
 		}
 
-		require.Len(t, r.MultipartForm.Value, 12, "multipart form values")
+		require.Len(t, r.MultipartForm.Value, 11, "multipart form values")
 
 		w.WriteHeader(202)
 		fmt.Fprint(w, "RESPONSE")
diff --git a/upload_test.go b/upload_test.go
index f61ec04f867bf673dfd7bd3d84dcda7e1c079d58..c08d04b100a3f2539d002d8c8df37e11dbadc113 100644
--- a/upload_test.go
+++ b/upload_test.go
@@ -79,7 +79,7 @@ func uploadTestServer(t *testing.T, extraTests func(r *http.Request)) *httptest.
 
 		require.NoError(t, r.ParseMultipartForm(100000))
 
-		const nValues = 11 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload, etag for just the upload (no metadata because we are not POSTing a valid zip file)
+		const nValues = 10 // file name, path, remote_url, remote_id, size, md5, sha1, sha256, sha512, gitlab-workhorse-upload for just the upload (no metadata because we are not POSTing a valid zip file)
 		require.Len(t, r.MultipartForm.Value, nValues)
 		require.Empty(t, r.MultipartForm.File, "multipart form files")
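Editor note, summarizing the refactored life cycle in one place: Consume drives strategy.Upload synchronously, counts bytes with countReader, optionally compares the local MD5 against the strategy's ETag, defers Abort on failure, and leaves Delete to a goroutine that fires once the outer context is done. The uploadStrategy interface itself is not touched by this diff; judging from the calls in uploader.go it provides at least Upload, ETag, Abort and Delete, which the hypothetical in-package stub below assumes:

    package objectstore

    import (
        "context"
        "io"
        "io/ioutil"
        "time"
    )

    // discardStrategy is a hypothetical test-style stub, not part of this MR.
    type discardStrategy struct{}

    func (discardStrategy) Upload(ctx context.Context, r io.Reader) error {
        _, err := io.Copy(ioutil.Discard, r)
        return err
    }
    func (discardStrategy) ETag() string { return "" }
    func (discardStrategy) Abort()       {}
    func (discardStrategy) Delete()      {}

    // exampleConsume shows the whole life cycle in one call: the upload is
    // bounded by the deadline, while deletion of the stored object waits for
    // ctx to be done.
    func exampleConsume(ctx context.Context, r io.Reader) (int64, error) {
        u := newUploader(discardStrategy{})
        return u.Consume(ctx, r, time.Now().Add(time.Minute))
    }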