From 358d01f5e0fbebb8d6d8641c63d48bf70287fffd Mon Sep 17 00:00:00 2001 From: Eduardo Bonet <ebonet@gitlab.com> Date: Mon, 22 Apr 2024 15:41:29 +0000 Subject: [PATCH] Migrates show_ml_model_version to graphql Updates components so what show_ml_model_version fetches the version data using graphql instead of passing as a prop. --- .../apps/show_ml_model_version.vue | 79 +++++++++++- .../components/model_version_detail.vue | 31 +++-- .../queries/get_model_version.query.graphql | 69 ++++++++++ .../javascripts/ml/model_registry/utils.js | 53 ++++++++ .../projects/ml/model_registry_helper.rb | 15 +++ .../projects/ml/model_versions/show.html.haml | 2 +- .../apps/show_ml_model_version_spec.js | 83 ++++++++++-- .../components/model_version_detail_spec.js | 121 +++++++++++++----- .../ml/model_registry/graphql_mock_data.js | 97 ++++++++++++++ spec/frontend/ml/model_registry/utils_spec.js | 60 +++++++++ .../projects/ml/model_registry_helper_spec.rb | 36 ++++++ 11 files changed, 591 insertions(+), 55 deletions(-) create mode 100644 app/assets/javascripts/ml/model_registry/graphql/queries/get_model_version.query.graphql create mode 100644 app/assets/javascripts/ml/model_registry/utils.js create mode 100644 spec/frontend/ml/model_registry/utils_spec.js diff --git a/app/assets/javascripts/ml/model_registry/apps/show_ml_model_version.vue b/app/assets/javascripts/ml/model_registry/apps/show_ml_model_version.vue index 6608f44ecf70c..e3ee9be4ccc6a 100644 --- a/app/assets/javascripts/ml/model_registry/apps/show_ml_model_version.vue +++ b/app/assets/javascripts/ml/model_registry/apps/show_ml_model_version.vue @@ -1,22 +1,91 @@ <script> import TitleArea from '~/vue_shared/components/registry/title_area.vue'; +import * as Sentry from '~/sentry/sentry_browser_wrapper'; +import getModelVersionQuery from '~/ml/model_registry/graphql/queries/get_model_version.query.graphql'; +import { convertToGraphQLId } from '~/graphql_shared/utils'; +import { makeLoadVersionsErrorMessage } from '~/ml/model_registry/translations'; import ModelVersionDetail from '../components/model_version_detail.vue'; +import LoadOrErrorOrShow from '../components/load_or_error_or_show.vue'; export default { name: 'ShowMlModelVersionApp', components: { + LoadOrErrorOrShow, ModelVersionDetail, TitleArea, }, + provide() { + return { + projectPath: this.projectPath, + }; + }, props: { - modelVersion: { - type: Object, + modelId: { + type: Number, + required: true, + }, + modelVersionId: { + type: Number, + required: true, + }, + versionName: { + type: String, + required: true, + }, + modelName: { + type: String, + required: true, + }, + projectPath: { + type: String, required: true, }, }, + apollo: { + modelWithModelVersion: { + query: getModelVersionQuery, + variables() { + return this.queryVariables; + }, + update(data) { + return data?.mlModel; + }, + error(error) { + this.handleError(error); + }, + }, + }, + data() { + return { + modelWithModelVersion: {}, + errorMessage: '', + }; + }, computed: { + modelVersion() { + return this.modelWithModelVersion?.version; + }, + isLoading() { + return this.$apollo.queries.modelWithModelVersion.loading; + }, title() { - return `${this.modelVersion.model.name} / ${this.modelVersion.version}`; + return `${this.modelName} / ${this.versionName}`; + }, + queryVariables() { + return { + modelId: convertToGraphQLId('Ml::Model', this.modelId), + modelVersionId: convertToGraphQLId('Ml::ModelVersion', this.modelVersionId), + }; + }, + }, + methods: { + handleError(error) { + this.errorMessage = makeLoadVersionsErrorMessage(error.message); + Sentry.captureException(error, { + tags: { + vue_component: 'show_ml_model_version', + }, + }); }, }, }; @@ -25,6 +94,8 @@ export default { <template> <div> <title-area :title="title" /> - <model-version-detail :model-version="modelVersion" /> + <load-or-error-or-show :is-loading="isLoading" :error-message="errorMessage"> + <model-version-detail :model-version="modelVersion" /> + </load-or-error-or-show> </div> </template> diff --git a/app/assets/javascripts/ml/model_registry/components/model_version_detail.vue b/app/assets/javascripts/ml/model_registry/components/model_version_detail.vue index 8d3e8cf2023cc..b2788a9c9ded4 100644 --- a/app/assets/javascripts/ml/model_registry/components/model_version_detail.vue +++ b/app/assets/javascripts/ml/model_registry/components/model_version_detail.vue @@ -1,5 +1,6 @@ <script> -import { convertToGraphQLId } from '~/graphql_shared/utils'; +import { convertCandidateFromGraphql } from '~/ml/model_registry/utils'; +import { convertToGraphQLId, isGid } from '~/graphql_shared/utils'; import { TYPENAME_PACKAGES_PACKAGE } from '~/graphql_shared/constants'; import * as i18n from '../translations'; import CandidateDetail from './candidate_detail.vue'; @@ -11,6 +12,7 @@ export default { import('~/packages_and_registries/package_registry/components/details/package_files.vue'), CandidateDetail, }, + inject: ['projectPath'], props: { modelVersion: { type: Object, @@ -18,15 +20,26 @@ export default { }, }, computed: { - packageId() { - return convertToGraphQLId(TYPENAME_PACKAGES_PACKAGE, this.modelVersion.packageId); - }, - projectPath() { - return this.modelVersion.projectPath; - }, packageType() { return 'ml_model'; }, + isFromGraphql() { + return isGid(this.modelVersion.id); + }, + candidate() { + if (this.isFromGraphql) { + return convertCandidateFromGraphql(this.modelVersion.candidate); + } + + return this.modelVersion.candidate; + }, + packageId() { + if (this.isFromGraphql) { + return this.modelVersion.packageId; + } + + return convertToGraphQLId(TYPENAME_PACKAGES_PACKAGE, this.modelVersion.packageId); + }, }, i18n, }; @@ -53,9 +66,9 @@ export default { <div class="gl-mt-5"> <span class="gl-font-weight-bold">{{ $options.i18n.MLFLOW_ID_LABEL }}:</span> - {{ modelVersion.candidate.info.eid }} + {{ candidate.info.eid }} </div> - <candidate-detail :candidate="modelVersion.candidate" :show-info-section="false" /> + <candidate-detail :candidate="candidate" :show-info-section="false" /> </div> </template> diff --git a/app/assets/javascripts/ml/model_registry/graphql/queries/get_model_version.query.graphql b/app/assets/javascripts/ml/model_registry/graphql/queries/get_model_version.query.graphql new file mode 100644 index 0000000000000..de36f252c9397 --- /dev/null +++ b/app/assets/javascripts/ml/model_registry/graphql/queries/get_model_version.query.graphql @@ -0,0 +1,69 @@ +query getModelVersion($modelId: MlModelID!, $modelVersionId: MlModelVersionID!) { + mlModel(id: $modelId) { + id + name + version(modelVersionId: $modelVersionId) { + id + version + packageId + description + candidate { + id + name + iid + eid + status + params { + nodes { + id + name + value + } + } + metadata { + nodes { + id + name + value + } + } + metrics { + nodes { + id + name + value + step + } + } + ciJob { + id + webPath + name + pipeline { + id + mergeRequest { + id + iid + title + webUrl + } + user { + id + avatarUrl + webUrl + username + name + } + } + } + _links { + showPath + artifactPath + } + } + _links { + showPath + } + } + } +} diff --git a/app/assets/javascripts/ml/model_registry/utils.js b/app/assets/javascripts/ml/model_registry/utils.js new file mode 100644 index 0000000000000..42b989e7bf488 --- /dev/null +++ b/app/assets/javascripts/ml/model_registry/utils.js @@ -0,0 +1,53 @@ +export function convertCandidateFromGraphql(graphqlCandidate) { + const { iid, eid, status, ciJob } = graphqlCandidate; + const links = graphqlCandidate._links; + + let ciJobValues = null; + + if (ciJob) { + let userInfo = null; + let mergeRequestInfo = null; + const user = ciJob?.pipeline.user; + const mr = ciJob?.pipeline.mergeRequest; + + if (user) { + userInfo = { + avatar: user.avatarUrl, + path: user.webUrl, + username: user.username, + name: user.name, + }; + } + + if (mr) { + mergeRequestInfo = { + title: mr.title, + path: mr.webUrl, + iid: mr.iid, + }; + } + + ciJobValues = { + name: ciJob.name, + path: ciJob.webPath, + user: userInfo, + mergeRequest: mergeRequestInfo, + }; + } + + return { + info: { + iid, + eid, + status, + experimentName: '', + pathToExperiment: '', + pathToArtifact: links.artifactPath, + path: links.showPath, + ciJob: ciJobValues, + }, + metrics: graphqlCandidate.metrics.nodes, + params: graphqlCandidate.params.nodes, + metadata: graphqlCandidate.metadata.nodes, + }; +} diff --git a/app/helpers/projects/ml/model_registry_helper.rb b/app/helpers/projects/ml/model_registry_helper.rb index c6fe5bed91a24..d75befffe5256 100644 --- a/app/helpers/projects/ml/model_registry_helper.rb +++ b/app/helpers/projects/ml/model_registry_helper.rb @@ -39,6 +39,21 @@ def show_ml_model_data(model, user) to_json(data) end + def show_ml_model_version_data(model_version, user) + project = model_version.project + + data = { + project_path: project.full_path, + model_id: model_version.model.id, + model_version_id: model_version.id, + model_name: model_version.name, + version_name: model_version.version, + can_write_model_registry: can_write_model_registry?(user, project) + } + + to_json(data) + end + private def can_write_model_registry?(user, project) diff --git a/app/views/projects/ml/model_versions/show.html.haml b/app/views/projects/ml/model_versions/show.html.haml index 1b4bdd29842d8..144c4deaa2088 100644 --- a/app/views/projects/ml/model_versions/show.html.haml +++ b/app/views/projects/ml/model_versions/show.html.haml @@ -3,4 +3,4 @@ - breadcrumb_title @model_version.version - page_title "#{@model_version.name} / #{@model_version.version}" -= render(Projects::Ml::ShowMlModelVersionComponent.new(model_version: @model_version, current_user: current_user)) +#js-mount-show-ml-model-version{ data: { view_model: show_ml_model_version_data(@model_version, @current_user) } } diff --git a/spec/frontend/ml/model_registry/apps/show_ml_model_version_spec.js b/spec/frontend/ml/model_registry/apps/show_ml_model_version_spec.js index 2605a75d9611e..d2062dbf5e68c 100644 --- a/spec/frontend/ml/model_registry/apps/show_ml_model_version_spec.js +++ b/spec/frontend/ml/model_registry/apps/show_ml_model_version_spec.js @@ -1,25 +1,84 @@ -import { shallowMount } from '@vue/test-utils'; +import { mount } from '@vue/test-utils'; +import Vue from 'vue'; +import VueApollo from 'vue-apollo'; import { ShowMlModelVersion } from '~/ml/model_registry/apps'; import ModelVersionDetail from '~/ml/model_registry/components/model_version_detail.vue'; import TitleArea from '~/vue_shared/components/registry/title_area.vue'; -import { MODEL_VERSION } from '../mock_data'; +import LoadOrErrorOrShow from '~/ml/model_registry/components/load_or_error_or_show.vue'; +import createMockApollo from 'helpers/mock_apollo_helper'; +import getModelVersionQuery from '~/ml/model_registry/graphql/queries/get_model_version.query.graphql'; +import waitForPromises from 'helpers/wait_for_promises'; +import * as Sentry from '~/sentry/sentry_browser_wrapper'; +import { modelVersionQuery, modelVersionWithCandidate } from '../graphql_mock_data'; -let wrapper; -const createWrapper = () => { - wrapper = shallowMount(ShowMlModelVersion, { propsData: { modelVersion: MODEL_VERSION } }); -}; - -const findTitleArea = () => wrapper.findComponent(TitleArea); -const findModelVersionDetail = () => wrapper.findComponent(ModelVersionDetail); +Vue.use(VueApollo); describe('ml/model_registry/apps/show_model_version.vue', () => { - beforeEach(() => createWrapper()); + let wrapper; + let apolloProvider; + + beforeEach(() => { + jest.spyOn(Sentry, 'captureException').mockImplementation(); + }); + + const createWrapper = (resolver = jest.fn().mockResolvedValue(modelVersionQuery)) => { + const requestHandlers = [[getModelVersionQuery, resolver]]; + apolloProvider = createMockApollo(requestHandlers); + + wrapper = mount(ShowMlModelVersion, { + propsData: { + modelName: 'blah', + versionName: '1.2.3', + modelId: 1, + modelVersionId: 2, + projectPath: 'path/to/project', + }, + apolloProvider, + }); + }; + + const findTitleArea = () => wrapper.findComponent(TitleArea); + const findModelVersionDetail = () => wrapper.findComponent(ModelVersionDetail); + const findLoadOrErrorOrShow = () => wrapper.findComponent(LoadOrErrorOrShow); it('renders the title', () => { + createWrapper(); + expect(findTitleArea().props('title')).toBe('blah / 1.2.3'); }); - it('renders the model version detail', () => { - expect(findModelVersionDetail().props('modelVersion')).toBe(MODEL_VERSION); + it('Requests data with the right parameters', async () => { + const resolver = jest.fn().mockResolvedValue(modelVersionQuery); + + createWrapper(resolver); + + await waitForPromises(); + + expect(resolver).toHaveBeenLastCalledWith( + expect.objectContaining({ + modelId: 'gid://gitlab/Ml::Model/1', + modelVersionId: 'gid://gitlab/Ml::ModelVersion/2', + }), + ); + }); + + it('Displays data when loaded', async () => { + createWrapper(); + + await waitForPromises(); + + expect(findModelVersionDetail().props('modelVersion')).toMatchObject(modelVersionWithCandidate); + }); + + it('Shows error message on error', async () => { + const error = new Error('Failure!'); + createWrapper(jest.fn().mockRejectedValue(error)); + + await waitForPromises(); + + expect(findLoadOrErrorOrShow().props('errorMessage')).toBe( + 'Failed to load model versions with error: Failure!', + ); + expect(Sentry.captureException).toHaveBeenCalled(); }); }); diff --git a/spec/frontend/ml/model_registry/components/model_version_detail_spec.js b/spec/frontend/ml/model_registry/components/model_version_detail_spec.js index d1874346ad740..db113165a184e 100644 --- a/spec/frontend/ml/model_registry/components/model_version_detail_spec.js +++ b/spec/frontend/ml/model_registry/components/model_version_detail_spec.js @@ -5,62 +5,125 @@ import ModelVersionDetail from '~/ml/model_registry/components/model_version_det import PackageFiles from '~/packages_and_registries/package_registry/components/details/package_files.vue'; import CandidateDetail from '~/ml/model_registry/components/candidate_detail.vue'; import createMockApollo from 'helpers/mock_apollo_helper'; +import { convertCandidateFromGraphql } from '~/ml/model_registry/utils'; +import { modelVersionWithCandidate } from '../graphql_mock_data'; import { makeModelVersion, MODEL_VERSION } from '../mock_data'; Vue.use(VueApollo); +const makeGraphqlModelVersion = (overrides = {}) => { + return { ...modelVersionWithCandidate, ...overrides }; +}; + let wrapper; -const createWrapper = (modelVersion = MODEL_VERSION) => { +const createWrapper = (modelVersion = modelVersionWithCandidate) => { const apolloProvider = createMockApollo([]); - wrapper = shallowMount(ModelVersionDetail, { apolloProvider, propsData: { modelVersion } }); + wrapper = shallowMount(ModelVersionDetail, { + apolloProvider, + propsData: { modelVersion }, + provide: { + projectPath: 'path/to/project', + }, + }); }; const findPackageFiles = () => wrapper.findComponent(PackageFiles); const findCandidateDetail = () => wrapper.findComponent(CandidateDetail); describe('ml/model_registry/components/model_version_detail.vue', () => { - describe('base behaviour', () => { - beforeEach(() => createWrapper()); + describe('When passing modelVersion passed on page load', () => { + describe('base behaviour', () => { + beforeEach(() => createWrapper(MODEL_VERSION)); - it('shows the description', () => { - expect(wrapper.text()).toContain(MODEL_VERSION.description); - }); + it('shows the description', () => { + expect(wrapper.text()).toContain(MODEL_VERSION.description); + }); - it('shows the candidate', () => { - expect(findCandidateDetail().props('candidate')).toBe(MODEL_VERSION.candidate); - }); + it('shows the candidate', () => { + expect(findCandidateDetail().props('candidate')).toMatchObject(MODEL_VERSION.candidate); + }); - it('shows the mlflow label string', () => { - expect(wrapper.text()).toContain('MLflow run ID'); + it('shows the mlflow label string', () => { + expect(wrapper.text()).toContain('MLflow run ID'); + }); + + it('shows the mlflow id', () => { + expect(wrapper.text()).toContain(MODEL_VERSION.candidate.info.eid); + }); + + it('renders files', () => { + expect(findPackageFiles().props()).toEqual({ + packageId: 'gid://gitlab/Packages::Package/12', + projectPath: 'path/to/project', + packageType: 'ml_model', + canDelete: false, + }); + }); }); - it('shows the mlflow id', () => { - expect(wrapper.text()).toContain(MODEL_VERSION.candidate.info.eid); + describe('if package does not exist', () => { + beforeEach(() => createWrapper(makeModelVersion({ packageId: 0 }))); + + it('does not render files', () => { + expect(findPackageFiles().exists()).toBe(false); + }); }); - it('renders files', () => { - expect(findPackageFiles().props()).toEqual({ - packageId: 'gid://gitlab/Packages::Package/12', - projectPath: MODEL_VERSION.projectPath, - packageType: 'ml_model', - canDelete: false, + describe('if model version does not have description', () => { + beforeEach(() => createWrapper(makeModelVersion({ description: null }))); + + it('renders no description provided label', () => { + expect(wrapper.text()).toContain('No description provided'); }); }); }); - describe('if package does not exist', () => { - beforeEach(() => createWrapper(makeModelVersion({ packageId: 0 }))); + describe('When passing modelVersion fetched from graphql', () => { + describe('base behaviour', () => { + beforeEach(() => createWrapper()); - it('does not render files', () => { - expect(findPackageFiles().exists()).toBe(false); + it('shows the description', () => { + expect(wrapper.text()).toContain('A model version description'); + }); + + it('shows the candidate', () => { + expect(findCandidateDetail().props('candidate')).toMatchObject( + convertCandidateFromGraphql(modelVersionWithCandidate.candidate), + ); + }); + + it('shows the mlflow label string', () => { + expect(wrapper.text()).toContain('MLflow run ID'); + }); + + it('shows the mlflow id', () => { + expect(wrapper.text()).toContain(modelVersionWithCandidate.candidate.eid); + }); + + it('renders files', () => { + expect(findPackageFiles().props()).toEqual({ + packageId: 'gid://gitlab/Packages::Package/12', + projectPath: 'path/to/project', + packageType: 'ml_model', + canDelete: false, + }); + }); }); - }); - describe('if model version does not have description', () => { - beforeEach(() => createWrapper(makeModelVersion({ description: null }))); + describe('if package does not exist', () => { + beforeEach(() => createWrapper(makeGraphqlModelVersion({ packageId: 0 }))); - it('renders no description provided label', () => { - expect(wrapper.text()).toContain('No description provided'); + it('does not render files', () => { + expect(findPackageFiles().exists()).toBe(false); + }); + }); + + describe('if model version does not have description', () => { + beforeEach(() => createWrapper(makeGraphqlModelVersion({ description: null }))); + + it('renders no description provided label', () => { + expect(wrapper.text()).toContain('No description provided'); + }); }); }); }); diff --git a/spec/frontend/ml/model_registry/graphql_mock_data.js b/spec/frontend/ml/model_registry/graphql_mock_data.js index fd4c88b4083b2..4aa78586b2ac9 100644 --- a/spec/frontend/ml/model_registry/graphql_mock_data.js +++ b/spec/frontend/ml/model_registry/graphql_mock_data.js @@ -41,6 +41,78 @@ export const modelVersionsQuery = (versions = graphqlModelVersions) => ({ }, }); +export const candidate = { + id: 'gid://gitlab/Ml::Candidate/1', + name: 'hare-zebra-cobra-9745', + iid: 1, + eid: 'e9a71521-45c6-4b0a-b0c3-21f0b4528a5c', + status: 'running', + params: { + nodes: [ + { + id: 'gid://gitlab/Ml::CandidateParam/1', + name: 'param1', + value: 'value1', + }, + ], + }, + metadata: { + nodes: [ + { + id: 'gid://gitlab/Ml::CandidateMetadata/1', + name: 'metadata1', + value: 'metadataValue1', + }, + ], + }, + metrics: { + nodes: [ + { + id: 'gid://gitlab/Ml::CandidateMetric/1', + name: 'metric1', + value: 0.3, + step: 0, + }, + ], + }, + ciJob: { + id: 'gid://gitlab/Ci::Build/1', + webPath: '/gitlab-org/gitlab-test/-/jobs/1', + name: 'build:linux', + pipeline: { + id: 'gid://gitlab/Ci::Pipeline/1', + mergeRequest: { + id: 'gid://gitlab/MergeRequest/1', + title: 'Merge Request 1', + webUrl: 'path/to/mr', + iid: 1, + }, + user: { + id: 'gid://gitlab/User/1', + avatarUrl: 'path/to/avatar', + webUrl: 'path/to/user/1', + username: 'user1', + name: 'User 1', + }, + }, + }, + _links: { + showPath: '/root/test-project/-/ml/candidates/1', + artifactPath: '/root/test-project/-/packages/1', + }, +}; + +export const modelVersionWithCandidate = { + id: 'gid://gitlab/Ml::ModelVersion/1', + version: '1.0.4999', + packageId: 'gid://gitlab/Packages::Package/12', + description: 'A model version description', + candidate, + _links: { + showPath: '/root/test-project/-/ml/models/1/versions/5000', + }, +}; + export const graphqlCandidates = [ { id: 'gid://gitlab/Ml::Candidate/1', @@ -201,6 +273,21 @@ export const modelWithoutVersion = { }, }; +export const model = { + id: 'gid://gitlab/Ml::Model/1', + description: 'A model description', + name: 'gitlab_amazing_model', + versionCount: 1, + candidateCount: 0, + latestVersion: modelVersionWithCandidate, +}; + +export const modelDetailQuery = { + data: { + mlModel: model, + }, +}; + export const modelsQuery = ( models = [modelWithOneVersion, modelWithoutVersion], pageInfo = graphqlPageInfo, @@ -216,3 +303,13 @@ export const modelsQuery = ( }, }, }); + +export const modelVersionQuery = { + data: { + mlModel: { + id: 'gid://gitlab/Ml::Model/1', + name: 'blah', + version: modelVersionWithCandidate, + }, + }, +}; diff --git a/spec/frontend/ml/model_registry/utils_spec.js b/spec/frontend/ml/model_registry/utils_spec.js new file mode 100644 index 0000000000000..9c9baf89328b6 --- /dev/null +++ b/spec/frontend/ml/model_registry/utils_spec.js @@ -0,0 +1,60 @@ +import { convertCandidateFromGraphql } from '~/ml/model_registry/utils'; +import { candidate } from './graphql_mock_data'; + +describe('~/ml/model_registry/utils', () => { + describe('convertCandidateFromGraphql', () => { + it('converts from graphql response', () => { + const converted = convertCandidateFromGraphql(candidate); + const expectedResponse = { + info: { + iid: 1, + eid: 'e9a71521-45c6-4b0a-b0c3-21f0b4528a5c', + status: 'running', + experimentName: '', + pathToExperiment: '', + pathToArtifact: '/root/test-project/-/packages/1', + path: '/root/test-project/-/ml/candidates/1', + ciJob: { + mergeRequest: { + iid: 1, + path: 'path/to/mr', + title: 'Merge Request 1', + }, + name: 'build:linux', + path: '/gitlab-org/gitlab-test/-/jobs/1', + user: { + avatar: 'path/to/avatar', + name: 'User 1', + path: 'path/to/user/1', + username: 'user1', + }, + }, + }, + metrics: [ + { + id: 'gid://gitlab/Ml::CandidateMetric/1', + name: 'metric1', + value: 0.3, + step: 0, + }, + ], + params: [ + { + id: 'gid://gitlab/Ml::CandidateParam/1', + name: 'param1', + value: 'value1', + }, + ], + metadata: [ + { + id: 'gid://gitlab/Ml::CandidateMetadata/1', + name: 'metadata1', + value: 'metadataValue1', + }, + ], + }; + + expect(converted).toEqual(expectedResponse); + }); + }); +}); diff --git a/spec/helpers/projects/ml/model_registry_helper_spec.rb b/spec/helpers/projects/ml/model_registry_helper_spec.rb index 4123df178a54e..f1fc2b197fb0f 100644 --- a/spec/helpers/projects/ml/model_registry_helper_spec.rb +++ b/spec/helpers/projects/ml/model_registry_helper_spec.rb @@ -102,4 +102,40 @@ end end end + + describe '#show_ml_model_version_data' do + let_it_be(:model) do + build_stubbed(:ml_models, :with_latest_version_and_package, project: project, id: 1) + end + + let_it_be(:model_version) do + model.latest_version + end + + subject(:parsed) { Gitlab::Json.parse(helper.show_ml_model_version_data(model_version, user)) } + + it 'generates the correct data' do + is_expected.to eq({ + "projectPath" => project.full_path, + "modelId" => model.id, + "modelVersionId" => model_version.id, + "modelName" => model_version.name, + "versionName" => model_version.version, + "canWriteModelRegistry" => true + }) + end + + context 'when user does not have write access to model registry' do + before do + allow(Ability).to receive(:allowed?).and_call_original + allow(Ability).to receive(:allowed?) + .with(user, :write_model_registry, project) + .and_return(false) + end + + it 'canWriteModelRegistry is false' do + expect(parsed['canWriteModelRegistry']).to eq(false) + end + end + end end -- GitLab