diff --git a/ee/lib/gitlab/llm/ai_gateway/client.rb b/ee/lib/gitlab/llm/ai_gateway/client.rb
index 79c4a9f0e251fb2d9e75a9993d2541644cf07b53..8cebc89e9ca1f98fb807bd4aa53d498f38ead78e 100644
--- a/ee/lib/gitlab/llm/ai_gateway/client.rb
+++ b/ee/lib/gitlab/llm/ai_gateway/client.rb
@@ -15,6 +15,8 @@ class Client
       DEFAULT_TYPE = 'prompt'
       DEFAULT_SOURCE = 'GitLab EE'
 
+      ALLOWED_PAYLOAD_PARAM_KEYS = %i[temperature max_tokens_to_sample stop_sequences].freeze
+
       def initialize(user, tracking_context: {})
         @user = user
         @tracking_context = tracking_context
@@ -124,10 +126,19 @@ def request_body(prompt:, options: {})
             payload: {
               content: prompt,
               provider: DEFAULT_PROVIDER,
-              model: DEFAULT_MODEL
-            }
-          }]
-        }.merge(options)
+              model: options.fetch(:model, DEFAULT_MODEL)
+            }.merge(payload_params(options))
+          }],
+          stream: options.fetch(:stream, false)
+        }
+      end
+
+      def payload_params(options)
+        params = options.slice(*ALLOWED_PAYLOAD_PARAM_KEYS)
+
+        return {} if params.empty?
+
+        { params: params }
       end
 
       def token_size(content)
diff --git a/ee/lib/gitlab/llm/chain/requests/ai_gateway.rb b/ee/lib/gitlab/llm/chain/requests/ai_gateway.rb
index c0c4a73634ddd02ce6a78d1dae98a34dd6ac82e4..9692b016c1af4c460552f4fa3d990bfc1432cb3c 100644
--- a/ee/lib/gitlab/llm/chain/requests/ai_gateway.rb
+++ b/ee/lib/gitlab/llm/chain/requests/ai_gateway.rb
@@ -7,6 +7,10 @@ module Requests
         class AiGateway < Base
           attr_reader :ai_client
 
+          TEMPERATURE = 0.1
+          STOP_WORDS = ["\n\nHuman", "Observation:"].freeze
+          DEFAULT_MAX_TOKENS = 2048
+
           def initialize(user, tracking_context: {})
             @user = user
             @ai_client = ::Gitlab::Llm::AiGateway::Client.new(user, tracking_context: tracking_context)
@@ -27,7 +31,11 @@ def request(prompt)
           attr_reader :user, :logger
 
           def default_options
-            {}
+            {
+              temperature: TEMPERATURE,
+              stop_sequences: STOP_WORDS,
+              max_tokens_to_sample: DEFAULT_MAX_TOKENS
+            }
           end
         end
       end
diff --git a/ee/spec/lib/gitlab/llm/ai_gateway/client_spec.rb b/ee/spec/lib/gitlab/llm/ai_gateway/client_spec.rb
index 6456c0e17a4c9a3d928658f4a0ed24f9afb44582..3b03b2e0c8cc318769c511e06596b1935e174eaf 100644
--- a/ee/spec/lib/gitlab/llm/ai_gateway/client_spec.rb
+++ b/ee/spec/lib/gitlab/llm/ai_gateway/client_spec.rb
@@ -42,7 +42,8 @@
             provider: described_class::DEFAULT_PROVIDER,
             model: described_class::DEFAULT_MODEL
           }
-        }]
+        }],
+        stream: false
       }
     end
 
diff --git a/ee/spec/lib/gitlab/llm/chain/requests/ai_gateway_spec.rb b/ee/spec/lib/gitlab/llm/chain/requests/ai_gateway_spec.rb
index 232659fd7e5ad1fe053c43723aa89219273cc650..2450b59039daa4f1c05d3c8265cae354ca61e8da 100644
--- a/ee/spec/lib/gitlab/llm/chain/requests/ai_gateway_spec.rb
+++ b/ee/spec/lib/gitlab/llm/chain/requests/ai_gateway_spec.rb
@@ -22,7 +22,10 @@
     let(:response) { 'Hello World' }
 
     let(:expected_params) do
       {
-        prompt: "some user request"
+        prompt: "some user request",
+        max_tokens_to_sample: 2048,
+        stop_sequences: ["\n\nHuman", "Observation:"],
+        temperature: 0.1
      }
    end
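
Note (illustration only, not part of the patch): the sketch below shows how the new payload_params allow-list behaves, assuming the default options now supplied by Chain::Requests::AiGateway are passed through to the client. The :model and :stream values here are hypothetical; only the three allow-listed keys end up nested under :params, while :model and :stream are read separately via options.fetch in request_body.

# Hypothetical options hash, as a caller might pass it after this change.
options = {
  model: "example-model",                        # read via options.fetch(:model, DEFAULT_MODEL)
  stream: true,                                  # read via options.fetch(:stream, false)
  temperature: 0.1,
  stop_sequences: ["\n\nHuman", "Observation:"],
  max_tokens_to_sample: 2048
}

# Equivalent of payload_params(options): keep only the allow-listed keys.
params = options.slice(:temperature, :max_tokens_to_sample, :stop_sequences)
# => { temperature: 0.1, stop_sequences: ["\n\nHuman", "Observation:"], max_tokens_to_sample: 2048 }

# Wrap them under :params so they merge into the payload hash; when none of
# the allow-listed keys are present, payload_params returns {} and the
# payload is left unchanged.
params.empty? ? {} : { params: params }
# => { params: { temperature: 0.1, stop_sequences: [...], max_tokens_to_sample: 2048 } }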