Skip to content

Commit

Permalink
feat: add engine_args argument to engine creation functions (#293)
Browse files Browse the repository at this point in the history
* feat: add engine_args argument to engine creation functions

* chore: fix formatting issues

* chore: add space for multi-line formatting

* chore: add basic tests

* chore: fix reference to pool size

* chore: diff syntax

* chore: diff syntax

* chore: real syntax this time

* Update tests/test_engine.py

---------

Co-authored-by: Averi Kitsch <[email protected]>
  • Loading branch information
kurtisvg and averikitsch authored Dec 19, 2024
1 parent c953125 commit 3cfa2f6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
21 changes: 19 additions & 2 deletions src/langchain_google_alloydb_pg/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from typing import TYPE_CHECKING, Any, Awaitable, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Awaitable, Mapping, Optional, TypeVar, Union

import aiohttp
import google.auth # type: ignore
Expand Down Expand Up @@ -143,6 +143,7 @@ def __start_background_loop(
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> Future:
# Running a loop in a background thread allows us to support
# async methods from non-async environments
Expand All @@ -164,6 +165,7 @@ def __start_background_loop(
loop=cls._default_loop,
thread=cls._default_thread,
iam_account_email=iam_account_email,
engine_args=engine_args,
)
return asyncio.run_coroutine_threadsafe(coro, cls._default_loop)

Expand All @@ -179,6 +181,7 @@ def from_instance(
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.
Expand All @@ -192,6 +195,9 @@ def from_instance(
password (Optional[str]): Cloud AlloyDB user password. Defaults to None.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
engine_args (Mapping): Additional arguments that are passed directly to
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
used to specify additional parameters to the underlying pool during it's creation.
Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
Expand All @@ -206,6 +212,7 @@ def from_instance(
password,
ip_type,
iam_account_email=iam_account_email,
engine_args=engine_args,
)
return future.result()

Expand All @@ -223,6 +230,7 @@ async def _create(
loop: Optional[asyncio.AbstractEventLoop] = None,
thread: Optional[Thread] = None,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.
Expand All @@ -238,6 +246,9 @@ async def _create(
loop (Optional[asyncio.AbstractEventLoop]): Async event loop used to create the engine.
thread (Optional[Thread]): Thread used to create the engine async.
iam_account_email (Optional[str]): IAM service account email.
engine_args (Mapping): Additional arguments that are passed directly to
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
used to specify additional parameters to the underlying pool during it's creation.
Raises:
ValueError: Raises error if only one of 'user' or 'password' is specified.
Expand Down Expand Up @@ -290,6 +301,7 @@ async def getconn() -> asyncpg.Connection:
engine = create_async_engine(
"postgresql+asyncpg://",
async_creator=getconn,
**engine_args,
)
return cls(cls.__create_key, engine, loop, thread)

Expand All @@ -305,6 +317,7 @@ async def afrom_instance(
password: Optional[str] = None,
ip_type: Union[str, IPTypes] = IPTypes.PUBLIC,
iam_account_email: Optional[str] = None,
engine_args: Mapping = {},
) -> AlloyDBEngine:
"""Create an AlloyDBEngine from an AlloyDB instance.
Expand All @@ -318,6 +331,9 @@ async def afrom_instance(
password (Optional[str], optional): Cloud AlloyDB user password. Defaults to None.
ip_type (Union[str, IPTypes], optional): IP address type. Defaults to IPTypes.PUBLIC.
iam_account_email (Optional[str], optional): IAM service account email. Defaults to None.
engine_args (Mapping): Additional arguments that are passed directly to
:func:`~sqlalchemy.ext.asyncio.mymodule.MyClass.create_async_engine`. This can be
used to specify additional parameters to the underlying pool during it's creation.
Returns:
AlloyDBEngine: A newly created AlloyDBEngine instance.
Expand All @@ -332,6 +348,7 @@ async def afrom_instance(
password,
ip_type,
iam_account_email=iam_account_email,
engine_args=engine_args,
)
return await asyncio.wrap_future(future)

Expand All @@ -347,7 +364,7 @@ def from_engine(
@classmethod
def from_engine_args(
cls,
url: Union[str | URL],
url: str | URL,
**kwargs: Any,
) -> AlloyDBEngine:
"""Create an AlloyDBEngine instance from arguments
Expand Down
8 changes: 8 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ async def engine(self, db_project, db_region, db_cluster, db_instance, db_name):
instance=db_instance,
region=db_region,
database=db_name,
engine_args={
# add some connection args to validate engine_args works correctly
"pool_size": 3,
"max_overflow": 2,
},
)
yield engine
await aexecute(engine, f'DROP TABLE "{CUSTOM_TABLE}"')
Expand All @@ -130,6 +135,9 @@ async def test_init_table(self, engine):
stmt = f"INSERT INTO {DEFAULT_TABLE} (langchain_id, content, embedding) VALUES ('{id}', '{content}','{embedding}');"
await aexecute(engine, stmt)

async def test_engine_args(self, engine):
assert "Pool size: 3" in engine._pool.pool.status()

async def test_init_table_custom(self, engine):
await engine.ainit_vectorstore_table(
CUSTOM_TABLE,
Expand Down

0 comments on commit 3cfa2f6

Please sign in to comment.