diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index e869dbdac..3c3871fde 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -19,6 +19,7 @@ NoReturn, Optional, TypeVar, + Union, ) from collections.abc import AsyncGenerator, Generator, Iterable from copy import deepcopy @@ -38,41 +39,57 @@ def _make_field_optional( field: FieldInfo, ) -> tuple[Any, FieldInfo]: tmp_field = deepcopy(field) - annotation = field.annotation # Handle generics (like List, Dict, etc.) if get_origin(annotation) is not None: - # Get the generic base (like List, Dict) and its arguments (like User in List[User]) generic_base = get_origin(annotation) generic_args = get_args(annotation) - # Recursively apply Partial to each of the generic arguments - modified_args = tuple( - ( - Partial[arg, MakeFieldsOptional] # type: ignore[valid-type] - if isinstance(arg, type) and issubclass(arg, BaseModel) + # Handle Union types specially + if generic_base is Union: + modified_args = tuple( + Partial[arg, MakeFieldsOptional] if isinstance(arg, type) and issubclass(arg, BaseModel) else arg + for arg in generic_args ) - for arg in generic_args - ) - - # Reconstruct the generic type with modified arguments - tmp_field.annotation = ( - Optional[generic_base[modified_args]] if generic_base else None - ) - tmp_field.default = None - # If the field is a BaseModel, then recursively convert it's - # attributes to optionals. + # Add None to Union options and set default + modified_args = modified_args + (None,) if None not in modified_args else modified_args + tmp_field.annotation = Union[modified_args] # type: ignore + tmp_field.default = None + else: + # For other generics (like List), process their arguments + modified_args = tuple( + _process_annotation(arg) + for arg in generic_args + ) + tmp_field.annotation = Optional[generic_base[modified_args]] # type: ignore + tmp_field.default = None + # If the field is a BaseModel, then recursively convert it's attributes to optionals elif isinstance(annotation, type) and issubclass(annotation, BaseModel): - tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type] - tmp_field.default = {} + tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore + tmp_field.default = None else: - tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment] + tmp_field.annotation = Optional[annotation] # type: ignore tmp_field.default = None return tmp_field.annotation, tmp_field # type: ignore +def _process_annotation(annotation: Any) -> Any: + """Helper function to process nested annotations""" + if get_origin(annotation) is Union: + modified_args = tuple( + Partial[arg, MakeFieldsOptional] if isinstance(arg, type) and issubclass(arg, BaseModel) + else arg + for arg in get_args(annotation) + ) + # Add None to Union options + modified_args = modified_args + (None,) if None not in modified_args else modified_args + return Union[modified_args] + elif isinstance(annotation, type) and issubclass(annotation, BaseModel): + return Partial[annotation, MakeFieldsOptional] + return annotation + class PartialBase(Generic[T_Model]): @classmethod