diff --git a/README.md b/README.md index ed27ba9a197d15fc3dedbd855d155c872183fc8c..317e662b3c1bbe3eaff1e61bc0e32b289ec80ed7 100644 --- a/README.md +++ b/README.md @@ -105,12 +105,16 @@ SentinelMaster = "mymaster" Optional fields are as follows: ``` [redis] -ReadTimeout = 1000 +DB = 0 +ReadTimeout = "1s" +KeepAlivePeriod = "5m" MaxIdle = 1 MaxActive = 1 ``` -- `ReadTimeout` is how many milliseconds that a redis read-command can take. Defaults to `1000` +- `DB` is the Database to connect to. Defaults to `0` +- `ReadTimeout` is how long a redis read-command can take. Defaults to `1s` +- `KeepAlivePeriod` is how long the redis connection is to be kept alive without anything flowing through it. Defaults to `5m` - `MaxIdle` is how many idle connections can be in the redis-pool at once. Defaults to 1 - `MaxActive` is how many connections the pool can keep. Defaults to 1 diff --git a/internal/config/config.go b/internal/config/config.go index dfb5b4fdeb2c02f3adf9d1d913f015346b64f0b2..3c7a912626d42f304f2d3023664d583b21e538e3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,14 +17,27 @@ func (u *TomlURL) UnmarshalText(text []byte) error { return err } +type TomlDuration struct { + time.Duration +} + +func (d *TomlDuration) UnmarshalTest(text []byte) error { + temp, err := time.ParseDuration(string(text)) + d.Duration = temp + return err +} + type RedisConfig struct { - URL TomlURL - Sentinel []TomlURL - SentinelMaster string - Password string - ReadTimeout *int - MaxIdle *int - MaxActive *int + URL TomlURL + Sentinel []TomlURL + SentinelMaster string + Password string + DB *int + ReadTimeout *TomlDuration + WriteTimeout *TomlDuration + KeepAlivePeriod *TomlDuration + MaxIdle *int + MaxActive *int } type Config struct { diff --git a/internal/redis/keywatcher.go b/internal/redis/keywatcher.go index 0b6227e78ecb7ef0ee1c2b3dc238b57f13876e81..6c7b465c77cfe8f0fc9fcb1c834a6a1b0b556a9e 100644 --- a/internal/redis/keywatcher.go +++ b/internal/redis/keywatcher.go @@ -1,7 +1,6 @@ package redis import ( - "errors" "fmt" "log" "strings" @@ -34,7 +33,7 @@ var ( totalMessages = prometheus.NewCounter( prometheus.CounterOpts{ Name: "gitlab_workhorse_keywather_total_messages", - Help: "How many messages gitlab-workhorse has recieved in total on pubsub.", + Help: "How many messages gitlab-workhorse has received in total on pubsub.", }, ) ) @@ -58,13 +57,11 @@ type KeyChan struct { Chan chan string } -func processInner(conn redis.Conn) { - redisReconnectTimeout.Reset() - +func processInner(conn redis.Conn) error { defer conn.Close() psc := redis.PubSubConn{Conn: conn} if err := psc.Subscribe(keySubChannel); err != nil { - return + return err } defer psc.Unsubscribe(keySubChannel) @@ -72,20 +69,38 @@ func processInner(conn redis.Conn) { switch v := psc.Receive().(type) { case redis.Message: totalMessages.Inc() - msg := strings.SplitN(string(v.Data), "=", 2) + dataStr := string(v.Data) + msg := strings.SplitN(dataStr, "=", 2) if len(msg) != 2 { - helper.LogError(nil, errors.New("Redis subscribe error: got an invalid notification")) + helper.LogError(nil, fmt.Errorf("Redis receive error: got an invalid notification: %q", dataStr)) continue } key, value := msg[0], msg[1] notifyChanWatchers(key, value) case error: - helper.LogError(nil, fmt.Errorf("Redis subscribe error: %s", v)) - return + helper.LogError(nil, fmt.Errorf("Redis receive error: %s", v)) + // Intermittent error, return nil so that it doesn't wait before reconnect + return nil } } } +func dialPubSub(dialer redisDialerFunc) (redis.Conn, error) { + conn, err := dialer() + if err != nil { + return nil, err + } + + // Make sure Redis is actually connected + conn.Do("PING") + if err := conn.Err(); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + // Process redis subscriptions // // NOTE: There Can Only Be One! @@ -97,13 +112,19 @@ func Process(reconnect bool) { for loop { loop = reconnect log.Println("Connecting to redis") - conn, err := redisDialFunc() + + conn, err := dialPubSub(workerDialFunc) if err != nil { helper.LogError(nil, fmt.Errorf("Failed to connect to redis: %s", err)) time.Sleep(redisReconnectTimeout.Duration()) continue } - processInner(conn) + redisReconnectTimeout.Reset() + + if err = processInner(conn); err != nil { + helper.LogError(nil, fmt.Errorf("Failed to process redis-queue: %s", err)) + continue + } } } diff --git a/internal/redis/keywatcher_test.go b/internal/redis/keywatcher_test.go index 68c459586e27446b92cdfe1a807bdbf47b9e7add..fd3a73ab00d18382d4e917d26b68925698a9ec34 100644 --- a/internal/redis/keywatcher_test.go +++ b/internal/redis/keywatcher_test.go @@ -103,7 +103,6 @@ func TestWatchKeyNoChange(t *testing.T) { processMessages(1, "something") wg.Wait() - } func TestWatchKeyTimeout(t *testing.T) { diff --git a/internal/redis/redis.go b/internal/redis/redis.go index 7e0825cd147313f79ae2c3e7b26c102b731e73aa..5e6347e71b892b47bef7903edcd1855d894f4768 100644 --- a/internal/redis/redis.go +++ b/internal/redis/redis.go @@ -3,6 +3,8 @@ package redis import ( "errors" "fmt" + "net" + "net/url" "time" "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" @@ -18,10 +20,26 @@ var ( ) const ( - defaultMaxIdle = 1 - defaultMaxActive = 1 + // Max Idle Connections in the pool. + defaultMaxIdle = 1 + // Max Active Connections in the pool. + defaultMaxActive = 1 + // Timeout for Read operations on the pool. 1 second is technically overkill, + // it's just for sanity. defaultReadTimeout = 1 * time.Second + // Timeout for Write operations on the pool. 1 second is technically overkill, + // it's just for sanity. + defaultWriteTimeout = 1 * time.Second + // Timeout before killing Idle connections in the pool. 3 minutes seemed good. + // If you _actually_ hit this timeout often, you should consider turning of + // redis-support since it's not necessary at that point... defaultIdleTimeout = 3 * time.Minute + // KeepAlivePeriod is to keep a TCP connection open for an extended period of + // time without being killed. This is used both in the pool, and in the + // worker-connection. + // See https://en.wikipedia.org/wiki/Keepalive#TCP_keepalive for more + // information. + defaultKeepAlivePeriod = 5 * time.Minute ) var ( @@ -65,37 +83,91 @@ func sentinelConn(master string, urls []config.TomlURL) *sentinel.Sentinel { } } -var redisDialFunc func() (redis.Conn, error) +var poolDialFunc func() (redis.Conn, error) +var workerDialFunc func() (redis.Conn, error) -func dialOptionsBuilder(cfg *config.RedisConfig) []redis.DialOption { +func timeoutDialOptions(cfg *config.RedisConfig) []redis.DialOption { readTimeout := defaultReadTimeout - if cfg.ReadTimeout != nil { - readTimeout = time.Millisecond * time.Duration(*cfg.ReadTimeout) + writeTimeout := defaultWriteTimeout + + if cfg != nil { + if cfg.ReadTimeout != nil { + readTimeout = cfg.ReadTimeout.Duration + } + + if cfg.WriteTimeout != nil { + writeTimeout = cfg.WriteTimeout.Duration + } + } + return []redis.DialOption{ + redis.DialReadTimeout(readTimeout), + redis.DialWriteTimeout(writeTimeout), + } +} + +func dialOptionsBuilder(cfg *config.RedisConfig, setTimeouts bool) []redis.DialOption { + var dopts []redis.DialOption + if setTimeouts { + dopts = timeoutDialOptions(cfg) + } + if cfg == nil { + return dopts } - dopts := []redis.DialOption{redis.DialReadTimeout(readTimeout)} if cfg.Password != "" { dopts = append(dopts, redis.DialPassword(cfg.Password)) } + if cfg.DB != nil { + dopts = append(dopts, redis.DialDatabase(*cfg.DB)) + } return dopts } -// DefaultDialFunc should always used. Only exception is for unit-tests. -func DefaultDialFunc(cfg *config.RedisConfig) func() (redis.Conn, error) { - dopts := dialOptionsBuilder(cfg) - innerDial := func() (redis.Conn, error) { - return redis.Dial(cfg.URL.Scheme, cfg.URL.Host, dopts...) +func keepAliveDialer(timeout time.Duration) func(string, string) (net.Conn, error) { + return func(network, address string) (net.Conn, error) { + addr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + tc, err := net.DialTCP(network, nil, addr) + if err != nil { + return nil, err + } + if err := tc.SetKeepAlive(true); err != nil { + return nil, err + } + if err := tc.SetKeepAlivePeriod(timeout); err != nil { + return nil, err + } + return tc, nil } - if sntnl != nil { - innerDial = func() (redis.Conn, error) { - address, err := sntnl.MasterAddr() - if err != nil { - return nil, err - } - return redis.Dial("tcp", address, dopts...) +} + +type redisDialerFunc func() (redis.Conn, error) + +func sentinelDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration) redisDialerFunc { + return func() (redis.Conn, error) { + address, err := sntnl.MasterAddr() + if err != nil { + return nil, err + } + dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod))) + return redis.Dial("tcp", address, dopts...) + } +} + +func defaultDialer(dopts []redis.DialOption, keepAlivePeriod time.Duration, url url.URL) redisDialerFunc { + return func() (redis.Conn, error) { + if url.Scheme == "unix" { + return redis.Dial(url.Scheme, url.Path, dopts...) } + dopts = append(dopts, redis.DialNetDial(keepAliveDialer(keepAlivePeriod))) + return redis.Dial(url.Scheme, url.Host, dopts...) } +} + +func countDialer(dialer redisDialerFunc) redisDialerFunc { return func() (redis.Conn, error) { - c, err := innerDial() + c, err := dialer() if err == nil { totalConnections.Inc() } @@ -103,8 +175,21 @@ func DefaultDialFunc(cfg *config.RedisConfig) func() (redis.Conn, error) { } } +// DefaultDialFunc should always used. Only exception is for unit-tests. +func DefaultDialFunc(cfg *config.RedisConfig, setReadTimeout bool) func() (redis.Conn, error) { + keepAlivePeriod := defaultKeepAlivePeriod + if cfg.KeepAlivePeriod != nil { + keepAlivePeriod = cfg.KeepAlivePeriod.Duration + } + dopts := dialOptionsBuilder(cfg, setReadTimeout) + if sntnl != nil { + return countDialer(sentinelDialer(dopts, keepAlivePeriod)) + } + return countDialer(defaultDialer(dopts, keepAlivePeriod, cfg.URL.URL)) +} + // Configure redis-connection -func Configure(cfg *config.RedisConfig, dialFunc func() (redis.Conn, error)) { +func Configure(cfg *config.RedisConfig, dialFunc func(*config.RedisConfig, bool) func() (redis.Conn, error)) { if cfg == nil { return } @@ -117,12 +202,13 @@ func Configure(cfg *config.RedisConfig, dialFunc func() (redis.Conn, error)) { maxActive = *cfg.MaxActive } sntnl = sentinelConn(cfg.SentinelMaster, cfg.Sentinel) - redisDialFunc = dialFunc + workerDialFunc = dialFunc(cfg, false) + poolDialFunc = dialFunc(cfg, true) pool = &redis.Pool{ MaxIdle: maxIdle, // Keep at most X hot connections MaxActive: maxActive, // Keep at most X live connections, 0 means unlimited IdleTimeout: defaultIdleTimeout, // X time until an unused connection is closed - Dial: redisDialFunc, + Dial: poolDialFunc, Wait: true, } if sntnl != nil { diff --git a/internal/redis/redis_test.go b/internal/redis/redis_test.go index 2d0f82a33ab7c03672247de9ea24c246295aa350..ba024cedfed9669c7e65f97a7ac3ff523c9290b0 100644 --- a/internal/redis/redis_test.go +++ b/internal/redis/redis_test.go @@ -5,6 +5,7 @@ import ( "time" "gitlab.com/gitlab-org/gitlab-workhorse/internal/config" + "gitlab.com/gitlab-org/gitlab-workhorse/internal/helper" "github.com/garyburd/redigo/redis" "github.com/rafaeljusto/redigomock" @@ -17,8 +18,10 @@ import ( func setupMockPool() (*redigomock.Conn, func()) { conn := redigomock.NewConn() cfg := &config.RedisConfig{URL: config.TomlURL{}} - Configure(cfg, func() (redis.Conn, error) { - return conn, nil + Configure(cfg, func(_ *config.RedisConfig, _ bool) func() (redis.Conn, error) { + return func() (redis.Conn, error) { + return conn, nil + } }) return conn, func() { pool = nil @@ -33,7 +36,7 @@ func TestConfigureNoConfig(t *testing.T) { func TestConfigureMinimalConfig(t *testing.T) { cfg := &config.RedisConfig{URL: config.TomlURL{}, Password: ""} - Configure(cfg, DefaultDialFunc(cfg)) + Configure(cfg, DefaultDialFunc) if assert.NotNil(t, pool, "Pool should not be nil") { assert.Equal(t, 1, pool.MaxIdle) assert.Equal(t, 1, pool.MaxActive) @@ -43,7 +46,8 @@ func TestConfigureMinimalConfig(t *testing.T) { } func TestConfigureFullConfig(t *testing.T) { - i, a, r := 4, 10, 3 + i, a := 4, 10 + r := config.TomlDuration{Duration: 3} cfg := &config.RedisConfig{ URL: config.TomlURL{}, Password: "", @@ -51,7 +55,7 @@ func TestConfigureFullConfig(t *testing.T) { MaxActive: &a, ReadTimeout: &r, } - Configure(cfg, DefaultDialFunc(cfg)) + Configure(cfg, DefaultDialFunc) if assert.NotNil(t, pool, "Pool should not be nil") { assert.Equal(t, i, pool.MaxIdle) assert.Equal(t, a, pool.MaxActive) @@ -88,3 +92,51 @@ func TestGetStringFail(t *testing.T) { _, err := GetString("foobar") assert.Error(t, err, "Expected error when not connected to redis") } + +func TestSentinelConnNoSentinel(t *testing.T) { + s := sentinelConn("", []config.TomlURL{}) + + assert.Nil(t, s, "Sentinel without urls should return nil") +} + +func TestSentinelConnTwoURLs(t *testing.T) { + urls := []string{"tcp://10.0.0.1:12345", "tcp://10.0.0.2:12345"} + var sentinelUrls []config.TomlURL + + for _, url := range urls { + parsedURL := helper.URLMustParse(url) + sentinelUrls = append(sentinelUrls, config.TomlURL{URL: *parsedURL}) + } + + s := sentinelConn("foobar", sentinelUrls) + assert.Equal(t, len(urls), len(s.Addrs)) + + for i := range urls { + assert.Equal(t, urls[i], s.Addrs[i]) + } +} + +func TestDialOptionsBuildersPassword(t *testing.T) { + dopts := dialOptionsBuilder(&config.RedisConfig{Password: "foo"}, false) + assert.Equal(t, 1, len(dopts)) +} + +func TestDialOptionsBuildersSetTimeouts(t *testing.T) { + dopts := dialOptionsBuilder(nil, true) + assert.Equal(t, 2, len(dopts)) +} + +func TestDialOptionsBuildersSetTimeoutsConfig(t *testing.T) { + cfg := &config.RedisConfig{ + ReadTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)}, + WriteTimeout: &config.TomlDuration{Duration: time.Second * time.Duration(15)}, + } + dopts := dialOptionsBuilder(cfg, true) + assert.Equal(t, 2, len(dopts)) +} + +func TestDialOptionsBuildersSelectDB(t *testing.T) { + db := 3 + dopts := dialOptionsBuilder(&config.RedisConfig{DB: &db}, false) + assert.Equal(t, 1, len(dopts)) +} diff --git a/main.go b/main.go index 866b21dca7690687487f24375d95b6e006b8ec1a..77b308b09f576c7f25e716a987a69b4dd838d2ac 100644 --- a/main.go +++ b/main.go @@ -133,7 +133,7 @@ func main() { cfg.Redis = cfgFromFile.Redis - redis.Configure(cfg.Redis, redis.DefaultDialFunc(cfg.Redis)) + redis.Configure(cfg.Redis, redis.DefaultDialFunc) go redis.Process(true) }