diff --git a/doc/api/graphql/reference/index.md b/doc/api/graphql/reference/index.md
index 4954f88e8a773a6cbaf6935ba201cb2ff8f06a6d..7e29e15f2783fdd647e348af98258b934d244a77 100644
--- a/doc/api/graphql/reference/index.md
+++ b/doc/api/graphql/reference/index.md
@@ -12564,6 +12564,7 @@ Information about a connected Agent.
 
 | Name | Type | Description |
 | ---- | ---- | ----------- |
+| <a id="airesponsechunkid"></a>`chunkId` | [`Int`](#int) | Incremental ID for a chunk from a streamed response. Null when it is not a streamed response. |
 | <a id="airesponseerrors"></a>`errors` | [`[String!]`](#string) | Errors return by AI API as response. |
 | <a id="airesponserequestid"></a>`requestId` | [`String`](#string) | ID of the original request. |
 | <a id="airesponseresponsebody"></a>`responseBody` | [`String`](#string) | Response body from AI API. |
diff --git a/ee/app/graphql/subscriptions/ai_completion_response.rb b/ee/app/graphql/subscriptions/ai_completion_response.rb
index f75900925a1fe081f1bba5b4fbff52b7dda54317..3469e9a5cc90e622e790dfa9ff4cb5f3860c6376 100644
--- a/ee/app/graphql/subscriptions/ai_completion_response.rb
+++ b/ee/app/graphql/subscriptions/ai_completion_response.rb
@@ -24,7 +24,8 @@ def update(*_args)
        role: object[:role],
        errors: object[:errors],
        timestamp: object[:timestamp],
-        type: object[:type]
+        type: object[:type],
+        chunk_id: object[:chunk_id]
      }
    end
diff --git a/ee/app/graphql/types/ai/ai_response_type.rb b/ee/app/graphql/types/ai/ai_response_type.rb
index 21cac39fb2f8b9cb31666ee3ef97ee44c1563242..989dcd5f1a5dbc98a04aae7f737bfb168e5e483f 100644
--- a/ee/app/graphql/types/ai/ai_response_type.rb
+++ b/ee/app/graphql/types/ai/ai_response_type.rb
@@ -32,6 +32,11 @@ class AiResponseType < BaseObject
            null: false,
            description: 'Message timestamp.'
 
+      field :chunk_id,
+            GraphQL::Types::Int,
+            null: true,
+            description: 'Incremental ID for a chunk from a streamed response. Null when it is not a streamed response.'
+
       field :errors, [GraphQL::Types::String],
            null: true,
            description: 'Errors return by AI API as response.'
diff --git a/ee/config/feature_flags/development/stream_gitlab_duo.yml b/ee/config/feature_flags/development/stream_gitlab_duo.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7d3cef90cd21567e6debdb6cb7571fdddf36c3db
--- /dev/null
+++ b/ee/config/feature_flags/development/stream_gitlab_duo.yml
@@ -0,0 +1,8 @@
+---
+name: stream_gitlab_duo
+introduced_by_url: https://gitlab.com/gitlab-org/gitlab/-/merge_requests/129966
+rollout_issue_url: https://gitlab.com/gitlab-org/gitlab/-/issues/423457
+milestone: '16.4'
+type: development
+group: group::ai framework
+default_enabled: false
diff --git a/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb b/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb
index 937b462b00dd1ad59584b53158a84ee096dfff34..29f74945d607849be11e0e4331cff9d231388c66 100644
--- a/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb
+++ b/ee/lib/gitlab/llm/chain/agents/zero_shot/executor.rb
@@ -24,18 +24,27 @@ class Executor
          # @param [String] user_input - a question from a user
          # @param [Array<Tool>] tools - an array of Tools defined in the tools module.
          # @param [GitlabContext] context - Gitlab context containing useful context information
-          def initialize(user_input:, tools:, context:, response_handler:)
+          # @param [ResponseService] response_handler - Handles returning the response to the client
+          # @param [ResponseService] stream_response_handler - Handles streaming chunks to the client
+          def initialize(user_input:, tools:, context:, response_handler:, stream_response_handler: nil)
            @user_input = user_input
            @tools = tools
            @context = context
            @iterations = 0
            @logger = Gitlab::Llm::Logger.build
            @response_handler = response_handler
+            @stream_response_handler = stream_response_handler
          end
 
          def execute
            MAX_ITERATIONS.times do
-              answer = Answer.from_response(response_body: "Thought: #{request}", tools: tools, context: context)
+              thought = if stream_response_handler && Feature.enabled?(:stream_gitlab_duo, context.current_user)
+                          execute_streamed_request
+                        else
+                          request
+                        end
+
+              answer = Answer.from_response(response_body: "Thought: #{thought}", tools: tools, context: context)
 
              return answer if answer.is_final?
 
@@ -72,11 +81,30 @@ def execute
 
          private
 
+          def execute_streamed_request
+            streamed_answer = StreamedAnswer.new
+
+            request do |content|
+              chunk = streamed_answer.next_chunk(content)
+
+              if chunk
+                stream_response_handler.execute(
+                  response: Gitlab::Llm::Chain::PlainResponseModifier.new(content),
+                  options: {
+                    cache_response: false,
+                    role: ::Gitlab::Llm::Cache::ROLE_ASSISTANT,
+                    chunk_id: chunk[:id]
+                  }
+                )
+              end
+            end
+          end
+
          def tools_cycle?
            context.tools_used.size != context.tools_used.uniq.size
          end
 
-          attr_reader :logger
+          attr_reader :logger, :stream_response_handler
 
          # This method should not be memoized because the input variables change over time
          def base_prompt
diff --git a/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb b/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb
index fc1b3fba3451ebb8ee20a893ff327a24e4f71982..4e11b190cbec38192d1b6c090e0e0554b9a5b6c0 100644
--- a/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb
+++ b/ee/lib/gitlab/llm/chain/concerns/ai_dependent.rb
@@ -11,9 +11,9 @@ def prompt
          provider_prompt_class.prompt(options)
        end
 
-        def request
+        def request(&block)
          logger.debug(message: "Prompt", class: self.class.to_s, content: prompt)
-          ai_request.request(prompt)
+          ai_request.request(prompt, &block)
        end
 
        private
diff --git a/ee/lib/gitlab/llm/chain/plain_response_modifier.rb b/ee/lib/gitlab/llm/chain/plain_response_modifier.rb
new file mode 100644
index 0000000000000000000000000000000000000000..49f69af803b76c456a8cab6bf253ed0d73e4e86d
--- /dev/null
+++ b/ee/lib/gitlab/llm/chain/plain_response_modifier.rb
@@ -0,0 +1,21 @@
+# frozen_string_literal: true
+
+module Gitlab
+  module Llm
+    module Chain
+      class PlainResponseModifier < Gitlab::Llm::BaseResponseModifier
+        def initialize(answer)
+          @ai_response = answer
+        end
+
+        def response_body
+          @ai_response
+        end
+
+        def errors
+          []
+        end
+      end
+    end
+  end
+end
diff --git a/ee/lib/gitlab/llm/chain/requests/anthropic.rb b/ee/lib/gitlab/llm/chain/requests/anthropic.rb
index 9dd66a83b2e2a6c76c72fd26ed07f17bad15884e..ff7582d247f12bc7a1d6f6831fc01db491c833c8 100644
--- a/ee/lib/gitlab/llm/chain/requests/anthropic.rb
+++ b/ee/lib/gitlab/llm/chain/requests/anthropic.rb
@@ -12,18 +12,34 @@ class Anthropic < Base
        PROMPT_SIZE = 30_000
 
        def initialize(user)
+          @user = user
          @ai_client = ::Gitlab::Llm::Anthropic::Client.new(user)
+          @logger = Gitlab::Llm::Logger.build
        end
 
        def request(prompt)
-          ai_client.complete(
-            prompt: prompt[:prompt],
-            **default_options.merge(prompt.fetch(:options, {}))
-          )&.dig("completion").to_s.strip
+          if Feature.enabled?(:stream_gitlab_duo, user)
+            ai_client.stream(
+              prompt: prompt[:prompt],
+              **default_options.merge(prompt.fetch(:options, {}))
+            ) do |data|
+              logger.info(message: "Streaming error", error: data&.dig("error")) if data&.dig("error")
+
+              content = data&.dig("completion").to_s
+              yield content if block_given?
+            end
+          else
+            ai_client.complete(
+              prompt: prompt[:prompt],
+              **default_options.merge(prompt.fetch(:options, {}))
+            )&.dig("completion").to_s.strip
+          end
        end
 
        private
 
+        attr_reader :user, :logger
+
        def default_options
          {
            temperature: TEMPERATURE,
diff --git a/ee/lib/gitlab/llm/chain/streamed_answer.rb b/ee/lib/gitlab/llm/chain/streamed_answer.rb
new file mode 100644
index 0000000000000000000000000000000000000000..0079f0f66abfe58c649c49ecd64f88f55e09d9b8
--- /dev/null
+++ b/ee/lib/gitlab/llm/chain/streamed_answer.rb
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+module Gitlab
+  module Llm
+    module Chain
+      class StreamedAnswer
+        def initialize
+          self.final_answer_started = false
+          self.id = 0
+          self.full_message = ""
+        end
+
+        def next_chunk(content)
+          # If it already contains the final answer, we can return the content directly.
+          # There is then also no longer the need to build the full message.
+          return payload(content) if final_answer_started
+
+          self.full_message += content
+
+          return unless final_answer_start.present?
+
+          self.final_answer_started = true
+          payload(final_answer_start.lstrip)
+        end
+
+        private
+
+        attr_accessor :full_message, :id, :final_answer_started
+
+        def payload(content)
+          self.id += 1
+
+          { content: content, id: id }
+        end
+
+        # The ChainOfThoughtParser would treat a response without any "Final Answer:" in the response
+        # as an answer. Because we do not have the full response when parsing the stream, we need to rely
+        # on the fact that everything after "Final Answer:" will be the final answer.
+        def final_answer_start
+          /Final Answer:(?<final_answer>.+)/m =~ full_message
+
+          final_answer
+        end
+      end
+    end
+  end
+end
diff --git a/ee/lib/gitlab/llm/completions/chat.rb b/ee/lib/gitlab/llm/completions/chat.rb
index 080089fb1dc92866f8af7e3f620b0eb9ec2374a3..80f90b44be7ff9b995db04de26bd95c4f121e095 100644
--- a/ee/lib/gitlab/llm/completions/chat.rb
+++ b/ee/lib/gitlab/llm/completions/chat.rb
@@ -23,13 +23,20 @@ def execute(user, resource, options)
          extra_resource: options.delete(:extra_resource) || {}
        )
 
-        response_handler = ::Gitlab::Llm::ResponseService.new(context, response_options)
+        response_handler = ::Gitlab::Llm::ResponseService
+          .new(context, response_options.except(:client_subscription_id))
+
+        stream_response_handler = nil
+        if response_options[:client_subscription_id]
+          stream_response_handler = ::Gitlab::Llm::ResponseService.new(context, response_options)
+        end
 
        response = Gitlab::Llm::Chain::Agents::ZeroShot::Executor.new(
          user_input: options[:content],
          tools: tools(user),
          context: context,
-          response_handler: response_handler
+          response_handler: response_handler,
+          stream_response_handler: stream_response_handler
        ).execute
 
        Gitlab::Metrics::Sli::Apdex[:llm_chat_answers].increment(
@@ -39,8 +46,6 @@ def execute(user, resource, options)
 
        response_modifier = Gitlab::Llm::Chain::ResponseModifier.new(response)
 
-        response_handler.execute(response: response_modifier)
-
        context.tools_used.each do |tool|
          Gitlab::Tracking.event(
            self.class.to_s,
@@ -52,6 +57,8 @@ def execute(user, resource, options)
            value: response.status == :ok ? 1 : 0
          )
        end
+
+        response_handler.execute(response: response_modifier)
      end
 
      def tools(user)
diff --git a/ee/lib/gitlab/llm/graphql_subscription_response_service.rb b/ee/lib/gitlab/llm/graphql_subscription_response_service.rb
index f6f3f72051d47ae5c918575b5be67439e91c910b..1fb89c74c167c4a6297ae7672a8f20b3f5e7a014 100644
--- a/ee/lib/gitlab/llm/graphql_subscription_response_service.rb
+++ b/ee/lib/gitlab/llm/graphql_subscription_response_service.rb
@@ -23,7 +23,8 @@ def execute
          errors: response_modifier.errors,
          role: options[:role] || Cache::ROLE_ASSISTANT,
          timestamp: Time.current,
-          type: options.fetch(:type, nil)
+          type: options.fetch(:type, nil),
+          chunk_id: options.fetch(:chunk_id, nil)
        }
 
        logger.debug(
diff --git a/ee/spec/lib/gitlab/llm/chain/agents/zero_shot/executor_spec.rb b/ee/spec/lib/gitlab/llm/chain/agents/zero_shot/executor_spec.rb
index f96efa95fcb89b9fd2000a13f7fbd5e1556b46aa..df0e009b2aba6edeecc883e2e1ee95a3c7600bd3 100644
--- a/ee/spec/lib/gitlab/llm/chain/agents/zero_shot/executor_spec.rb
+++ b/ee/spec/lib/gitlab/llm/chain/agents/zero_shot/executor_spec.rb
@@ -14,6 +14,7 @@
   let(:response_double) { "I know the final answer\nFinal Answer: FooBar" }
   let(:resource) { user }
   let(:response_service_double) { instance_double(::Gitlab::Llm::ResponseService) }
+  let(:stream_response_service_double) { nil }
 
   let(:context) do
     Gitlab::Llm::Chain::GitlabContext.new(
@@ -28,7 +29,8 @@
       user_input: input,
       tools: tools,
       context: context,
-      response_handler: response_service_double
+      response_handler: response_service_double,
+      stream_response_handler: stream_response_service_double
     )
   end
 
@@ -46,80 +48,117 @@
       .and_return(tool_double)
   end
 
-  it 'executes associated tools and adds observations during the execution' do
-    answer = agent.execute
-
-    expect(answer.is_final).to eq(true)
-    expect(answer.content).to include('FooBar')
-  end
-
-  context 'without final answer' do
+  context 'when streaming is disabled' do
     before do
-      # just limiting the number of iterations here from 10 to 2
-      stub_const("#{described_class.name}::MAX_ITERATIONS", 2)
+      stub_feature_flags(stream_gitlab_duo: false)
     end
 
     it 'executes associated tools and adds observations during the execution' do
-      logger = instance_double(Gitlab::Llm::Logger)
-
-      expect(Gitlab::Llm::Logger).to receive(:build).at_least(:once).and_return(logger)
-      expect(logger).to receive(:info).with(hash_including(message: "Tool cycling detected")).exactly(2)
-      expect(logger).to receive(:info).at_least(:once)
-      expect(logger).to receive(:debug).at_least(:once)
-      expect(response_service_double).to receive(:execute).at_least(:once)
-
-      allow(agent).to receive(:request).and_return("Action: IssueIdentifier\nAction Input: #3")
+      answer = agent.execute
 
-      agent.execute
+      expect(answer.is_final).to eq(true)
+      expect(answer.content).to include('FooBar')
     end
 
-    context 'with the ai_tool_info flag switched off' do
+    context 'without final answer' do
       before do
-        stub_feature_flags(ai_tool_info: false)
+        # just limiting the number of iterations here from 10 to 2
+        stub_const("#{described_class.name}::MAX_ITERATIONS", 2)
       end
 
-      it 'does not call response_service' do
-        expect(response_service_double).not_to receive(:execute)
+      it 'executes associated tools and adds observations during the execution' do
+        logger = instance_double(Gitlab::Llm::Logger)
+
+        expect(Gitlab::Llm::Logger).to receive(:build).at_least(:once).and_return(logger)
+        expect(logger).to receive(:info).with(hash_including(message: "Tool cycling detected")).exactly(2)
+        expect(logger).to receive(:info).at_least(:once)
+        expect(logger).to receive(:debug).at_least(:once)
+        expect(response_service_double).to receive(:execute).at_least(:once)
 
         allow(agent).to receive(:request).and_return("Action: IssueIdentifier\nAction Input: #3")
 
         agent.execute
       end
+
+      context 'with the ai_tool_info flag switched off' do
+        before do
+          stub_feature_flags(ai_tool_info: false)
+        end
+
+        it 'does not call response_service' do
+          expect(response_service_double).not_to receive(:execute)
+
+          allow(agent).to receive(:request).and_return("Action: IssueIdentifier\nAction Input: #3")
+
+          agent.execute
+        end
+      end
     end
-  end
 
-  context 'when max iterations reached' do
-    it 'returns' do
-      stub_const("#{described_class.name}::MAX_ITERATIONS", 2)
+    context 'when max iterations reached' do
+      it 'returns' do
+        stub_const("#{described_class.name}::MAX_ITERATIONS", 2)
+
+        allow(agent).to receive(:request).and_return("Action: IssueIdentifier\nAction Input: #3")
+        expect(agent).to receive(:request).twice.times
+        expect(response_service_double).to receive(:execute).at_least(:once)
 
-      allow(agent).to receive(:request).and_return("Action: IssueIdentifier\nAction Input: #3")
-      expect(agent).to receive(:request).twice.times
-      expect(response_service_double).to receive(:execute).at_least(:once)
+        answer = agent.execute
 
-      answer = agent.execute
+        expect(answer.is_final?).to eq(true)
+        expect(answer.content).to include(Gitlab::Llm::Chain::Answer.default_final_message)
+      end
+    end
+
+    context 'when answer is final' do
+      let(:response_content_1) { "Thought: I know final answer\nFinal Answer: Foo" }
 
-      expect(answer.is_final?).to eq(true)
-      expect(answer.content).to include(Gitlab::Llm::Chain::Answer.default_final_message)
+      it 'returns final answer' do
+        answer = agent.execute
+
+        expect(answer.is_final?).to eq(true)
+      end
     end
-  end
 
-  context 'when answer is final' do
-    let(:response_content_1) { "Thought: I know final answer\nFinal Answer: Foo" }
+    context 'when tool answer if final' do
+      let(:tool_answer) { instance_double(Gitlab::Llm::Chain::Answer, is_final?: true) }
 
-    it 'returns final answer' do
-      answer = agent.execute
+      it 'returns final answer' do
+        answer = agent.execute
 
-      expect(answer.is_final?).to eq(true)
+        expect(answer.is_final?).to eq(true)
+      end
     end
   end
 
-  context 'when tool answer if final' do
-    let(:tool_answer) { instance_double(Gitlab::Llm::Chain::Answer, is_final?: true) }
+  context 'when streaming is enabled' do
+    let(:stream_response_service_double) { instance_double(::Gitlab::Llm::ResponseService) }
 
-    it 'returns final answer' do
-      answer = agent.execute
+    before do
+      stub_feature_flags(stream_gitlab_duo: true)
+      allow(ai_request_double).to receive(:request).and_yield("Final Answer:").and_yield("Hello").and_yield(" World")
+    end
+
+    it 'streams the final answer' do
+      first_response_double = double
+      second_response_double = double
 
-      expect(answer.is_final?).to eq(true)
+      allow(Gitlab::Llm::Chain::PlainResponseModifier).to receive(:new).with("Hello")
+        .and_return(first_response_double)
+
+      allow(Gitlab::Llm::Chain::PlainResponseModifier).to receive(:new).with(" World")
+        .and_return(second_response_double)
+
+      expect(stream_response_service_double).to receive(:execute).with(
+        response: first_response_double,
+        options: { cache_response: false, role: ::Gitlab::Llm::Cache::ROLE_ASSISTANT, chunk_id: 1 }
+      )
+      expect(stream_response_service_double).to receive(:execute).with(
+        response: second_response_double,
+        options: { cache_response: false, role: ::Gitlab::Llm::Cache::ROLE_ASSISTANT, chunk_id: 2 }
      )
+
+      agent.execute
     end
   end
 end
diff --git a/ee/spec/lib/gitlab/llm/chain/concerns/ai_dependent_spec.rb b/ee/spec/lib/gitlab/llm/chain/concerns/ai_dependent_spec.rb
index a535353291724138425ba65768267fedc4b1ff34..9dee0aab141a7ba19105c87f3e00d2a8f24d5283 100644
--- a/ee/spec/lib/gitlab/llm/chain/concerns/ai_dependent_spec.rb
+++ b/ee/spec/lib/gitlab/llm/chain/concerns/ai_dependent_spec.rb
@@ -3,18 +3,18 @@
 require 'spec_helper'
 
 RSpec.describe Gitlab::Llm::Chain::Concerns::AiDependent, feature_category: :duo_chat do
-  describe '#prompt' do
-    let(:options) { { suggestions: "", input: "" } }
-    let(:ai_request) { ::Gitlab::Llm::Chain::Requests::Anthropic.new(double) }
-    let(:context) do
-      ::Gitlab::Llm::Chain::GitlabContext.new(
-        current_user: double,
-        container: double,
-        resource: double,
-        ai_request: ai_request
-      )
-    end
+  let(:options) { { suggestions: "", input: "" } }
+  let(:ai_request) { ::Gitlab::Llm::Chain::Requests::Anthropic.new(double) }
+  let(:context) do
+    ::Gitlab::Llm::Chain::GitlabContext.new(
+      current_user: double,
+      container: double,
+      resource: double,
+      ai_request: ai_request
+    )
+  end
 
+  describe '#prompt' do
     context 'when prompt is called' do
       it 'returns provider specific prompt' do
         tool = ::Gitlab::Llm::Chain::Tools::IssueIdentifier::Executor.new(context: context, options: options)
@@ -65,4 +65,23 @@ def provider_prompt_class
       end
     end
   end
+
+  describe '#request' do
+    it 'passes prompt to the ai_client' do
+      tool = ::Gitlab::Llm::Chain::Tools::IssueIdentifier::Executor.new(context: context, options: options)
+
+      expect(ai_request).to receive(:request).with(tool.prompt)
+
+      tool.request
+    end
+
+    it 'passes blocks forward to the ai_client' do
+      b = proc { "something" }
+      tool = ::Gitlab::Llm::Chain::Tools::IssueIdentifier::Executor.new(context: context, options: options)
+
+      expect(ai_request).to receive(:request).with(tool.prompt, &b)
+
+      tool.request(&b)
+    end
+  end
 end
diff --git a/ee/spec/lib/gitlab/llm/chain/plain_response_modifier_spec.rb b/ee/spec/lib/gitlab/llm/chain/plain_response_modifier_spec.rb
new file mode 100644
index 0000000000000000000000000000000000000000..159a8a395c099e8f2759a535b143fcf781cc4424
--- /dev/null
+++ b/ee/spec/lib/gitlab/llm/chain/plain_response_modifier_spec.rb
@@ -0,0 +1,19 @@
+# frozen_string_literal: true
+
+require 'spec_helper'
+
+RSpec.describe Gitlab::Llm::Chain::PlainResponseModifier, feature_category: :shared do
+  let(:content) { "content" }
+
+  context 'on success' do
+    subject { described_class.new(content).response_body }
+
+    it { is_expected.to eq "content" }
+  end
+
+  context 'on error' do
+    subject { described_class.new(content).errors }
+
+    it { is_expected.to eq [] }
+  end
+end
diff --git a/ee/spec/lib/gitlab/llm/chain/requests/anthropic_spec.rb b/ee/spec/lib/gitlab/llm/chain/requests/anthropic_spec.rb
index b9302fa76a05e01b76d0fc1eec83fde75797b866..8aa7c8fab4f6f4f201ba125f93d917f694eb1768 100644
--- a/ee/spec/lib/gitlab/llm/chain/requests/anthropic_spec.rb
+++ b/ee/spec/lib/gitlab/llm/chain/requests/anthropic_spec.rb
@@ -3,9 +3,11 @@
 require 'spec_helper'
 
 RSpec.describe Gitlab::Llm::Chain::Requests::Anthropic, feature_category: :duo_chat do
+  let_it_be(:user) { build(:user) }
+
   describe 'initializer' do
     it 'initializes the anthropic client' do
-      request = described_class.new(double)
+      request = described_class.new(user)
 
       expect(request.ai_client.class).to eq(::Gitlab::Llm::Anthropic::Client)
     end
@@ -14,8 +16,10 @@
   describe '#request' do
     subject(:request) { instance.request(params) }
 
-    let(:instance) { described_class.new(double) }
+    let(:instance) { described_class.new(user) }
+    let(:logger) { instance_double(Gitlab::Llm::Logger) }
     let(:ai_client) { double }
+    let(:response) { { "completion" => "Hello World " } }
     let(:expected_params) do
       {
         prompt: "some user request",
@@ -25,26 +29,73 @@
     end
 
     before do
+      allow(Gitlab::Llm::Logger).to receive(:build).and_return(logger)
       allow(instance).to receive(:ai_client).and_return(ai_client)
     end
 
-    context 'with prompt and options' do
-      let(:params) { { prompt: "some user request", options: { max_tokens: 4000 } } }
+    context 'when streaming is disabled' do
+      before do
+        stub_feature_flags(stream_gitlab_duo: false)
+      end
+
+      context 'with prompt and options' do
+        let(:params) { { prompt: "some user request", options: { max_tokens: 4000 } } }
 
-      it 'calls the anthropic completion endpoint' do
-        expect(ai_client).to receive(:complete).with(expected_params.merge({ max_tokens: 4000 }))
+        it 'calls the anthropic completion endpoint, parses response and strips it' do
+          expect(ai_client).to receive(:complete).with(expected_params.merge({ max_tokens: 4000 })).and_return(response)
 
-        request
+          expect(request).to eq("Hello World")
+        end
+      end
+
+      context 'when options are not present' do
+        let(:params) { { prompt: "some user request" } }
+
+        it 'calls the anthropic completion endpoint' do
+          expect(ai_client).to receive(:complete).with(expected_params)
+
+          request
+        end
       end
     end
 
-    context 'when options are not present' do
-      let(:params) { { prompt: "some user request" } }
+    context 'when streaming is enabled' do
+      before do
+        stub_feature_flags(stream_gitlab_duo: true)
+      end
+
+      context 'with prompt and options' do
+        let(:params) { { prompt: "some user request", options: { max_tokens: 4000 } } }
+
+        it 'calls the anthropic streaming endpoint and yields response without stripping it' do
+          expect(ai_client).to receive(:stream).with(expected_params.merge({ max_tokens: 4000 })).and_yield(response)
+
+          expect { |b| instance.request(params, &b) }.to yield_with_args(
+            "Hello World "
+          )
+        end
+      end
+
+      context 'when options are not present' do
+        let(:params) { { prompt: "some user request" } }
+
+        it 'calls the anthropic streaming endpoint' do
+          expect(ai_client).to receive(:stream).with(expected_params)
+
+          request
+        end
+      end
+
+      context 'when stream errors' do
+        let(:params) { { prompt: "some user request" } }
+        let(:response) { { "error" => { "type" => "overload_error", message: "Overloaded" } } }
 
-      it 'calls the anthropic completion endpoint' do
-        expect(ai_client).to receive(:complete).with(expected_params)
+        it 'logs the error' do
+          expect(ai_client).to receive(:stream).with(expected_params).and_yield(response)
+          expect(logger).to receive(:info).with(hash_including(message: "Streaming error", error: response["error"]))
 
-        request
+          request
+        end
       end
     end
   end
 end
diff --git a/ee/spec/lib/gitlab/llm/chain/streamed_answer_spec.rb b/ee/spec/lib/gitlab/llm/chain/streamed_answer_spec.rb
new file mode 100644
index 0000000000000000000000000000000000000000..4be16d3588f40bf3a1a68a71afad82174534aaac
--- /dev/null
+++ b/ee/spec/lib/gitlab/llm/chain/streamed_answer_spec.rb
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+require 'fast_spec_helper'
+
+RSpec.describe Gitlab::Llm::Chain::StreamedAnswer, feature_category: :duo_chat do
+  let(:input) do
+    <<-INPUT
+      Thought: thought
+      Action: IssueIdentifier
+      Action Input: Bar
+    INPUT
+  end
+
+  describe '#next_chunk' do
+    let(:streamed_answer) { described_class.new }
+
+    context 'when stream is empty' do
+      it 'returns nil' do
+        expect(streamed_answer.next_chunk("")).to be_nil
+      end
+    end
+
+    context 'when stream does not contain the final answer' do
+      it 'returns nil' do
+        expect(streamed_answer.next_chunk("Some")).to be_nil
+        expect(streamed_answer.next_chunk("Content")).to be_nil
+      end
+    end
+
+    context 'when receiving thoughts and actions' do
+      it 'only returns the final answer', :aggregate_failures do
+        expect(streamed_answer.next_chunk("Thought: thought\n")).to be_nil
+        expect(streamed_answer.next_chunk("Action: IssueIdentifier\n")).to be_nil
+        expect(streamed_answer.next_chunk("Final Answer: Hello")).to eq({ id: 1, content: "Hello" })
+      end
+    end
+
+    context 'when receiving a final answer split up in multiple tokens', :aggregate_failures do
+      it 'returns the final answer once it is ready', :aggregate_failures do
+        expect(streamed_answer.next_chunk("Final Answer")).to be_nil
+        expect(streamed_answer.next_chunk(": ")).to be_nil
+        expect(streamed_answer.next_chunk("Hello")).to eq({ id: 1, content: "Hello" })
+        expect(streamed_answer.next_chunk(" ")).to eq({ id: 2, content: " " })
+      end
+    end
+  end
+end
diff --git a/ee/spec/lib/gitlab/llm/completions/chat_spec.rb b/ee/spec/lib/gitlab/llm/completions/chat_spec.rb
index 259d559a3f2169489e20166b3bd6aa01a835b4f0..87512f96a3183cac28bf4f90304c81cbf758ac58 100644
--- a/ee/spec/lib/gitlab/llm/completions/chat_spec.rb
+++ b/ee/spec/lib/gitlab/llm/completions/chat_spec.rb
@@ -17,6 +17,7 @@
   let(:blob) { fake_blob(path: 'file.md') }
   let(:extra_resource) { { blob: blob } }
   let(:options) { { request_id: 'uuid', content: content, extra_resource: extra_resource } }
+  let(:params) { { request_id: 'uuid' } }
   let(:container) { group }
   let(:context) do
     instance_double(
@@ -35,8 +36,9 @@
   end
 
   let(:response_handler) { instance_double(Gitlab::Llm::ResponseService) }
+  let(:stream_response_handler) { nil }
 
-  subject { described_class.new(nil, request_id: 'uuid').execute(user, resource, options) }
+  subject { described_class.new(nil, **params).execute(user, resource, options) }
 
   shared_examples 'success' do
     it 'calls the ZeroShot Agent with the right parameters', :snowplow do
@@ -46,11 +48,13 @@
         ::Gitlab::Llm::Chain::Tools::GitlabDocumentation,
         ::Gitlab::Llm::Chain::Tools::EpicIdentifier
       ]
+
       expected_params = [
         user_input: content,
         tools: match_array(tools),
         context: context,
-        response_handler: response_handler
+        response_handler: response_handler,
+        stream_response_handler: stream_response_handler
       ]
 
       expect_next_instance_of(::Gitlab::Llm::Chain::Agents::ZeroShot::Executor, *expected_params) do |instance|
@@ -81,6 +85,36 @@
       )
     end
 
+    context 'when client_subscription_id is set' do
+      let(:params) { { request_id: 'uuid', content: content, client_subscription_id: 'someid' } }
+      let(:stream_response_handler) { instance_double(Gitlab::Llm::ResponseService) }
+
+      it 'correctly initializes response handlers' do
+        expected_params = [
+          user_input: content,
+          tools: an_instance_of(Array),
+          context: an_instance_of(Gitlab::Llm::Chain::GitlabContext),
+          response_handler: response_handler,
+          stream_response_handler: stream_response_handler
+        ]
+
+        expect_next_instance_of(::Gitlab::Llm::Chain::Agents::ZeroShot::Executor, *expected_params) do |instance|
+          expect(instance).to receive(:execute).and_return(answer)
+        end
+
+        expect(response_handler).to receive(:execute)
+        expect(::Gitlab::Llm::ResponseService).to receive(:new).with(
+          an_instance_of(Gitlab::Llm::Chain::GitlabContext), { request_id: 'uuid' }
+        ).and_return(response_handler)
+
+        expect(::Gitlab::Llm::ResponseService).to receive(:new).with(
+          an_instance_of(Gitlab::Llm::Chain::GitlabContext), { request_id: 'uuid', client_subscription_id: 'someid' }
+        ).and_return(stream_response_handler)
+
+        subject
+      end
+    end
+
     context 'with unsuccessful response' do
       let(:answer) do
         ::Gitlab::Llm::Chain::Answer.new(
@@ -150,7 +184,8 @@
         user_input: content,
         tools: match_array(tools),
         context: context,
-        response_handler: response_handler
+        response_handler: response_handler,
+        stream_response_handler: stream_response_handler
       ]
 
       expect_next_instance_of(::Gitlab::Llm::Chain::Agents::ZeroShot::Executor, *expected_params) do |instance|
diff --git a/ee/spec/lib/gitlab/llm/completions/summarize_all_open_notes_spec.rb b/ee/spec/lib/gitlab/llm/completions/summarize_all_open_notes_spec.rb
index bff94346942ffcb364f544bac282ea370fe6140a..c492525fb433b19853c4bd53f24afedec4edc81e 100644
--- a/ee/spec/lib/gitlab/llm/completions/summarize_all_open_notes_spec.rb
+++ b/ee/spec/lib/gitlab/llm/completions/summarize_all_open_notes_spec.rb
@@ -3,7 +3,7 @@
 require 'spec_helper'
 
 RSpec.describe Gitlab::Llm::Completions::SummarizeAllOpenNotes, feature_category: :duo_chat do
-  let(:ai_response) { { "completion" => "some ai response text" } }
+  let(:ai_response) { "some ai response text" }
   let(:template_class) { nil }
   let(:ai_options) do
     {
@@ -76,7 +76,7 @@
   describe "#execute", :saas do
     let(:ai_request_class) { ::Gitlab::Llm::Anthropic::Client }
-    let(:completion_method) { :complete }
+    let(:completion_method) { :stream }
     let(:options) { { ai_provider: :anthropic } }
 
     let_it_be(:user) { create(:user) }
@@ -130,7 +130,20 @@
       let_it_be(:notes) { create_pair(:note_on_issue, project: project, noteable: issuable) }
       let_it_be(:system_note) { create(:note_on_issue, :system, project: project, noteable: issuable) }
 
-      it_behaves_like 'performs completion'
+      context 'when streaming is enabled' do
+        it_behaves_like 'performs completion'
+      end
+
+      context 'when streaming is disabled' do
+        let(:completion_method) { :complete }
+        let(:ai_response) { { "completion" => "some ai response text" } }
+
+        before do
+          stub_feature_flags(stream_gitlab_duo: false)
+        end
+
+        it_behaves_like 'performs completion'
+      end
 
       context 'with vertex_ai provider' do
         let(:options) { { ai_provider: :vertex_ai } }
diff --git a/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb b/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb
index c5b144dd8f9ac4457e9e9ad72442144bae1c3d3c..ea7694f74b8925ee661a7d31a918f7b828362963 100644
--- a/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb
+++ b/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb
@@ -50,7 +50,8 @@
         role: 'assistant',
         timestamp: an_instance_of(ActiveSupport::TimeWithZone),
         errors: [],
-        type: nil
+        type: nil,
+        chunk_id: nil
       }
     end
diff --git a/ee/spec/lib/gitlab/llm/open_ai/completions/generate_test_file_spec.rb b/ee/spec/lib/gitlab/llm/open_ai/completions/generate_test_file_spec.rb
index 0057efb6264799bf2b67aee54bcdfca14c1674e7..814a92e31780d468f1563fa0dccaa4bf2f4e721a 100644
--- a/ee/spec/lib/gitlab/llm/open_ai/completions/generate_test_file_spec.rb
+++ b/ee/spec/lib/gitlab/llm/open_ai/completions/generate_test_file_spec.rb
@@ -60,7 +60,8 @@
         role: 'assistant',
         timestamp: an_instance_of(ActiveSupport::TimeWithZone),
         errors: [],
-        type: nil
+        type: nil,
+        chunk_id: nil
       }
 
       expect(GraphqlTriggers).to receive(:ai_completion_response).with(
diff --git a/ee/spec/requests/api/graphql/subscriptions/ai_completion_response_spec.rb b/ee/spec/requests/api/graphql/subscriptions/ai_completion_response_spec.rb
index 143003a2ad87cb65578245407962810af3411cc7..038444596ca8bff01b87c4636fccd88dfa62bcaf 100644
--- a/ee/spec/requests/api/graphql/subscriptions/ai_completion_response_spec.rb
+++ b/ee/spec/requests/api/graphql/subscriptions/ai_completion_response_spec.rb
@@ -42,7 +42,8 @@
         request_id: request_id,
         content: content,
         role: ::Gitlab::Llm::Cache::ROLE_ASSISTANT,
-        errors: []
+        errors: [],
+        chunk_id: nil
       }
 
       GraphqlTriggers.ai_completion_response(params, data)
@@ -56,6 +57,7 @@
       expect(ai_completion_response['role']).to eq('ASSISTANT')
       expect(ai_completion_response['requestId']).to eq(request_id)
       expect(ai_completion_response['errors']).to eq([])
+      expect(ai_completion_response['chunk_id']).to eq(nil)
     end
   end
 
@@ -132,6 +134,7 @@ def build_subscription_query(requested_user, params)
           role
           requestId
           errors
+          chunkId
         }
       }
    SUBSCRIPTION