diff --git a/internal/redis/redis.go b/internal/redis/redis.go index 8c8d1809e74e3693efe415d03ca356faade2dc98..803052dc2f07b70f2f6eb230bb64425cac8f6340 100644 --- a/internal/redis/redis.go +++ b/internal/redis/redis.go @@ -13,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" ) var ( @@ -73,9 +74,10 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { } var addrs []string for _, url := range urls { - h := url.URL.Host + h := url.URL.String() log.WithFields(log.Fields{ - "host": h, + "scheme": url.URL.Scheme, + "host": url.URL.Host, }).Printf("redis: using sentinel") addrs = append(addrs, h) } @@ -88,7 +90,22 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { // For every address it should try to connect to the Sentinel, // using a short timeout (in the order of a few hundreds of milliseconds). timeout := 500 * time.Millisecond - c, err := redis.Dial("tcp", addr, redis.DialConnectTimeout(timeout), redis.DialReadTimeout(timeout), redis.DialWriteTimeout(timeout)) + url := helper.URLMustParse(addr) + + var c redis.Conn + var err error + options := []redis.DialOption{ + redis.DialConnectTimeout(timeout), + redis.DialReadTimeout(timeout), + redis.DialWriteTimeout(timeout), + } + + if url.Scheme == "redis" || url.Scheme == "redisss" { + c, err = redis.DialURL(addr, options...) + } else { + c, err = redis.Dial("tcp", url.Host, options...) + } + if err != nil { errorCounter.WithLabelValues("dial", "sentinel").Inc() return nil, err @@ -176,11 +193,27 @@ func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url if url.Scheme == "unix" { return redisDial(url.Scheme, url.Path, dopts...) } + dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod))) + + // redis.DialURL only works with redis[s]:// URLs + if url.Scheme == "redis" || url.Scheme == "rediss" { + return redisURLDial(url, dopts...) + } + return redisDial(url.Scheme, url.Host, dopts...) } } +func redisURLDial(url url.URL, options ...redis.DialOption) (redis.Conn, error) { + log.WithFields(log.Fields{ + "scheme": url.Scheme, + "address": url.Host, + }).Printf("redis: dialing") + + return redis.DialURL(url.String(), options...) +} + func redisDial(network, address string, options ...redis.DialOption) (redis.Conn, error) { log.WithFields(log.Fields{ "network": network, diff --git a/internal/redis/redis_test.go b/internal/redis/redis_test.go index e83c059b7ef17a265566433b09fbdb0d590e4316..b601ba205d58d4c53f5434c0cf3deb3b5117328d 100644 --- a/internal/redis/redis_test.go +++ b/internal/redis/redis_test.go @@ -1,6 +1,7 @@ package redis import ( + "net" "testing" "time" @@ -12,6 +13,22 @@ import ( "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" ) +func mockRedisServer(t *testing.T, connectReceived *bool) string { + ln, err := net.Listen("tcp", "127.0.0.1:0") + + assert.Nil(t, err) + + go func() { + defer ln.Close() + conn, err := ln.Accept() + assert.Nil(t, err) + *connectReceived = true + conn.Write([]byte("OK\n")) + }() + + return ln.Addr().String() +} + // Setup a MockPool for Redis // // Returns a teardown-function and the mock-connection @@ -28,6 +45,37 @@ func setupMockPool() (*redigomock.Conn, func()) { } } +func TestDefaultDialFunc(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "tcp", + }, + { + scheme: "redis", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + parsedURL := helper.URLMustParse(tc.scheme + "://" + a) + cfg := &config.RedisConfig{URL: config.TomlURL{URL: *parsedURL}} + + dialer := DefaultDialFunc(cfg, true) + conn, err := dialer() + + assert.Nil(t, err) + conn.Receive() + + assert.True(t, connectReceived) + }) + } +} + func TestConfigureNoConfig(t *testing.T) { pool = nil Configure(nil, nil) @@ -99,12 +147,54 @@ func TestSentinelConnNoSentinel(t *testing.T) { assert.Nil(t, s, "Sentinel without urls should return nil") } +func TestSentinelConnDialURL(t *testing.T) { + testCases := []struct { + scheme string + }{ + { + scheme: "tcp", + }, + { + scheme: "redis", + }, + } + + for _, tc := range testCases { + t.Run(tc.scheme, func(t *testing.T) { + connectReceived := false + a := mockRedisServer(t, &connectReceived) + + addrs := []string{tc.scheme + "://" + a} + var sentinelUrls []config.TomlURL + + for _, a := range addrs { + parsedURL := helper.URLMustParse(a) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + s := sentinelConn("foobar", sentinelUrls) + assert.Equal(t, len(addrs), len(s.Addrs)) + + for i := range addrs { + assert.Equal(t, addrs[i], s.Addrs[i]) + } + + conn, err := s.Dial(s.Addrs[0]) + + assert.Nil(t, err) + conn.Receive() + + assert.True(t, connectReceived) + }) + } +} + func TestSentinelConnTwoURLs(t *testing.T) { - addrs := []string{"10.0.0.1:12345", "10.0.0.2:12345"} + addrs := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"} var sentinelUrls []config.TomlURL for _, a := range addrs { - parsedURL := helper.URLMustParse(`tcp://` + a) + parsedURL := helper.URLMustParse(a) sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) }