From 3f66cf31d9c60c8c12de0c2d2d336baf698e9064 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Rodr=C3=ADguez?= <alejandro@gitlab.com> Date: Thu, 8 Feb 2024 18:06:49 -0500 Subject: [PATCH] Stream Duo Chat slash commands responses --- .../llm/chain/agents/zero_shot/executor.rb | 17 +---------- .../gitlab/llm/chain/concerns/ai_dependent.rb | 15 ++++++++++ .../llm/chain/tools/explain_code/executor.rb | 6 ---- .../llm/chain/tools/refactor_code/executor.rb | 6 ---- .../llm/chain/tools/slash_command_tool.rb | 8 +++++ .../llm/chain/tools/write_tests/executor.rb | 6 ---- .../chain/tools/explain_code/executor_spec.rb | 7 ++++- .../tools/refactor_code/executor_spec.rb | 7 ++++- .../chain/tools/write_tests/executor_spec.rb | 7 ++++- .../slash_command_tool_shared_examples.rb | 30 +++++++++++++++++++ 10 files changed, 72 insertions(+), 37 deletions(-) diff --git a/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb b/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb index 0bd997aba660d..8646e85e5dcd5 100644 --- a/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb +++ b/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb @@ -80,22 +80,7 @@ def execute private def execute_streamed_request - streamed_answer = StreamedZeroShotAnswer.new - - request do |content| - next unless stream_response_handler - - chunk = streamed_answer.next_chunk(content) - - if chunk - stream_response_handler.execute( - response: Gitlab::Llm::Chain::PlainResponseModifier.new(content), - options: { - chunk_id: chunk[:id] - } - ) - end - end + request(&streamed_request_handler(StreamedZeroShotAnswer.new)) end attr_reader :logger, :stream_response_handler diff --git a/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb b/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb index 1e7944dad3a3b..ca6611bf32691 100644 --- a/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb +++ b/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb @@ -19,6 +19,21 @@ def request(&block) ai_request.request(prompt_str, &block) end + def streamed_request_handler(streamed_answer) + proc do |content| + next unless stream_response_handler + + chunk = streamed_answer.next_chunk(content) + + if chunk + stream_response_handler.execute( + response: Gitlab::Llm::Chain::PlainResponseModifier.new(content), + options: { chunk_id: chunk[:id] } + ) + end + end + end + private def ai_request diff --git a/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb b/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb index 286c1ec254577..012347da5eef0 100644 --- a/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb +++ b/ee/lib/gitlab/llm/chain/tools/explain_code/executor.rb @@ -57,12 +57,6 @@ def self.slash_commands SLASH_COMMANDS end - def perform - Answer.new(status: :ok, context: context, content: request, tool: nil) - rescue StandardError - Answer.error_answer(context: context, content: _("Unexpected error")) - end - private def authorize diff --git a/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb b/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb index c5b5ee2a62ce6..a098bb0301a78 100644 --- a/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb +++ b/ee/lib/gitlab/llm/chain/tools/refactor_code/executor.rb @@ -64,12 +64,6 @@ def self.slash_commands SLASH_COMMANDS end - def perform - Answer.new(status: :ok, context: context, content: request, tool: nil) - rescue StandardError - Answer.error_answer(context: context, content: _("Unexpected error")) - end - private def selected_text_options diff --git a/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb b/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb index 99927db1f9e98..a4f517e06cf50 100644 --- a/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb +++ b/ee/lib/gitlab/llm/chain/tools/slash_command_tool.rb @@ -7,6 +7,14 @@ module Tools class SlashCommandTool < Tool extend ::Gitlab::Utils::Override + def perform + content = request(&streamed_request_handler(StreamedAnswer.new)) + + Answer.new(status: :ok, context: context, content: content, tool: nil) + rescue StandardError + Answer.error_answer(context: context, content: _("Unexpected error")) + end + private attr_reader :command diff --git a/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb b/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb index 3e5e25b68ac2f..57006e86617c2 100644 --- a/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb +++ b/ee/lib/gitlab/llm/chain/tools/write_tests/executor.rb @@ -63,12 +63,6 @@ def self.slash_commands SLASH_COMMANDS end - def perform - Answer.new(status: :ok, context: context, content: request, tool: nil) - rescue StandardError - Answer.error_answer(context: context, content: _("Unexpected error")) - end - private def authorize diff --git a/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb b/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb index cc961659a0a84..ecc78cb103b2a 100644 --- a/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb +++ b/ee/spec/lib/gitlab/llm/chain/tools/explain_code/executor_spec.rb @@ -8,6 +8,7 @@ let(:ai_request_double) { instance_double(Gitlab::Llm::Chain::Requests::Anthropic) } let(:input) { 'input' } let(:options) { { input: input } } + let(:stream_response_handler) { nil } let(:command) { nil } let(:context) do @@ -22,7 +23,11 @@ ) end - subject(:tool) { described_class.new(context: context, options: options, command: command) } + subject(:tool) do + described_class.new( + context: context, options: options, stream_response_handler: stream_response_handler, command: command + ) + end describe '#name' do it 'returns tool name' do diff --git a/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb b/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb index 0c77d6a38e0c7..22503102fd974 100644 --- a/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb +++ b/ee/spec/lib/gitlab/llm/chain/tools/refactor_code/executor_spec.rb @@ -8,6 +8,7 @@ let(:ai_request_double) { instance_double(Gitlab::Llm::Chain::Requests::Anthropic) } let(:input) { 'input' } let(:options) { { input: input } } + let(:stream_response_handler) { nil } let(:command) { nil } let(:context) do @@ -22,7 +23,11 @@ ) end - subject(:tool) { described_class.new(context: context, options: options, command: command) } + subject(:tool) do + described_class.new( + context: context, options: options, stream_response_handler: stream_response_handler, command: command + ) + end describe '#name' do it 'returns tool name' do diff --git a/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb b/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb index b5d7e440b25b5..aac6991990f26 100644 --- a/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb +++ b/ee/spec/lib/gitlab/llm/chain/tools/write_tests/executor_spec.rb @@ -8,6 +8,7 @@ let(:ai_request_double) { instance_double(Gitlab::Llm::Chain::Requests::Anthropic) } let(:input) { 'input' } let(:options) { { input: input } } + let(:stream_response_handler) { nil } let(:command) { nil } let(:context) do @@ -17,7 +18,11 @@ ) end - subject(:tool) { described_class.new(context: context, options: options, command: command) } + subject(:tool) do + described_class.new( + context: context, options: options, stream_response_handler: stream_response_handler, command: command + ) + end describe '#name' do it 'returns tool name' do diff --git a/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb b/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb index 00b5d3fa8f668..528ad652c1d85 100644 --- a/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb +++ b/ee/spec/support/shared_examples/lib/gitlab/llm/chain/slash_command_tool_shared_examples.rb @@ -84,4 +84,34 @@ tool.execute end end + + context 'when stream_response_service is set' do + let(:stream_response_handler) { instance_double(::Gitlab::Llm::ResponseService) } + + before do + allow(ai_request_double).to receive(:request).and_yield("Hello").and_yield(" World") + end + + it 'streams the final answer' do + first_response_double = double + second_response_double = double + + allow(Gitlab::Llm::Chain::PlainResponseModifier).to receive(:new).with("Hello") + .and_return(first_response_double) + + allow(Gitlab::Llm::Chain::PlainResponseModifier).to receive(:new).with(" World") + .and_return(second_response_double) + + expect(stream_response_handler).to receive(:execute).with( + response: first_response_double, + options: { chunk_id: 1 } + ) + expect(stream_response_handler).to receive(:execute).with( + response: second_response_double, + options: { chunk_id: 2 } + ) + + tool.execute + end + end end -- GitLab