chore: adhere to PEP 585 and removed unused imports (#283)
vishwarajanand authored Dec 10, 2024
1 parent e7d2c05 commit 779f18a
Showing 19 changed files with 311 additions and 325 deletions.
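
For context, PEP 585 (Python 3.9+) makes the built-in containers usable as generics, so annotations like List[str] become list[str] and the corresponding typing imports can be dropped. A minimal before/after sketch with a hypothetical function, not taken from this repo:

# Before: PEP 484-style aliases need an import from typing.
from typing import Dict, List

def count_tokens(texts: List[str]) -> Dict[str, int]:
    return {t: len(t.split()) for t in texts}

# After: PEP 585 built-in generics, no typing import required.
def count_tokens(texts: list[str]) -> dict[str, int]:
    return {t: len(t.split()) for t in texts}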
4 changes: 2 additions & 2 deletions samples/index_tuning_sample/README.md
@@ -187,7 +187,7 @@ class HNSWIndex(
index_type: str = "hnsw",
# Distance strategy does not affect recall and has minimal impact on latency; refer to this guide to learn more: https://cloud.google.com/spanner/docs/choose-vector-distance-function
distance_strategy: DistanceStrategy = lambda : DistanceStrategy.COSINE_DISTANCE,
- partial_indexes: List[str] | None = None,
+ partial_indexes: list[str] | None = None,
m: int = 16,
ef_construction: int = 64
)
@@ -235,7 +235,7 @@ class IVFFlatIndex(
name: str = DEFAULT_INDEX_NAME,
index_type: str = "ivfflat",
distance_strategy: DistanceStrategy = lambda : DistanceStrategy.COSINE_DISTANCE,
- partial_indexes: List[str] | None = None,
+ partial_indexes: list[str] | None = None,
lists: int = 1
)
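
As a hedged usage sketch of the two signatures above; the import path and index names are assumptions, not part of this README:

from langchain_google_alloydb_pg.indexes import (
    DistanceStrategy,
    HNSWIndex,
    IVFFlatIndex,
)

# partial_indexes is now annotated list[str] | None per PEP 585.
hnsw = HNSWIndex(
    name="movies_hnsw",
    distance_strategy=DistanceStrategy.COSINE_DISTANCE,
    partial_indexes=None,
    m=16,
    ef_construction=64,
)

ivf = IVFFlatIndex(name="movies_ivfflat", lists=100)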
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
- from typing import List

import vertexai # type: ignore
from config import (
@@ -39,14 +38,14 @@
engine = None # Use global variable to share connection pooling


- def similarity_search(query: str) -> List[Document]:
+ def similarity_search(query: str) -> list[Document]:
"""Searches and returns movies.
Args:
query: The user query to search for related items
Returns:
- List[Document]: A list of Documents
+ list[Document]: A list of Documents
"""
global engine
if not engine: # Reuse connection pool
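
The global engine above is the point of the pattern: the pool is built once per process and reused by later calls. A stripped-down sketch of the same idea; AlloyDBEngine.afrom_instance and its parameter values are assumptions for illustration:

from langchain_google_alloydb_pg import AlloyDBEngine

engine = None  # module-level, shared across invocations

async def get_engine() -> AlloyDBEngine:
    """Create the engine on first use, then reuse its connection pool."""
    global engine
    if not engine:
        engine = await AlloyDBEngine.afrom_instance(  # placeholder values
            project_id="my-project",
            region="us-central1",
            cluster="my-cluster",
            instance="my-instance",
            database="movies",
        )
    return engine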
4 changes: 2 additions & 2 deletions src/langchain_google_alloydb_pg/async_chat_message_history.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import json
- from typing import List, Sequence
+ from typing import Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, messages_from_dict
@@ -128,7 +128,7 @@ async def aclear(self) -> None:
await conn.execute(text(query), {"session_id": self.session_id})
await conn.commit()

- async def _aget_messages(self) -> List[BaseMessage]:
+ async def _aget_messages(self) -> list[BaseMessage]:
"""Retrieve the messages from AlloyDB."""
query = f"""SELECT data, type FROM "{self.schema_name}"."{self.table_name}" WHERE session_id = :session_id ORDER BY id;"""
async with self.pool.connect() as conn:
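
Given the query and the messages_from_dict import above, _aget_messages evidently rebuilds chat messages from the stored (data, type) rows. A hedged sketch of that reconstruction step; the exact row shape is an assumption:

from langchain_core.messages import BaseMessage, messages_from_dict

def rows_to_messages(rows: list[tuple[dict, str]]) -> list[BaseMessage]:
    # Each row holds the serialized message dict plus its type tag.
    items = [{"data": data, "type": kind} for data, kind in rows]
    return messages_from_dict(items)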
50 changes: 25 additions & 25 deletions src/langchain_google_alloydb_pg/async_loader.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import json
- from typing import Any, AsyncIterator, Callable, Dict, Iterable, List, Optional
+ from typing import Any, AsyncIterator, Callable, Iterable, Optional

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document
@@ -28,24 +28,24 @@
DEFAULT_METADATA_COL = "langchain_metadata"


- def text_formatter(row: dict, content_columns: List[str]) -> str:
+ def text_formatter(row: dict, content_columns: list[str]) -> str:
"""txt document formatter."""
return " ".join(str(row[column]) for column in content_columns if column in row)


- def csv_formatter(row: dict, content_columns: List[str]) -> str:
+ def csv_formatter(row: dict, content_columns: list[str]) -> str:
"""CSV document formatter."""
return ", ".join(str(row[column]) for column in content_columns if column in row)


- def yaml_formatter(row: dict, content_columns: List[str]) -> str:
+ def yaml_formatter(row: dict, content_columns: list[str]) -> str:
"""YAML document formatter."""
return "\n".join(
f"{column}: {str(row[column])}" for column in content_columns if column in row
)


- def json_formatter(row: dict, content_columns: List[str]) -> str:
+ def json_formatter(row: dict, content_columns: list[str]) -> str:
"""JSON document formatter."""
dictionary = {}
for column in content_columns:
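
A quick demonstration of the formatters defined above, with an assumed sample row; the functions are importable from this module once the package is installed:

from langchain_google_alloydb_pg.async_loader import (
    csv_formatter,
    text_formatter,
    yaml_formatter,
)

row = {"title": "The Matrix", "year": 1999, "rating": 8.7}

text_formatter(row, ["title", "year"])    # 'The Matrix 1999'
csv_formatter(row, ["title", "year"])     # 'The Matrix, 1999'
yaml_formatter(row, ["title", "rating"])  # 'title: The Matrix\nrating: 8.7'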
@@ -63,7 +63,7 @@ def _parse_doc_from_row(
) -> Document:
"""Parse row into document."""
page_content = formatter(row, content_columns)
- metadata: Dict[str, Any] = {}
+ metadata: dict[str, Any] = {}
# unnest metadata from langchain_metadata column
if metadata_json_column and row.get(metadata_json_column):
for k, v in row[metadata_json_column].items():
@@ -81,10 +81,10 @@ def _parse_row_from_doc(
column_names: Iterable[str],
content_column: str = DEFAULT_CONTENT_COL,
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
- ) -> Dict:
+ ) -> dict:
"""Parse document into a dictionary of rows."""
doc_metadata = doc.metadata.copy()
- row: Dict[str, Any] = {content_column: doc.page_content}
+ row: dict[str, Any] = {content_column: doc.page_content}
for entry in doc.metadata:
if entry in column_names:
row[entry] = doc_metadata[entry]
@@ -111,8 +111,8 @@ def __init__(
key: object,
pool: AsyncEngine,
query: str,
- content_columns: List[str],
- metadata_columns: List[str],
+ content_columns: list[str],
+ metadata_columns: list[str],
formatter: Callable,
metadata_json_column: Optional[str] = None,
) -> None:
@@ -122,8 +122,8 @@
key (object): Prevent direct constructor usage.
engine (AlloyDBEngine): AsyncEngine with pool connection to the postgres database
query (Optional[str], optional): SQL query. Defaults to None.
- content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
- metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+ content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
+ metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata".
@@ -149,8 +149,8 @@ async def create(
query: Optional[str] = None,
table_name: Optional[str] = None,
schema_name: str = "public",
- content_columns: Optional[List[str]] = None,
- metadata_columns: Optional[List[str]] = None,
+ content_columns: Optional[list[str]] = None,
+ metadata_columns: Optional[list[str]] = None,
metadata_json_column: Optional[str] = None,
format: Optional[str] = None,
formatter: Optional[Callable] = None,
@@ -162,8 +162,8 @@
query (Optional[str], optional): SQL query. Defaults to None.
table_name (Optional[str], optional): Name of table to query. Defaults to None.
schema_name (str, optional): Name of the schema where table is located. Defaults to "public".
- content_columns (Optional[List[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
- metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
+ content_columns (Optional[list[str]], optional): Column that represent a Document's page_content. Defaults to the first column.
+ metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to None.
metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata".
format (Optional[str], optional): Format of page content (OneOf: text, csv, YAML, JSON). Defaults to 'text'.
formatter (Optional[Callable], optional): A function to format page content (OneOf: format, formatter). Defaults to None.
@@ -236,7 +236,7 @@ async def create(
metadata_json_column,
)

- async def aload(self) -> List[Document]:
+ async def aload(self) -> list[Document]:
"""Load PostgreSQL data into Document objects."""
return [doc async for doc in self.alazy_load()]
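
Putting create and aload together, a hedged end-to-end sketch; the class name AsyncAlloyDBLoader and the positional engine argument are assumptions based on this file's layout:

from langchain_core.documents import Document
from langchain_google_alloydb_pg.async_loader import AsyncAlloyDBLoader

async def load_movies(engine) -> list[Document]:  # PEP 585 return annotation
    loader = await AsyncAlloyDBLoader.create(
        engine,
        table_name="movies",
        content_columns=["title", "overview"],
        metadata_columns=["year", "rating"],
        format="text",  # OneOf: text, csv, YAML, JSON
    )
    return await loader.aload()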

@@ -282,7 +282,7 @@ def __init__(
table_name: str,
content_column: str,
schema_name: str = "public",
- metadata_columns: List[str] = [],
+ metadata_columns: list[str] = [],
metadata_json_column: Optional[str] = None,
):
"""AsyncAlloyDBDocumentSaver constructor.
@@ -293,7 +293,7 @@
table_name (str): Name of table to query.
schema_name (str, optional): Name of schema where the table is located. Defaults to "public".
content_column (str, optional): Column that represent a Document's page_content. Defaults to "page_content".
- metadata_columns (Optional[List[str]], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
+ metadata_columns (Optional[list[str]], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
metadata_json_column (str, optional): Column to store metadata as JSON. Defaults to "langchain_metadata".
Raises:
@@ -317,7 +317,7 @@ async def create(
table_name: str,
schema_name: str = "public",
content_column: str = DEFAULT_CONTENT_COL,
- metadata_columns: List[str] = [],
+ metadata_columns: list[str] = [],
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
) -> AsyncAlloyDBDocumentSaver:
"""Create an AsyncAlloyDBDocumentSaver instance.
@@ -327,7 +327,7 @@ async def create(
table_name (str): Name of table to query.
schema_name (str, optional): Name of schema where the table is located. Defaults to "public".
content_column (str, optional): Column that represent a Document's page_content. Defaults to "page_content".
- metadata_columns (List[str], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
+ metadata_columns (list[str], optional): Column(s) that represent a Document's metadata. Defaults to an empty list.
metadata_json_column (Optional[str], optional): Column to store metadata as JSON. Defaults to "langchain_metadata".
Returns:
@@ -370,13 +370,13 @@ async def create(
metadata_json_column,
)

- async def aadd_documents(self, docs: List[Document]) -> None:
+ async def aadd_documents(self, docs: list[Document]) -> None:
"""
Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
stored in langchain_metadata JSON column.
Args:
- docs (List[langchain_core.documents.Document]): a list of documents to be saved.
+ docs (list[langchain_core.documents.Document]): a list of documents to be saved.
"""

for doc in docs:
@@ -414,13 +414,13 @@ async def aadd_documents(self, docs: List[Document]) -> None:
await conn.execute(text(query), row)
await conn.commit()

- async def adelete(self, docs: List[Document]) -> None:
+ async def adelete(self, docs: list[Document]) -> None:
"""
Delete all instances of a document from the DocumentSaver table by matching the entire Document
object.
Args:
- docs (List[langchain_core.documents.Document]): a list of documents to be deleted.
+ docs (list[langchain_core.documents.Document]): a list of documents to be deleted.
"""
for doc in docs:
row = _parse_row_from_doc(
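
To close out the saver methods above, a minimal sketch of create, aadd_documents, and adelete; the engine argument is an assumption:

from langchain_core.documents import Document
from langchain_google_alloydb_pg.async_loader import AsyncAlloyDBDocumentSaver

async def save_then_delete(engine) -> None:
    saver = await AsyncAlloyDBDocumentSaver.create(
        engine,
        table_name="movies",
        content_column="overview",
        metadata_columns=["title"],
    )
    docs = [
        Document(
            page_content="A hacker learns his world is a simulation.",
            metadata={"title": "The Matrix", "year": 1999},
        )
    ]
    await saver.aadd_documents(docs)  # 'title' fills its column; 'year' goes to JSON
    await saver.adelete(docs)         # deletes rows matching the whole Document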