From e085cf3c68493a540debe02f8db9ab1544d2357a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Rodr=C3=ADguez?= <alejandro@gitlab.com>
Date: Sat, 17 Feb 2024 16:34:07 -0500
Subject: [PATCH] Fix Duo Chat messages being saved twice

We support two response handling workflows. The default response handler
is always executed, and then there's an optional streaming response
handler. Both of them currently save the Chat message, which results in
duplicated messages in our Chat storage. To address this, we introduce a
new toggle so that the message is saved only by the default response
handler and skipped by the streaming one.
---
 ee/lib/gitlab/llm/completions/chat.rb         |  3 +-
 .../graphql_subscription_response_service.rb  |  8 ++--
 ee/lib/gitlab/llm/response_service.rb         |  5 ++-
 .../lib/gitlab/llm/completions/chat_spec.rb   |  2 +-
 ...phql_subscription_response_service_spec.rb | 38 +++++++++++--------
 .../lib/gitlab/llm/response_service_spec.rb   |  7 +++-
 6 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/ee/lib/gitlab/llm/completions/chat.rb b/ee/lib/gitlab/llm/completions/chat.rb
index 3997b55f696d..59e157b3026a 100644
--- a/ee/lib/gitlab/llm/completions/chat.rb
+++ b/ee/lib/gitlab/llm/completions/chat.rb
@@ -70,7 +70,8 @@ def execute
 
           # Send full message to custom clientSubscriptionId at the end of streaming.
           if response_options[:client_subscription_id]
-            ::Gitlab::Llm::ResponseService.new(context, response_options).execute(response: response_modifier)
+            ::Gitlab::Llm::ResponseService.new(context, response_options)
+              .execute(response: response_modifier, save_message: false)
           end
 
           response_handler.execute(response: response_modifier)
diff --git a/ee/lib/gitlab/llm/graphql_subscription_response_service.rb b/ee/lib/gitlab/llm/graphql_subscription_response_service.rb
index cac52ab5af9a..13906175d81f 100644
--- a/ee/lib/gitlab/llm/graphql_subscription_response_service.rb
+++ b/ee/lib/gitlab/llm/graphql_subscription_response_service.rb
@@ -3,11 +3,12 @@
 module Gitlab
   module Llm
     class GraphqlSubscriptionResponseService < BaseService
-      def initialize(user, resource, response_modifier, options:)
+      def initialize(user, resource, response_modifier, options:, save_message: true)
         @user = user
         @resource = resource
         @response_modifier = response_modifier
         @options = options
+        @save_message = save_message
         @logger = Gitlab::Llm::Logger.build
       end
 
@@ -39,10 +40,11 @@ def execute
 
       private
 
-      attr_reader :user, :resource, :response_modifier, :options, :logger
+      attr_reader :user, :resource, :response_modifier, :options, :save_message, :logger
 
       def save_message?
-        response_message.is_a?(ChatMessage) &&
+        save_message &&
+          response_message.is_a?(ChatMessage) &&
           !response_message.type &&
           !response_message.chunk_id
       end
diff --git a/ee/lib/gitlab/llm/response_service.rb b/ee/lib/gitlab/llm/response_service.rb
index b701937d8f72..709dc0443e67 100644
--- a/ee/lib/gitlab/llm/response_service.rb
+++ b/ee/lib/gitlab/llm/response_service.rb
@@ -9,10 +9,11 @@ def initialize(context, basic_options)
         @basic_options = basic_options
       end
 
-      def execute(response:, options: {})
+      def execute(response:, options: {}, save_message: true)
         ::Gitlab::Llm::GraphqlSubscriptionResponseService
           .new(user, resource, response,
-            options: basic_options.merge(options))
+            options: basic_options.merge(options),
+            save_message: save_message)
           .execute
       end
 
diff --git a/ee/spec/lib/gitlab/llm/completions/chat_spec.rb b/ee/spec/lib/gitlab/llm/completions/chat_spec.rb
index 6ffe9c80a727..3de78b5ac26f 100644
--- a/ee/spec/lib/gitlab/llm/completions/chat_spec.rb
+++ b/ee/spec/lib/gitlab/llm/completions/chat_spec.rb
@@ -140,7 +140,7 @@
           an_instance_of(Gitlab::Llm::Chain::GitlabContext), { request_id: 'uuid', ai_action: :chat,
 client_subscription_id: 'someid' }
         ).and_return(stream_response_handler).twice
-        expect(stream_response_handler).to receive(:execute)
+        expect(stream_response_handler).to receive(:execute).with(response: anything, save_message: false)
         expect(categorize_service).to receive(:execute)
         expect(::Llm::ExecuteMethodService).to receive(:new)
           .with(user, user, :categorize_question, categorize_service_params)
diff --git a/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb b/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb
index a3ff19fb21c6..6fb61ddb062e 100644
--- a/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb
+++ b/ee/spec/lib/gitlab/llm/graphql_subscription_response_service_spec.rb
@@ -75,12 +75,26 @@
   end
 
   describe '#execute' do
-    let(:service) { described_class.new(user, resource, response_modifier, options: options) }
+    let(:save_message) { true }
+    let(:service) do
+      described_class.new(user, resource, response_modifier, options: options, save_message: save_message)
+    end
+
     let_it_be(:resource) { project }
 
     subject { service.execute }
 
     context 'when message is chat' do
+      shared_examples 'not saving the message' do
+        it 'does not save the message' do
+          expect_next_instance_of(::Gitlab::Llm::AiMessage) do |instance|
+            expect(instance).not_to receive(:save!)
+          end
+
+          subject
+        end
+      end
+
       let(:ai_action) { 'chat' }
 
       it 'saves the message' do
@@ -89,28 +103,22 @@
         subject
       end
 
+      context 'when save_message is false' do
+        let(:save_message) { false }
+
+        it_behaves_like 'not saving the message'
+      end
+
       context 'when message is stream chunk' do
         let(:options) { super().merge(chunk_id: 1) }
 
-        it 'does not save the message' do
-          expect_next_instance_of(::Gitlab::Llm::AiMessage) do |instance|
-            expect(instance).not_to receive(:save!)
-          end
-
-          subject
-        end
+        it_behaves_like 'not saving the message'
       end
 
       context 'when message has special type' do
         let(:options) { super().merge(type: 'tool') }
 
-        it 'does not save the message' do
-          expect_next_instance_of(::Gitlab::Llm::AiMessage) do |instance|
-            expect(instance).not_to receive(:save!)
-          end
-
-          subject
-        end
+        it_behaves_like 'not saving the message'
       end
     end
 
diff --git a/ee/spec/lib/gitlab/llm/response_service_spec.rb b/ee/spec/lib/gitlab/llm/response_service_spec.rb
index 40a294733d9e..27125c4b68e1 100644
--- a/ee/spec/lib/gitlab/llm/response_service_spec.rb
+++ b/ee/spec/lib/gitlab/llm/response_service_spec.rb
@@ -11,15 +11,18 @@
 
   let(:basic_options) { { cache_request: true } }
   let(:options) { { cache_request: false } }
+  let(:save_message) { false }
   let(:graphql_subscription_double) { instance_double(::Gitlab::Llm::GraphqlSubscriptionResponseService) }
 
   describe '#execute' do
     it 'calls GraphQL subscription service with the right params' do
       expect(graphql_subscription_double).to receive(:execute)
       expect(::Gitlab::Llm::GraphqlSubscriptionResponseService).to receive(:new)
-        .with(user, issue, 'response', options: { cache_request: false }).and_return(graphql_subscription_double)
+        .with(user, issue, 'response', options: options, save_message: save_message)
+        .and_return(graphql_subscription_double)
 
-      described_class.new(context, basic_options).execute(response: 'response', options: options)
+      described_class.new(context, basic_options)
+        .execute(response: 'response', options: options, save_message: save_message)
     end
   end
 end
-- 
GitLab