From 3f66cf31d9c60c8c12de0c2d2d336baf698e9064 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Rodr=C3=ADguez?= <alejandro@gitlab.com>
Date: Thu, 8 Feb 2024 18:06:49 -0500
Subject: [PATCH] Stream Duo Chat slash commands responses

---
 .../llm/chain/agents/zero_shot/executor.rb    | 17 +----------
 .../gitlab/llm/chain/concerns/ai_dependent.rb | 15 ++++++++++
 .../llm/chain/tools/explain_code/executor.rb  |  6 ----
 .../llm/chain/tools/refactor_code/executor.rb |  6 ----
 .../llm/chain/tools/slash_command_tool.rb     |  8 +++++
 .../llm/chain/tools/write_tests/executor.rb   |  6 ----
 .../chain/tools/explain_code/executor_spec.rb |  7 ++++-
 .../tools/refactor_code/executor_spec.rb      |  7 ++++-
 .../chain/tools/write_tests/executor_spec.rb  |  7 ++++-
 .../slash_command_tool_shared_examples.rb     | 30 +++++++++++++++++++
 10 files changed, 72 insertions(+), 37 deletions(-)

diff --git a/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb b/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb
index 0bd997aba660d..8646e85e5dcd5 100644
--- a/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb
+++ b/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb
@@ -80,22 +80,7 @@ def execute
             private
 
             def execute_streamed_request
-              streamed_answer = StreamedZeroShotAnswer.new
-
-              request do |content|
-                next unless stream_response_handler
-
-                chunk = streamed_answer.next_chunk(content)
-
-                if chunk
-                  stream_response_handler.execute(
-                    response: Gitlab::Llm::Chain::PlainResponseModifier.new(content),
-                    options: {
-                      chunk_id: chunk[:id]
-                    }
-                  )
-                end
-              end
+              request(&streamed_request_handler(StreamedZeroShotAnswer.new))
             end
 
             attr_reader :logger, :stream_response_handler
diff --git a/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb b/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb
index 1e7944dad3a3b..ca6611bf32691 100644
--- a/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb
+++ b/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb
@@ -19,6 +19,21 @@ def request(&block)
             ai_request.request(prompt_str, &block)
           end
 
+          def streamed_request_handler(streamed_answer)
+            proc do |content|
+              next unless stream_response_handler
+
+              chunk = streamed_answer.next_chunk(content)
+
+              if chunk
+                stream_response_handler.execute(
+                  response: Gitlab::Llm::Chain::PlainResponseModifier.new(content),
+                  options: { chunk_id: chunk[:id] }
+                )
+              end
+            end
+          end
+
           private
 
           def ai_request
diff --git a/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb b/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb
index 286c1ec254577..012347da5eef0 100644
--- a/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb
+++ b/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb
@@ -57,12 +57,6 @@ def self.slash_commands
               SLASH_COMMANDS
             end
 
-            def perform
-              Answer.new(status: :ok, context: context, content: request, tool: nil)
-            rescue StandardError
-              Answer.error_answer(context: context, content: _("Unexpected error"))
-            end
-
             private
 
             def authorize
diff --git a/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb b/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb
index c5b5ee2a62ce6..a098bb0301a78 100644
--- a/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb
+++ b/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb
@@ -64,12 +64,6 @@ def self.slash_commands
               SLASH_COMMANDS
             end
 
-            def perform
-              Answer.new(status: :ok, context: context, content: request, tool: nil)
-            rescue StandardError
-              Answer.error_answer(context: context, content: _("Unexpected error"))
-            end
-
             private
 
             def selected_text_options
diff --git a/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb b/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb
index 99927db1f9e98..a4f517e06cf50 100644
--- a/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb
+++ b/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb
@@ -7,6 +7,14 @@ module Tools
         class SlashCommandTool < Tool
           extend ::Gitlab::Utils::Override
 
+          def perform
+            content = request(&streamed_request_handler(StreamedAnswer.new))
+
+            Answer.new(status: :ok, context: context, content: content, tool: nil)
+          rescue StandardError
+            Answer.error_answer(context: context, content: _("Unexpected error"))
+          end
+
           private
 
           attr_reader :command
diff --git a/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb b/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb
index 3e5e25b68ac2f..57006e86617c2 100644
--- a/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb
+++ b/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb
@@ -63,12 +63,6 @@ def self.slash_commands
               SLASH_COMMANDS
             end
 
-            def perform
-              Answer.new(status: :ok, context: context, content: request, tool: nil)
-            rescue StandardError
-              Answer.error_answer(context: context, content: _("Unexpected error"))
-            end
-
             private
 
             def authorize
diff --git a/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb b/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb
index cc961659a0a84..ecc78cb103b2a 100644
--- a/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb
+++ b/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb
@@ -8,6 +8,7 @@
   let(:ai_request_double) { instance_double(Gitlab::Llm::Chain::Requests::Anthropic) }
   let(:input) { 'input' }
   let(:options) { { input: input } }
+  let(:stream_response_handler) { nil }
   let(:command) { nil }
 
   let(:context) do
@@ -22,7 +23,11 @@
     )
   end
 
-  subject(:tool) { described_class.new(context: context, options: options, command: command) }
+  subject(:tool) do
+    described_class.new(
+      context: context, options: options, stream_response_handler: stream_response_handler, command: command
+    )
+  end
 
   describe '#name' do
     it 'returns tool name' do
diff --git a/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb b/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb
index 0c77d6a38e0c7..22503102fd974 100644
--- a/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb
+++ b/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb
@@ -8,6 +8,7 @@
   let(:ai_request_double) { instance_double(Gitlab::Llm::Chain::Requests::Anthropic) }
   let(:input) { 'input' }
   let(:options) { { input: input } }
+  let(:stream_response_handler) { nil }
   let(:command) { nil }
 
   let(:context) do
@@ -22,7 +23,11 @@
     )
   end
 
-  subject(:tool) { described_class.new(context: context, options: options, command: command) }
+  subject(:tool) do
+    described_class.new(
+      context: context, options: options, stream_response_handler: stream_response_handler, command: command
+    )
+  end
 
   describe '#name' do
     it 'returns tool name' do
diff --git a/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb b/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb
index b5d7e440b25b5..aac6991990f26 100644
--- a/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb
+++ b/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb
@@ -8,6 +8,7 @@
   let(:ai_request_double) { instance_double(Gitlab::Llm::Chain::Requests::Anthropic) }
   let(:input) { 'input' }
   let(:options) { { input: input } }
+  let(:stream_response_handler) { nil }
   let(:command) { nil }
 
   let(:context) do
@@ -17,7 +18,11 @@
     )
   end
 
-  subject(:tool) { described_class.new(context: context, options: options, command: command) }
+  subject(:tool) do
+    described_class.new(
+      context: context, options: options, stream_response_handler: stream_response_handler, command: command
+    )
+  end
 
   describe '#name' do
     it 'returns tool name' do
diff --git a/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb b/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb
index 00b5d3fa8f668..528ad652c1d85 100644
--- a/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb
+++ b/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb
@@ -84,4 +84,34 @@
       tool.execute
     end
   end
+
+  context 'when stream_response_service is set' do
+    let(:stream_response_handler) { instance_double(::Gitlab::Llm::ResponseService) }
+
+    before do
+      allow(ai_request_double).to receive(:request).and_yield("Hello").and_yield(" World")
+    end
+
+    it 'streams the final answer' do
+      first_response_double = double
+      second_response_double = double
+
+      allow(Gitlab::Llm::Chain::PlainResponseModifier).to receive(:new).with("Hello")
+        .and_return(first_response_double)
+
+      allow(Gitlab::Llm::Chain::PlainResponseModifier).to receive(:new).with(" World")
+        .and_return(second_response_double)
+
+      expect(stream_response_handler).to receive(:execute).with(
+        response: first_response_double,
+        options: { chunk_id: 1 }
+      )
+      expect(stream_response_handler).to receive(:execute).with(
+        response: second_response_double,
+        options: { chunk_id: 2 }
+      )
+
+      tool.execute
+    end
+  end
 end
-- 
GitLab