Skip to content

Commit

Permalink
Improve ruff format/lint rules (#385)
Browse files Browse the repository at this point in the history
* Add split-on-trailing-comma for imports

* make format

* Allow S101 assert, remove noqa comments

* Reenable temporarily disabled ruff rules
  • Loading branch information
jackmpcollins authored Dec 1, 2024
1 parent f7e7e37 commit af31a87
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 121 deletions.
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ ignore = [
"TD", # flake8-todos
"FIX", # flake8-fixme
"PL", # Pylint
"S101", # assert
# Compatibility with ruff formatter
"E501",
"ISC001",
Expand All @@ -116,17 +117,14 @@ ignore = [
"Q002",
"Q003",
"W191",
"B905", # TODO: Reenable this
"UP006", # TODO: Reenable this
"UP007", # TODO: Reenable this
"UP035", # TODO: Reenable this
]

[tool.ruff.lint.flake8-pytest-style]
mark-parentheses = false

[tool.ruff.lint.isort]
known-first-party = ["magentic"]
split-on-trailing-comma = false

[tool.ruff.lint.per-file-ignores]
"docs/examples/*" = [
Expand Down
24 changes: 5 additions & 19 deletions src/magentic/chat_model/anthropic_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import base64
import json
from collections.abc import (
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
from enum import Enum
from functools import singledispatch
from itertools import groupby
Expand All @@ -15,11 +9,7 @@
import filetype

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic.chat_model.base import (
ChatModel,
aparse_stream,
parse_stream,
)
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
BaseFunctionSchema,
FunctionCallFunctionSchema,
Expand All @@ -44,11 +34,7 @@
StreamParser,
StreamState,
)
from magentic.function_call import (
FunctionCall,
ParallelFunctionCall,
_create_unique_id,
)
from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id
from magentic.vision import UserImageMessage

try:
Expand Down Expand Up @@ -273,7 +259,7 @@ def update(self, item: MessageStreamEvent) -> None:
current_snapshot=self._current_message_snapshot,
)
if item.type == "message_stop":
assert not self.usage_ref # noqa: S101
assert not self.usage_ref
self.usage_ref.append(
Usage(
input_tokens=item.message.usage.input_tokens,
Expand All @@ -283,7 +269,7 @@ def update(self, item: MessageStreamEvent) -> None:

@property
def current_message_snapshot(self) -> Message[Any]:
assert self._current_message_snapshot is not None # noqa: S101
assert self._current_message_snapshot is not None
# TODO: Possible to return AssistantMessage here?
return _RawMessage(self._current_message_snapshot.model_dump())

Expand Down
4 changes: 2 additions & 2 deletions src/magentic/chat_model/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import types
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from collections.abc import AsyncIterator, Callable, Iterable, Iterator
from contextvars import ContextVar
from itertools import chain
from typing import Any, AsyncIterator, Iterator, TypeVar, cast, get_origin, overload
from typing import Any, TypeVar, cast, get_origin, overload

from pydantic import ValidationError

Expand Down
18 changes: 9 additions & 9 deletions src/magentic/chat_model/litellm_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,20 @@

class LitellmStreamParser(StreamParser[ModelResponse]):
def is_content(self, item: ModelResponse) -> bool:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
return bool(item.choices[0].delta.content)

def get_content(self, item: ModelResponse) -> str | None:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0].delta.content, str | None) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
assert isinstance(item.choices[0].delta.content, str | None)
return item.choices[0].delta.content

def is_tool_call(self, item: ModelResponse) -> bool:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
return bool(item.choices[0].delta.tool_calls)

def iter_tool_calls(self, item: ModelResponse) -> Iterable[FunctionCallChunk]:
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
if item.choices and item.choices[0].delta.tool_calls:
for tool_call in item.choices[0].delta.tool_calls:
if tool_call.function:
Expand All @@ -75,13 +75,13 @@ def update(self, item: ModelResponse) -> None:
# litellm requires usage is not None for its total usage calculation
item.usage = litellm.Usage() # type: ignore[attr-defined]
if not hasattr(item, "refusal"):
assert isinstance(item.choices[0], StreamingChoices) # noqa: S101
assert isinstance(item.choices[0], StreamingChoices)
item.choices[0].delta.refusal = None # type: ignore[attr-defined]
self._chat_completion_stream_state.handle_chunk(item) # type: ignore[arg-type]
usage = cast(litellm.Usage, item.usage) # type: ignore[attr-defined,name-defined]
# Ignore usages with 0 tokens
if usage and usage.prompt_tokens and usage.completion_tokens:
assert not self.usage_ref # noqa: S101
assert not self.usage_ref
self.usage_ref.append(
Usage(
input_tokens=usage.prompt_tokens,
Expand Down Expand Up @@ -210,7 +210,7 @@ def complete(
tool_schemas=tool_schemas, output_types=output_types
), # type: ignore[arg-type,unused-ignore]
)
assert not isinstance(response, ModelResponse) # noqa: S101
assert not isinstance(response, ModelResponse)
stream = OutputStream(
stream=response,
function_schemas=function_schemas,
Expand Down Expand Up @@ -270,7 +270,7 @@ async def acomplete(
tool_schemas=tool_schemas, output_types=output_types
), # type: ignore[arg-type,unused-ignore]
)
assert not isinstance(response, ModelResponse) # noqa: S101
assert not isinstance(response, ModelResponse)
stream = AsyncOutputStream(
stream=response,
function_schemas=function_schemas,
Expand Down
28 changes: 7 additions & 21 deletions src/magentic/chat_model/openai_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import base64
from collections.abc import (
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
from enum import Enum
from functools import singledispatch
from typing import Any, Generic, Literal, TypeVar, cast, overload
Expand All @@ -24,11 +18,7 @@

from magentic._parsing import contains_parallel_function_call_type, contains_string_type
from magentic._streamed_response import StreamedResponse
from magentic.chat_model.base import (
ChatModel,
aparse_stream,
parse_stream,
)
from magentic.chat_model.base import ChatModel, aparse_stream, parse_stream
from magentic.chat_model.function_schema import (
BaseFunctionSchema,
FunctionCallFunctionSchema,
Expand All @@ -53,11 +43,7 @@
StreamParser,
StreamState,
)
from magentic.function_call import (
FunctionCall,
ParallelFunctionCall,
_create_unique_id,
)
from magentic.function_call import FunctionCall, ParallelFunctionCall, _create_unique_id
from magentic.streaming import StreamedStr
from magentic.vision import UserImageMessage

Expand All @@ -78,9 +64,9 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessagePar

@message_to_openai_message.register(_RawMessage)
def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam:
assert isinstance(message.content, dict) # noqa: S101
assert "role" in message.content # noqa: S101
assert "content" in message.content # noqa: S101
assert isinstance(message.content, dict)
assert "role" in message.content
assert "content" in message.content
return cast(ChatCompletionMessageParam, message.content)


Expand Down Expand Up @@ -316,7 +302,7 @@ def update(self, item: ChatCompletionChunk) -> None:
tool_call_chunk.index = self._current_tool_call_index
self._chat_completion_stream_state.handle_chunk(item)
if item.usage:
assert not self.usage_ref # noqa: S101
assert not self.usage_ref
self.usage_ref.append(
Usage(
input_tokens=item.usage.prompt_tokens,
Expand Down
11 changes: 2 additions & 9 deletions src/magentic/chat_model/retry_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,8 @@
from functools import singledispatchmethod
from typing import Any, TypeVar, overload

from magentic.chat_model.base import (
ChatModel,
ToolSchemaParseError,
)
from magentic.chat_model.message import (
AssistantMessage,
Message,
ToolResultMessage,
)
from magentic.chat_model.base import ChatModel, ToolSchemaParseError
from magentic.chat_model.message import AssistantMessage, Message, ToolResultMessage
from magentic.logger import logfire

R = TypeVar("R")
Expand Down
28 changes: 14 additions & 14 deletions src/magentic/chat_model/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _streamed_str(
yield content
if self._parser.is_tool_call(item):
# TODO: Check if output types allow for early return and raise if not
assert not current_item_ref # noqa: S101
assert not current_item_ref
current_item_ref.append(item)
return
self._exhausted = True
Expand All @@ -113,7 +113,7 @@ def _tool_call(
# so that the whole stream is consumed including stop_reason/usage chunks
if item.id and item.id != current_tool_call_id:
# TODO: Check if output types allow for early return and raise if not
assert not current_tool_call_ref # noqa: S101
assert not current_tool_call_ref
current_tool_call_ref.append(item)
return
if item.args:
Expand Down Expand Up @@ -144,13 +144,13 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]:
while tool_call_ref:
current_tool_call_chunk = tool_call_ref.pop()
current_tool_call_id = current_tool_call_chunk.id
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_chunk.name is not None # noqa: S101
assert current_tool_call_id is not None
assert current_tool_call_chunk.name is not None
function_schema = select_function_schema(
self._function_schemas, current_tool_call_chunk.name
)
if function_schema is None:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise UnknownToolError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand All @@ -169,12 +169,12 @@ def __stream__(self) -> Iterator[StreamedStr | OutputT]:
if not tool_call_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
# Output must be Iterable if parse_args above did not consume
assert isinstance(output, Iterable), output # noqa: S101
assert isinstance(output, Iterable), output
# Consume stream via the output type so it can cache
consume(output)

except ValidationError as e:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise ToolSchemaParseError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand Down Expand Up @@ -221,7 +221,7 @@ async def _streamed_str(
yield content
if self._parser.is_tool_call(item):
# TODO: Check if output types allow for early return
assert not current_item_ref # noqa: S101
assert not current_item_ref
current_item_ref.append(item)
return
self._exhausted = True
Expand All @@ -235,7 +235,7 @@ async def _tool_call(
async for item in stream:
if item.id and item.id != current_tool_call_id:
# TODO: Check if output types allow for early return
assert not current_tool_call_ref # noqa: S101
assert not current_tool_call_ref
current_tool_call_ref.append(item)
return
if item.args:
Expand Down Expand Up @@ -267,13 +267,13 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
while tool_call_ref:
current_tool_call_chunk = tool_call_ref.pop()
current_tool_call_id = current_tool_call_chunk.id
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_chunk.name is not None # noqa: S101
assert current_tool_call_id is not None
assert current_tool_call_chunk.name is not None
function_schema = select_function_schema(
self._function_schemas, current_tool_call_chunk.name
)
if function_schema is None:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise UnknownToolError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand All @@ -292,11 +292,11 @@ async def __stream__(self) -> AsyncIterator[AsyncStreamedStr | OutputT]:
if not tool_call_ref and not self._exhausted:
# Finish the group to allow advancing to the next one
# Output must be AsyncIterable if aparse_args above did not consume
assert isinstance(output, AsyncIterable), output # noqa: S101
assert isinstance(output, AsyncIterable), output
# Consume stream via the output type so it can cache
await aconsume(output)
except ValidationError as e:
assert current_tool_call_id is not None # noqa: S101
assert current_tool_call_id is not None
raise ToolSchemaParseError(
output_message=self._state.current_message_snapshot,
tool_call_id=current_tool_call_id,
Expand Down
10 changes: 1 addition & 9 deletions src/magentic/chatprompt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
import inspect
from collections.abc import Awaitable, Callable, Sequence
from functools import update_wrapper
from typing import (
Any,
Generic,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
)
from typing import Any, Generic, ParamSpec, Protocol, TypeVar, cast, overload

from magentic.backend import get_chat_model
from magentic.chat_model.base import ChatModel
Expand Down
8 changes: 1 addition & 7 deletions src/magentic/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@
Iterable,
Iterator,
)
from typing import (
Any,
Generic,
ParamSpec,
TypeVar,
cast,
)
from typing import Any, Generic, ParamSpec, TypeVar, cast
from uuid import uuid4

from magentic.logger import logfire
Expand Down
7 changes: 1 addition & 6 deletions src/magentic/prompt_chain.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import inspect
from collections.abc import Callable
from functools import wraps
from typing import (
Any,
ParamSpec,
TypeVar,
cast,
)
from typing import Any, ParamSpec, TypeVar, cast

from magentic.chat import Chat
from magentic.chat_model.base import ChatModel
Expand Down
10 changes: 1 addition & 9 deletions src/magentic/prompt_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,7 @@
import inspect
from collections.abc import Awaitable, Callable, Sequence
from functools import update_wrapper
from typing import (
Any,
Generic,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
)
from typing import Any, Generic, ParamSpec, Protocol, TypeVar, cast, overload

from magentic.backend import get_chat_model
from magentic.chat_model.base import ChatModel
Expand Down
Loading

0 comments on commit af31a87

Please sign in to comment.