From 010dc11145cd18c3a4e9dc27503a57cfabefe72a Mon Sep 17 00:00:00 2001
From: Madelein van Niekerk <mvanniekerk@gitlab.com>
Date: Thu, 20 Feb 2025 11:09:25 +0000
Subject: [PATCH] ActiveContext Find partition by serializing collection

Changelog: added
---
 ee/app/models/ai/active_context/collection.rb |  4 ++
 .../ai/active_context/collection_spec.rb      | 25 ++++++++++
 .../lib/active_context/collection_cache.rb    | 46 +++++++++++++++++++
 .../lib/active_context/concerns/collection.rb | 11 ++++-
 .../lib/active_context/concerns/queue.rb      |  2 +-
 .../lib/active_context/config.rb              | 14 +++++-
 .../databases/concerns/executor.rb            |  8 ++++
 .../lib/active_context/hash.rb                | 10 ++++
 .../lib/active_context/reference.rb           | 42 +++++++++++------
 .../lib/active_context/shard.rb               |  9 ----
 .../lib/active_context/concerns/queue_spec.rb |  2 +-
 .../spec/lib/active_context/config_spec.rb    | 26 +++++++++++
 .../spec/lib/active_context/reference_spec.rb | 11 +----
 .../spec/lib/active_context/tracker_spec.rb   | 12 -----
 14 files changed, 175 insertions(+), 47 deletions(-)
 create mode 100644 gems/gitlab-active-context/lib/active_context/collection_cache.rb
 create mode 100644 gems/gitlab-active-context/lib/active_context/hash.rb
 delete mode 100644 gems/gitlab-active-context/lib/active_context/shard.rb

diff --git a/ee/app/models/ai/active_context/collection.rb b/ee/app/models/ai/active_context/collection.rb
index 69225d69c754..fb99d105c89a 100644
--- a/ee/app/models/ai/active_context/collection.rb
+++ b/ee/app/models/ai/active_context/collection.rb
@@ -8,6 +8,10 @@ class Collection < ApplicationRecord
       validates :name, presence: true, length: { maximum: 255 }
       validates :metadata, json_schema: { filename: 'ai_active_context_collection_metadata' }
       validates :number_of_partitions, presence: true, numericality: { greater_than_or_equal_to: 1, only_integer: true }
+
+      def partition_for(routing_value)
+        ::ActiveContext::Hash.consistent_hash(number_of_partitions, routing_value)
+      end
     end
   end
 end
diff --git a/ee/spec/models/ai/active_context/collection_spec.rb b/ee/spec/models/ai/active_context/collection_spec.rb
index e706ee2e4c37..dfaf183fecde 100644
--- a/ee/spec/models/ai/active_context/collection_spec.rb
+++ b/ee/spec/models/ai/active_context/collection_spec.rb
@@ -25,4 +25,29 @@
       expect(collection.errors[:metadata]).to include('must be a valid json schema')
     end
   end
+
+  describe '#partition_for' do
+    using RSpec::Parameterized::TableSyntax
+
+    let(:collection) { create(:ai_active_context_collection, number_of_partitions: 5) }
+
+    where(:routing_value, :partition_number) do
+      1 | 0
+      2 | 1
+      3 | 3
+      4 | 2
+      5 | 3
+      6 | 3
+      7 | 4
+      8 | 4
+      9 | 2
+      10 | 2
+    end
+
+    with_them do
+      it 'always returns the same partition for a routing value' do
+        expect(collection.partition_for(routing_value)).to eq(partition_number)
+      end
+    end
+  end
 end
diff --git a/gems/gitlab-active-context/lib/active_context/collection_cache.rb b/gems/gitlab-active-context/lib/active_context/collection_cache.rb
new file mode 100644
index 000000000000..ae8eb716d66b
--- /dev/null
+++ b/gems/gitlab-active-context/lib/active_context/collection_cache.rb
@@ -0,0 +1,55 @@
+# frozen_string_literal: true
+
+module ActiveContext
+  # In-process cache of collection records, keyed by record id and reloaded
+  # from Config.collection_model on the first read after the TTL has elapsed.
+  module CollectionCache
+    # Maximum age of the cached records before the next read reloads them.
+    TTL = 1.minute
+
+    class << self
+      # Returns the cached { id => record } hash, reloading it first if expired.
+      def collections
+        refresh_cache if cache_expired?
+
+        @collections ||= {}
+      end
+
+      # Looks a record up by id first, then by name; nil when neither matches.
+      def fetch(value)
+        by_id(value) || by_name(value)
+      end
+
+      def by_id(id)
+        collections[id]
+      end
+
+      # Compares against name.to_s; returns the first record whose name matches.
+      def by_name(name)
+        wanted = name.to_s
+        collections.each_value.find { |collection| collection.name == wanted }
+      end
+
+      private
+
+      def cache_expired?
+        return true unless @last_refreshed_at
+
+        Time.current - @last_refreshed_at > TTL
+      end
+
+      # Builds the replacement hash completely before assigning it to
+      # @collections, then records the refresh time.
+      def refresh_cache
+        new_collections = {}
+
+        Config.collection_model.find_each do |record|
+          new_collections[record.id] = record
+        end
+
+        @collections = new_collections
+        @last_refreshed_at = Time.current
+      end
+    end
+  end
+end
diff --git a/gems/gitlab-active-context/lib/active_context/concerns/collection.rb b/gems/gitlab-active-context/lib/active_context/concerns/collection.rb
index 049ae007bb83..5c4ee132bcbf 100644
--- a/gems/gitlab-active-context/lib/active_context/concerns/collection.rb
+++ b/gems/gitlab-active-context/lib/active_context/concerns/collection.rb
@@ -10,6 +10,10 @@ def track!(*objects)
           ActiveContext::Tracker.track!(objects, collection: self)
         end
 
+        def collection_name
+          raise NotImplementedError
+        end
+
         def queue
           raise NotImplementedError
         end
@@ -27,6 +31,10 @@ def reference_klasses
         def reference_klass
           nil
         end
+
+        def collection_record
+          ActiveContext::CollectionCache.fetch(collection_name)
+        end
       end
 
       attr_reader :object
@@ -38,9 +46,10 @@ def initialize(object)
       def references
         reference_klasses = Array.wrap(self.class.reference_klasses)
         routing = self.class.routing(object)
+        collection_id = self.class.collection_record.id
 
         reference_klasses.map do |reference_klass|
-          reference_klass.serialize(object, routing)
+          reference_klass.serialize(collection_id, routing, object)
         end
       end
     end
diff --git a/gems/gitlab-active-context/lib/active_context/concerns/queue.rb b/gems/gitlab-active-context/lib/active_context/concerns/queue.rb
index 685362dec7d8..9c3a9e3d9e6c 100644
--- a/gems/gitlab-active-context/lib/active_context/concerns/queue.rb
+++ b/gems/gitlab-active-context/lib/active_context/concerns/queue.rb
@@ -31,7 +31,7 @@ def register!
         end
 
         def push(references)
-          refs_by_shard = references.group_by { |ref| ActiveContext::Shard.shard_number(number_of_shards, ref) }
+          refs_by_shard = references.group_by { |ref| ActiveContext::Hash.consistent_hash(number_of_shards, ref) }
 
           ActiveContext::Redis.with_redis do |redis|
             refs_by_shard.each do |shard_number, shard_items|
diff --git a/gems/gitlab-active-context/lib/active_context/config.rb b/gems/gitlab-active-context/lib/active_context/config.rb
index b68843f65e11..5d3a5323691e 100644
--- a/gems/gitlab-active-context/lib/active_context/config.rb
+++ b/gems/gitlab-active-context/lib/active_context/config.rb
@@ -2,7 +2,15 @@
 
 module ActiveContext
   class Config
-    Cfg = Struct.new(:enabled, :databases, :logger, :indexing_enabled, :re_enqueue_indexing_workers, :migrations_path)
+    Cfg = Struct.new(
+      :enabled,
+      :databases,
+      :logger,
+      :indexing_enabled,
+      :re_enqueue_indexing_workers,
+      :migrations_path,
+      :collection_model
+    )
 
     class << self
       def configure(&block)
@@ -25,6 +33,10 @@ def migrations_path
         current.migrations_path || Rails.root.join('ee/db/active_context/migrate')
       end
 
+      def collection_model
+        current.collection_model || ::Ai::ActiveContext::Collection
+      end
+
       def logger
         current.logger || ::Logger.new($stdout)
       end
diff --git a/gems/gitlab-active-context/lib/active_context/databases/concerns/executor.rb b/gems/gitlab-active-context/lib/active_context/databases/concerns/executor.rb
index 3ca6974f3207..910bb0192d41 100644
--- a/gems/gitlab-active-context/lib/active_context/databases/concerns/executor.rb
+++ b/gems/gitlab-active-context/lib/active_context/databases/concerns/executor.rb
@@ -20,10 +20,18 @@ def create_collection(name, number_of_partitions:, &block)
             number_of_partitions: number_of_partitions,
             fields: builder.fields
           )
+
+          create_collection_record(full_name, number_of_partitions)
         end
 
         private
 
+        def create_collection_record(name, number_of_partitions)
+          collection = Config.collection_model.find_or_initialize_by(name: name)
+          collection.number_of_partitions = number_of_partitions # persisted by the save! below
+          collection.save!
+        end
+
         def do_create_collection(...)
           raise NotImplementedError
         end
diff --git a/gems/gitlab-active-context/lib/active_context/hash.rb b/gems/gitlab-active-context/lib/active_context/hash.rb
new file mode 100644
index 000000000000..8a68bcbb23e1
--- /dev/null
+++ b/gems/gitlab-active-context/lib/active_context/hash.rb
@@ -0,0 +1,14 @@
+# frozen_string_literal: true
+
+module ActiveContext
+  # NOTE(review): inside the ActiveContext namespace a bare `Hash` constant
+  # resolves to this class, not ::Hash - confirm no gem code uses the core class unqualified.
+  class Hash
+    # Maps `data` onto a bucket by taking the SHA256 digest of its string form
+    # modulo `number` (0...number for a positive number); equal inputs give equal buckets.
+    def self.consistent_hash(number, data)
+      # to_s is a no-op for Strings, so no type check is needed before hashing
+      Digest::SHA256.hexdigest(data.to_s).hex % number # rubocop: disable Fips/OpenSSL -- used for data distribution, not for security
+    end
+  end
+end
diff --git a/gems/gitlab-active-context/lib/active_context/reference.rb b/gems/gitlab-active-context/lib/active_context/reference.rb
index bda7d8612354..f2bec3c6b496 100644
--- a/gems/gitlab-active-context/lib/active_context/reference.rb
+++ b/gems/gitlab-active-context/lib/active_context/reference.rb
@@ -9,20 +9,18 @@ class Reference
 
     class << self
       def deserialize(string)
-        ref_klass = ref_klass(string)
-
-        if ref_klass
-          ref_klass.instantiate(string)
-        else
-          Search::Elastic::Reference.deserialize(string)
-        end
+        ref_klass(string)&.instantiate(string)
       end
 
       def instantiate(string)
         new(*deserialize_string(string))
       end
 
-      def serialize
+      def serialize(collection_id, routing, data)
+        new(collection_id, routing, *serialize_data(data)).serialize
+      end
+
+      def serialize_data(_data) # abstract hook; its result is splatted into .new by serialize
         raise NotImplementedError
       end
 
@@ -35,23 +33,37 @@ def preprocess_references(refs)
       end
     end
 
+    attr_reader :collection_id, :collection, :routing, :serialized_args
+
+    def initialize(collection_id, routing, *serialized_args)
+      @collection_id = collection_id.to_i
+      @collection = ActiveContext::CollectionCache.fetch(@collection_id)
+      @routing = routing
+      @serialized_args = serialized_args
+      init
+    end
+
     def klass
       self.class.klass
     end
 
     def serialize
+      self.class.join_delimited([collection_id, routing, serialize_arguments].flatten.compact)
+    end
+
+    def init
       raise NotImplementedError
     end
 
-    def as_indexed_json
+    def serialize_arguments
       raise NotImplementedError
     end
 
-    def operation
+    def as_indexed_json
       raise NotImplementedError
     end
 
-    def partition_name
+    def operation
       raise NotImplementedError
     end
 
@@ -59,8 +71,12 @@ def identifier
       raise NotImplementedError
     end
 
-    def routing
-      nil
+    def partition_name
+      collection.name
+    end
+
+    def partition_number
+      collection.partition_for(routing)
     end
   end
 end
diff --git a/gems/gitlab-active-context/lib/active_context/shard.rb b/gems/gitlab-active-context/lib/active_context/shard.rb
deleted file mode 100644
index c3a628436b79..000000000000
--- a/gems/gitlab-active-context/lib/active_context/shard.rb
+++ /dev/null
@@ -1,9 +0,0 @@
-# frozen_string_literal: true
-
-module ActiveContext
-  class Shard
-    def self.shard_number(number_of_shards, data)
-      Digest::SHA256.hexdigest(data).hex % number_of_shards # rubocop: disable Fips/OpenSSL -- used for data distribution, not for security
-    end
-  end
-end
diff --git a/gems/gitlab-active-context/spec/lib/active_context/concerns/queue_spec.rb b/gems/gitlab-active-context/spec/lib/active_context/concerns/queue_spec.rb
index 9ee1824cddb8..e05efe2a30e7 100644
--- a/gems/gitlab-active-context/spec/lib/active_context/concerns/queue_spec.rb
+++ b/gems/gitlab-active-context/spec/lib/active_context/concerns/queue_spec.rb
@@ -36,7 +36,7 @@ def self.number_of_shards
     it 'pushes references to Redis' do
       references = %w[ref1 ref2 ref3]
 
-      allow(ActiveContext::Shard).to receive(:shard_number).and_return(0, 1, 0)
+      allow(ActiveContext::Hash).to receive(:consistent_hash).and_return(0, 1, 0)
       expect(redis_double).to receive(:incrby).with('mockmodule:{test_queue}:0:score', 2).and_return(2)
       expect(redis_double).to receive(:incrby).with('mockmodule:{test_queue}:1:score', 1).and_return(1)
       expect(redis_double).to receive(:zadd).with('mockmodule:{test_queue}:0:zset', [[1, 'ref1'], [2, 'ref3']])
diff --git a/gems/gitlab-active-context/spec/lib/active_context/config_spec.rb b/gems/gitlab-active-context/spec/lib/active_context/config_spec.rb
index d689a0d60fdf..ef8fdcb85435 100644
--- a/gems/gitlab-active-context/spec/lib/active_context/config_spec.rb
+++ b/gems/gitlab-active-context/spec/lib/active_context/config_spec.rb
@@ -72,6 +72,32 @@
     end
   end
 
+  describe '.collection_model' do
+    before do
+      stub_const('Ai::ActiveContext::Collection', Class.new)
+    end
+
+    context 'when collection_model is not set' do
+      it 'returns the default model' do
+        expect(described_class.collection_model).to eq(::Ai::ActiveContext::Collection)
+      end
+    end
+
+    context 'when collection_model is set' do
+      let(:custom_model) { Class.new }
+
+      before do
+        described_class.configure do |config|
+          config.collection_model = custom_model
+        end
+      end
+
+      it 'returns the configured collection model' do
+        expect(described_class.collection_model).to eq(custom_model)
+      end
+    end
+  end
+
   describe '.logger' do
     context 'when logger is not set' do
       it 'returns a default stdout logger' do
diff --git a/gems/gitlab-active-context/spec/lib/active_context/reference_spec.rb b/gems/gitlab-active-context/spec/lib/active_context/reference_spec.rb
index fd83d9de69ce..04d7e76917d0 100644
--- a/gems/gitlab-active-context/spec/lib/active_context/reference_spec.rb
+++ b/gems/gitlab-active-context/spec/lib/active_context/reference_spec.rb
@@ -23,9 +23,8 @@
         stub_const('Search::Elastic::Reference', Class.new)
       end
 
-      it 'falls back to Search::Elastic::Reference.deserialize' do
-        expect(Search::Elastic::Reference).to receive(:deserialize).with('test|string')
-        described_class.deserialize('test|string')
+      it 'returns nil' do
+        expect(described_class.deserialize('test|string')).to be_nil
       end
     end
   end
@@ -45,12 +44,6 @@
     end
   end
 
-  describe '#klass' do
-    it 'returns the demodulized class name' do
-      expect(described_class.new.klass).to eq('Reference')
-    end
-  end
-
   describe 'ReferenceUtils methods' do
     describe '.delimit' do
       it 'splits the string by the delimiter' do
diff --git a/gems/gitlab-active-context/spec/lib/active_context/tracker_spec.rb b/gems/gitlab-active-context/spec/lib/active_context/tracker_spec.rb
index 7ac7bbb0fc7e..8351a894c7cf 100644
--- a/gems/gitlab-active-context/spec/lib/active_context/tracker_spec.rb
+++ b/gems/gitlab-active-context/spec/lib/active_context/tracker_spec.rb
@@ -30,18 +30,6 @@ def references
       expect(mock_queue).to contain_exactly(['test_string'])
     end
 
-    it 'serializes ActiveContext::Reference objects' do
-      reference_class = Class.new(ActiveContext::Reference) do
-        def serialize
-          'serialized_reference'
-        end
-      end
-      reference = reference_class.new
-
-      expect(described_class.track!(reference, collection: mock_collection)).to eq(1)
-      expect(mock_queue).to contain_exactly(['serialized_reference'])
-    end
-
     it 'uses collection.references for other objects' do
       obj = double('SomeObject')
       collection_instance = instance_double('CollectionInstance')
-- 
GitLab