Feat/vector db manage (langgenius#997)
Co-authored-by: jyong <[email protected]>
JohnJyong authored Aug 24, 2023
1 parent 5397799 commit b1fd1b3
Showing 8 changed files with 1,975 additions and 10 deletions.
68 changes: 68 additions & 0 deletions api/commands.py
@@ -1,15 +1,22 @@
import datetime
import json
import math
import random
import string
import time

import click
from flask import current_app
from langchain.embeddings import OpenAIEmbeddings
from werkzeug.exceptions import NotFound

from core.embedding.cached_embedding import CacheEmbedding
from core.index.index import IndexBuilder
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.hosted import hosted_model_providers
from core.model_providers.providers.openai_provider import OpenAIProvider
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
@@ -296,6 +303,66 @@ def sync_anthropic_hosted_providers():
    click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))


@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
def create_qdrant_indexes():
    click.echo(click.style('Start creating qdrant indexes.', fg='green'))
    create_count = 0

    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break

        page += 1
        for dataset in datasets:
            try:
                click.echo('Creating qdrant index for dataset: {}'.format(dataset.id))
                try:
                    embedding_model = ModelFactory.get_embedding_model(
                        tenant_id=dataset.tenant_id,
                        model_provider_name=dataset.embedding_model_provider,
                        model_name=dataset.embedding_model
                    )
                except Exception:
                    # Fall back to an OpenAI embedding model when the dataset's
                    # configured embedding provider can no longer be loaded.
                    provider = Provider(
                        id='provider_id',
                        tenant_id='tenant_id',
                        provider_name='openai',
                        provider_type=ProviderType.CUSTOM.value,
                        encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                        is_valid=True,
                    )
                    model_provider = OpenAIProvider(provider=provider)
                    embedding_model = OpenAIEmbedding(name="text-embedding-ada-002", model_provider=model_provider)
                embeddings = CacheEmbedding(embedding_model)

                from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

                index = QdrantVectorIndex(
                    dataset=dataset,
                    config=QdrantConfig(
                        endpoint=current_app.config.get('QDRANT_URL'),
                        api_key=current_app.config.get('QDRANT_API_KEY'),
                        root_path=current_app.root_path
                    ),
                    embeddings=embeddings
                )
                if index:
                    index.create_qdrant_dataset(dataset)
                    create_count += 1
                else:
                    click.echo('passed.')
            except Exception as e:
                click.echo(
                    click.style('Error creating dataset index: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
                continue

    click.echo(click.style('Congratulations! Created {} dataset indexes.'.format(create_count), fg='green'))


def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
@@ -304,3 +371,4 @@ def register_commands(app):
    app.cli.add_command(recreate_all_dataset_indexes)
    app.cli.add_command(sync_anthropic_hosted_providers)
    app.cli.add_command(clean_unused_dataset_indexes)
    app.cli.add_command(create_qdrant_indexes)
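
Once registered on app.cli, the migration command can be run through the Flask CLI from the api service (a usage sketch; the command name comes from the click decorator above):

flask create-qdrant-indexes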
2 changes: 1 addition & 1 deletion api/core/data_loader/loader/excel.py
@@ -38,7 +38,7 @@ def load(self) -> List[Document]:
else:
    row_dict = dict(zip(keys, list(map(str, row))))
    row_dict = {k: v for k, v in row_dict.items() if v}
-   item = ''.join(f'{k}:{v}\n' for k, v in row_dict.items())
+   item = ''.join(f'{k}:{v};' for k, v in row_dict.items())
    document = Document(page_content=item, metadata={'source': self._file_path})
    data.append(document)

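The serialization change in one line: each spreadsheet row is now flattened into a single semicolon-delimited string instead of one key:value pair per line. A minimal sketch with a hypothetical row:

row_dict = {'name': 'Alice', 'age': '30'}
''.join(f'{k}:{v}\n' for k, v in row_dict.items())  # before: 'name:Alice\nage:30\n'
''.join(f'{k}:{v};' for k, v in row_dict.items())   # after:  'name:Alice;age:30;'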
46 changes: 46 additions & 0 deletions api/core/index/vector_index/base.py
@@ -173,3 +173,49 @@ def recreate_dataset(self, dataset: Dataset):

        self.dataset = dataset
        logging.info(f"Dataset {dataset.id} recreated successfully.")

    def create_qdrant_dataset(self, dataset: Dataset):
        logging.info(f"create_qdrant_dataset {dataset.id}")

        try:
            self.delete()
        except UnexpectedStatusCodeException as e:
            if e.status_code != 400:
                # a 400 status means the index does not exist
                raise e

        dataset_documents = db.session.query(DatasetDocument).filter(
            DatasetDocument.dataset_id == dataset.id,
            DatasetDocument.indexing_status == 'completed',
            DatasetDocument.enabled == True,
            DatasetDocument.archived == False,
        ).all()

        documents = []
        for dataset_document in dataset_documents:
            segments = db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.status == 'completed',
                DocumentSegment.enabled == True
            ).all()

            for segment in segments:
                document = Document(
                    page_content=segment.content,
                    metadata={
                        "doc_id": segment.index_node_id,
                        "doc_hash": segment.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    }
                )

                documents.append(document)

        if documents:
            self.create(documents)

        logging.info(f"Dataset {dataset.id} recreated successfully.")
114 changes: 114 additions & 0 deletions api/core/index/vector_index/milvus_vector_index.py
@@ -0,0 +1,114 @@
from typing import Optional, cast

from langchain.embeddings.base import Embeddings
from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore, milvus
from pydantic import BaseModel, root_validator

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset


class MilvusConfig(BaseModel):
    endpoint: str
    user: str
    password: str
    batch_size: int = 100

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        if not values['endpoint']:
            raise ValueError("config MILVUS_ENDPOINT is required")
        if not values['user']:
            raise ValueError("config MILVUS_USER is required")
        if not values['password']:
            raise ValueError("config MILVUS_PASSWORD is required")
        return values


class MilvusVectorIndex(BaseVectorIndex):
    def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
        super().__init__(dataset, embeddings)
        self._client = self._init_client(config)

    def get_type(self) -> str:
        return 'milvus'

    def get_index_name(self, dataset: Dataset) -> str:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                class_prefix += '_Node'

            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

    def to_index_struct(self) -> dict:
        return {
            "type": self.get_type(),
            "vector_store": {"class_prefix": self.get_index_name(self.dataset)}
        }

    def create(self, texts: list[Document], **kwargs) -> BaseIndex:
        uuids = self._get_uuids(texts)
        self._vector_store = MilvusVectorStore.from_documents(
            texts,
            self._embeddings,
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            uuids=uuids,
            by_text=False
        )

        return self

    def _get_vector_store(self) -> VectorStore:
        """Only for created index."""
        if self._vector_store:
            return self._vector_store

        attributes = ['doc_id', 'dataset_id', 'document_id']
        if self._is_origin():
            attributes = ['doc_id']

        return MilvusVectorStore(
            client=self._client,
            index_name=self.get_index_name(self.dataset),
            text_key='text',
            embedding=self._embeddings,
            attributes=attributes,
            by_text=False
        )

    def _get_vector_store_class(self) -> type:
        return MilvusVectorStore

    def delete_by_document_id(self, document_id: str):
        if self._is_origin():
            self.recreate_dataset(self.dataset)
            return

        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        vector_store.del_texts({
            "operator": "Equal",
            "path": ["document_id"],
            "valueText": document_id
        })

    def _is_origin(self):
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # original class_prefix
                return True

        return False
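
MilvusConfig fails fast on missing connection settings. A minimal sketch of the expected behaviour, with illustrative credentials (pydantic wraps the validator's ValueError in a ValidationError):

from pydantic import ValidationError

try:
    MilvusConfig(endpoint='', user='root', password='milvus')
except ValidationError as e:
    print(e)  # reports: config MILVUS_ENDPOINT is required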
