From 1c4905be2ac6ee92e27a22312c2d6c29b6f5ceee Mon Sep 17 00:00:00 2001
From: Matthias Kaeppler <mkaeppler@gitlab.com>
Date: Mon, 14 Feb 2022 13:52:42 +0100
Subject: [PATCH] Break out process supervision into lib

This extracts the process supervision logic from sidekiq-cluster into
a reusable helper module, Gitlab::ProcessSupervisor.

We will re-use it for Puma soon.
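
A rough usage sketch, based on how sidekiq-cluster now drives the new
class (the intervals, signal lists, and the worker_pids variable below
are illustrative):

    supervisor = Gitlab::ProcessSupervisor.new(
      health_check_interval_seconds: 5,
      terminate_timeout_seconds: 30,
      term_signals: %i(INT TERM),
      forwarded_signals: %i(TTIN USR1 USR2 HUP)
    )

    # worker_pids: array of already-started process IDs (illustrative).
    supervisor.supervise(worker_pids) do |dead_pids|
      # Act on the dead processes here; return the new set of PIDs to
      # observe, or an empty list to make the supervisor shut down.
      []
    end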
---
 lib/gitlab/process_supervisor.rb           | 110 +++++++++++
 sidekiq_cluster/cli.rb                     | 107 ++++-------
 sidekiq_cluster/sidekiq_cluster.rb         |   7 +-
 spec/commands/sidekiq_cluster/cli_spec.rb  | 213 +++++++--------------
 spec/lib/gitlab/process_supervisor_spec.rb | 127 ++++++++++++
 5 files changed, 339 insertions(+), 225 deletions(-)
 create mode 100644 lib/gitlab/process_supervisor.rb
 create mode 100644 spec/lib/gitlab/process_supervisor_spec.rb

diff --git a/lib/gitlab/process_supervisor.rb b/lib/gitlab/process_supervisor.rb
new file mode 100644
index 0000000000000..f0d2bbc33bdc7
--- /dev/null
+++ b/lib/gitlab/process_supervisor.rb
@@ -0,0 +1,110 @@
+# frozen_string_literal: true
+
+module Gitlab
+  # Given a set of process IDs, the supervisor monitors these processes
+  # for liveness and invokes a callback when some or all of them die.
+  # The receiver of the callback can then act on this event, for instance
+  # by restarting those processes or performing clean-up work.
+  #
+  # The supervisor will also trap termination signals if provided and
+  # propagate those to the supervised processes. Any supervised processes
+  # that do not terminate within a specified grace period will be killed.
+  class ProcessSupervisor
+    DEFAULT_HEALTH_CHECK_INTERVAL_SECONDS = 5
+    DEFAULT_TERMINATE_INTERVAL_SECONDS = 1
+    DEFAULT_TERMINATE_TIMEOUT_SECONDS = 10
+
+    attr_reader :alive
+
+    def initialize(
+      health_check_interval_seconds: DEFAULT_HEALTH_CHECK_INTERVAL_SECONDS,
+      check_terminate_interval_seconds: DEFAULT_TERMINATE_INTERVAL_SECONDS,
+      terminate_timeout_seconds: DEFAULT_TERMINATE_TIMEOUT_SECONDS,
+      term_signals: %i(INT TERM),
+      forwarded_signals: [])
+
+      @term_signals = term_signals
+      @forwarded_signals = forwarded_signals
+      @health_check_interval_seconds = health_check_interval_seconds
+      @check_terminate_interval_seconds = check_terminate_interval_seconds
+      @terminate_timeout_seconds = terminate_timeout_seconds
+    end
+
+    # Starts a supervision loop for the given process ID(s).
+    #
+    # If any or all processes go away, the IDs of any dead processes will
+    # be yielded to the given block, so callers can act on them.
+    #
+    # If the block returns a non-empty list of IDs, the supervisor will
+    # start observing those processes instead. Otherwise it will shut down.
+    def supervise(pid_or_pids, &on_process_death)
+      @pids = Array(pid_or_pids)
+
+      trap_signals!
+
+      @alive = true
+      while @alive
+        sleep(@health_check_interval_seconds)
+
+        check_process_health(&on_process_death)
+      end
+    end
+
+    private
+
+    def check_process_health(&on_process_death)
+      unless all_alive?
+        dead_pids = @pids - live_pids
+        @pids = Array(yield(dead_pids))
+        @alive = @pids.any?
+      end
+    end
+
+    def trap_signals!
+      ProcessManagement.trap_signals(@term_signals) do |signal|
+        @alive = false
+        signal_all(signal)
+        wait_for_termination
+      end
+
+      ProcessManagement.trap_signals(@forwarded_signals) do |signal|
+        signal_all(signal)
+      end
+    end
+
+    def wait_for_termination
+      deadline = monotonic_time + @terminate_timeout_seconds
+      sleep(@check_terminate_interval_seconds) while continue_waiting?(deadline)
+
+      hard_stop_stuck_pids
+    end
+
+    def monotonic_time
+      Process.clock_gettime(Process::CLOCK_MONOTONIC, :float_second)
+    end
+
+    def continue_waiting?(deadline)
+      any_alive? && monotonic_time < deadline
+    end
+
+    def signal_all(signal)
+      ProcessManagement.signal_processes(@pids, signal)
+    end
+
+    def hard_stop_stuck_pids
+      ProcessManagement.signal_processes(live_pids, "-KILL")
+    end
+
+    def any_alive?
+      ProcessManagement.any_alive?(@pids)
+    end
+
+    def all_alive?
+      ProcessManagement.all_alive?(@pids)
+    end
+
+    def live_pids
+      ProcessManagement.pids_alive(@pids)
+    end
+  end
+end
diff --git a/sidekiq_cluster/cli.rb b/sidekiq_cluster/cli.rb
index 2feb77601b89c..f366cb26b8ef8 100644
--- a/sidekiq_cluster/cli.rb
+++ b/sidekiq_cluster/cli.rb
@@ -14,6 +14,7 @@
 require_relative '../lib/gitlab/sidekiq_config/worker_matcher'
 require_relative '../lib/gitlab/sidekiq_logging/json_formatter'
 require_relative '../lib/gitlab/process_management'
+require_relative '../lib/gitlab/process_supervisor'
 require_relative '../metrics_server/metrics_server'
 require_relative 'sidekiq_cluster'
 
@@ -38,8 +39,7 @@ def initialize(log_output = $stderr)
         @metrics_dir = ENV["prometheus_multiproc_dir"] || File.absolute_path("tmp/prometheus_multiproc_dir/sidekiq")
         @pid = nil
         @interval = 5
-        @alive = true
-        @processes = []
+        @soft_timeout_seconds = DEFAULT_SOFT_TIMEOUT_SECONDS
         @logger = Logger.new(log_output)
         @logger.formatter = ::Gitlab::SidekiqLogging::JSONFormatter.new
         @rails_path = Dir.pwd
@@ -103,95 +103,63 @@ def run(argv = ARGV)
           @logger.info("Starting cluster with #{queue_groups.length} processes")
         end
 
-        start_metrics_server(wipe_metrics_dir: true)
+        start_and_supervise_workers(queue_groups)
+      end
 
-        @processes = SidekiqCluster.start(
+      def start_and_supervise_workers(queue_groups)
+        worker_pids = SidekiqCluster.start(
           queue_groups,
           env: @environment,
           directory: @rails_path,
           max_concurrency: @max_concurrency,
           min_concurrency: @min_concurrency,
           dryrun: @dryrun,
-          timeout: soft_timeout_seconds
+          timeout: @soft_timeout_seconds
         )
 
         return if @dryrun
 
-        write_pid
-        trap_signals
-        start_loop
-      end
-
-      def write_pid
         ProcessManagement.write_pid(@pid) if @pid
-      end
-
-      def soft_timeout_seconds
-        @soft_timeout_seconds || DEFAULT_SOFT_TIMEOUT_SECONDS
-      end
-
-      # The amount of time it'll wait for killing the alive Sidekiq processes.
-      def hard_timeout_seconds
-        soft_timeout_seconds + DEFAULT_HARD_TIMEOUT_SECONDS
-      end
-
-      def monotonic_time
-        Process.clock_gettime(Process::CLOCK_MONOTONIC, :float_second)
-      end
-
-      def continue_waiting?(deadline)
-        ProcessManagement.any_alive?(@processes) && monotonic_time < deadline
-      end
-
-      def hard_stop_stuck_pids
-        ProcessManagement.signal_processes(ProcessManagement.pids_alive(@processes), "-KILL")
-      end
-
-      def wait_for_termination
-        deadline = monotonic_time + hard_timeout_seconds
-        sleep(CHECK_TERMINATE_INTERVAL_SECONDS) while continue_waiting?(deadline)
 
-        hard_stop_stuck_pids
-      end
-
-      def trap_signals
-        ProcessManagement.trap_signals(TERMINATE_SIGNALS) do |signal|
-          @alive = false
-          ProcessManagement.signal_processes(@processes, signal)
-          wait_for_termination
-        end
-
-        ProcessManagement.trap_signals(FORWARD_SIGNALS) do |signal|
-          ProcessManagement.signal_processes(@processes, signal)
-        end
-      end
+        supervisor = Gitlab::ProcessSupervisor.new(
+          health_check_interval_seconds: @interval,
+          terminate_timeout_seconds: @soft_timeout_seconds + TIMEOUT_GRACE_PERIOD_SECONDS,
+          term_signals: TERMINATE_SIGNALS,
+          forwarded_signals: FORWARD_SIGNALS
+        )
 
-      def start_loop
-        while @alive
-          sleep(@interval)
+        metrics_server_pid = start_metrics_server
 
-          if metrics_server_enabled? && ProcessManagement.process_died?(@metrics_server_pid)
-            @logger.warn('Metrics server went away')
-            start_metrics_server(wipe_metrics_dir: false)
-          end
+        all_pids = worker_pids + Array(metrics_server_pid)
 
-          unless ProcessManagement.all_alive?(@processes)
-            # If a child process died we'll just terminate the whole cluster. It's up to
-            # runit and such to then restart the cluster.
+        supervisor.supervise(all_pids) do |dead_pids|
+          # If we're not in the process of shutting down the cluster,
+          # and the metrics server died, restart it.
+          if supervisor.alive && dead_pids.include?(metrics_server_pid)
+            @logger.info('Metrics server terminated, restarting...')
+            metrics_server_pid = restart_metrics_server(wipe_metrics_dir: false)
+            all_pids = worker_pids + Array(metrics_server_pid)
+          else
+            # If a worker process died, we'll just terminate the whole cluster.
+            # We let an external system (runit, Kubernetes) handle the restart.
             @logger.info('A worker terminated, shutting down the cluster')
 
-            stop_metrics_server
-            ProcessManagement.signal_processes(@processes, :TERM)
-            break
+            ProcessManagement.signal_processes(all_pids - dead_pids, :TERM)
+            # Signal supervisor not to respawn workers and shut down.
+            []
           end
         end
       end
 
-      def start_metrics_server(wipe_metrics_dir: false)
+      def start_metrics_server
         return unless metrics_server_enabled?
 
+        restart_metrics_server(wipe_metrics_dir: true)
+      end
+
+      def restart_metrics_server(wipe_metrics_dir: false)
         @logger.info("Starting metrics server on port #{sidekiq_exporter_port}")
-        @metrics_server_pid = MetricsServer.fork(
+        MetricsServer.fork(
           'sidekiq',
           metrics_dir: @metrics_dir,
           wipe_metrics_dir: wipe_metrics_dir,
@@ -225,13 +193,6 @@ def metrics_server_enabled?
         !@dryrun && sidekiq_exporter_enabled? && exporter_has_a_unique_port?
       end
 
-      def stop_metrics_server
-        return unless @metrics_server_pid
-
-        @logger.info("Stopping metrics server (PID #{@metrics_server_pid})")
-        ProcessManagement.signal(@metrics_server_pid, :TERM)
-      end
-
       def option_parser
         OptionParser.new do |opt|
           opt.banner = "#{File.basename(__FILE__)} [QUEUE,QUEUE] [QUEUE] ... [OPTIONS]"
diff --git a/sidekiq_cluster/sidekiq_cluster.rb b/sidekiq_cluster/sidekiq_cluster.rb
index c5139ab887467..3ba3211b0e42c 100644
--- a/sidekiq_cluster/sidekiq_cluster.rb
+++ b/sidekiq_cluster/sidekiq_cluster.rb
@@ -4,8 +4,6 @@
 
 module Gitlab
   module SidekiqCluster
-    CHECK_TERMINATE_INTERVAL_SECONDS = 1
-
     # How long to wait when asking for a clean termination.
     # It maps the Sidekiq default timeout:
     # https://github.com/mperham/sidekiq/wiki/Signals#term
@@ -14,8 +12,9 @@ module SidekiqCluster
     # is given through arguments.
     DEFAULT_SOFT_TIMEOUT_SECONDS = 25
 
-    # After surpassing the soft timeout.
-    DEFAULT_HARD_TIMEOUT_SECONDS = 5
+    # Additional time granted after surpassing the soft timeout
+    # before we kill the process.
+    TIMEOUT_GRACE_PERIOD_SECONDS = 5
 
     # Starts Sidekiq workers for the pairs of processes.
     #
diff --git a/spec/commands/sidekiq_cluster/cli_spec.rb b/spec/commands/sidekiq_cluster/cli_spec.rb
index 15b738cacd100..6baaa98eff9d4 100644
--- a/spec/commands/sidekiq_cluster/cli_spec.rb
+++ b/spec/commands/sidekiq_cluster/cli_spec.rb
@@ -5,8 +5,11 @@
 
 require_relative '../../support/stub_settings_source'
 require_relative '../../../sidekiq_cluster/cli'
+require_relative '../../support/helpers/next_instance_of'
 
 RSpec.describe Gitlab::SidekiqCluster::CLI, stub_settings_source: true do # rubocop:disable RSpec/FilePath
+  include NextInstanceOf
+
   let(:cli) { described_class.new('/dev/null') }
   let(:timeout) { Gitlab::SidekiqCluster::DEFAULT_SOFT_TIMEOUT_SECONDS }
   let(:default_options) do
@@ -61,9 +64,8 @@
 
     context 'with arguments' do
       before do
-        allow(cli).to receive(:write_pid)
-        allow(cli).to receive(:trap_signals)
-        allow(cli).to receive(:start_loop)
+        allow(Gitlab::ProcessManagement).to receive(:write_pid)
+        allow_next_instance_of(Gitlab::ProcessSupervisor) { |it| allow(it).to receive(:supervise) }
       end
 
       it 'starts the Sidekiq workers' do
@@ -81,7 +83,7 @@
           .to receive(:worker_queues).and_return(worker_queues)
 
         expect(Gitlab::SidekiqCluster)
-          .to receive(:start).with([worker_queues], default_options)
+          .to receive(:start).with([worker_queues], default_options).and_return([])
 
         cli.run(%w(*))
       end
@@ -135,6 +137,7 @@
         it 'when given', 'starts Sidekiq workers with given timeout' do
           expect(Gitlab::SidekiqCluster).to receive(:start)
             .with([['foo']], default_options.merge(timeout: 10))
+            .and_return([])
 
           cli.run(%w(foo --timeout 10))
         end
@@ -142,6 +145,7 @@
         it 'when not given', 'starts Sidekiq workers with default timeout' do
           expect(Gitlab::SidekiqCluster).to receive(:start)
             .with([['foo']], default_options.merge(timeout: Gitlab::SidekiqCluster::DEFAULT_SOFT_TIMEOUT_SECONDS))
+            .and_return([])
 
           cli.run(%w(foo))
         end
@@ -257,7 +261,7 @@
             .to receive(:worker_queues).and_return(worker_queues)
 
           expect(Gitlab::SidekiqCluster)
-            .to receive(:start).with([worker_queues], default_options)
+            .to receive(:start).with([worker_queues], default_options).and_return([])
 
           cli.run(%w(--queue-selector *))
         end
@@ -292,16 +296,15 @@
 
       context 'starting the server' do
         context 'without --dryrun' do
+          before do
+            allow(Gitlab::SidekiqCluster).to receive(:start).and_return([])
+            allow(Gitlab::ProcessManagement).to receive(:write_pid)
+            allow_next_instance_of(Gitlab::ProcessSupervisor) { |it| allow(it).to receive(:supervise) }
+          end
+
           context 'when there are no sidekiq_health_checks settings set' do
             let(:sidekiq_exporter_enabled) { true }
 
-            before do
-              allow(Gitlab::SidekiqCluster).to receive(:start)
-              allow(cli).to receive(:write_pid)
-              allow(cli).to receive(:trap_signals)
-              allow(cli).to receive(:start_loop)
-            end
-
             it 'does not start a sidekiq metrics server' do
               expect(MetricsServer).not_to receive(:fork)
 
@@ -312,13 +315,6 @@
           context 'when the sidekiq_exporter.port setting is not set' do
             let(:sidekiq_exporter_enabled) { true }
 
-            before do
-              allow(Gitlab::SidekiqCluster).to receive(:start)
-              allow(cli).to receive(:write_pid)
-              allow(cli).to receive(:trap_signals)
-              allow(cli).to receive(:start_loop)
-            end
-
             it 'does not start a sidekiq metrics server' do
               expect(MetricsServer).not_to receive(:fork)
 
@@ -342,13 +338,6 @@
               }
             end
 
-            before do
-              allow(Gitlab::SidekiqCluster).to receive(:start)
-              allow(cli).to receive(:write_pid)
-              allow(cli).to receive(:trap_signals)
-              allow(cli).to receive(:start_loop)
-            end
-
             it 'does not start a sidekiq metrics server' do
               expect(MetricsServer).not_to receive(:fork)
 
@@ -368,13 +357,6 @@
               }
             end
 
-            before do
-              allow(Gitlab::SidekiqCluster).to receive(:start)
-              allow(cli).to receive(:write_pid)
-              allow(cli).to receive(:trap_signals)
-              allow(cli).to receive(:start_loop)
-            end
-
             it 'does not start a sidekiq metrics server' do
               expect(MetricsServer).not_to receive(:fork)
 
@@ -397,13 +379,6 @@
             end
 
             with_them do
-              before do
-                allow(Gitlab::SidekiqCluster).to receive(:start)
-                allow(cli).to receive(:write_pid)
-                allow(cli).to receive(:trap_signals)
-                allow(cli).to receive(:start_loop)
-              end
-
               specify do
                 if start_metrics_server
                   expect(MetricsServer).to receive(:fork).with('sidekiq', metrics_dir: metrics_dir, wipe_metrics_dir: true, reset_signals: trapped_signals)
@@ -415,6 +390,23 @@
               end
             end
           end
+
+          context 'when a PID is specified' do
+            it 'writes the PID to a file' do
+              expect(Gitlab::ProcessManagement).to receive(:write_pid).with('/dev/null')
+
+              cli.option_parser.parse!(%w(-P /dev/null))
+              cli.run(%w(foo))
+            end
+          end
+
+          context 'when no PID is specified' do
+            it 'does not write a PID' do
+              expect(Gitlab::ProcessManagement).not_to receive(:write_pid)
+
+              cli.run(%w(foo))
+            end
+          end
         end
 
         context 'with --dryrun set' do
@@ -427,130 +419,55 @@
           end
         end
       end
-
-      context 'supervising the server' do
-        let(:sidekiq_exporter_enabled) { true }
-        let(:sidekiq_health_checks_port) { '3907' }
-
-        before do
-          allow(cli).to receive(:sleep).with(a_kind_of(Numeric))
-          allow(MetricsServer).to receive(:fork).and_return(99)
-          cli.start_metrics_server
-        end
-
-        it 'stops the metrics server when one of the processes has been terminated' do
-          allow(Gitlab::ProcessManagement).to receive(:process_died?).and_return(false)
-          allow(Gitlab::ProcessManagement).to receive(:all_alive?).with(an_instance_of(Array)).and_return(false)
-          allow(Gitlab::ProcessManagement).to receive(:signal_processes).with(an_instance_of(Array), :TERM)
-
-          expect(Process).to receive(:kill).with(:TERM, 99)
-
-          cli.start_loop
-        end
-
-        it 'starts the metrics server when it is down' do
-          allow(Gitlab::ProcessManagement).to receive(:process_died?).and_return(true)
-          allow(Gitlab::ProcessManagement).to receive(:all_alive?).with(an_instance_of(Array)).and_return(false)
-          allow(cli).to receive(:stop_metrics_server)
-
-          expect(MetricsServer).to receive(:fork).with(
-            'sidekiq', metrics_dir: metrics_dir, wipe_metrics_dir: false, reset_signals: trapped_signals
-          )
-
-          cli.start_loop
-        end
-      end
-    end
-  end
-
-  describe '#write_pid' do
-    context 'when a PID is specified' do
-      it 'writes the PID to a file' do
-        expect(Gitlab::ProcessManagement).to receive(:write_pid).with('/dev/null')
-
-        cli.option_parser.parse!(%w(-P /dev/null))
-        cli.write_pid
-      end
     end
 
-    context 'when no PID is specified' do
-      it 'does not write a PID' do
-        expect(Gitlab::ProcessManagement).not_to receive(:write_pid)
-
-        cli.write_pid
-      end
-    end
-  end
+    context 'supervising the cluster' do
+      let(:sidekiq_exporter_enabled) { true }
+      let(:sidekiq_health_checks_port) { '3907' }
+      let(:metrics_server_pid) { 99 }
+      let(:sidekiq_worker_pids) { [2, 42] }
 
-  describe '#wait_for_termination' do
-    it 'waits for termination of all sub-processes and succeeds after 3 checks' do
-      expect(Gitlab::ProcessManagement).to receive(:any_alive?)
-        .with(an_instance_of(Array)).and_return(true, true, true, false)
-
-      expect(Gitlab::ProcessManagement).to receive(:pids_alive)
-        .with([]).and_return([])
-
-      expect(Gitlab::ProcessManagement).to receive(:signal_processes)
-        .with([], "-KILL")
-
-      stub_const("Gitlab::SidekiqCluster::CHECK_TERMINATE_INTERVAL_SECONDS", 0.1)
-      allow(cli).to receive(:terminate_timeout_seconds) { 1 }
-
-      cli.wait_for_termination
-    end
-
-    context 'with hanging workers' do
       before do
-        expect(cli).to receive(:write_pid)
-        expect(cli).to receive(:trap_signals)
-        expect(cli).to receive(:start_loop)
+        allow(Gitlab::SidekiqCluster).to receive(:start).and_return(sidekiq_worker_pids)
+        allow(Gitlab::ProcessManagement).to receive(:write_pid)
       end
 
-      it 'hard kills workers after timeout expires' do
-        worker_pids = [101, 102, 103]
-        expect(Gitlab::SidekiqCluster).to receive(:start)
-                                            .with([['foo']], default_options)
-                                            .and_return(worker_pids)
-
-        expect(Gitlab::ProcessManagement).to receive(:any_alive?)
-          .with(worker_pids).and_return(true).at_least(10).times
-
-        expect(Gitlab::ProcessManagement).to receive(:pids_alive)
-          .with(worker_pids).and_return([102])
+      it 'stops the entire process cluster if one of the workers has been terminated' do
+        allow_next_instance_of(Gitlab::ProcessSupervisor) do |it|
+          allow(it).to receive(:supervise).and_yield([2])
+        end
 
-        expect(Gitlab::ProcessManagement).to receive(:signal_processes)
-          .with([102], "-KILL")
+        expect(MetricsServer).to receive(:fork).once.and_return(metrics_server_pid)
+        expect(Gitlab::ProcessManagement).to receive(:signal_processes).with([42, 99], :TERM)
 
         cli.run(%w(foo))
-
-        stub_const("Gitlab::SidekiqCluster::CHECK_TERMINATE_INTERVAL_SECONDS", 0.1)
-        allow(cli).to receive(:terminate_timeout_seconds) { 1 }
-
-        cli.wait_for_termination
       end
-    end
-  end
 
-  describe '#trap_signals' do
-    it 'traps termination and sidekiq specific signals' do
-      expect(Gitlab::ProcessManagement).to receive(:trap_signals).with(%i[INT TERM])
-      expect(Gitlab::ProcessManagement).to receive(:trap_signals).with(%i[TTIN USR1 USR2 HUP])
+      context 'when the supervisor is alive' do
+        it 'restarts the metrics server when it is down' do
+          allow_next_instance_of(Gitlab::ProcessSupervisor) do |it|
+            allow(it).to receive(:alive).and_return(true)
+            allow(it).to receive(:supervise).and_yield([metrics_server_pid])
+          end
 
-      cli.trap_signals
-    end
-  end
+          expect(MetricsServer).to receive(:fork).twice.and_return(metrics_server_pid)
 
-  describe '#start_loop' do
-    it 'runs until one of the processes has been terminated' do
-      allow(cli).to receive(:sleep).with(a_kind_of(Numeric))
+          cli.run(%w(foo))
+        end
+      end
 
-      expect(Gitlab::ProcessManagement).to receive(:all_alive?)
-        .with(an_instance_of(Array)).and_return(false)
+      context 'when the supervisor is shutting down' do
+        it 'does not restart the metrics server' do
+          allow_next_instance_of(Gitlab::ProcessSupervisor) do |it|
+            allow(it).to receive(:alive).and_return(false)
+            allow(it).to receive(:supervise).and_yield([metrics_server_pid])
+          end
 
-      expect(Gitlab::ProcessManagement).to receive(:signal_processes)
-        .with(an_instance_of(Array), :TERM)
+          expect(MetricsServer).to receive(:fork).once.and_return(metrics_server_pid)
 
-      cli.start_loop
+          cli.run(%w(foo))
+        end
+      end
     end
   end
 end
diff --git a/spec/lib/gitlab/process_supervisor_spec.rb b/spec/lib/gitlab/process_supervisor_spec.rb
new file mode 100644
index 0000000000000..d264c77d5fb30
--- /dev/null
+++ b/spec/lib/gitlab/process_supervisor_spec.rb
@@ -0,0 +1,127 @@
+# frozen_string_literal: true
+
+require_relative '../../../lib/gitlab/process_supervisor'
+
+RSpec.describe Gitlab::ProcessSupervisor do
+  let(:health_check_interval_seconds) { 0.1 }
+  let(:check_terminate_interval_seconds) { 1 }
+  let(:forwarded_signals) { [] }
+  let(:process_id) do
+    Process.spawn('while true; do sleep 1; done').tap do |pid|
+      Process.detach(pid)
+    end
+  end
+
+  subject(:supervisor) do
+    described_class.new(
+      health_check_interval_seconds: health_check_interval_seconds,
+      check_terminate_interval_seconds: check_terminate_interval_seconds,
+      terminate_timeout_seconds: 1 + check_terminate_interval_seconds,
+      forwarded_signals: forwarded_signals
+    )
+  end
+
+  after do
+    if Gitlab::ProcessManagement.process_alive?(process_id)
+      Process.kill('KILL', process_id)
+    end
+  end
+
+  describe '#supervise' do
+    context 'while supervised process is alive' do
+      it 'does not invoke callback' do
+        expect(Gitlab::ProcessManagement.process_alive?(process_id)).to be(true)
+        pids_killed = []
+
+        thread = Thread.new do
+          supervisor.supervise(process_id) do |dead_pids|
+            pids_killed = dead_pids
+            []
+          end
+        end
+
+        # Wait for several multiples of the supervisor's poll interval.
+        sleep health_check_interval_seconds * 10
+        thread.terminate
+
+        expect(pids_killed).to be_empty
+        expect(Gitlab::ProcessManagement.process_alive?(process_id)).to be(true)
+      end
+    end
+
+    context 'when supervised process dies' do
+      it 'triggers callback with the dead PIDs' do
+        expect(Gitlab::ProcessManagement.process_alive?(process_id)).to be(true)
+        pids_killed = []
+
+        thread = Thread.new do
+          supervisor.supervise(process_id) do |dead_pids|
+            pids_killed = dead_pids
+            []
+          end
+        end
+
+        # Terminate the supervised process.
+        Process.kill('TERM', process_id)
+
+        await_condition(sleep_sec: health_check_interval_seconds) do
+          pids_killed == [process_id]
+        end
+        thread.terminate
+
+        expect(Gitlab::ProcessManagement.process_alive?(process_id)).to be(false)
+      end
+    end
+
+    context 'signal handling' do
+      before do
+        allow(supervisor).to receive(:sleep)
+        allow(Gitlab::ProcessManagement).to receive(:trap_signals)
+        allow(Gitlab::ProcessManagement).to receive(:all_alive?).and_return(false)
+        allow(Gitlab::ProcessManagement).to receive(:signal_processes).with([process_id], anything)
+      end
+
+      context 'termination signals' do
+        context 'when TERM results in timely shutdown of processes' do
+          it 'forwards them to observed processes without waiting for the grace period to expire' do
+            allow(Gitlab::ProcessManagement).to receive(:any_alive?).and_return(false)
+
+            expect(Gitlab::ProcessManagement).to receive(:trap_signals).ordered.with(%i(INT TERM)).and_yield(:TERM)
+            expect(Gitlab::ProcessManagement).to receive(:signal_processes).ordered.with([process_id], :TERM)
+            expect(supervisor).not_to receive(:sleep).with(check_terminate_interval_seconds)
+
+            supervisor.supervise(process_id) { [] }
+          end
+        end
+
+        context 'when TERM does not result in timely shutdown of processes' do
+          it 'issues a KILL signal after the grace period expires' do
+            expect(Gitlab::ProcessManagement).to receive(:trap_signals).with(%i(INT TERM)).and_yield(:TERM)
+            expect(Gitlab::ProcessManagement).to receive(:signal_processes).ordered.with([process_id], :TERM)
+            expect(supervisor).to receive(:sleep).ordered.with(check_terminate_interval_seconds).at_least(:once)
+            expect(Gitlab::ProcessManagement).to receive(:signal_processes).ordered.with([process_id], '-KILL')
+
+            supervisor.supervise(process_id) { [] }
+          end
+        end
+      end
+
+      context 'forwarded signals' do
+        let(:forwarded_signals) { %i(USR1) }
+
+        it 'forwards given signals to the observed processes' do
+          expect(Gitlab::ProcessManagement).to receive(:trap_signals).with(%i(USR1)).and_yield(:USR1)
+          expect(Gitlab::ProcessManagement).to receive(:signal_processes).ordered.with([process_id], :USR1)
+
+          supervisor.supervise(process_id) { [] }
+        end
+      end
+    end
+  end
+
+  def await_condition(timeout_sec: 5, sleep_sec: 0.1)
+    Timeout.timeout(timeout_sec) do
+      sleep sleep_sec until yield
+    end
+  end
+end
-- 
GitLab