From 1236829b799f1afdfcd93792240ce9c18b2f4b99 Mon Sep 17 00:00:00 2001 From: Eduardo Bonet <ebonet@gitlab.com> Date: Fri, 19 Jan 2024 16:28:06 +0000 Subject: [PATCH] Migrates index_ml_models to graphql To avoid technical debt and improve user experience, this MR migrates index_ml_model.vue to use graphql instead of preloading and refreshing the page on every change to the search query and pagination. To do so, we introduce search to the SearchableList component, in a way that doesn't affect existent usage. --- .../model_registry/apps/index_ml_models.vue | 134 +++++++--- .../components/candidate_list.vue | 21 +- .../model_registry/components/model_row.vue | 33 ++- .../components/model_version_list.vue | 21 +- .../components/searchable_list.vue | 87 ++++++- .../graphql/queries/get_models.query.graphql | 46 ++++ .../ml/model_registry/translations.js | 5 + .../projects/ml/models_controller.rb | 12 +- .../projects/ml/model_registry_helper.rb | 30 +++ app/views/projects/ml/models/index.html.haml | 2 +- locale/gitlab.pot | 3 + .../apps/index_ml_models_spec.js | 243 +++++++++++++----- .../components/model_row_spec.js | 33 ++- .../components/searchable_list_spec.js | 94 ++++++- .../ml/model_registry/graphql_mock_data.js | 61 +++++ .../projects/ml/model_registry_helper_spec.rb | 37 +++ .../projects/ml/models_controller_spec.rb | 69 +---- 17 files changed, 701 insertions(+), 230 deletions(-) create mode 100644 app/assets/javascripts/ml/model_registry/graphql/queries/get_models.query.graphql create mode 100644 app/helpers/projects/ml/model_registry_helper.rb create mode 100644 spec/helpers/projects/ml/model_registry_helper_spec.rb diff --git a/app/assets/javascripts/ml/model_registry/apps/index_ml_models.vue b/app/assets/javascripts/ml/model_registry/apps/index_ml_models.vue index 59b68fc00636c..7a04ccfe16329 100644 --- a/app/assets/javascripts/ml/model_registry/apps/index_ml_models.vue +++ b/app/assets/javascripts/ml/model_registry/apps/index_ml_models.vue @@ -1,32 +1,29 @@ <script> -import { isEmpty } from 'lodash'; -import { GlBadge, GlButton, GlTooltipDirective } from '@gitlab/ui'; -import Pagination from '~/vue_shared/components/incubation/pagination.vue'; +import { GlExperimentBadge, GlButton } from '@gitlab/ui'; import MetadataItem from '~/vue_shared/components/registry/metadata_item.vue'; import TitleArea from '~/vue_shared/components/registry/title_area.vue'; import { helpPagePath } from '~/helpers/help_page_helper'; +import * as Sentry from '~/sentry/sentry_browser_wrapper'; import EmptyState from '../components/empty_state.vue'; import * as i18n from '../translations'; -import { BASE_SORT_FIELDS, MODEL_ENTITIES } from '../constants'; -import SearchBar from '../components/search_bar.vue'; +import { BASE_SORT_FIELDS, GRAPHQL_PAGE_SIZE, MODEL_ENTITIES } from '../constants'; import ModelRow from '../components/model_row.vue'; import ActionsDropdown from '../components/actions_dropdown.vue'; +import getModelsQuery from '../graphql/queries/get_models.query.graphql'; +import { makeLoadModelErrorMessage } from '../translations'; +import SearchableList from '../components/searchable_list.vue'; export default { name: 'IndexMlModels', components: { - Pagination, ModelRow, - SearchBar, MetadataItem, TitleArea, - GlBadge, - EmptyState, + GlExperimentBadge, GlButton, + EmptyState, ActionsDropdown, - }, - directives: { - GlTooltip: GlTooltipDirective, + SearchableList, }, provide() { return { @@ -34,23 +31,14 @@ export default { }; }, props: { - models: { - type: Array, - required: true, - }, - pageInfo: { - type: Object, + projectPath: { + type: String, required: true, }, createModelPath: { type: String, required: true, }, - modelCount: { - type: Number, - required: false, - default: 0, - }, canWriteModelRegistry: { type: Boolean, required: false, @@ -62,9 +50,68 @@ export default { default: '', }, }, + apollo: { + models: { + query: getModelsQuery, + variables() { + return this.queryVariables; + }, + update(data) { + return data?.project?.mlModels ?? []; + }, + error(error) { + this.handleError(error); + }, + }, + }, + data() { + return { + models: [], + errorMessage: undefined, + }; + }, computed: { - hasModels() { - return !isEmpty(this.models); + pageInfo() { + return this.models?.pageInfo ?? {}; + }, + items() { + return this.models?.nodes ?? []; + }, + count() { + return this.models?.count ?? 0; + }, + isLoading() { + return this.$apollo.queries.models.loading; + }, + queryVariables() { + return { + fullPath: this.projectPath, + first: GRAPHQL_PAGE_SIZE, + }; + }, + }, + methods: { + fetchPage(variables) { + const vars = { + ...this.queryVariables, + ...variables, + name: variables.name, + orderBy: variables.orderBy?.toUpperCase() || 'CREATED_AT', + sort: variables.sort?.toUpperCase() || 'DESC', + }; + + this.$apollo.queries.models + .fetchMore({ + variables: vars, + updateQuery: (previousResult, { fetchMoreResult }) => { + return fetchMoreResult; + }, + }) + .catch(this.handleError); + }, + handleError(error) { + this.errorMessage = makeLoadModelErrorMessage(error.message); + Sentry.captureException(error); }, }, i18n, @@ -80,28 +127,39 @@ export default { <template #title> <div class="gl-flex-grow-1 gl-display-flex gl-align-items-center"> <span>{{ $options.i18n.TITLE_LABEL }}</span> - <gl-badge variant="neutral" class="gl-mx-4" size="lg" :href="$options.docHref"> - {{ __('Experiment') }} - </gl-badge> + <gl-experiment-badge :help-page-url="$options.docHref" /> </div> </template> <template #metadata-models-count> - <metadata-item icon="machine-learning" :text="$options.i18n.modelsCountLabel(modelCount)" /> + <metadata-item icon="machine-learning" :text="$options.i18n.modelsCountLabel(count)" /> </template> <template #right-actions> - <gl-button v-if="canWriteModelRegistry" :href="createModelPath">{{ - $options.i18n.CREATE_MODEL_LABEL - }}</gl-button> + <gl-button + v-if="canWriteModelRegistry" + :href="createModelPath" + data-testid="create-model-button" + >{{ $options.i18n.CREATE_MODEL_LABEL }}</gl-button + > <actions-dropdown /> </template> </title-area> - <template v-if="hasModels"> - <search-bar :sortable-fields="$options.sortableFields" /> - <model-row v-for="model in models" :key="model.name" :model="model" /> - <pagination v-bind="pageInfo" /> - </template> + <searchable-list + show-search + :page-info="pageInfo" + :items="items" + :error-message="errorMessage" + :is-loading="isLoading" + :sortable-fields="$options.sortableFields" + @fetch-page="fetchPage" + > + <template #empty-state> + <empty-state :entity-type="$options.modelEntity" /> + </template> - <empty-state v-else :entity-type="$options.modelEntity" /> + <template #item="{ item }"> + <model-row :model="item" /> + </template> + </searchable-list> </div> </template> diff --git a/app/assets/javascripts/ml/model_registry/components/candidate_list.vue b/app/assets/javascripts/ml/model_registry/components/candidate_list.vue index fca4462d7d2b4..d05a827c54512 100644 --- a/app/assets/javascripts/ml/model_registry/components/candidate_list.vue +++ b/app/assets/javascripts/ml/model_registry/components/candidate_list.vue @@ -35,8 +35,7 @@ export default { return data.mlModel?.candidates ?? {}; }, error(error) { - this.errorMessage = makeLoadCandidatesErrorMessage(error.message); - Sentry.captureException(error); + this.handleError(error); }, }, }, @@ -67,12 +66,18 @@ export default { ...newPageInfo, }; - this.$apollo.queries.candidates.fetchMore({ - variables, - updateQuery: (previousResult, { fetchMoreResult }) => { - return fetchMoreResult; - }, - }); + this.$apollo.queries.candidates + .fetchMore({ + variables, + updateQuery: (previousResult, { fetchMoreResult }) => { + return fetchMoreResult; + }, + }) + .catch(this.handleError); + }, + handleError(error) { + this.errorMessage = makeLoadCandidatesErrorMessage(error.message); + Sentry.captureException(error); }, }, i18n: { diff --git a/app/assets/javascripts/ml/model_registry/components/model_row.vue b/app/assets/javascripts/ml/model_registry/components/model_row.vue index 15be7bd0b47c8..49f72c7cef2ad 100644 --- a/app/assets/javascripts/ml/model_registry/components/model_row.vue +++ b/app/assets/javascripts/ml/model_registry/components/model_row.vue @@ -1,11 +1,14 @@ <script> -import { GlLink } from '@gitlab/ui'; +import { GlLink, GlTruncate } from '@gitlab/ui'; import { s__, n__ } from '~/locale'; +import ListItem from '~/vue_shared/components/registry/list_item.vue'; export default { name: 'MlModelRow', components: { GlLink, + ListItem, + GlTruncate, }, props: { model: { @@ -15,7 +18,7 @@ export default { }, computed: { hasVersions() { - return this.model.version != null; + return this.model.versionCount > 0; }, modelVersionCountMessage() { if (!this.model.versionCount) return s__('MlModelRegistry|No registered versions'); @@ -31,15 +34,23 @@ export default { </script> <template> - <div class="gl-border-b-solid gl-border-b-1 gl-border-b-gray-100 gl-py-3"> - <gl-link :href="model.path" class="gl-text-body gl-font-weight-bold gl-line-height-24"> - {{ model.name }} - </gl-link> + <list-item v-bind="$attrs"> + <template #left-primary> + <div class="gl-display-flex gl-align-items-center"> + <gl-link class="gl-text-body" :href="model._links.showPath"> + <gl-truncate :text="model.name" /> + </gl-link> + </div> + </template> - <div class="gl-text-secondary"> - <gl-link v-if="hasVersions" :href="model.versionPath">{{ model.version }}</gl-link> + <template #left-secondary> + <div class="gl-text-secondary"> + <gl-link v-if="hasVersions" :href="model.latestVersion._links.showPath">{{ + model.latestVersion.version + }}</gl-link> - {{ modelVersionCountMessage }} - </div> - </div> + {{ modelVersionCountMessage }} + </div> + </template> + </list-item> </template> diff --git a/app/assets/javascripts/ml/model_registry/components/model_version_list.vue b/app/assets/javascripts/ml/model_registry/components/model_version_list.vue index 5a649a9596a6c..ea5258a299ed0 100644 --- a/app/assets/javascripts/ml/model_registry/components/model_version_list.vue +++ b/app/assets/javascripts/ml/model_registry/components/model_version_list.vue @@ -36,8 +36,7 @@ export default { return data.mlModel?.versions ?? {}; }, error(error) { - this.errorMessage = makeLoadVersionsErrorMessage(error.message); - Sentry.captureException(error); + this.handleError(error); }, }, }, @@ -68,12 +67,18 @@ export default { ...pageInfo, }; - this.$apollo.queries.modelVersions.fetchMore({ - variables, - updateQuery: (previousResult, { fetchMoreResult }) => { - return fetchMoreResult; - }, - }); + this.$apollo.queries.modelVersions + .fetchMore({ + variables, + updateQuery: (previousResult, { fetchMoreResult }) => { + return fetchMoreResult; + }, + }) + .catch(this.handleError); + }, + handleError(error) { + this.errorMessage = makeLoadVersionsErrorMessage(error.message); + Sentry.captureException(error); }, }, modelVersionEntity: MODEL_ENTITIES.modelVersion, diff --git a/app/assets/javascripts/ml/model_registry/components/searchable_list.vue b/app/assets/javascripts/ml/model_registry/components/searchable_list.vue index 05062ae6fbf73..1ff8cc578a18c 100644 --- a/app/assets/javascripts/ml/model_registry/components/searchable_list.vue +++ b/app/assets/javascripts/ml/model_registry/components/searchable_list.vue @@ -2,11 +2,14 @@ import { GlAlert } from '@gitlab/ui'; import PackagesListLoader from '~/packages_and_registries/shared/components/packages_list_loader.vue'; import RegistryList from '~/packages_and_registries/shared/components/registry_list.vue'; -import { GRAPHQL_PAGE_SIZE } from '~/ml/model_registry/constants'; +import RegistrySearch from '~/vue_shared/components/registry/registry_search.vue'; +import { GRAPHQL_PAGE_SIZE, LIST_KEY_CREATED_AT } from '~/ml/model_registry/constants'; +import { queryToObject, setUrlParams, updateHistory } from '~/lib/utils/url_utility'; +import { FILTERED_SEARCH_TERM } from '~/vue_shared/components/filtered_search_bar/constants'; export default { name: 'SearchableList', - components: { PackagesListLoader, RegistryList, GlAlert }, + components: { PackagesListLoader, RegistryList, RegistrySearch, GlAlert }, props: { items: { type: Array, @@ -26,30 +29,92 @@ export default { required: false, default: '', }, + showSearch: { + type: Boolean, + required: false, + default: false, + }, + sortableFields: { + type: Array, + required: false, + default: () => [], + }, + }, + data() { + const query = queryToObject(window.location.search); + + const filter = query.name ? [{ value: { data: query.name }, type: FILTERED_SEARCH_TERM }] : []; + + const orderBy = query.orderBy || LIST_KEY_CREATED_AT; + + return { + filters: filter, + sorting: { + orderBy, + sort: (query.sort || 'desc').toLowerCase(), + }, + }; }, computed: { isListEmpty() { return this.items.length === 0; }, + parsedQuery() { + const name = this.filters + .map((f) => f.value.data) + .join(' ') + .trim(); + + const filterByQuery = name === '' ? {} : { name }; + + return { ...filterByQuery, ...this.sorting }; + }, + }, + created() { + this.nextPage(); }, methods: { prevPage() { - const pageInfo = { + const variables = { first: null, last: GRAPHQL_PAGE_SIZE, before: this.pageInfo.startCursor, + ...this.parsedQuery, }; - this.$emit('fetch-page', pageInfo); + this.fetchPage(variables); }, nextPage() { - const pageInfo = { + const variables = { first: GRAPHQL_PAGE_SIZE, last: null, after: this.pageInfo.endCursor, + ...this.parsedQuery, }; - this.$emit('fetch-page', pageInfo); + this.fetchPage(variables); + }, + fetchPage(variables) { + updateHistory({ + url: setUrlParams(variables, window.location.href, true), + title: document.title, + replace: true, + }); + + this.$emit('fetch-page', variables); + }, + submitFilters() { + this.fetchPage(this.parsedQuery); + }, + updateFilters(newValue) { + this.filters = newValue; + }, + updateSorting(newValue) { + this.sorting = { ...this.sorting, ...newValue }; + }, + updateSortingAndEmitUpdate(newValue) { + this.updateSorting(newValue); + this.submitFilters(); }, }, }; @@ -57,6 +122,16 @@ export default { <template> <div> + <registry-search + v-if="showSearch" + :filters="filters" + :sorting="sorting" + :sortable-fields="sortableFields" + @sorting:changed="updateSortingAndEmitUpdate" + @filter:changed="updateFilters" + @filter:submit="submitFilters" + @filter:clear="filters = []" + /> <packages-list-loader v-if="isLoading" /> <gl-alert v-else-if="errorMessage" variant="danger" :dismissible="false"> {{ errorMessage }} diff --git a/app/assets/javascripts/ml/model_registry/graphql/queries/get_models.query.graphql b/app/assets/javascripts/ml/model_registry/graphql/queries/get_models.query.graphql new file mode 100644 index 0000000000000..a9559bd7f5dcd --- /dev/null +++ b/app/assets/javascripts/ml/model_registry/graphql/queries/get_models.query.graphql @@ -0,0 +1,46 @@ +#import "~/graphql_shared/fragments/page_info.fragment.graphql" + +query getModels( + $fullPath: ID! + $name: String + $orderBy: MlModelsOrderBy + $sort: SortDirectionEnum + $first: Int + $last: Int + $after: String + $before: String +) { + project(fullPath: $fullPath) { + id + mlModels( + name: $name + orderBy: $orderBy + sort: $sort + after: $after + before: $before + first: $first + last: $last + ) { + count + nodes { + id + name + versionCount + createdAt + latestVersion { + id + version + _links { + showPath + } + } + _links { + showPath + } + } + pageInfo { + ...PageInfo + } + } + } +} diff --git a/app/assets/javascripts/ml/model_registry/translations.js b/app/assets/javascripts/ml/model_registry/translations.js index 006142979e258..9d3e1e7badb5d 100644 --- a/app/assets/javascripts/ml/model_registry/translations.js +++ b/app/assets/javascripts/ml/model_registry/translations.js @@ -47,6 +47,11 @@ export const makeLoadVersionsErrorMessage = (message) => message, }); +export const makeLoadModelErrorMessage = (message) => + sprintf(s__('MlModelRegistry|Failed to load model with error: %{message}'), { + message, + }); + export const NO_CANDIDATES_LABEL = s__('MlModelRegistry|This model has no candidates'); export const makeLoadCandidatesErrorMessage = (message) => sprintf(s__('MlModelRegistry|Failed to load model candidates with error: %{message}'), { diff --git a/app/controllers/projects/ml/models_controller.rb b/app/controllers/projects/ml/models_controller.rb index 2dff3ec33252c..1c5d773bc7d52 100644 --- a/app/controllers/projects/ml/models_controller.rb +++ b/app/controllers/projects/ml/models_controller.rb @@ -10,17 +10,7 @@ class ModelsController < ::Projects::ApplicationController MAX_MODELS_PER_PAGE = 20 - def index - find_params = params - .transform_keys(&:underscore) - .permit(:name, :order_by, :sort) - - finder = ::Projects::Ml::ModelFinder.new(@project, find_params) - - @paginator = finder.execute.keyset_paginate(cursor: params[:cursor], per_page: MAX_MODELS_PER_PAGE) - - @model_count = finder.count - end + def index; end def new; end diff --git a/app/helpers/projects/ml/model_registry_helper.rb b/app/helpers/projects/ml/model_registry_helper.rb new file mode 100644 index 0000000000000..828ae088b1b61 --- /dev/null +++ b/app/helpers/projects/ml/model_registry_helper.rb @@ -0,0 +1,30 @@ +# frozen_string_literal: true + +module Projects + module Ml + module ModelRegistryHelper + require 'json' + + def index_ml_model_data(project, user) + data = { + projectPath: project.full_path, + create_model_path: new_project_ml_model_path(project), + can_write_model_registry: user&.can?(:write_model_registry, project), + mlflow_tracking_url: mlflow_tracking_url(project) + } + + Gitlab::Json.generate(data.deep_transform_keys { |k| k.to_s.camelize(:lower) }) + end + + private + + def mlflow_tracking_url(project) + path = api_v4_projects_ml_mlflow_api_2_0_mlflow_registered_models_create_path(id: project.id) + + path = path.delete_suffix('registered-models/create') + + expose_url(path) + end + end + end +end diff --git a/app/views/projects/ml/models/index.html.haml b/app/views/projects/ml/models/index.html.haml index ba695bce435d1..1e6a268ed5c52 100644 --- a/app/views/projects/ml/models/index.html.haml +++ b/app/views/projects/ml/models/index.html.haml @@ -1,4 +1,4 @@ - breadcrumb_title s_('ModelRegistry|Model registry') - page_title s_('ModelRegistry|Model registry') -= render(Projects::Ml::ModelsIndexComponent.new(project: @project, current_user: current_user, paginator: @paginator, model_count: @model_count)) +#js-index-ml-models{ data: { view_model: index_ml_model_data(@project, @current_user) } } diff --git a/locale/gitlab.pot b/locale/gitlab.pot index 92bb70a880b90..c5b30afcee732 100644 --- a/locale/gitlab.pot +++ b/locale/gitlab.pot @@ -31334,6 +31334,9 @@ msgstr "" msgid "MlModelRegistry|Failed to load model versions with error: %{message}" msgstr "" +msgid "MlModelRegistry|Failed to load model with error: %{message}" +msgstr "" + msgid "MlModelRegistry|ID" msgstr "" diff --git a/spec/frontend/ml/model_registry/apps/index_ml_models_spec.js b/spec/frontend/ml/model_registry/apps/index_ml_models_spec.js index 07d8b4b8b3da8..12677470127e4 100644 --- a/spec/frontend/ml/model_registry/apps/index_ml_models_spec.js +++ b/spec/frontend/ml/model_registry/apps/index_ml_models_spec.js @@ -1,55 +1,80 @@ -import { GlBadge, GlButton } from '@gitlab/ui'; -import { shallowMountExtended } from 'helpers/vue_test_utils_helper'; +import { GlExperimentBadge } from '@gitlab/ui'; +import Vue from 'vue'; +import VueApollo from 'vue-apollo'; +import { mountExtended } from 'helpers/vue_test_utils_helper'; import { IndexMlModels } from '~/ml/model_registry/apps'; import ModelRow from '~/ml/model_registry/components/model_row.vue'; -import Pagination from '~/vue_shared/components/incubation/pagination.vue'; -import SearchBar from '~/ml/model_registry/components/search_bar.vue'; -import { BASE_SORT_FIELDS, MODEL_ENTITIES } from '~/ml/model_registry/constants'; +import { MODEL_ENTITIES } from '~/ml/model_registry/constants'; import TitleArea from '~/vue_shared/components/registry/title_area.vue'; import MetadataItem from '~/vue_shared/components/registry/metadata_item.vue'; import EmptyState from '~/ml/model_registry/components/empty_state.vue'; import ActionsDropdown from '~/ml/model_registry/components/actions_dropdown.vue'; -import { mockModels, startCursor, defaultPageInfo } from '../mock_data'; - -let wrapper; - -const createWrapper = (propsData = {}) => { - wrapper = shallowMountExtended(IndexMlModels, { - propsData: { - models: mockModels, - pageInfo: defaultPageInfo, - modelCount: 2, - createModelPath: 'path/to/create', - canWriteModelRegistry: false, - ...propsData, - }, - }); +import SearchableList from '~/ml/model_registry/components/searchable_list.vue'; +import createMockApollo from 'helpers/mock_apollo_helper'; +import getModelsQuery from '~/ml/model_registry/graphql/queries/get_models.query.graphql'; +import * as Sentry from '~/sentry/sentry_browser_wrapper'; +import waitForPromises from 'helpers/wait_for_promises'; +import { modelsQuery, modelWithOneVersion, modelWithoutVersion } from '../graphql_mock_data'; + +Vue.use(VueApollo); + +const defaultProps = { + projectPath: 'path/to/project', + createModelPath: 'path/to/create', + canWriteModelRegistry: false, }; -const findModelRow = (index) => wrapper.findAllComponents(ModelRow).at(index); -const findPagination = () => wrapper.findComponent(Pagination); -const findEmptyState = () => wrapper.findComponent(EmptyState); -const findSearchBar = () => wrapper.findComponent(SearchBar); -const findTitleArea = () => wrapper.findComponent(TitleArea); -const findModelCountMetadataItem = () => findTitleArea().findComponent(MetadataItem); -const findBadge = () => wrapper.findComponent(GlBadge); -const findCreateButton = () => findTitleArea().findComponent(GlButton); -const findActionsDropdown = () => wrapper.findComponent(ActionsDropdown); - describe('ml/model_registry/apps/index_ml_models', () => { - describe('empty state', () => { - beforeEach(() => createWrapper({ models: [], pageInfo: defaultPageInfo })); + let wrapper; + let apolloProvider; + + const createWrapper = ({ + props = {}, + resolver = jest.fn().mockResolvedValue(modelsQuery()), + } = {}) => { + const requestHandlers = [[getModelsQuery, resolver]]; + apolloProvider = createMockApollo(requestHandlers); + + const propsData = { + ...defaultProps, + ...props, + }; + + wrapper = mountExtended(IndexMlModels, { + apolloProvider, + propsData, + }); + }; - it('shows empty state', () => { - expect(findEmptyState().props('entityType')).toBe(MODEL_ENTITIES.model); + beforeEach(() => { + jest.spyOn(Sentry, 'captureException').mockImplementation(); + }); + + const emptyQueryResolver = () => jest.fn().mockResolvedValue(modelsQuery([])); + + const findAllRows = () => wrapper.findAllComponents(ModelRow); + const findRow = (index) => findAllRows().at(index); + const findEmptyState = () => wrapper.findComponent(EmptyState); + const findTitleArea = () => wrapper.findComponent(TitleArea); + const findModelCountMetadataItem = () => findTitleArea().findComponent(MetadataItem); + const findBadge = () => wrapper.findComponent(GlExperimentBadge); + const findCreateButton = () => wrapper.findByTestId('create-model-button'); + const findActionsDropdown = () => wrapper.findComponent(ActionsDropdown); + const findSearchableList = () => wrapper.findComponent(SearchableList); + + describe('header', () => { + beforeEach(() => { + createWrapper(); }); - it('does not show pagination', () => { - expect(findPagination().exists()).toBe(false); + it('displays the title', () => { + expect(findTitleArea().text()).toContain('Model registry'); }); - it('does not show search bar', () => { - expect(findSearchBar().exists()).toBe(false); + it('displays the experiment badge', () => { + expect(findBadge().props('helpPageUrl')).toBe( + '/help/user/project/ml/model_registry/index.md', + ); }); it('renders the extra actions button', () => { @@ -57,65 +82,155 @@ describe('ml/model_registry/apps/index_ml_models', () => { }); }); + describe('empty state', () => { + it('shows empty state', async () => { + createWrapper({ resolver: emptyQueryResolver() }); + + await waitForPromises(); + + expect(findEmptyState().props('entityType')).toBe(MODEL_ENTITIES.model); + }); + }); + describe('create button', () => { describe('when user has no permission to write model registry', () => { - it('does not display create button', () => { - createWrapper(); + it('does not display create button', async () => { + createWrapper({ resolver: emptyQueryResolver() }); + + await waitForPromises(); expect(findCreateButton().exists()).toBe(false); }); }); describe('when user has permission to write model registry', () => { - it('displays create button', () => { - createWrapper({ canWriteModelRegistry: true }); + it('displays create button', async () => { + createWrapper({ + props: { canWriteModelRegistry: true }, + resolver: emptyQueryResolver(), + }); + + await waitForPromises(); expect(findCreateButton().attributes().href).toBe('path/to/create'); }); }); }); + describe('when loading data fails', () => { + beforeEach(async () => { + const error = new Error('Failure!'); + + createWrapper({ resolver: jest.fn().mockRejectedValue(error) }); + + await waitForPromises(); + }); + + it('error message is displayed', () => { + expect(findSearchableList().props('errorMessage')).toBe( + 'Failed to load model with error: Failure!', + ); + }); + + it('error is logged in sentry', () => { + expect(Sentry.captureException).toHaveBeenCalled(); + }); + }); + describe('with data', () => { - beforeEach(() => { + it('does not show empty state', async () => { createWrapper(); - }); + await waitForPromises(); - it('does not show empty state', () => { expect(findEmptyState().exists()).toBe(false); }); describe('header', () => { - it('displays the title', () => { - expect(findTitleArea().text()).toContain('Model registry'); + it('sets model metadata item to model count', async () => { + createWrapper(); + await waitForPromises(); + + expect(findModelCountMetadataItem().props('text')).toBe('2 models'); }); + }); - it('displays the experiment badge', () => { - expect(findBadge().attributes().href).toBe('/help/user/project/ml/model_registry/index.md'); + describe('shows models', () => { + beforeEach(async () => { + createWrapper(); + await waitForPromises(); }); - it('sets model metadata item to model count', () => { - expect(findModelCountMetadataItem().props('text')).toBe(`2 models`); + it('passes items to list', () => { + expect(findSearchableList().props('items')).toEqual([ + modelWithOneVersion, + modelWithoutVersion, + ]); }); - }); - it('adds a search bar', () => { - expect(findSearchBar().props()).toMatchObject({ sortableFields: BASE_SORT_FIELDS }); - }); + it('displays package version rows', () => { + expect(findAllRows()).toHaveLength(2); + }); + + it('binds the correct props', () => { + expect(findRow(0).props()).toMatchObject({ + model: expect.objectContaining(modelWithOneVersion), + }); - describe('model list', () => { - it('displays the models', () => { - expect(findModelRow(0).props('model')).toMatchObject(mockModels[0]); - expect(findModelRow(1).props('model')).toMatchObject(mockModels[1]); + expect(findRow(1).props()).toMatchObject({ + model: expect.objectContaining(modelWithoutVersion), + }); }); }); - describe('pagination', () => { - it('should show', () => { - expect(findPagination().exists()).toBe(true); + describe('when query is updated', () => { + let resolver; + + beforeEach(() => { + resolver = jest.fn().mockResolvedValue(modelsQuery()); + createWrapper({ resolver }); + }); + + it('when orderBy or sort are not present, use default value', async () => { + findSearchableList().vm.$emit('fetch-page', { + after: 'eyJpZCI6IjIifQ', + first: 30, + }); + + await waitForPromises(); + + expect(resolver).toHaveBeenLastCalledWith( + expect.objectContaining({ + fullPath: 'path/to/project', + first: 30, + name: undefined, + orderBy: 'CREATED_AT', + sort: 'DESC', + after: 'eyJpZCI6IjIifQ', + }), + ); }); - it('passes pagination to pagination component', () => { - expect(findPagination().props('startCursor')).toBe(startCursor); + it('when orderBy or sort present, updates filters', async () => { + findSearchableList().vm.$emit('fetch-page', { + after: 'eyJpZCI6IjIifQ', + first: 30, + orderBy: 'name', + sort: 'asc', + name: 'something', + }); + + await waitForPromises(); + + expect(resolver).toHaveBeenLastCalledWith( + expect.objectContaining({ + fullPath: 'path/to/project', + first: 30, + name: 'something', + orderBy: 'NAME', + sort: 'ASC', + after: 'eyJpZCI6IjIifQ', + }), + ); }); }); }); diff --git a/spec/frontend/ml/model_registry/components/model_row_spec.js b/spec/frontend/ml/model_registry/components/model_row_spec.js index 0972929235514..02359949f5a48 100644 --- a/spec/frontend/ml/model_registry/components/model_row_spec.js +++ b/spec/frontend/ml/model_registry/components/model_row_spec.js @@ -1,38 +1,45 @@ -import { GlLink } from '@gitlab/ui'; +import { GlLink, GlTruncate } from '@gitlab/ui'; import { shallowMountExtended } from 'helpers/vue_test_utils_helper'; +import ListItem from '~/vue_shared/components/registry/list_item.vue'; import ModelRow from '~/ml/model_registry/components/model_row.vue'; -import { mockModels, modelWithoutVersion } from '../mock_data'; +import { modelWithOneVersion, modelWithVersions, modelWithoutVersion } from '../graphql_mock_data'; let wrapper; -const createWrapper = (model = mockModels[0]) => { +const createWrapper = (model = modelWithVersions) => { wrapper = shallowMountExtended(ModelRow, { propsData: { model } }); }; -const findTitleLink = () => wrapper.findAllComponents(GlLink).at(0); -const findVersionLink = () => wrapper.findAllComponents(GlLink).at(1); +const findListItem = () => wrapper.findComponent(ListItem); +const findTitleLink = () => findListItem().findAllComponents(GlLink).at(0); +const findTruncated = () => findTitleLink().findComponent(GlTruncate); +const findVersionLink = () => findListItem().findAllComponents(GlLink).at(1); const findMessage = (message) => wrapper.findByText(message); describe('ModelRow', () => { it('Has a link to the model', () => { createWrapper(); - expect(findTitleLink().text()).toBe(mockModels[0].name); - expect(findTitleLink().attributes('href')).toBe(mockModels[0].path); + expect(findTruncated().props('text')).toBe(modelWithVersions.name); + expect(findTitleLink().attributes('href')).toBe(modelWithVersions._links.showPath); }); it('Shows the latest version and the version count', () => { createWrapper(); - expect(findVersionLink().text()).toBe(mockModels[0].version); - expect(findVersionLink().attributes('href')).toBe(mockModels[0].versionPath); - expect(findMessage('· 3 versions').exists()).toBe(true); + expect(findVersionLink().text()).toBe(modelWithVersions.latestVersion.version); + expect(findVersionLink().attributes('href')).toBe( + modelWithVersions.latestVersion._links.showPath, + ); + expect(findMessage('· 2 versions').exists()).toBe(true); }); it('Shows the latest version and no version count if it has only 1 version', () => { - createWrapper(mockModels[1]); + createWrapper(modelWithOneVersion); - expect(findVersionLink().text()).toBe(mockModels[1].version); - expect(findVersionLink().attributes('href')).toBe(mockModels[1].versionPath); + expect(findVersionLink().text()).toBe(modelWithOneVersion.latestVersion.version); + expect(findVersionLink().attributes('href')).toBe( + modelWithOneVersion.latestVersion._links.showPath, + ); expect(findMessage('· 1 version').exists()).toBe(true); }); diff --git a/spec/frontend/ml/model_registry/components/searchable_list_spec.js b/spec/frontend/ml/model_registry/components/searchable_list_spec.js index ea58a9a830abe..67c61a0d58369 100644 --- a/spec/frontend/ml/model_registry/components/searchable_list_spec.js +++ b/spec/frontend/ml/model_registry/components/searchable_list_spec.js @@ -3,6 +3,9 @@ import { shallowMountExtended } from 'helpers/vue_test_utils_helper'; import SearchableList from '~/ml/model_registry/components/searchable_list.vue'; import PackagesListLoader from '~/packages_and_registries/shared/components/packages_list_loader.vue'; import RegistryList from '~/packages_and_registries/shared/components/registry_list.vue'; +import RegistrySearch from '~/vue_shared/components/registry/registry_search.vue'; +import { BASE_SORT_FIELDS } from '~/ml/model_registry/constants'; +import * as urlHelpers from '~/lib/utils/url_utility'; import { defaultPageInfo } from '../mock_data'; describe('ml/model_registry/components/searchable_list.vue', () => { @@ -14,12 +17,23 @@ describe('ml/model_registry/components/searchable_list.vue', () => { const findEmptyState = () => wrapper.findByTestId('empty-state-slot'); const findFirstRow = () => wrapper.findByTestId('element'); const findRows = () => wrapper.findAllByTestId('element'); + const findSearch = () => wrapper.findComponent(RegistrySearch); + + const expectedFirstPage = { + after: 'eyJpZCI6IjIifQ', + first: 30, + last: null, + orderBy: 'created_at', + sort: 'desc', + }; const defaultProps = { items: ['a', 'b', 'c'], pageInfo: defaultPageInfo, isLoading: false, errorMessage: '', + showSearch: false, + sortableFields: [], }; const mountComponent = (props = {}) => { @@ -143,6 +157,12 @@ describe('ml/model_registry/components/searchable_list.vue', () => { describe('when user interacts with pagination', () => { beforeEach(() => mountComponent()); + it('when it is created emits fetch-page to get first page', () => { + mountComponent({ showSearch: true, sortableFields: BASE_SORT_FIELDS }); + + expect(wrapper.emitted('fetch-page')).toEqual([[expectedFirstPage]]); + }); + it('when list emits next-page emits fetchPage with correct pageInfo', () => { findRegistryList().vm.$emit('next-page'); @@ -150,9 +170,11 @@ describe('ml/model_registry/components/searchable_list.vue', () => { after: 'eyJpZCI6IjIifQ', first: 30, last: null, + orderBy: 'created_at', + sort: 'desc', }; - expect(wrapper.emitted('fetch-page')).toEqual([[expectedNewPageInfo]]); + expect(wrapper.emitted('fetch-page')).toEqual([[expectedFirstPage], [expectedNewPageInfo]]); }); it('when list emits prev-page emits fetchPage with correct pageInfo', () => { @@ -162,9 +184,77 @@ describe('ml/model_registry/components/searchable_list.vue', () => { before: 'eyJpZCI6IjE2In0', first: null, last: 30, + orderBy: 'created_at', + sort: 'desc', + }; + + expect(wrapper.emitted('fetch-page')).toEqual([[expectedFirstPage], [expectedNewPageInfo]]); + }); + }); + + describe('search', () => { + beforeEach(() => { + jest.spyOn(urlHelpers, 'updateHistory').mockImplementation(() => {}); + }); + + it('does not show search bar when showSearch is false', () => { + mountComponent({ showSearch: false }); + + expect(findSearch().exists()).toBe(false); + }); + + it('mounts search correctly', () => { + mountComponent({ showSearch: true, sortableFields: BASE_SORT_FIELDS }); + + expect(findSearch().props()).toMatchObject({ + filters: [], + sorting: { + orderBy: 'created_at', + sort: 'desc', + }, + sortableFields: BASE_SORT_FIELDS, + }); + }); + + it('on search submit, emits fetch-page with correct variables', () => { + mountComponent({ showSearch: true, sortableFields: BASE_SORT_FIELDS }); + + findSearch().vm.$emit('filter:submit'); + + const expectedVariables = { + orderBy: 'created_at', + sort: 'desc', + }; + + expect(wrapper.emitted('fetch-page')).toEqual([[expectedFirstPage], [expectedVariables]]); + }); + + it('on sorting changed, emits fetch-page with correct variables', () => { + mountComponent({ showSearch: true, sortableFields: BASE_SORT_FIELDS }); + + const orderBy = 'name'; + findSearch().vm.$emit('sorting:changed', { orderBy }); + + const expectedVariables = { + orderBy: 'name', + sort: 'desc', + }; + + expect(wrapper.emitted('fetch-page')).toEqual([[expectedFirstPage], [expectedVariables]]); + }); + + it('on direction changed, emits fetch-page with correct variables', () => { + mountComponent({ showSearch: true, sortableFields: BASE_SORT_FIELDS }); + + const sort = 'asc'; + findSearch().vm.$emit('sorting:changed', { sort }); + + const expectedVariables = { + orderBy: 'created_at', + sort: 'asc', }; - expect(wrapper.emitted('fetch-page')).toEqual([[expectedNewPageInfo]]); + expect(wrapper.emitted('fetch-page')).toEqual([[expectedFirstPage], [expectedVariables]]); }); }); }); diff --git a/spec/frontend/ml/model_registry/graphql_mock_data.js b/spec/frontend/ml/model_registry/graphql_mock_data.js index 27424fbf0dfc1..b44963577bf6d 100644 --- a/spec/frontend/ml/model_registry/graphql_mock_data.js +++ b/spec/frontend/ml/model_registry/graphql_mock_data.js @@ -138,3 +138,64 @@ export const createModelResponses = { }, }, }; + +export const modelWithVersions = { + id: 'gid://gitlab/Ml::Model/1', + name: 'model_1', + versionCount: 2, + createdAt: '2023-12-06T12:41:48Z', + latestVersion: { + id: 'gid://gitlab/Ml::ModelVersion/1', + version: '1.0.0', + _links: { + showPath: '/my_project/-/ml/models/1/versions/1', + }, + }, + _links: { + showPath: '/my_project/-/ml/models/1', + }, +}; + +export const modelWithOneVersion = { + id: 'gid://gitlab/Ml::Model/2', + name: 'model_2', + versionCount: 1, + createdAt: '2023-12-06T12:41:48Z', + latestVersion: { + id: 'gid://gitlab/Ml::ModelVersion/1', + version: '1.0.0', + _links: { + showPath: '/my_project/-/ml/models/2/versions/1', + }, + }, + _links: { + showPath: '/my_project/-/ml/models/2', + }, +}; + +export const modelWithoutVersion = { + id: 'gid://gitlab/Ml::Model/3', + name: 'model_3', + versionCount: 0, + latestVersion: null, + createdAt: '2023-12-06T12:41:48Z', + _links: { + showPath: '/my_project/-/ml/models/3', + }, +}; + +export const modelsQuery = ( + models = [modelWithOneVersion, modelWithoutVersion], + pageInfo = graphqlPageInfo, +) => ({ + data: { + project: { + id: 'gid://gitlab/Project/1', + mlModels: { + count: models.length, + nodes: models, + pageInfo, + }, + }, + }, +}); diff --git a/spec/helpers/projects/ml/model_registry_helper_spec.rb b/spec/helpers/projects/ml/model_registry_helper_spec.rb new file mode 100644 index 0000000000000..2180d4388ca97 --- /dev/null +++ b/spec/helpers/projects/ml/model_registry_helper_spec.rb @@ -0,0 +1,37 @@ +# frozen_string_literal: true + +require 'rspec' + +require 'spec_helper' +require 'mime/types' + +RSpec.describe Projects::Ml::ModelRegistryHelper, feature_category: :mlops do + let_it_be(:project) { build_stubbed(:project) } + let_it_be(:user) { project.owner } + + describe '#index_ml_model_data' do + subject(:parsed) { Gitlab::Json.parse(helper.index_ml_model_data(project, user)) } + + it 'generates the correct data' do + is_expected.to eq({ + 'projectPath' => project.full_path, + 'createModelPath' => "/#{project.full_path}/-/ml/models/new", + 'canWriteModelRegistry' => true, + 'mlflowTrackingUrl' => "http://localhost/api/v4/projects/#{project.id}/ml/mlflow/api/2.0/mlflow/" + }) + end + + context 'when user does not have write access to model registry' do + before do + allow(Ability).to receive(:allowed?).and_call_original + allow(Ability).to receive(:allowed?) + .with(user, :write_model_registry, project) + .and_return(false) + end + + it 'canWriteModelRegistry is false' do + expect(parsed['canWriteModelRegistry']).to eq(false) + end + end + end +end diff --git a/spec/requests/projects/ml/models_controller_spec.rb b/spec/requests/projects/ml/models_controller_spec.rb index e469ee837bcd8..8f68d0c8d0013 100644 --- a/spec/requests/projects/ml/models_controller_spec.rb +++ b/spec/requests/projects/ml/models_controller_spec.rb @@ -6,9 +6,6 @@ let_it_be(:project) { create(:project) } let_it_be(:user) { project.first_owner } let_it_be(:model1) { create(:ml_models, :with_versions, project: project) } - let_it_be(:model2) { create(:ml_models, project: project) } - let_it_be(:model3) { create(:ml_models, project: project) } - let_it_be(:model_in_different_project) { create(:ml_models) } let(:read_model_registry) { true } let(:write_model_registry) { true } @@ -37,36 +34,6 @@ expect(index_request).to render_template('projects/ml/models/index') end - it 'fetches the models using the finder' do - expect(::Projects::Ml::ModelFinder).to receive(:new).with(project, {}).and_call_original - - index_request - end - - it 'fetches the correct variables', :aggregate_failures do - stub_const("Projects::Ml::ModelsController::MAX_MODELS_PER_PAGE", 2) - - index_request - - page_models = [model3, model2] - all_models = [model3, model2, model1] - - expect(assigns(:paginator).records).to match_array(page_models) - expect(assigns(:model_count)).to be all_models.count - end - - it 'does not perform N+1 sql queries' do - list_models - - control_count = ActiveRecord::QueryRecorder.new(skip_cached: false) { list_models } - - create_list(:ml_model_versions, 2, model: model1) - create_list(:ml_model_versions, 2, model: model2) - create_list(:ml_models, 4, project: project) - - expect { list_models }.not_to exceed_all_query_limit(control_count) - end - context 'when user does not have access' do let(:read_model_registry) { false } @@ -74,40 +41,6 @@ is_expected.to have_gitlab_http_status(:not_found) end end - - context 'with search params' do - let(:params) { { name: 'some_name', order_by: 'name', sort: 'asc' } } - - it 'passes down params to the finder' do - expect(Projects::Ml::ModelFinder).to receive(:new).and_call_original do |_exp, params| - expect(params.to_h).to include({ - name: 'some_name', - order_by: 'name', - sort: 'asc' - }) - end - - index_request - end - end - - describe 'pagination' do - before do - stub_const("Projects::Ml::ModelsController::MAX_MODELS_PER_PAGE", 2) - end - - it 'paginates', :aggregate_failures do - list_models - - paginator = assigns(:paginator) - - expect(paginator.records).to match_array([model3, model2]) - - list_models({ cursor: paginator.cursor_for_next_page }) - - expect(assigns(:paginator).records.first).to eq(model1) - end - end end describe 'show' do @@ -140,7 +73,7 @@ end context 'when model project does not match project id' do - let(:request_project) { model_in_different_project.project } + let(:request_project) { create(:project) } it { is_expected.to have_gitlab_http_status(:not_found) } end -- GitLab