From 1b40136dd20f4fc25746af3b2504644ba7a12c2f Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 16 Feb 2024 10:12:44 -0800 Subject: [PATCH] Wfh/project name (#449) Add support for: ``` @traceable(project_name="foo") def foo(): pass ``` and ``` langsmith.run_helpers.get_current_run_tree() ``` Add support for ``` run_tree.add_metadata() run_tree.add_events() run_tree.add_tags() ``` --- js/src/tests/batch_client.int.test.ts | 8 +- python/langsmith/cli/main.py | 1 - python/langsmith/client.py | 306 ++++++++++------ python/langsmith/evaluation/evaluator.py | 26 +- .../langsmith/evaluation/string_evaluator.py | 1 + python/langsmith/run_helpers.py | 337 +++++++++++++----- python/langsmith/run_trees.py | 73 +++- python/langsmith/schemas.py | 115 ++++-- python/langsmith/utils.py | 48 ++- python/langsmith/wrappers/__init__.py | 2 + python/pyproject.toml | 18 +- python/tests/unit_tests/test_run_helpers.py | 56 ++- 12 files changed, 745 insertions(+), 246 deletions(-) diff --git a/js/src/tests/batch_client.int.test.ts b/js/src/tests/batch_client.int.test.ts index 095ff0d05..aed01da82 100644 --- a/js/src/tests/batch_client.int.test.ts +++ b/js/src/tests/batch_client.int.test.ts @@ -57,7 +57,7 @@ test.concurrent( async () => { const langchainClient = new Client({ autoBatchTracing: true, - callerOptions: { maxRetries: 0 }, + callerOptions: { maxRetries: 2 }, timeout_ms: 30_000, }); const projectName = "__test_persist_update_run_batch_1"; @@ -96,7 +96,7 @@ test.concurrent( async () => { const langchainClient = new Client({ autoBatchTracing: true, - callerOptions: { maxRetries: 0 }, + callerOptions: { maxRetries: 2 }, pendingAutoBatchedRunLimit: 2, timeout_ms: 30_000, }); @@ -142,7 +142,7 @@ test.concurrent( async () => { const langchainClient = new Client({ autoBatchTracing: true, - callerOptions: { maxRetries: 0 }, + callerOptions: { maxRetries: 2 }, timeout_ms: 30_000, }); const projectName = "__test_persist_update_run_batch_with_delay"; @@ -183,7 +183,7 @@ test.concurrent( async () => { const langchainClient = new Client({ autoBatchTracing: true, - callerOptions: { maxRetries: 0 }, + callerOptions: { maxRetries: 2 }, timeout_ms: 30_000, }); const projectName = "__test_persist_update_run_tree"; diff --git a/python/langsmith/cli/main.py b/python/langsmith/cli/main.py index 11878ec5a..6ebeb1ff5 100644 --- a/python/langsmith/cli/main.py +++ b/python/langsmith/cli/main.py @@ -190,7 +190,6 @@ def logs(self) -> None: def status(self) -> None: """Provide information about the status LangSmith server.""" - command = [ *self.docker_compose_command, "-f", diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 16e7e2667..badc7e965 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -51,7 +51,7 @@ from langsmith.evaluation import evaluator as ls_evaluator if TYPE_CHECKING: - import pandas as pd + import pandas as pd # type: ignore logger = logging.getLogger(__name__) _urllib3_logger = logging.getLogger("urllib3.connectionpool") @@ -65,7 +65,7 @@ def _is_localhost(url: str) -> bool: url : str The URL to check. - Returns + Returns: ------- bool True if the URL is localhost, False otherwise. @@ -105,7 +105,7 @@ def _is_langchain_hosted(url: str) -> bool: url : str The URL to check. - Returns + Returns: ------- bool True if the URL is langchain hosted, False otherwise. @@ -128,7 +128,7 @@ def _default_retry_config() -> Retry: If urllib3 version is 1.26 or greater, retry on all methods. - Returns + Returns: ------- Retry The default retry configuration. @@ -237,7 +237,7 @@ def _validate_api_key_if_hosted(api_url: str, api_key: Optional[str]) -> None: api_key : str or None The API key. - Raises + Raises: ------ LangSmithUserError If the API key is not provided when using the hosted service. @@ -253,7 +253,7 @@ def _validate_api_key_if_hosted(api_url: str, api_key: Optional[str]) -> None: def _get_tracing_sampling_rate() -> float | None: """Get the tracing sampling rate. - Returns + Returns: ------- float The tracing sampling rate. @@ -322,6 +322,14 @@ def _parse_url(url): @dataclass(order=True) class TracingQueueItem: + """An item in the tracing queue. + + Attributes: + priority (str): The priority of the item. + action (str): The action associated with the item. + item (Any): The item itself. + """ + priority: str action: str item: Any = field(compare=False) @@ -377,7 +385,7 @@ def __init__( The session to use for requests. If None, a new session will be created. - Raises + Raises: ------ LangSmithUserError If the API key is not provided when using the hosted service. @@ -418,7 +426,7 @@ def __init__( def _repr_html_(self) -> str: """Return an HTML representation of the instance with a link to the URL. - Returns + Returns: ------- str The HTML representation of the instance. @@ -429,7 +437,7 @@ def _repr_html_(self) -> str: def __repr__(self) -> str: """Return a string representation of the instance with a link to the URL. - Returns + Returns: ------- str The string representation of the instance. @@ -462,7 +470,7 @@ def _host_url(self) -> str: def _headers(self) -> Dict[str, str]: """Get the headers for the API request. - Returns + Returns: ------- Dict[str, str] The headers for the API request. @@ -476,7 +484,7 @@ def _headers(self) -> Dict[str, str]: def info(self) -> Optional[ls_schemas.LangSmithInfo]: """Get the information about the LangSmith API. - Returns + Returns: ------- Optional[ls_schemas.LangSmithInfo] The information about the LangSmith API, or None if the API is @@ -533,12 +541,12 @@ def request_with_retries( to_ignore : Sequence[Type[BaseException]] or None, default=None The exceptions to ignore / pass on. - Returns + Returns: ------- Response The response object. - Raises + Raises: ------ LangSmithAPIError If a server error occurs. @@ -658,7 +666,7 @@ def _get_paginated_list( params : dict or None, default=None The query parameters. - Yields + Yields: ------ dict The items in the paginated list. @@ -700,7 +708,7 @@ def _get_cursor_paginated_list( The HTTP request method. data_key : str, default="runs" - Yields + Yields: ------ dict The items in the paginated list. @@ -756,12 +764,12 @@ def upload_dataframe( data_type : DataType or None, default=DataType.kv The data type of the dataset. - Returns + Returns: ------- Dataset The uploaded dataset. - Raises + Raises: ------ ValueError If the csv_file is not a string or tuple. @@ -807,12 +815,12 @@ def upload_csv( data_type : DataType or None, default=DataType.kv The data type of the dataset. - Returns + Returns: ------- Dataset The uploaded dataset. - Raises + Raises: ------ ValueError If the csv_file is not a string or tuple. @@ -860,8 +868,7 @@ def upload_csv( def _run_transform( run: Union[ls_schemas.Run, dict, ls_schemas.RunLikeDict], update: bool = False ) -> dict: - """ - Transforms the given run object into a dictionary representation. + """Transform the given run object into a dictionary representation. Args: run (Union[ls_schemas.Run, dict]): The run object to transform. @@ -948,7 +955,7 @@ def create_run( **kwargs : Any Additional keyword arguments. - Raises + Raises: ------ LangSmithUserError If the API key is not provided when using the hosted service. @@ -1013,8 +1020,7 @@ def batch_ingest_runs( *, pre_sampled: bool = False, ): - """ - Batch ingest/upsert multiple runs in the Langsmith system. + """Batch ingest/upsert multiple runs in the Langsmith system. Args: create (Optional[Sequence[Union[ls_schemas.Run, RunLikeDict]]]): @@ -1037,7 +1043,6 @@ def batch_ingest_runs( - The run objects MUST contain the dotted_order and trace_id fields to be accepted by the API. """ - if not create and not update: return # transform and convert to dicts @@ -1192,12 +1197,12 @@ def _load_child_runs(self, run: ls_schemas.Run) -> ls_schemas.Run: run : Run The run to load child runs for. - Returns + Returns: ------- Run The run with loaded child runs. - Raises + Raises: ------ LangSmithError If a child run has no parent. @@ -1232,7 +1237,7 @@ def read_run( load_child_runs : bool, default=False Whether to load nested child runs. - Returns + Returns: ------- Run The run. @@ -1293,7 +1298,7 @@ def list_runs( **kwargs : Any Additional keyword arguments. - Yields + Yields: ------ Run The runs. @@ -1352,7 +1357,7 @@ def get_run_url( project_id : UUID or None, default=None The ID of the project. - Returns + Returns: ------- str The URL for the run. @@ -1397,6 +1402,15 @@ def unshare_run(self, run_id: ID_TYPE) -> None: ls_utils.raise_for_status_with_text(response) def read_run_shared_link(self, run_id: ID_TYPE) -> Optional[str]: + """Retrieve the shared link for a specific run. + + Args: + run_id (ID_TYPE): The ID of the run. + + Returns: + Optional[str]: The shared link for the run, or None if the link is not + available. + """ response = self.session.get( f"{self.api_url}/runs/{_as_uuid(run_id, 'run_id')}/share", headers=self._headers, @@ -1433,6 +1447,20 @@ def read_dataset_shared_schema( *, dataset_name: Optional[str] = None, ) -> ls_schemas.DatasetShareSchema: + """Retrieve the shared schema of a dataset. + + Args: + dataset_id (Optional[ID_TYPE]): The ID of the dataset. + Either `dataset_id` or `dataset_name` must be given. + dataset_name (Optional[str]): The name of the dataset. + Either `dataset_id` or `dataset_name` must be given. + + Returns: + ls_schemas.DatasetShareSchema: The shared schema of the dataset. + + Raises: + ValueError: If neither `dataset_id` nor `dataset_name` is given. + """ if dataset_id is None and dataset_name is None: raise ValueError("Either dataset_id or dataset_name must be given") if dataset_id is None: @@ -1528,6 +1556,21 @@ def list_shared_projects( name: Optional[str] = None, name_contains: Optional[str] = None, ) -> Iterator[ls_schemas.TracerSessionResult]: + """List shared projects. + + Args: + dataset_share_token : str + The share token of the dataset. + project_ids : List[ID_TYPE], optional + List of project IDs to filter the results, by default None. + name : str, optional + Name of the project to filter the results, by default None. + name_contains : str, optional + Substring to search for in project names, by default None. + + Yields: + TracerSessionResult: The shared projects. + """ params = {"id": project_ids, "name": name, "name_contains": name_contains} share_token = _as_uuid(dataset_share_token, "dataset_share_token") yield from [ @@ -1565,7 +1608,7 @@ def create_project( reference_dataset_id: UUID or None, default=None The ID of the reference dataset to associate with the project. - Returns + Returns: ------- TracerSession The created project. @@ -1618,7 +1661,7 @@ def update_project( project_extra : dict or None, default=None Additional project information. - Returns + Returns: ------- TracerSession The updated project. @@ -1674,7 +1717,7 @@ def read_project( include_stats : bool, default=False Whether to include a project's aggregate statistics in the response. - Returns + Returns: ------- TracerSessionResult The project. @@ -1712,7 +1755,7 @@ def has_project( project_id : str or None, default=None The ID of the project to check for. - Returns + Returns: ------- bool Whether the project exists. @@ -1734,7 +1777,7 @@ def get_test_results( Note: this will fetch whatever data exists in the DB. Results are not immediately available in the DB upon evaluation run completion. - Returns + Returns: ------- pd.DataFrame A dataframe containing the test results. @@ -1801,8 +1844,7 @@ def list_projects( reference_dataset_name: Optional[str] = None, reference_free: Optional[bool] = None, ) -> Iterator[ls_schemas.TracerSession]: - """ - List projects from the LangSmith API. + """List projects from the LangSmith API. Parameters ---------- @@ -1819,7 +1861,7 @@ def list_projects( reference_free : Optional[bool], optional Whether to filter for only projects not associated with a dataset. - Yields + Yields: ------ TracerSession The projects. @@ -1891,7 +1933,7 @@ def create_dataset( data_type : DataType or None, default=DataType.kv The data type of the dataset. - Returns + Returns: ------- Dataset The created dataset. @@ -1925,7 +1967,7 @@ def has_dataset( dataset_id : str or None, default=None The ID of the dataset to check. - Returns + Returns: ------- bool Whether the dataset exists. @@ -1952,7 +1994,7 @@ def read_dataset( dataset_id : UUID or None, default=None The ID of the dataset to read. - Returns + Returns: ------- Dataset The dataset. @@ -1985,8 +2027,7 @@ def read_dataset( def read_dataset_openai_finetuning( self, dataset_id: Optional[str] = None, *, dataset_name: Optional[str] = None ) -> list: - """ - Download a dataset in OpenAI Jsonl format and load it as a list of dicts. + """Download a dataset in OpenAI Jsonl format and load it as a list of dicts. Parameters ---------- @@ -1995,7 +2036,7 @@ def read_dataset_openai_finetuning( dataset_name : str The name of the dataset to download. - Returns + Returns: ------- list The dataset loaded as a list of dicts. @@ -2023,7 +2064,7 @@ def list_datasets( ) -> Iterator[ls_schemas.Dataset]: """List the datasets on the LangSmith API. - Yields + Yields: ------ Dataset The datasets. @@ -2298,11 +2339,11 @@ def create_examples( dataset_name : Optional[str], default=None The name of the dataset to create the examples in. - Returns + Returns: ------- None - Raises + Raises: ------ ValueError If both `dataset_id` and `dataset_name` are `None`. @@ -2344,26 +2385,23 @@ def create_example( and expected outputs (or other reference information) for a model or chain. - Parameters - ---------- - inputs : Mapping[str, Any] - The input values for the example. - dataset_id : UUID or None, default=None - The ID of the dataset to create the example in. - dataset_name : str or None, default=None - The name of the dataset to create the example in. - created_at : datetime or None, default=None - The creation timestamp of the example. - outputs : Mapping[str, Any] or None, default=None - The output values for the example. - exemple_id : UUID or None, default=None - The ID of the example to create. If not provided, a new - example will be created. + Args: + inputs : Mapping[str, Any] + The input values for the example. + dataset_id : UUID or None, default=None + The ID of the dataset to create the example in. + dataset_name : str or None, default=None + The name of the dataset to create the example in. + created_at : datetime or None, default=None + The creation timestamp of the example. + outputs : Mapping[str, Any] or None, default=None + The output values for the example. + exemple_id : UUID or None, default=None + The ID of the example to create. If not provided, a new + example will be created. - Returns - ------- - Example - The created example. + Returns: + Example: The created example. """ if dataset_id is None: dataset_id = self.read_dataset(dataset_name=dataset_name).id @@ -2392,15 +2430,11 @@ def create_example( def read_example(self, example_id: ID_TYPE) -> ls_schemas.Example: """Read an example from the LangSmith API. - Parameters - ---------- - example_id : str or UUID - The ID of the example to read. + Args: + example_id (UUID): The ID of the example to read. - Returns - ------- - Example - The example. + Returns: + Example: The example. """ response = self._get_with_retries( f"/examples/{_as_uuid(example_id, 'example_id')}", @@ -2420,19 +2454,18 @@ def list_examples( ) -> Iterator[ls_schemas.Example]: """Retrieve the example rows of the specified dataset. - Parameters - ---------- - dataset_id : UUID or None, default=None - The ID of the dataset to filter by. - dataset_name : str or None, default=None - The name of the dataset to filter by. - example_ids : List[UUID] or None, default=None - The IDs of the examples to filter by. - - Yields - ------ - Example - The examples. + Args: + dataset_id (UUID, optional): The ID of the dataset to filter by. + Defaults to None. + dataset_name (str, optional): The name of the dataset to filter by. + Defaults to None. + example_ids (List[UUID], optional): The IDs of the examples to filter by. + Defaults to None. + inline_s3_urls (bool, optional): Whether to inline S3 URLs. + Defaults to True. + + Yields: + Example: The examples. """ params: Dict[str, Any] = {} if dataset_id is not None: @@ -2473,7 +2506,7 @@ def update_example( dataset_id : UUID or None, default=None The ID of the dataset to update. - Returns + Returns: ------- Dict[str, Any] The updated example. @@ -2519,12 +2552,12 @@ def _resolve_run_id( load_child_runs : bool Whether to load child runs. - Returns + Returns: ------- Run The resolved run. - Raises + Raises: ------ TypeError If the run type is invalid. @@ -2549,7 +2582,7 @@ def _resolve_example_id( run : Run The run associated with the example. - Returns + Returns: ------- Example or None The resolved example. @@ -2611,7 +2644,7 @@ def evaluate_run( load_child_runs : bool, default=False Whether to load child runs when resolving the run ID. - Returns + Returns: ------- Feedback The feedback object created by the evaluation. @@ -2685,7 +2718,7 @@ async def aevaluate_run( load_child_runs : bool, default=False Whether to load child runs when resolving the run ID. - Returns + Returns: ------- EvaluationResult The evaluation result object created by the evaluation. @@ -2864,7 +2897,7 @@ def read_feedback(self, feedback_id: ID_TYPE) -> ls_schemas.Feedback: feedback_id : str or UUID The ID of the feedback to read. - Returns + Returns: ------- Feedback The feedback. @@ -2897,7 +2930,7 @@ def list_feedback( **kwargs : Any Additional keyword arguments. - Yields + Yields: ------ Feedback The feedback objects. @@ -2938,6 +2971,20 @@ def list_annotation_queues( name: Optional[str] = None, name_contains: Optional[str] = None, ) -> Iterator[ls_schemas.AnnotationQueue]: + """List the annotation queues on the LangSmith API. + + Args: + queue_ids : List[str or UUID] or None, default=None + The IDs of the queues to filter by. + name : str or None, default=None + The name of the queue to filter by. + name_contains : str or None, default=None + The substring that the queue name should contain. + + Yields: + AnnotationQueue + The annotation queues. + """ params: dict = { "ids": ( [_as_uuid(id_, f"queue_ids[{i}]") for i, id_ in enumerate(queue_ids)] @@ -2959,6 +3006,20 @@ def create_annotation_queue( description: Optional[str] = None, queue_id: Optional[ID_TYPE] = None, ) -> ls_schemas.AnnotationQueue: + """Create an annotation queue on the LangSmith API. + + Args: + name : str + The name of the annotation queue. + description : str, optional + The description of the annotation queue. + queue_id : str or UUID, optional + The ID of the annotation queue. + + Returns: + AnnotationQueue + The created annotation queue object. + """ body = { "name": name, "description": description, @@ -2976,12 +3037,28 @@ def create_annotation_queue( return ls_schemas.AnnotationQueue(**response.json()) def read_annotation_queue(self, queue_id: ID_TYPE) -> ls_schemas.AnnotationQueue: + """Read an annotation queue with the specified queue ID. + + Args: + queue_id (ID_TYPE): The ID of the annotation queue to read. + + Returns: + ls_schemas.AnnotationQueue: The annotation queue object. + """ # TODO: Replace when actual endpoint is added return next(self.list_annotation_queues(queue_ids=[queue_id])) def update_annotation_queue( self, queue_id: ID_TYPE, *, name: str, description: Optional[str] = None ) -> None: + """Update an annotation queue with the specified queue_id. + + Args: + queue_id (ID_TYPE): The ID of the annotation queue to update. + name (str): The new name for the annotation queue. + description (Optional[str], optional): The new description for the + annotation queue. Defaults to None. + """ response = self.request_with_retries( "patch", f"{self.api_url}/annotation-queues/{_as_uuid(queue_id, 'queue_id')}", @@ -2996,6 +3073,11 @@ def update_annotation_queue( ls_utils.raise_for_status_with_text(response) def delete_annotation_queue(self, queue_id: ID_TYPE) -> None: + """Delete an annotation queue with the specified queue ID. + + Args: + queue_id (ID_TYPE): The ID of the annotation queue to delete. + """ response = self.session.delete( f"{self.api_url}/annotation-queues/{_as_uuid(queue_id, 'queue_id')}", headers=self._headers, @@ -3005,6 +3087,13 @@ def delete_annotation_queue(self, queue_id: ID_TYPE) -> None: def add_runs_to_annotation_queue( self, queue_id: ID_TYPE, *, run_ids: List[ID_TYPE] ) -> None: + """Add runs to an annotation queue with the specified queue ID. + + Args: + queue_id (ID_TYPE): The ID of the annotation queue. + run_ids (List[ID_TYPE]): The IDs of the runs to be added to the annotation + queue. + """ response = self.request_with_retries( "post", f"{self.api_url}/annotation-queues/{_as_uuid(queue_id, 'queue_id')}/runs", @@ -3020,6 +3109,15 @@ def add_runs_to_annotation_queue( def list_runs_from_annotation_queue( self, queue_id: ID_TYPE ) -> Iterator[ls_schemas.RunWithAnnotationQueueInfo]: + """List runs from an annotation queue with the specified queue ID. + + Args: + queue_id (ID_TYPE): The ID of the annotation queue. + + Yields: + ls_schemas.RunWithAnnotationQueueInfo: An iterator of runs from the + annotation queue. + """ path = f"/annotation-queues/{_as_uuid(queue_id, 'queue_id')}/runs" yield from ( ls_schemas.RunWithAnnotationQueueInfo(**run) @@ -3040,9 +3138,9 @@ async def arun_on_dataset( input_mapper: Optional[Callable[[Dict], Any]] = None, revision_id: Optional[str] = None, ) -> Dict[str, Any]: - """ - Asynchronously run the Chain or language model on a dataset - and store traces to the specified project name. + """Asynchronously run the Chain or language model on a dataset. + + Store traces to the specified project name. Args: dataset_name: Name of the dataset to run the chain on. @@ -3070,9 +3168,8 @@ async def arun_on_dataset( For the synchronous version, see client.run_on_dataset. - Examples + Examples: -------- - .. code-block:: python from langsmith import Client @@ -3181,9 +3278,9 @@ def run_on_dataset( input_mapper: Optional[Callable[[Dict], Any]] = None, revision_id: Optional[str] = None, ) -> Dict[str, Any]: - """ - Run the Chain or language model on a dataset and store traces - to the specified project name. + """Run the Chain or language model on a dataset. + + Store traces to the specified project name. Args: dataset_name: Name of the dataset to run the chain on. @@ -3212,9 +3309,8 @@ def run_on_dataset( For the (usually faster) async version of this function, see `client.arun_on_dataset`. - Examples + Examples: -------- - .. code-block:: python from langsmith import Client diff --git a/python/langsmith/evaluation/evaluator.py b/python/langsmith/evaluation/evaluator.py index 92d067fed..592d301ee 100644 --- a/python/langsmith/evaluation/evaluator.py +++ b/python/langsmith/evaluation/evaluator.py @@ -1,3 +1,5 @@ +"""This module contains the evaluator classes for evaluating runs.""" + import asyncio import uuid from abc import abstractmethod @@ -43,8 +45,11 @@ class Config: class EvaluationResults(TypedDict, total=False): - """Batch evaluation results, if your evaluator wishes - to return multiple scores.""" + """Batch evaluation results. + + This makes it easy for your evaluator to return multiple + metrics at once. + """ results: List[EvaluationResult] """The evaluation results.""" @@ -69,8 +74,7 @@ async def aevaluate_run( class DynamicRunEvaluator(RunEvaluator): - """ - A dynamic evaluator that wraps a function and transforms it into a `RunEvaluator`. + """A dynamic evaluator that wraps a function and transforms it into a `RunEvaluator`. This class is designed to be used with the `@run_evaluator` decorator, allowing functions that take a `Run` and an optional `Example` as arguments, and return @@ -86,8 +90,7 @@ def __init__( [Run, Optional[Example]], Union[EvaluationResult, EvaluationResults, dict] ], ): - """ - Initialize the DynamicRunEvaluator with a given function. + """Initialize the DynamicRunEvaluator with a given function. Args: func (Callable): A function that takes a `Run` and an optional `Example` as @@ -130,8 +133,7 @@ def _coerce_evaluation_results( def evaluate_run( self, run: Run, example: Optional[Example] = None ) -> Union[EvaluationResult, EvaluationResults]: - """ - Evaluate a run using the wrapped function. + """Evaluate a run using the wrapped function. This method directly invokes the wrapped function with the provided arguments. @@ -154,8 +156,7 @@ def evaluate_run( def __call__( self, run: Run, example: Optional[Example] = None ) -> Union[EvaluationResult, EvaluationResults]: - """ - Make the evaluator callable, allowing it to be used like a function. + """Make the evaluator callable, allowing it to be used like a function. This method enables the evaluator instance to be called directly, forwarding the call to `evaluate_run`. @@ -175,5 +176,8 @@ def run_evaluator( [Run, Optional[Example]], Union[EvaluationResult, EvaluationResults, dict] ], ): - """Decorator to create a run evaluator from a function.""" + """Create a run evaluator from a function. + + Decorator that transforms a function into a `RunEvaluator`. + """ return DynamicRunEvaluator(func) diff --git a/python/langsmith/evaluation/string_evaluator.py b/python/langsmith/evaluation/string_evaluator.py index 3854bdece..85b6ce530 100644 --- a/python/langsmith/evaluation/string_evaluator.py +++ b/python/langsmith/evaluation/string_evaluator.py @@ -1,3 +1,4 @@ +"""This module contains the StringEvaluator class.""" from typing import Callable, Dict, Optional from pydantic import BaseModel diff --git a/python/langsmith/run_helpers.py b/python/langsmith/run_helpers.py index ea0a6dd47..3f4f3057d 100644 --- a/python/langsmith/run_helpers.py +++ b/python/langsmith/run_helpers.py @@ -115,11 +115,23 @@ class _TraceableContainer(TypedDict, total=False): outer_tags: Optional[List[str]] +class _ContainerInput(TypedDict, total=False): + """Typed response when initializing a run a traceable.""" + + extra_outer: Optional[Dict] + name: Optional[str] + metadata: Optional[Dict[str, Any]] + tags: Optional[List[str]] + client: Optional[ls_client.Client] + reduce_fn: Optional[Callable] + project_name: Optional[str] + run_type: ls_client.RUN_TYPE_T + + def _container_end( container: _TraceableContainer, outputs: Optional[Any] = None, error: Optional[str] = None, - events: Optional[List[dict]] = None, ): """End the run.""" run_tree = container.get("new_run") @@ -127,7 +139,7 @@ def _container_end( # Tracing disabled return outputs_ = outputs if isinstance(outputs, dict) else {"output": outputs} - run_tree.end(outputs=outputs_, error=error, events=events) + run_tree.end(outputs=outputs_, error=error) run_tree.patch() @@ -142,33 +154,38 @@ def _collect_extra(extra_outer: dict, langsmith_extra: LangSmithExtra) -> dict: def _setup_run( func: Callable, - run_type: ls_client.RUN_TYPE_T, - extra_outer: dict, + container_input: _ContainerInput, langsmith_extra: Optional[LangSmithExtra] = None, - name: Optional[str] = None, - metadata: Optional[Mapping[str, Any]] = None, - tags: Optional[List[str]] = None, - client: Optional[ls_client.Client] = None, args: Any = None, kwargs: Any = None, ) -> _TraceableContainer: - outer_project = _PROJECT_NAME.get() or utils.get_tracer_project() + """Create a new run or create_child() if run is passed in kwargs.""" + extra_outer = container_input.get("extra_outer") or {} + name = container_input.get("name") + metadata = container_input.get("metadata") + tags = container_input.get("tags") + client = container_input.get("client") + run_type = container_input.get("run_type") or "chain" + outer_project = _PROJECT_NAME.get() langsmith_extra = langsmith_extra or LangSmithExtra() - parent_run_ = langsmith_extra.get("run_tree") or _PARENT_RUN_TREE.get() + parent_run_ = langsmith_extra.get("run_tree") or get_run_tree_context() + selected_project = ( + _PROJECT_NAME.get() # From parent trace + or langsmith_extra.get("project_name") # at invocation time + or container_input["project_name"] # at decorator time + or utils.get_tracer_project() # default + ) if not parent_run_ and not utils.tracing_is_enabled(): utils.log_once( logging.DEBUG, "LangSmith tracing is disabled, returning original function." ) return _TraceableContainer( new_run=None, - project_name=outer_project, + project_name=selected_project, outer_project=outer_project, outer_metadata=None, outer_tags=None, ) - # Else either the env var is set OR a parent run was explicitly set, - # which occurs in the `as_runnable()` flow - project_name_ = langsmith_extra.get("project_name", outer_project) signature = inspect.signature(func) name_ = name or func.__name__ docstring = func.__doc__ @@ -219,7 +236,7 @@ def _setup_run( inputs=inputs, run_type=run_type, reference_example_id=langsmith_extra.get("reference_example_id"), - project_name=project_name_, + project_name=selected_project, extra=extra_inner, tags=tags_, client=client_, @@ -230,7 +247,7 @@ def _setup_run( logger.error(f"Failed to post run {new_run.id}: {e}") response_container = _TraceableContainer( new_run=new_run, - project_name=project_name_, + project_name=selected_project, outer_project=outer_project, outer_metadata=outer_metadata, outer_tags=outer_tags, @@ -255,12 +272,36 @@ def _setup_run( @runtime_checkable class SupportsLangsmithExtra(Protocol, Generic[R]): + """Implementations of this Protoc accept an optional langsmith_extra parameter. + + Args: + *args: Variable length arguments. + langsmith_extra (Optional[Dict[str, Any]]): Optional dictionary of + additional parameters for Langsmith. + **kwargs: Keyword arguments. + + Returns: + R: The return type of the callable. + """ + def __call__( self, *args: Any, langsmith_extra: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> R: + """Call the instance when it is called as a function. + + Args: + *args: Variable length argument list. + langsmith_extra: Optional dictionary containing additional + parameters specific to Langsmith. + **kwargs: Arbitrary keyword arguments. + + Returns: + R: The return value of the method. + + """ ... @@ -279,8 +320,8 @@ def traceable( metadata: Optional[Mapping[str, Any]] = None, tags: Optional[List[str]] = None, client: Optional[ls_client.Client] = None, - extra: Optional[Dict] = None, reduce_fn: Optional[Callable] = None, + project_name: Optional[str] = None, ) -> Callable[[Callable[..., R]], SupportsLangsmithExtra[R]]: ... @@ -289,10 +330,10 @@ def traceable( *args: Any, **kwargs: Any, ) -> Union[Callable, Callable[[Callable], Callable]]: - """Decorator for creating or adding a run to a run tree. + """Trace a function with langsmith. Args: - run_type: The type of run to create. Examples: llm, chain, tool, prompt, + run_type: The type of run (span) to create. Examples: llm, chain, tool, prompt, retriever, etc. Defaults to "chain". name: The name of the run. Defaults to the function name. metadata: The metadata to add to the run. Defaults to None. @@ -304,12 +345,122 @@ def traceable( logged as a list. Note: if the iterator is never exhausted (e.g. the function returns an infinite generator), this will never be called, and the run itself will be stuck in a pending state. + project_name: The name of the project to log the run to. Defaults to None, + which will use the default project. + + + Returns: + Union[Callable, Callable[[Callable], Callable]]: The decorated function. + + Note: + - Requires that LANGCHAIN_TRACING_V2 be set to 'true' in the environment. + + Examples: + .. code-block:: python + import httpx + import asyncio + + from typing import Iterable + from langsmith import traceable, Client + + + # Basic usage: + @traceable + def my_function(x: float, y: float) -> float: + return x + y + + + my_function(5, 6) + + @traceable + async def my_async_function(query_params: dict) -> dict: + async with httpx.AsyncClient() as http_client: + response = await http_client.get( + "https://api.example.com/data", + params=query_params, + ) + return response.json() + + + asyncio.run(my_async_function({"param": "value"})) + + + # Streaming data with a generator: + @traceable + def my_generator(n: int) -> Iterable: + for i in range(n): + yield i + + + for item in my_generator(5): + print(item) + + + # Async streaming data + @traceable + async def my_async_generator(query_params: dict) -> Iterable: + async with httpx.AsyncClient() as http_client: + response = await http_client.get( + "https://api.example.com/data", + params=query_params, + ) + for item in response.json(): + yield item + + + async def async_code(): + async for item in my_async_generator({"param": "value"}): + print(item) + + + asyncio.run(async_code()) + + + # Specifying a run type and name: + @traceable(name="CustomName", run_type="tool") + def another_function(a: float, b: float) -> float: + return a * b + + + another_function(5, 6) + + + # Logging with custom metadata and tags: + @traceable( + metadata={"version": "1.0", "author": "John Doe"}, + tags=["beta", "test"] + ) + def tagged_function(x): + return x**2 + + + tagged_function(5) + + # Specifying a custom client and project name: + custom_client = Client(api_key="your_api_key") + + + @traceable(client=custom_client, project_name="My Special Project") + def project_specific_function(data): + return data + + + project_specific_function({"data": "to process"}) + + + # Manually passing langsmith_extra: + @traceable + def manual_extra_function(x): + return x**2 + + + manual_extra_function(5, langsmith_extra={"metadata": {"version": "1.0"}}) """ run_type: ls_client.RUN_TYPE_T = ( args[0] if args and isinstance(args[0], str) - else (kwargs.get("run_type") or "chain") + else (kwargs.pop("run_type", None) or "chain") ) if run_type not in _VALID_RUN_TYPES: warnings.warn( @@ -322,12 +473,29 @@ def traceable( "which should be the run_type. All other arguments should be passed " "as keyword arguments." ) - extra_outer = kwargs.get("extra") or {} - name = kwargs.get("name") - metadata = kwargs.get("metadata") - tags = kwargs.get("tags") - client = kwargs.get("client") - reduce_fn = kwargs.get("reduce_fn") + if "extra" in kwargs: + warnings.warn( + "The `extra` keyword argument is deprecated. Please use `metadata` " + "instead.", + DeprecationWarning, + ) + reduce_fn = kwargs.pop("reduce_fn", None) + container_input = _ContainerInput( + # TODO: Deprecate raw extra + extra_outer=kwargs.pop("extra", None), + name=kwargs.pop("name", None), + metadata=kwargs.pop("metadata", None), + tags=kwargs.pop("tags", None), + client=kwargs.pop("client", None), + project_name=kwargs.pop("project_name", None), + run_type=run_type, + ) + if kwargs: + warnings.warn( + f"The following keyword arguments are not recognized and will be ignored: " + f"{sorted(kwargs.keys())}.", + DeprecationWarning, + ) def decorator(func: Callable): @functools.wraps(func) @@ -336,17 +504,12 @@ async def async_wrapper( langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any, ) -> Any: - """Async version of wrapper function""" - context_run = _PARENT_RUN_TREE.get() + """Async version of wrapper function.""" + context_run = get_run_tree_context() run_container = _setup_run( func, - run_type=run_type, + container_input=container_input, langsmith_extra=langsmith_extra, - extra_outer=extra_outer, - name=name, - metadata=metadata, - tags=tags, - client=client, args=args, kwargs=kwargs, ) @@ -376,17 +539,11 @@ async def async_wrapper( async def async_generator_wrapper( *args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any ) -> AsyncGenerator: - events: List[dict] = [] - context_run = _PARENT_RUN_TREE.get() + context_run = get_run_tree_context() run_container = _setup_run( func, - run_type=run_type, + container_input=container_input, langsmith_extra=langsmith_extra, - extra_outer=extra_outer, - name=name, - metadata=metadata, - tags=tags, - client=client, args=args, kwargs=kwargs, ) @@ -413,20 +570,21 @@ async def async_generator_wrapper( async_gen_result = await async_gen_result async for item in async_gen_result: if run_type == "llm": - events.append( - { - "name": "new_token", - "time": datetime.datetime.now( - datetime.timezone.utc - ).isoformat(), - "kwargs": {"token": item}, - }, - ) + if run_container["new_run"]: + run_container["new_run"].add_event( + { + "name": "new_token", + "time": datetime.datetime.now( + datetime.timezone.utc + ).isoformat(), + "kwargs": {"token": item}, + } + ) results.append(item) yield item except BaseException as e: stacktrace = traceback.format_exc() - _container_end(run_container, error=stacktrace, events=events) + _container_end(run_container, error=stacktrace) raise e finally: _PARENT_RUN_TREE.set(context_run) @@ -444,7 +602,7 @@ async def async_generator_wrapper( function_result = results else: function_result = None - _container_end(run_container, outputs=function_result, events=events) + _container_end(run_container, outputs=function_result) @functools.wraps(func) def wrapper( @@ -453,16 +611,11 @@ def wrapper( **kwargs: Any, ) -> Any: """Create a new run or create_child() if run is passed in kwargs.""" - context_run = _PARENT_RUN_TREE.get() + context_run = get_run_tree_context() run_container = _setup_run( func, - run_type=run_type, + container_input=container_input, langsmith_extra=langsmith_extra, - extra_outer=extra_outer, - name=name, - metadata=metadata, - tags=tags, - client=client, args=args, kwargs=kwargs, ) @@ -492,17 +645,11 @@ def wrapper( def generator_wrapper( *args: Any, langsmith_extra: Optional[LangSmithExtra] = None, **kwargs: Any ) -> Any: - context_run = _PARENT_RUN_TREE.get() - events: List[dict] = [] + context_run = get_run_tree_context() run_container = _setup_run( func, - run_type=run_type, + container_input=container_input, langsmith_extra=langsmith_extra, - extra_outer=extra_outer, - name=name, - metadata=metadata, - tags=tags, - client=client, args=args, kwargs=kwargs, ) @@ -522,15 +669,16 @@ def generator_wrapper( generator_result = func(*args, **kwargs) for item in generator_result: if run_type == "llm": - events.append( - { - "name": "new_token", - "time": datetime.datetime.now( - datetime.timezone.utc - ).isoformat(), - "kwargs": {"token": item}, - }, - ) + if run_container["new_run"]: + run_container["new_run"].add_event( + { + "name": "new_token", + "time": datetime.datetime.now( + datetime.timezone.utc + ).isoformat(), + "kwargs": {"token": item}, + } + ) results.append(item) try: yield item @@ -538,7 +686,7 @@ def generator_wrapper( break except BaseException as e: stacktrace = traceback.format_exc() - _container_end(run_container, error=stacktrace, events=events) + _container_end(run_container, error=stacktrace) raise e finally: _PARENT_RUN_TREE.set(context_run) @@ -556,7 +704,7 @@ def generator_wrapper( function_result = results else: function_result = None - _container_end(run_container, outputs=function_result, events=events) + _container_end(run_container, outputs=function_result) if inspect.isasyncgenfunction(func): selected_wrapper: Callable = async_generator_wrapper @@ -604,9 +752,9 @@ def trace( outer_tags = _TAGS.get() outer_metadata = _METADATA.get() outer_project = _PROJECT_NAME.get() or utils.get_tracer_project() - parent_run_ = _PARENT_RUN_TREE.get() if run_tree is None else run_tree + parent_run_ = get_run_tree_context() if run_tree is None else run_tree - # Merge and set context varaibles + # Merge and set context variables tags_ = sorted(set((tags or []) + (outer_tags or []))) _TAGS.set(tags_) metadata = {**(metadata or {}), **(outer_metadata or {}), "ls_method": "trace"} @@ -655,6 +803,27 @@ def trace( def as_runnable(traceable_fn: Callable) -> Runnable: + """Convert a function wrapped by the LangSmith @traceable decorator to a Runnable. + + Args: + traceable_fn (Callable): The function wrapped by the @traceable decorator. + + Returns: + Runnable: A Runnable object that maintains a consistent LangSmith + tracing context. + + Raises: + ImportError: If langchain module is not installed. + ValueError: If the provided function is not wrapped by the @traceable decorator. + + Example: + >>> @traceable + ... def my_function(input_data): + ... # Function implementation + ... pass + ... + >>> runnable = as_runnable(my_function) + """ try: from langchain.callbacks.manager import ( AsyncCallbackManager, @@ -679,8 +848,9 @@ def as_runnable(traceable_fn: Callable) -> Runnable: ) class RunnableTraceable(RunnableLambda): - """RunnableTraceable converts a @traceable decorated function - to a Runnable in a way that hands off the LangSmith tracing context. + """Converts a @traceable decorated function to a Runnable. + + This helps maintain a consistent LangSmith tracing context. """ def __init__( @@ -756,7 +926,6 @@ def _wrap_async( afunc: Optional[Callable[..., Awaitable[Output]]], ) -> Optional[Callable[[Input, RunnableConfig], Awaitable[Output]]]: """Wrap an async function to make it synchronous.""" - if afunc is None: return None diff --git a/python/langsmith/run_trees.py b/python/langsmith/run_trees.py index 5e57caccc..392a0c9e5 100644 --- a/python/langsmith/run_trees.py +++ b/python/langsmith/run_trees.py @@ -4,26 +4,22 @@ import logging from datetime import datetime, timezone -from typing import Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID, uuid4 try: - from pydantic.v1 import ( # type: ignore[import] - Field, - root_validator, - validator, - ) + from pydantic.v1 import Field, root_validator, validator # type: ignore[import] except ImportError: from pydantic import Field, root_validator, validator +from langsmith import schemas as ls_schemas from langsmith import utils from langsmith.client import ID_TYPE, RUN_TYPE_T, Client -from langsmith.schemas import RunBase logger = logging.getLogger(__name__) -class RunTree(RunBase): +class RunTree(ls_schemas.RunBase): """Run Schema with back-references for posting runs.""" name: str @@ -48,6 +44,8 @@ class RunTree(RunBase): trace_id: UUID = Field(default="", description="The trace id of the run.") class Config: + """Pydantic model configuration.""" + arbitrary_types_allowed = True allow_population_by_field_name = True extra = "allow" @@ -78,6 +76,7 @@ def infer_defaults(cls, values: dict) -> dict: @root_validator(pre=False) def ensure_dotted_order(cls, values: dict) -> dict: + """Ensure the dotted order of the run.""" current_dotted_order = values.get("dotted_order") if current_dotted_order and current_dotted_order.strip(): return values @@ -92,13 +91,65 @@ def ensure_dotted_order(cls, values: dict) -> dict: values["dotted_order"] = current_dotted_order return values + def add_tags(self, tags: Union[Sequence[str], str]) -> None: + """Add tags to the run.""" + if isinstance(tags, str): + tags = [tags] + if self.tags is None: + self.tags = [] + self.tags.extend(tags) + + def add_metadata(self, metadata: Dict[str, Any]) -> None: + """Add metadata to the run.""" + if self.extra is None: + self.extra = {} + metadata_: dict = self.extra.setdefault("metadata", {}) + metadata_.update(metadata) + + def add_event( + self, + events: Union[ + ls_schemas.RunEvent, + Sequence[ls_schemas.RunEvent], + Sequence[dict], + dict, + str, + ], + ) -> None: + """Add an event to the list of events. + + Args: + events (Union[ls_schemas.RunEvent, Sequence[ls_schemas.RunEvent], + Sequence[dict], dict, str]): + The event(s) to be added. It can be a single event, a sequence + of events, + a sequence of dictionaries, a dictionary, or a string. + + Returns: + None + """ + if self.events is None: + self.events = [] + if isinstance(events, dict): + self.events.append(events) # type: ignore[arg-type] + elif isinstance(events, str): + self.events.append( + { + "name": "event", + "time": datetime.now(timezone.utc).isoformat(), + "message": events, + } + ) + else: + self.events.extend(events) # type: ignore[arg-type] + def end( self, *, outputs: Optional[Dict] = None, error: Optional[str] = None, end_time: Optional[datetime] = None, - events: Optional[List[Dict]] = None, + events: Optional[Sequence[ls_schemas.RunEvent]] = None, ) -> None: """Set the end time of the run and all child runs.""" self.end_time = end_time or datetime.now(timezone.utc) @@ -107,7 +158,7 @@ def end( if error is not None: self.error = error if events is not None: - self.events = events + self.add_event(events) def create_child( self, @@ -127,6 +178,8 @@ def create_child( ) -> RunTree: """Add a child run to the run tree.""" serialized_ = serialized or {"name": name} + + logger.warning(f"session_name: {self.session_name} {self.name}") run = RunTree( name=name, id=run_id or uuid4(), diff --git a/python/langsmith/schemas.py b/python/langsmith/schemas.py index ef269f9b6..df421e657 100644 --- a/python/langsmith/schemas.py +++ b/python/langsmith/schemas.py @@ -2,6 +2,7 @@ from __future__ import annotations +import threading from datetime import datetime, timedelta, timezone from enum import Enum from typing import ( @@ -50,6 +51,8 @@ class ExampleBase(BaseModel): outputs: Optional[Dict[str, Any]] = Field(default=None) class Config: + """Configuration class for the schema.""" + frozen = True @@ -101,6 +104,8 @@ class ExampleUpdate(BaseModel): outputs: Optional[Dict[str, Any]] = None class Config: + """Configuration class for the schema.""" + frozen = True @@ -120,6 +125,8 @@ class DatasetBase(BaseModel): data_type: Optional[DataType] = None class Config: + """Configuration class for the schema.""" + frozen = True @@ -181,9 +188,12 @@ class RunTypeEnum(str, Enum): class RunBase(BaseModel): - """ - Base Run schema. - Contains the fundamental fields to define a run in a system. + """Base Run schema. + + A Run is a span representing a single unit of work or operation within your LLM app. + This could be a single call to an LLM or chain, to a prompt formatting call, + to a runnable lambda invocation. If you are familiar with OpenTelemetry, + you can think of a run as a span. """ id: UUID @@ -230,6 +240,27 @@ class RunBase(BaseModel): tags: Optional[List[str]] = None """Tags for categorizing or annotating the run.""" + _lock: threading.Lock = Field(default_factory=threading.Lock) + + class Config: + """Configuration class for the schema.""" + + underscore_attrs_are_private = True + + @property + def metadata(self) -> dict[str, Any]: + """Retrieve the metadata (if any).""" + with self._lock: + if self.extra is None: + self.extra = {} + metadata = self.extra.setdefault("metadata", {}) + return metadata + + @property + def revision_id(self) -> Optional[UUID]: + """Retrieve the revision ID (if any).""" + return self.metadata.get("revision_id") + class Run(RunBase): """Run schema when loading from the DB.""" @@ -291,18 +322,6 @@ def url(self) -> Optional[str]: return f"{self._host_url}{self.app_path}" return None - @property - def metadata(self) -> dict[str, Any]: - """Retrieve the metadata (if any).""" - if self.extra is None or "metadata" not in self.extra: - return {} - return self.extra["metadata"] - - @property - def revision_id(self) -> Optional[UUID]: - """Retrieve the revision ID (if any).""" - return self.metadata.get("revision_id") - class RunLikeDict(TypedDict, total=False): """Run-like dictionary, for type-hinting.""" @@ -342,6 +361,17 @@ class RunWithAnnotationQueueInfo(RunBase): class FeedbackSourceBase(BaseModel): + """Base class for feedback sources. + + This represents whether feedback is submitted from the API, model, human labeler, + etc. + + Attributes: + type (str): The type of the feedback source. + metadata (Optional[Dict[str, Any]]): Additional metadata for the feedback + source. + """ + type: str metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) @@ -392,6 +422,8 @@ class FeedbackBase(BaseModel): """The source of the feedback.""" class Config: + """Configuration class for the schema.""" + frozen = True @@ -465,8 +497,10 @@ def tags(self) -> List[str]: class TracerSessionResult(TracerSession): - """TracerSession schema returned when reading a project - by ID. Sessions are also referred to as "Projects" in the UI.""" + """A project, hydrated with additional information. + + Sessions are also referred to as "Projects" in the UI. + """ run_count: Optional[int] """The number of runs in the project.""" @@ -492,9 +526,7 @@ class TracerSessionResult(TracerSession): @runtime_checkable class BaseMessageLike(Protocol): - """ - A protocol representing objects similar to BaseMessage. - """ + """A protocol representing objects similar to BaseMessage.""" content: str additional_kwargs: Dict @@ -505,12 +537,34 @@ def type(self) -> str: class DatasetShareSchema(TypedDict, total=False): + """Represents the schema for a dataset share. + + Attributes: + dataset_id (UUID): The ID of the dataset. + share_token (UUID): The token for sharing the dataset. + url (str): The URL of the shared dataset. + """ + dataset_id: UUID share_token: UUID url: str class AnnotationQueue(BaseModel): + """Represents an annotation queue. + + Attributes: + id (UUID): The ID of the annotation queue. + name (str): The name of the annotation queue. + description (Optional[str], optional): The description of the annotation queue. + Defaults to None. + created_at (datetime, optional): The creation timestamp of the annotation queue. + Defaults to the current UTC time. + updated_at (datetime, optional): The last update timestamp of the annotation + queue. Defaults to the current UTC time. + tenant_id (UUID): The ID of the tenant associated with the annotation queue. + """ + id: UUID name: str description: Optional[str] = None @@ -520,6 +574,16 @@ class AnnotationQueue(BaseModel): class BatchIngestConfig(TypedDict, total=False): + """Configuration for batch ingestion. + + Attributes: + scale_up_qsize_trigger (int): The queue size threshold that triggers scaling up. + scale_up_nthreads_limit (int): The maximum number of threads to scale up to. + scale_down_nempty_trigger (int): The number of empty threads that triggers + scaling down. + size_limit (int): The maximum size limit for the batch. + """ + scale_up_qsize_trigger: int scale_up_nthreads_limit: int scale_down_nempty_trigger: int @@ -537,3 +601,14 @@ class LangSmithInfo(BaseModel): Example.update_forward_refs() + + +class RunEvent(TypedDict, total=False): + """Run event schema.""" + + name: str + """Type of event.""" + time: Union[datetime, str] + """Time of the event.""" + kwargs: Optional[Dict[str, Any]] + """Additional metadata for the event.""" diff --git a/python/langsmith/utils.py b/python/langsmith/utils.py index 8cffd488e..6472f01b5 100644 --- a/python/langsmith/utils.py +++ b/python/langsmith/utils.py @@ -113,6 +113,7 @@ def get_enum_value(enu: Union[enum.Enum, str]) -> str: @functools.lru_cache(maxsize=1) def log_once(level: int, message: str) -> None: + """Log a message at the specified level, but only once.""" _LOGGER.log(level, message) @@ -162,6 +163,18 @@ def _convert_message(message: Mapping[str, Any]) -> Dict[str, Any]: def get_messages_from_inputs(inputs: Mapping[str, Any]) -> List[Dict[str, Any]]: + """Extract messages from the given inputs dictionary. + + Args: + inputs (Mapping[str, Any]): The inputs dictionary. + + Returns: + List[Dict[str, Any]]: A list of dictionaries representing + the extracted messages. + + Raises: + ValueError: If no message(s) are found in the inputs dictionary. + """ if "messages" in inputs: return [_convert_message(message) for message in inputs["messages"]] if "message" in inputs: @@ -170,6 +183,17 @@ def get_messages_from_inputs(inputs: Mapping[str, Any]) -> List[Dict[str, Any]]: def get_message_generation_from_outputs(outputs: Mapping[str, Any]) -> Dict[str, Any]: + """Retrieve the message generation from the given outputs. + + Args: + outputs (Mapping[str, Any]): The outputs dictionary. + + Returns: + Dict[str, Any]: The message generation. + + Raises: + ValueError: If no generations are found or if multiple generations are present. + """ if "generations" not in outputs: raise ValueError(f"No generations found in in run with output: {outputs}.") generations = outputs["generations"] @@ -188,6 +212,17 @@ def get_message_generation_from_outputs(outputs: Mapping[str, Any]) -> Dict[str, def get_prompt_from_inputs(inputs: Mapping[str, Any]) -> str: + """Retrieve the prompt from the given inputs. + + Args: + inputs (Mapping[str, Any]): The inputs dictionary. + + Returns: + str: The prompt. + + Raises: + ValueError: If the prompt is not found or if multiple prompts are present. + """ if "prompt" in inputs: return inputs["prompt"] if "prompts" in inputs: @@ -202,6 +237,7 @@ def get_prompt_from_inputs(inputs: Mapping[str, Any]) -> str: def get_llm_generation_from_outputs(outputs: Mapping[str, Any]) -> str: + """Get the LLM generation from the outputs.""" if "generations" not in outputs: raise ValueError(f"No generations found in in run with output: {outputs}.") generations = outputs["generations"] @@ -253,8 +289,7 @@ def convert_langchain_message(message: ls_schemas.BaseMessageLike) -> dict: def is_base_message_like(obj: object) -> bool: - """ - Check if the given object is similar to BaseMessage. + """Check if the given object is similar to BaseMessage. Args: obj (object): The object to check. @@ -295,6 +330,12 @@ class FilterPoolFullWarning(logging.Filter): """Filter urrllib3 warnings logged when the connection pool isn't reused.""" def __init__(self, name: str = "", host: str = "") -> None: + """Initialize the FilterPoolFullWarning filter. + + Args: + name (str, optional): The name of the filter. Defaults to "". + host (str, optional): The host to filter. Defaults to "". + """ super().__init__(name) self._host = host @@ -327,8 +368,7 @@ class LangSmithRetry(Retry): def filter_logs( logger: logging.Logger, filters: Sequence[logging.Filter] ) -> Generator[None, None, None]: - """ - Temporarily adds specified filters to a logger. + """Temporarily adds specified filters to a logger. Parameters: - logger: The logger to which the filters will be added. diff --git a/python/langsmith/wrappers/__init__.py b/python/langsmith/wrappers/__init__.py index 0f9a0150f..34f425953 100644 --- a/python/langsmith/wrappers/__init__.py +++ b/python/langsmith/wrappers/__init__.py @@ -1,3 +1,5 @@ +"""This module provides convenient tracing wrappers for popular libraries.""" + from langsmith.wrappers._openai import wrap_openai __all__ = ["wrap_openai"] diff --git a/python/pyproject.toml b/python/pyproject.toml index ed66eee04..1496acf6b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -59,10 +59,22 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] select = [ - "E", # pycodestyle - "F", # pyflakes - "I", # isort + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "D", # pydocstyle + "D401", # First line should be in imperative mood ] +ignore = [ + # Relax the convention by _not_ requiring documentation for every function parameter. + "D417", +] +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["D"] +"langsmith/cli/*" = ["D"] [tool.mypy] ignore_missing_imports = "True" diff --git a/python/tests/unit_tests/test_run_helpers.py b/python/tests/unit_tests/test_run_helpers.py index 9b5a6aeaf..c43c60038 100644 --- a/python/tests/unit_tests/test_run_helpers.py +++ b/python/tests/unit_tests/test_run_helpers.py @@ -158,13 +158,17 @@ def foo(kwargs: int, *, b: int, c: int, **some_other_kwargs: Any) -> None: } -@pytest.fixture -def mock_client() -> Client: +def _get_mock_client() -> Client: mock_session = MagicMock() client = Client(session=mock_session, api_key="test") return client +@pytest.fixture +def mock_client() -> Client: + return _get_mock_client() + + @pytest.mark.parametrize("use_next", [True, False]) def test_traceable_iterator(use_next: bool, mock_client: Client) -> None: with patch.dict(os.environ, {"LANGCHAIN_TRACING_V2": "true"}): @@ -187,7 +191,7 @@ def my_iterator_fn(a, b, d): results = list(genout) assert results == expected # Wait for batcher - time.sleep(0.1) + time.sleep(0.25) # check the mock_calls mock_calls = mock_client.session.request.mock_calls # type: ignore assert 1 <= len(mock_calls) <= 2 @@ -219,7 +223,7 @@ async def my_iterator_fn(a, b, d): results = [item async for item in genout] assert results == expected # Wait for batcher - await asyncio.sleep(0.1) + await asyncio.sleep(0.25) # check the mock_calls mock_calls = mock_client.session.request.mock_calls # type: ignore assert 1 <= len(mock_calls) <= 2 @@ -309,6 +313,50 @@ async def my_function(a, b, d): assert result == [6, 7] +def test_traceable_project_name() -> None: + with patch.dict(os.environ, {"LANGCHAIN_TRACING_V2": "true"}): + mock_client_ = _get_mock_client() + + @traceable(client=mock_client_, project_name="my foo project") + def my_function(a: int, b: int, d: int) -> int: + return a + b + d + + my_function(1, 2, 3) + time.sleep(0.25) + # Inspect the mock_calls and asser tthat "my foo project" is in + # the session_name arg of the body + mock_calls = mock_client_.session.request.mock_calls # type: ignore + assert 1 <= len(mock_calls) <= 2 + call = mock_calls[0] + assert call.args[0] == "post" + assert call.args[1].startswith("https://api.smith.langchain.com") + body = json.loads(call.kwargs["data"]) + assert body["post"] + assert body["post"][0]["session_name"] == "my foo project" + + # reset + mock_client_ = _get_mock_client() + + @traceable(client=mock_client_, project_name="my bar project") + def my_other_function(run_tree) -> int: + return my_function(1, 2, 3) + + my_other_function() + time.sleep(0.25) + # Inspect the mock_calls and assert that "my bar project" is in + # both all POST runs in the single request. We want to ensure + # all runs in a trace are associated with the same project. + mock_calls = mock_client_.session.request.mock_calls # type: ignore + assert 1 <= len(mock_calls) <= 2 + call = mock_calls[0] + assert call.args[0] == "post" + assert call.args[1].startswith("https://api.smith.langchain.com") + body = json.loads(call.kwargs["data"]) + assert body["post"] + assert body["post"][0]["session_name"] == "my bar project" + assert body["post"][1]["session_name"] == "my bar project" + + def test_is_traceable_function(mock_client: Client) -> None: @traceable(client=mock_client) def my_function(a: int, b: int, d: int) -> int: