Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve ruff format/lint rules #385

Merged
merged 4 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ ignore = [
"TD", # flake8-todos
"FIX", # flake8-fixme
"PL", # Pylint
"S101", # assert
# Compatibility with ruff formatter
"E501",
"ISC001",
Expand All @@ -116,17 +117,14 @@ ignore = [
"Q002",
"Q003",
"W191",
"B905", # TODO: Reenable this
"UP006", # TODO: Reenable this
"UP007", # TODO: Reenable this
"UP035", # TODO: Reenable this
]

[tool.ruff.lint.flake8-pytest-style]
mark-parentheses = false

[tool.ruff.lint.isort]
known-first-party = ["magentic"]
split-on-trailing-comma = false

[tool.ruff.lint.per-file-ignores]
"docs/examples/*" = [
Expand Down
24 changes: 5 additions & 19 deletions src/magentic/chat_model/anthropic_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import base64
import json
from collections.abc import (
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
from enum import Enum
from functools import singledispatch
from itertools import groupby
Expand All @@ -15,11 +9,7 @@
import filetype

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic.chat_model.base import (
ChatModel,
aparse_stream,
parse_stream,
)
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
BaseFunctionSchema,
FunctionCallFunctionSchema,
Expand All @@ -44,11 +34,7 @@
StreamParser,
StreamState,
)
from magentic.function_call import (
FunctionCall,
ParallelFunctionCall,
_create_unique_id,
)
from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id
from magentic.vision import UserImageMessage

try:
Expand Down Expand Up @@ -273,7 +259,7 @@ def update(self, item: MessageStreamEvent) -> None:
current_snapshot=self._current_message_snapshot,
)
if item.type == "message_stop":
assert not self.usage_ref # noqa: S101
assert not self.usage_ref
self.usage_ref.append(
Usage(
input_tokens=item.message.usage.input_tokens,
Expand All @@ -283,7 +269,7 @@ def update(self, item: MessageStreamEvent) -> None:

@property
def current_message_snapshot(self) -> Message[Any]:
assert self._current_message_snapshot is not None # noqa: S101
assert self._current_message_snapshot is not None
# TODO: Possible to return AssistantMessage here?
return _RawMessage(self._current_message_snapshot.model_dump())

Expand Down
4 changes: 2 additions & 2 deletions src/magentic/chat_model/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import types
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
from contextvars import ContextVar
from itertools import chain
from typing import Any, AsyncIterator, Iterator, TypeVar, cast, get_origin, overload
from typing import Any, TypeVar, cast, get_origin, overload

from pydantic import ValidationError

Expand Down
18 changes: 9 additions & 9 deletions src/magentic/chat_model/litellm_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@

class LitellmStreamParser(StreamParser[ModelResponse]):
def is_content(self, item: ModelResponse) -> bool:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
return bool(item.choices[0].delta.content)

def get_content(self, item: ModelResponse) -> str | None:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0].delta.content, str | None) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
assert isinstance(item.choices[0].delta.content, str | None)
return item.choices[0].delta.content

def is_tool_call(self, item: ModelResponse) -> bool:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
return bool(item.choices[0].delta.tool_calls)

def iter_tool_calls(self, item: ModelResponse) -> Iterable[FunctionCallChunk]:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
if item.choices and item.choices[0].delta.tool_calls:
for tool_call in item.choices[0].delta.tool_calls:
if tool_call.function:
Expand All @@ -75,13 +75,13 @@ def update(self, item: ModelResponse) -> None:
# litellm requires usage is not None for its total usage calculation
item.usage = litellm.Usage() # type: ignore[attr-defined]
if not hasattr(item, "refusal"):
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
item.choices[0].delta.refusal = None # type: ignore[attr-defined]
self._chat_completion_stream_state.handle_chunk(item) # type: ignore[arg-type]
usage = cast(litellm.Usage, item.usage) # type: ignore[attr-defined,name-defined]
# Ignore usages with 0 tokens
if usage and usage.prompt_tokens and usage.completion_tokens:
assert not self.usage_ref # noqa: S101
assert not self.usage_ref
self.usage_ref.append(
Usage(
input_tokens=usage.prompt_tokens,
Expand Down Expand Up @@ -210,7 +210,7 @@ def complete(
tool_schemas=tool_schemas, output_types=output_types
), # type: ignore[arg-type,unused-ignore]
)
assert not isinstance(response, ModelResponse) # noqa: S101
assert not isinstance(response, ModelResponse)
stream = OutputStream(
stream=response,
function_schemas=function_schemas,
Expand Down Expand Up @@ -270,7 +270,7 @@ async def acomplete(
tool_schemas=tool_schemas, output_types=output_types
), # type: ignore[arg-type,unused-ignore]
)
assert not isinstance(response, ModelResponse) # noqa: S101
assert not isinstance(response, ModelResponse)
stream = AsyncOutputStream(
stream=response,
function_schemas=function_schemas,
Expand Down
28 changes: 7 additions & 21 deletions src/magentic/chat_model/openai_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import base64
from collections.abc import (
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
from enum import Enum
from functools import singledispatch
from typing import Any, Generic, Literal, TypeVar, cast, overload
Expand All @@ -24,11 +18,7 @@

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic._streamed_response import StreamedResponse
from magentic.chat_model.base import (
ChatModel,
aparse_stream,
parse_stream,
)
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
BaseFunctionSchema,
FunctionCallFunctionSchema,
Expand All @@ -53,11 +43,7 @@
StreamParser,
StreamState,
)
from magentic.function_call import (
FunctionCall,
ParallelFunctionCall,
_create_unique_id,
)
from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id
from magentic.streaming import StreamedStr
from magentic.vision import UserImageMessage

Expand All @@ -78,9 +64,9 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessagePar

@message_to_openai_message.register(_RawMessage)
def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam:
assert isinstance(message.content, dict) # noqa: S101
assert "role" in message.content # noqa: S101
assert "content" in message.content # noqa: S101
assert isinstance(message.content, dict)
assert "role" in message.content
assert "content" in message.content
return cast(ChatCompletionMessageParam, message.content)


Expand Down Expand Up @@ -316,7 +302,7 @@ def update(self, item: ChatCompletionChunk) -> None:
tool_call_chunk.index = self._current_tool_call_index
self._chat_completion_stream_state.handle_chunk(item)
if item.usage:
assert not self.usage_ref # noqa: S101
assert not self.usage_ref
self.usage_ref.append(
Usage(
input_tokens=item.usage.prompt_tokens,
Expand Down
11 changes: 2 additions & 9 deletions src/magentic/chat_model/retry_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,8 @@
from functools import singledispatchmethod
from typing import Any, TypeVar, overload

from magentic.chat_model.base import (
ChatModel,
ToolSchemaParseError,
)
from magentic.chat_model.message import (
AssistantMessage,
Message,
ToolResultMessage,
)
from magentic.chat_model.base import ChatModel, ToolSchemaParseError
from magentic.chat_model.message import AssistantMessage, Message, ToolResultMessage
from magentic.logger import logfire

R = TypeVar("R")
Expand Down
28 changes: 14 additions & 14 deletions src/magentic/chat_model/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _streamed_str(
yield content
if self._parser.is_tool_call(item):
# TODO: Check if output types allow for early return and raise if not
assert not current_item_ref # noqa: S101
assert not current_item_ref
current_item_ref.append(item)
return
self._exhausted = True
Expand All @@ -113,7 +113,7 @@ def _tool_call(
# so that the whole stream is consumed including stop_reason/usage chunks
if item.id and item.id != current_tool_call_id:
# TODO: Check if output types allow for early return and raise if not
assert not current_tool_call_ref # noqa: S101
assert not current_tool_call_ref
current_tool_call_ref.append(item)
return
if item.args:
Expand Down Expand Up @@ -144,13 +144,13 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]:
while tool_call_ref:
current_tool_call_chunk = tool_call_ref.pop()
current_tool_call_id = current_tool_call_chunk.id
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_chunk.name is not None # noqa: S101
assert current_tool_call_id is not None
assert current_tool_call_chunk.name is not None
function_schema = select_function_schema(
self._function_schemas, current_tool_call_chunk.name
)
if function_schema is None:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise UnknownToolError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand All @@ -169,12 +169,12 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]:
if not tool_call_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
# Output must be Iterable if parse_args above did not consume
assert isinstance(output, Iterable), output # noqa: S101
assert isinstance(output, Iterable), output
# Consume stream via the output type so it can cache
consume(output)

except ValidationError as e:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise ToolSchemaParseError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand Down Expand Up @@ -221,7 +221,7 @@ async def _streamed_str(
yield content
if self._parser.is_tool_call(item):
# TODO: Check if output types allow for early return
assert not current_item_ref # noqa: S101
assert not current_item_ref
current_item_ref.append(item)
return
self._exhausted = True
Expand All @@ -235,7 +235,7 @@ async def _tool_call(
async for item in stream:
if item.id and item.id != current_tool_call_id:
# TODO: Check if output types allow for early return
assert not current_tool_call_ref # noqa: S101
assert not current_tool_call_ref
current_tool_call_ref.append(item)
return
if item.args:
Expand Down Expand Up @@ -267,13 +267,13 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
while tool_call_ref:
current_tool_call_chunk = tool_call_ref.pop()
current_tool_call_id = current_tool_call_chunk.id
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_chunk.name is not None # noqa: S101
assert current_tool_call_id is not None
assert current_tool_call_chunk.name is not None
function_schema = select_function_schema(
self._function_schemas, current_tool_call_chunk.name
)
if function_schema is None:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise UnknownToolError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand All @@ -292,11 +292,11 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
if not tool_call_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
# Output must be AsyncIterable if aparse_args above did not consume
assert isinstance(output, AsyncIterable), output # noqa: S101
assert isinstance(output, AsyncIterable), output
# Consume stream via the output type so it can cache
await aconsume(output)
except ValidationError as e:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise ToolSchemaParseError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand Down
10 changes: 1 addition & 9 deletions src/magentic/chatprompt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
import inspect
from collections.abc import Awaitable, Callable, Sequence
from functools import update_wrapper
from typing import (
Any,
Generic,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
)
from typing import Any, Generic, ParamSpec, Protocol, TypeVar, cast, overload

from magentic.backend import get_chat_model
from magentic.chat_model.base import ChatModel
Expand Down
8 changes: 1 addition & 7 deletions src/magentic/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@
Iterable,
Iterator,
)
from typing import (
Any,
Generic,
ParamSpec,
TypeVar,
cast,
)
from typing import Any, Generic, ParamSpec, TypeVar, cast
from uuid import uuid4

from magentic.logger import logfire
Expand Down
7 changes: 1 addition & 6 deletions src/magentic/prompt_chain.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import inspect
from collections.abc import Callable
from functools import wraps
from typing import (
Any,
ParamSpec,
TypeVar,
cast,
)
from typing import Any, ParamSpec, TypeVar, cast

from magentic.chat import Chat
from magentic.chat_model.base import ChatModel
Expand Down
10 changes: 1 addition & 9 deletions src/magentic/prompt_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,7 @@
import inspect
from collections.abc import Awaitable, Callable, Sequence
from functools import update_wrapper
from typing import (
Any,
Generic,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
)
from typing import Any, Generic, ParamSpec, Protocol, TypeVar, cast, overload

from magentic.backend import get_chat_model
from magentic.chat_model.base import ChatModel
Expand Down
Loading