Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Dec 9, 2024
1 parent 39be3c7 commit 0daf245
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 61 deletions.
45 changes: 1 addition & 44 deletions python/langsmith/evaluation/_arunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import asyncio
import concurrent.futures as cf
import datetime
import inspect
import logging
import pathlib
import uuid
Expand Down Expand Up @@ -41,6 +40,7 @@
_ExperimentManagerMixin,
_extract_feedback_keys,
_ForwardResults,
_include_attachments,
_is_langchain_runnable,
_load_examples_map,
_load_experiment,
Expand Down Expand Up @@ -1058,49 +1058,6 @@ def _get_run(r: run_trees.RunTree) -> None:
)


def _include_attachments(
target: Union[ATARGET_T, Iterable[schemas.Run], AsyncIterable[dict], Runnable],
) -> bool:
"""Whether the target function accepts attachments."""
if _is_langchain_runnable(target) or not callable(target):
return False
# Check function signature
sig = inspect.signature(target)
params = list(sig.parameters.values())
positional_params = [
p
for p in params
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
and p.default is p.empty
]

if len(positional_params) == 0:
raise ValueError(
"Target function must accept at least one positional argument (inputs)"
)
elif len(positional_params) > 2:
raise ValueError(
"Target function must accept at most two positional "
"arguments (inputs, attachments)"
)
elif len(positional_params) == 2:
mismatches = []
for i, (p, expected) in enumerate(
zip(positional_params, ("inputs", "attachments"))
):
if p.name != expected:
mismatches.append((i, p.name))

if mismatches:
raise ValueError(
"When target function has two positional arguments, they must be named "
"'inputs' and 'attachments', respectively. Received: "
+ ",".join(f"'{p}' at index {i}" for i, p in mismatches)
)

return len(positional_params) == 2


def _ensure_async_traceable(
target: ATARGET_T,
) -> rh.SupportsLangsmithExtra[[dict], Awaitable]:
Expand Down
34 changes: 17 additions & 17 deletions python/langsmith/evaluation/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,47 +1913,47 @@ def _ensure_traceable(
return fn


def _include_attachments(
target: Union[TARGET_T, Iterable[schemas.Run], Runnable],
) -> bool:
def _include_attachments(target: Any) -> bool:
"""Whether the target function accepts attachments."""
if _is_langchain_runnable(target) or not callable(target):
return False
# Check function signature
sig = inspect.signature(target)
params = list(sig.parameters.values())
positional_params = [
p
for p in params
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
and p.default is p.empty
p for p in params if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
]
positional_no_default = [p for p in positional_params if p.default is p.empty]

if len(positional_params) == 0:
raise ValueError(
"Target function must accept at least one positional argument (inputs)"
"Target function must accept at least one positional argument (inputs)."
)
elif len(positional_params) > 2:
elif len(positional_no_default) > 2:
raise ValueError(
"Target function must accept at most two positional "
"arguments (inputs, attachments)"
"Target function must accept at most two "
"arguments without default values: (inputs, attachments)."
)
elif len(positional_params) == 2:
else:
mismatches = []
num_args = 0
for i, (p, expected) in enumerate(
zip(positional_params, ("inputs", "attachments"))
):
if p.name != expected:
mismatches.append((i, p.name))
else:
num_args += 1

if mismatches:
raise ValueError(
"When target function has two positional arguments, they must be named "
"'inputs' and 'attachments', respectively. Received: "
+ ",".join(f"'{p}' at index {i}" for i, p in mismatches)
msg = (
"Target function is expected to have a first positional argument "
"'inputs' and optionally a second positional argument 'attachments'. "
"Received: " + ", ".join(f"'{p}' at index {i}" for i, p in mismatches)
)
raise ValueError(msg)

return len(positional_params) == 2
return num_args == 2


def _resolve_experiment(
Expand Down

0 comments on commit 0daf245

Please sign in to comment.