Skip to content
代码片段 群组 项目
未验证 提交 05d02a39 编辑于 作者: Eduardo Bonet's avatar Eduardo Bonet 提交者: GitLab
浏览文件

Improvements to the index creation script

上级 0b42ee37
No related branches found
No related tags found
无相关合并请求
# Usage
# 1. Install requirements:
# pip install requests langchain langchain-text-splitters
# 2. Run the script:
# GLAB_TOKEN=<api_token> python3 scripts/custom_models/create_index.py --version_tag="v17.0.0"
import argparse import argparse
import glob import glob
import os import os
import datetime import datetime
import re import re
import sqlite3 import sqlite3
import sys
import requests import requests
import json import json
from zipfile import ZipFile from zipfile import ZipFile
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain_text_splitters import MarkdownHeaderTextSplitter from langchain_text_splitters import MarkdownHeaderTextSplitter
import tempfile import tempfile
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Function to parse command-line arguments # Function to parse command-line arguments
def parse_arguments():
    """Parse the command-line options of the index creation script.

    The access token is not a command-line option; the script reads it from
    the GLAB_TOKEN environment variable.

    Returns:
        argparse.Namespace with:
            project_id: GitLab project ID (defaults to 278964).
            version_tag: tag or branch to fetch the docs from (defaults to
                'master').
            base_url: URL of the GitLab instance (defaults to
                https://gitlab.com).
    """
    parser = argparse.ArgumentParser(description="Generate and upload GitLab docs index.")
    # project_id is left untyped: a value given on the command line stays a
    # string and is only ever interpolated into the upload URL.
    parser.add_argument("--project_id", help="GitLab project ID", default=278964)
    parser.add_argument("--version_tag", help="GitLab version tag to include in the URL (e.g., v17.1.0-ee)",
                        default='master')
    parser.add_argument("--base_url", help="URL to gitlab instance", default="https://gitlab.com")
    return parser.parse_args()
def execution_error(error_message):
    """Log a fatal error and terminate the script.

    Args:
        error_message: Text written to the module logger at ERROR level.

    Raises:
        SystemExit: always, with exit status 1.
    """
    logger.error(error_message)
    sys.exit(1)
# Function to fetch documents from GitLab # Function to fetch documents from GitLab
def fetch_documents(version_tag=None): def fetch_documents(version_tag):
if version_tag: docs_url = f"https://gitlab.com/gitlab-org/gitlab/-/archive/{version_tag}/gitlab-{version_tag}.zip?path=doc"
docs_url = f"https://gitlab.com/gitlab-org/gitlab/-/archive/{version_tag}/gitlab-{version_tag}.zip?path=doc"
else:
print("No version tag provided. Defaulting to fetching from master.")
docs_url = f"https://gitlab.com/gitlab-org/gitlab/-/archive/master/gitlab-master.zip?path=doc"
response = requests.get(docs_url) response = requests.get(docs_url)
...@@ -40,23 +55,44 @@ def fetch_documents(version_tag=None): ...@@ -40,23 +55,44 @@ def fetch_documents(version_tag=None):
# Find the directory that was extracted # Find the directory that was extracted
extracted_dirs = [os.path.join(tmpdirname, name) for name in os.listdir(tmpdirname) if os.path.isdir(os.path.join(tmpdirname, name))] extracted_dirs = [os.path.join(tmpdirname, name) for name in os.listdir(tmpdirname) if os.path.isdir(os.path.join(tmpdirname, name))]
if not extracted_dirs: if not extracted_dirs:
print("No directory found after extraction. Exiting.") execution_error("No directory found after extraction. Exiting.")
return None
print("Documents are fetched.") logger.info("Documents are fetched.")
extracted_dir = extracted_dirs[0] extracted_dir = extracted_dirs[0]
print(f"Extracted documents to {extracted_dir}") logger.info(f"Extracted documents to {extracted_dir}")
return extracted_dir return extracted_dir
else: else:
print(f"Failed to download documents. Status code: {response.status_code}") execution_error(f"Failed to download documents. Status code: {response.status_code}")
return None
def upload_url(base_url, project_id, version_tag):
    """Build the generic package registry URL the index is uploaded to.

    Args:
        base_url: URL of the GitLab instance, with or without a trailing slash.
        project_id: ID of the project that owns the package.
        version_tag: Docs version; used as the package version in the URL.

    Returns:
        URL of ``docs.db`` inside the ``gitlab-duo-local-documentation-index``
        package of the given project and version.
    """
    # Trim trailing slashes so "https://host/" does not produce "https://host//api/..."
    root = base_url.rstrip('/')
    return f"{root}/api/v4/projects/{project_id}/packages/generic/gitlab-duo-local-documentation-index/{version_tag}/docs.db"
def build_row_corpus(row):
    """Build the normalised search text for one document chunk.

    Args:
        row: Mapping with a ``content`` string (the chunk text) and a
            ``metadata`` mapping that may hold ``Header1`` .. ``Header5``
            section titles.

    Returns:
        Lower-cased text with punctuation removed: the section titles
        followed by the chunk body, separated by single spaces. Returns ``''``
        when only whitespace is left once the preamble is removed.
    """
    corpus = row['content'].strip()

    # Remove the preamble. Only a '---' block that opens the chunk counts: a
    # '---' further down is a table separator or a horizontal rule, and
    # cutting there would throw away the text in front of it.
    if corpus.startswith('---'):
        preamble_end = corpus.find('---', 3)
        if preamble_end != -1:
            # Skip all three delimiter characters and keep the rest untouched
            corpus = corpus[preamble_end + 3:].strip()

    if not corpus:
        return ''

    # Attach the titles to the corpus, these can still be useful. They are
    # space-separated so neighbouring titles do not fuse into a single token.
    titles = [row['metadata'].get(f"Header{i}", '') for i in range(1, 6)]
    corpus = ' '.join([*(title for title in titles if title), corpus])

    # Stemming could be helpful, but it is already applied by the sqlite
    # Remove punctuation and set to lowercase, this should reduce the size of the corpus and allow
    # the query to be a bit more robust
    corpus = corpus.lower()
    corpus = re.sub(r'[^\w\s]', '', corpus)
    return corpus
# Function to process documents and create the database # Function to process documents and create the database
def create_database(path, output_path): def create_database(path, output_path):
files = glob.glob(os.path.join(path, "doc/**/*.md"), recursive=True) files = glob.glob(os.path.join(path, "doc/**/*.md"), recursive=True)
if not files: if not files:
print("No markdown files found. Exiting.") execution_error("No markdown files found")
return
documents = [] documents = []
...@@ -86,25 +122,6 @@ def create_database(path, output_path): ...@@ -86,25 +122,6 @@ def create_database(path, output_path):
metadata = {**chunk.metadata, **d.metadata} metadata = {**chunk.metadata, **d.metadata}
rows_to_insert.append({"content": chunk.page_content, "metadata": metadata}) rows_to_insert.append({"content": chunk.page_content, "metadata": metadata})
# Process each row to yield better results
def build_row_corpus(row):
    """Build the normalised search text for one document chunk.

    Args:
        row: Mapping with a ``content`` string (the chunk text) and a
            ``metadata`` mapping that may hold ``Header1`` .. ``Header5``
            section titles.

    Returns:
        Lower-cased text with punctuation removed: the section titles
        followed by the chunk body, separated by single spaces. Returns ``''``
        when only whitespace is left once the preamble is removed.
    """
    corpus = row['content'].strip()

    # Remove the preamble. Only a '---' block that opens the chunk counts: a
    # '---' further down is a table separator or a horizontal rule, and
    # cutting there would throw away the text in front of it.
    if corpus.startswith('---'):
        preamble_end = corpus.find('---', 3)
        if preamble_end != -1:
            # Skip all three delimiter characters and keep the rest untouched
            corpus = corpus[preamble_end + 3:].strip()

    if not corpus:
        return ''

    # Attach the titles to the corpus, these can still be useful. They are
    # space-separated so neighbouring titles do not fuse into a single token.
    titles = [row['metadata'].get(f"Header{i}", '') for i in range(1, 6)]
    corpus = ' '.join([*(title for title in titles if title), corpus])

    # Stemming could be helpful, but it is already applied by the sqlite
    # Remove punctuation and set to lowercase, this should reduce the size of the corpus and allow
    # the query to be a bit more robust
    corpus = corpus.lower()
    corpus = re.sub(r'[^\w\s]', '', corpus)
    return corpus
for r in rows_to_insert: for r in rows_to_insert:
r['processed'] = build_row_corpus(r) r['processed'] = build_row_corpus(r)
# sql_tuples = [(r['processed'], r['content'], r['metadata']['filename']) for r in rows_to_insert if r['processed']] # sql_tuples = [(r['processed'], r['content'], r['metadata']['filename']) for r in rows_to_insert if r['processed']]
...@@ -117,7 +134,8 @@ def create_database(path, output_path): ...@@ -117,7 +134,8 @@ def create_database(path, output_path):
conn.commit() conn.commit()
conn.close() conn.close()
# Function to upload the database file to GitLab model registry
# Function to upload the database file to GitLab package registry
def upload_to_gitlab(upload_url, file_path, private_token): def upload_to_gitlab(upload_url, file_path, private_token):
headers = {"Authorization": f"Bearer {private_token}"} headers = {"Authorization": f"Bearer {private_token}"}
...@@ -126,31 +144,35 @@ def upload_to_gitlab(upload_url, file_path, private_token): ...@@ -126,31 +144,35 @@ def upload_to_gitlab(upload_url, file_path, private_token):
response = requests.put(upload_url, headers=headers, files=files) response = requests.put(upload_url, headers=headers, files=files)
if response.status_code in {200, 201}: if response.status_code in {200, 201}:
print("Database uploaded successfully.") logger.info("Database uploaded successfully.")
else: else:
print(f"Upload failed with status code: {response.status_code}, response: {response.content}") logger.error(f"Upload failed with status code: {response.status_code}, response: {response.content}")
# Script entry point: fetch the docs, build the index database and upload it
# to the package registry of the configured project.
if __name__ == "__main__":
    args = parse_arguments()

    # .get() rather than [] so that an unset GLAB_TOKEN reaches the error
    # message below instead of aborting with a KeyError traceback.
    private_token = os.environ.get('GLAB_TOKEN')

    if not private_token:
        execution_error("Private token must be set.")

    # Fetch documents based on version tag (if provided)
    docs_path = fetch_documents(version_tag=args.version_tag)
    if not docs_path:
        execution_error("Fetching documents failed")

    # Create database
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    output_path = f"{docs_path}/created_index_docs_{timestamp}.db"
    create_database(docs_path, output_path)
    logger.info(f"Database created at {output_path}")

    # Upload to GitLab
    if not os.path.exists(output_path):
        execution_error("Database file not found.")

    url = upload_url(args.base_url, args.project_id, args.version_tag)

    logger.info(f"Uploading to {url}")
    upload_to_gitlab(url, output_path, private_token)
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册