diff --git a/libs/weaviate/langchain_weaviate/vectorstores.py b/libs/weaviate/langchain_weaviate/vectorstores.py index 43d1f8f..d3ca03a 100644 --- a/libs/weaviate/langchain_weaviate/vectorstores.py +++ b/libs/weaviate/langchain_weaviate/vectorstores.py @@ -91,6 +91,7 @@ def __init__( index_name: Optional[str], text_key: str, embedding: Optional[Embeddings] = None, + schema: Optional[dict] = None, attributes: Optional[List[str]] = None, relevance_score_fn: Optional[ Callable[[float], float] @@ -113,12 +114,15 @@ def __init__( if attributes is not None: self._query_attrs.extend(attributes) - schema = _default_schema(self._index_name) - schema["MultiTenancyConfig"] = {"enabled": use_multi_tenancy} + if not schema: + self.schema = _default_schema(self._index_name) + self.schema["MultiTenancyConfig"] = {"enabled": use_multi_tenancy} + else: + self.schema = schema # check whether the index already exists if not client.collections.exists(self._index_name): - client.collections.create_from_dict(schema) + client.collections.create_from_dict(self.schema) # store collection for convenience # this does not actually send a request to weaviate