diff --git a/app/models/ci/stage.rb b/app/models/ci/stage.rb
index ba1a0a46247ae0a13c149ad91b754bc76fd0f822..080f93746ccf944d46099509d95650a5e545415d 100644
--- a/app/models/ci/stage.rb
+++ b/app/models/ci/stage.rb
@@ -8,6 +8,8 @@ class Stage < Ci::ApplicationRecord
     include Gitlab::OptimisticLocking
     include Presentable
 
+    self.primary_key = :id
+
     partitionable scope: :pipeline
 
     enum status: Ci::HasStatus::STATUSES_ENUM
diff --git a/db/post_migrate/20240207094958_swap_primary_key_ci_stage.rb b/db/post_migrate/20240207094958_swap_primary_key_ci_stage.rb
new file mode 100644
index 0000000000000000000000000000000000000000..311830be50108ab6f3f536cebaaee1d71e75bb18
--- /dev/null
+++ b/db/post_migrate/20240207094958_swap_primary_key_ci_stage.rb
@@ -0,0 +1,45 @@
+# frozen_string_literal: true
+
+class SwapPrimaryKeyCiStage < Gitlab::Database::Migration[2.2]
+  include Gitlab::Database::PartitioningMigrationHelpers
+
+  milestone '16.9'
+  disable_ddl_transaction!
+
+  TABLE_NAME = :ci_stages
+  PRIMARY_KEY = :ci_stages_pkey
+  NEW_INDEX = :index_ci_stages_on_id_partition_id_unique
+  OLD_INDEX = :index_ci_stages_on_id_unique
+
+  def up
+    swap_primary_key(TABLE_NAME, PRIMARY_KEY, NEW_INDEX)
+  end
+
+  def down
+    add_concurrent_index(TABLE_NAME, :id, unique: true, name: OLD_INDEX)
+    add_concurrent_index(TABLE_NAME, [:id, :partition_id], unique: true, name: NEW_INDEX)
+
+    unswap_primary_key(TABLE_NAME, PRIMARY_KEY, OLD_INDEX)
+
+    recreate_partitioned_foreign_keys
+  end
+
+  private
+
+  def recreate_partitioned_foreign_keys
+    add_partitioned_fk(:p_ci_builds, :fk_3a9eaa254d_p, column: :stage_id)
+  end
+
+  def add_partitioned_fk(source_table, name, column: nil)
+    add_concurrent_partitioned_foreign_key(
+      source_table,
+      TABLE_NAME,
+      column: [:partition_id, column],
+      target_column: [:partition_id, :id],
+      reverse_lock_order: true,
+      on_update: :cascade,
+      on_delete: :cascade,
+      name: name
+    )
+  end
+end
diff --git a/db/schema_migrations/20240207094958 b/db/schema_migrations/20240207094958
new file mode 100644
index 0000000000000000000000000000000000000000..8b947ba8aa422ab7b647d7e78fcd4816c7420eca
--- /dev/null
+++ b/db/schema_migrations/20240207094958
@@ -0,0 +1 @@
+48ff6b60706c886ddbd3046bab40e73c90f2a95d244402276b58c13e62ab29ab
\ No newline at end of file
diff --git a/db/structure.sql b/db/structure.sql
index 05fd042fb7b505054c8d923406299f392a0dcb68..5d622aaa7674a56f25fd997da3eb96536d006061 100644
--- a/db/structure.sql
+++ b/db/structure.sql
@@ -29344,7 +29344,7 @@ ALTER TABLE ONLY ci_sources_projects
     ADD CONSTRAINT ci_sources_projects_pkey PRIMARY KEY (id);
 
 ALTER TABLE ONLY ci_stages
-    ADD CONSTRAINT ci_stages_pkey PRIMARY KEY (id);
+    ADD CONSTRAINT ci_stages_pkey PRIMARY KEY (id, partition_id);
 
 ALTER TABLE ONLY ci_subscriptions_projects
     ADD CONSTRAINT ci_subscriptions_projects_pkey PRIMARY KEY (id);
@@ -33362,8 +33362,6 @@ CREATE INDEX index_ci_sources_projects_on_pipeline_id ON ci_sources_projects USI
 
 CREATE UNIQUE INDEX index_ci_sources_projects_on_source_project_id_and_pipeline_id ON ci_sources_projects USING btree (source_project_id, pipeline_id);
 
-CREATE UNIQUE INDEX index_ci_stages_on_id_partition_id_unique ON ci_stages USING btree (id, partition_id);
-
 CREATE INDEX index_ci_stages_on_pipeline_id ON ci_stages USING btree (pipeline_id);
 
 CREATE INDEX index_ci_stages_on_pipeline_id_and_id ON ci_stages USING btree (pipeline_id, id) WHERE (status = ANY (ARRAY[0, 1, 2, 8, 9, 10]));