Skip to content

Commit

Permalink
✨ pydanticv2
Browse files Browse the repository at this point in the history
  • Loading branch information
j1g5awi committed Feb 9, 2024
1 parent 12f57fc commit 66425b1
Show file tree
Hide file tree
Showing 9 changed files with 506 additions and 422 deletions.
16 changes: 10 additions & 6 deletions nonebot/adapters/telegram/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Adapter(BaseAdapter):
@overrides(BaseAdapter)
def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
self.adapter_config = AdapterConfig(**self.config.dict())
self.adapter_config = AdapterConfig(**self.config.model_dump())
self.tasks: List[asyncio.Task] = []
self.setup()

Expand All @@ -47,7 +47,9 @@ async def __handle_update(self, bot: Bot, update: Dict[str, Any]):

log(
"DEBUG",
escape_tag(str(event.dict(exclude_none=True, exclude={"telegram_model"}))),
escape_tag(
str(event.model_dump(exclude_none=True, exclude={"telegram_model"}))
),
)
await bot.handle_event(event)

Expand Down Expand Up @@ -89,7 +91,7 @@ async def poll(self, bot: Bot):
update_offset = update.update_id + 1
asyncio.create_task(
self.__handle_update(
bot, update.dict(by_alias=True, exclude_none=True)
bot, update.model_dump(by_alias=True, exclude_none=True)
)
)
elif updates:
Expand Down Expand Up @@ -212,9 +214,11 @@ async def process_input_file(file: Union[InputFile, str]) -> Optional[str]:
data[key] = json.dumps(
data[key],
default=(
lambda o: o.dict(exclude_none=True)
if isinstance(o, BaseModel)
else pydantic_encoder(o)
lambda o: (
o.model_dump(exclude_none=True)
if isinstance(o, BaseModel)
else pydantic_encoder(o)
)
),
)

Expand Down
9 changes: 4 additions & 5 deletions nonebot/adapters/telegram/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import partial
from typing import Any, List, Union, Optional, Sequence, cast

from pydantic import parse_obj_as
from pydantic import TypeAdapter
from nonebot.typing import overrides
from nonebot.message import handle_event

Expand Down Expand Up @@ -99,8 +99,8 @@ async def call_api(self, api: str, *args: Any, **kargs: Any) -> Any:
kargs[param.name] = args_.pop(0)
except IndexError:
kargs[param.name] = None
return parse_obj_as(
sign.return_annotation, await super().call_api(api, **kargs)
return TypeAdapter(sign.return_annotation).validate_python(
await super().call_api(api, **kargs)
)
return await super().call_api(api, **kargs)

Expand Down Expand Up @@ -237,8 +237,7 @@ async def send_to(
if len(files) > 1:
# 多个文件
medias = [
parse_obj_as(
InputMedia,
TypeAdapter(InputMedia).validate_python(
{
"type": file.type,
"media": file.data.pop("file"),
Expand Down
58 changes: 31 additions & 27 deletions nonebot/adapters/telegram/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __parse_event(cls, obj: dict) -> "Event":
}

event = event_map[post_type].parse_event(obj[post_type])
setattr(event, "telegram_model", Update.parse_obj(obj))
setattr(event, "telegram_model", Update.model_validate(obj))
return event

@classmethod
Expand All @@ -72,7 +72,7 @@ def _parse_event(cls, obj: dict, failed: set = set()) -> "Event":
return subclass.parse_event(obj)
except:
pass
return cls.parse_obj(obj)
return cls.model_validate(obj)

@classmethod
def parse_event(cls, obj: dict) -> "Event":
Expand All @@ -92,7 +92,11 @@ def get_event_name(self) -> str:
@overrides(BaseEvent)
def get_event_description(self) -> str:
return escape_tag(
str(self.dict(by_alias=True, exclude_none=True, exclude={"telegram_model"}))
str(
self.model_dump(
by_alias=True, exclude_none=True, exclude={"telegram_model"}
)
)
)

@overrides(BaseEvent)
Expand Down Expand Up @@ -126,24 +130,24 @@ class MessageEvent(Event):
message_id: int
date: int
chat: Chat
forward_from: Optional[User]
forward_from_chat: Optional[Chat]
forward_from_message_id: Optional[int]
forward_signature: Optional[str]
forward_sender_name: Optional[str]
forward_date: Optional[int]
via_bot: Optional[User]
has_protected_content: Optional[Literal[True]]
media_group_id: Optional[str]
author_signature: Optional[str]
forward_from: Optional[User] = None
forward_from_chat: Optional[Chat] = None
forward_from_message_id: Optional[int] = None
forward_signature: Optional[str] = None
forward_sender_name: Optional[str] = None
forward_date: Optional[int] = None
via_bot: Optional[User] = None
has_protected_content: Optional[Literal[True]] = None
media_group_id: Optional[str] = None
author_signature: Optional[str] = None
reply_to_message: Optional["MessageEvent"] = None
message: Message = Message()
original_message: Message = Message()
_tome: bool = False

@classmethod
def __parse_event(cls, obj: dict) -> "Event":
message = Message.parse_obj(obj)
message = Message.model_validate(obj)
if not message:
return NoticeEvent.parse_event(obj)
else:
Expand Down Expand Up @@ -214,11 +218,11 @@ def __parse_event(cls, obj: dict) -> "Event":
event = ForumTopicMessageEvent.parse_event(obj)
else:
obj.pop("message_thread_id", None)
event = cls.parse_obj(obj)
event = cls.model_validate(obj)
return event

from_: User = Field(alias="from")
sender_chat: Optional[Chat]
sender_chat: Optional[Chat] = None

@overrides(MessageEvent)
def get_event_name(self) -> str:
Expand Down Expand Up @@ -254,7 +258,7 @@ def get_event_description(self) -> str:


class ChannelPostEvent(MessageEvent):
sender_chat: Optional[Chat]
sender_chat: Optional[Chat] = None

@overrides(MessageEvent)
def get_event_name(self) -> str:
Expand All @@ -269,10 +273,10 @@ class EditedMessageEvent(Event):
message_id: int
date: int
chat: Chat
via_bot: Optional[User]
via_bot: Optional[User] = None
edit_date: int
media_group_id: Optional[str]
author_signature: Optional[str]
media_group_id: Optional[str] = None
author_signature: Optional[str] = None
reply_to_message: Optional["MessageEvent"] = None
message: Message = Message()

Expand All @@ -287,7 +291,7 @@ def __parse_event(cls, obj: dict):
"channel": EditedChannelPostEvent,
}
event = event_map[message_type].parse_event(obj)
setattr(event, "message", Message.parse_obj(obj))
setattr(event, "message", Message.model_validate(obj))
if reply_to_message:
setattr(
event, "reply_to_message", MessageEvent.parse_event(reply_to_message)
Expand Down Expand Up @@ -317,7 +321,7 @@ def get_event_description(self) -> str:

class PrivateEditedMessageEvent(EditedMessageEvent):
from_: User = Field(alias="from")
sender_chat: Optional[Chat]
sender_chat: Optional[Chat] = None

@overrides(EditedMessageEvent)
def get_event_name(self) -> str:
Expand All @@ -342,15 +346,15 @@ def get_event_description(self) -> str:

class GroupEditedMessageEvent(EditedMessageEvent):
from_: User = Field(default=None, alias="from")
sender_chat: Optional[Chat]
sender_chat: Optional[Chat] = None

@classmethod
def __parse_event(cls, obj: dict) -> "Event":
if obj.pop("is_topic_message", None):
event = ForumTopicEditedMessageEvent.parse_event(obj)
else:
obj.pop("message_thread_id", None)
event = cls.parse_obj(obj)
event = cls.model_validate(obj)
return event

@overrides(EditedMessageEvent)
Expand Down Expand Up @@ -387,7 +391,7 @@ def get_event_description(self) -> str:


class EditedChannelPostEvent(EditedMessageEvent):
sender_chat: Optional[Chat]
sender_chat: Optional[Chat] = None

@overrides(EditedMessageEvent)
def get_event_name(self) -> str:
Expand Down Expand Up @@ -440,15 +444,15 @@ def get_event_name(self) -> str:
class PinnedMessageEvent(NoticeEvent):
message_id: int
from_: Optional[User] = Field(alias="from")
sender_chat: Optional[Chat]
sender_chat: Optional[Chat] = None
chat: Chat
date: int
pinned_message: MessageEvent = Field(default=None)

@classmethod
def __parse_event(cls, obj: dict):
pinned_message = obj.pop("pinned_message")
event = cls.parse_obj(obj)
event = cls.model_validate(obj)
setattr(event, "pinned_message", MessageEvent.parse_event(pinned_message))
return event

Expand Down
2 changes: 1 addition & 1 deletion nonebot/adapters/telegram/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def _construct(msg: str) -> Iterable[MessageSegment]:
yield Entity.text(msg)

@classmethod
def parse_obj(cls, obj: Dict[str, Any]) -> "Message":
def model_validate(cls, obj: Dict[str, Any]) -> "Message":
msg = []
if "text" in obj or "caption" in obj:
key, entities_key = (
Expand Down
Loading

0 comments on commit 66425b1

Please sign in to comment.