Skip to content

Commit

Permalink
✨ upgrade to pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
yanyongyu authored Feb 11, 2024
1 parent 34a22fe commit 486ea8c
Show file tree
Hide file tree
Showing 8 changed files with 366 additions and 329 deletions.
22 changes: 12 additions & 10 deletions nonebot/adapters/qq/adapter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import sys
import json
import asyncio
from typing_extensions import override
from typing import Any, List, Tuple, Literal, Optional

from pydantic import parse_raw_as
from nonebot.utils import escape_tag
from nonebot.exception import WebSocketClosed
from nonebot.compat import PYDANTIC_V2, type_validate_python
from nonebot.drivers import (
URL,
Driver,
Expand Down Expand Up @@ -376,36 +377,37 @@ async def _loop(self, bot: Bot, ws: WebSocket):
)

def get_auth_base(self) -> URL:
return URL(self.qq_config.qq_auth_base)
return URL(str(self.qq_config.qq_auth_base))

def get_api_base(self) -> URL:
if self.qq_config.qq_is_sandbox:
return URL(self.qq_config.qq_sandbox_api_base)
return URL(str(self.qq_config.qq_sandbox_api_base))
else:
return URL(self.qq_config.qq_api_base)
return URL(str(self.qq_config.qq_api_base))

@staticmethod
async def receive_payload(bot: Bot, ws: WebSocket) -> Payload:
payload = parse_raw_as(PayloadType, await ws.receive())
payload = type_validate_python(PayloadType, json.loads(await ws.receive()))
if isinstance(payload, Dispatch):
bot.on_dispatch(payload)
return payload

@staticmethod
def payload_to_json(payload: Payload) -> str:
return payload.__config__.json_dumps(
payload.dict(), default=payload.__json_encoder__
)
if PYDANTIC_V2:
return payload.model_dump_json(by_alias=True)

return payload.json(by_alias=True)

@staticmethod
def payload_to_event(payload: Dispatch) -> Event:
EventClass = EVENT_CLASSES.get(payload.type, None)
if EventClass is None:
log("WARNING", f"Unknown payload type: {payload.type}")
event = Event.parse_obj(payload.data)
event = type_validate_python(Event, payload.data)
event.__type__ = payload.type # type: ignore
return event
return EventClass.parse_obj(payload.data)
return type_validate_python(EventClass, payload.data)

@override
async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
Expand Down
14 changes: 11 additions & 3 deletions nonebot/adapters/qq/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import List, Tuple, Optional

from pydantic import Extra, Field, HttpUrl, BaseModel
from pydantic import Field, HttpUrl, BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict


class Intents(BaseModel, extra=Extra.forbid):
class Intents(BaseModel):
guilds: bool = True
guild_members: bool = True
guild_messages: bool = False
Expand All @@ -18,6 +19,13 @@ class Intents(BaseModel, extra=Extra.forbid):
audio_action: bool = False
at_messages: bool = True

if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="forbid")
else:

class Config:
extra = "forbid"

def to_int(self):
return (
self.guilds << 0
Expand Down Expand Up @@ -54,7 +62,7 @@ def is_group_bot(self) -> bool:
return self.intent.is_group_enabled


class Config(BaseModel, extra=Extra.ignore):
class Config(BaseModel):
qq_is_sandbox: bool = False
qq_api_base: HttpUrl = Field("https://api.sgroup.qq.com/")
qq_sandbox_api_base: HttpUrl = Field("https://sandbox.api.sgroup.qq.com")
Expand Down
106 changes: 0 additions & 106 deletions nonebot/adapters/qq/models/_transformer.py

This file was deleted.

27 changes: 20 additions & 7 deletions nonebot/adapters/qq/models/guild.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import json
from enum import IntEnum
from datetime import datetime
from typing import List, Union, Generic, TypeVar, Optional

from pydantic.generics import GenericModel
from pydantic import BaseModel, validator, root_validator
from nonebot.compat import PYDANTIC_V2, model_fields, type_validate_python

from .common import MessageArk, MessageEmbed, MessageReference, MessageAttachment

if PYDANTIC_V2:
GenericModel = BaseModel
else:
from pydantic.generics import GenericModel

T = TypeVar("T")


Expand Down Expand Up @@ -329,7 +335,8 @@ class Elem(BaseModel):
video: Optional[VideoElem] = None
url: Optional[URLElem] = None

@root_validator(pre=True, allow_reuse=True)
@root_validator(pre=True)
@classmethod
def infer_type(cls, values: dict):
if values.get("type") is not None:
return values
Expand Down Expand Up @@ -373,10 +380,11 @@ class ThreadObjectInfo(BaseModel):
content: RichText
date_time: datetime

@validator("content", pre=True, allow_reuse=True)
@validator("content", pre=True)
@classmethod
def parse_content(cls, v):
if isinstance(v, str):
return RichText.parse_raw(v, content_type="json")
return type_validate_python(RichText, json.loads(v))
return v


Expand All @@ -387,10 +395,15 @@ class ThreadInfo(ThreadObjectInfo, GenericModel, Generic[_T_Title]):
# 事件推送拿到的title实际上是RichText的JSON字符串,而API调用返回的title是普通文本
title: _T_Title

@validator("title", pre=True, allow_reuse=True)
@validator("title", pre=True)
@classmethod
def parse_title(cls, v):
if isinstance(v, str) and cls.__fields__["title"].type_ is RichText:
return RichText.parse_raw(v, content_type="json")
if (
isinstance(v, str)
and next(f for f in model_fields(cls) if f.name == "title").annotation
is RichText
):
return type_validate_python(RichText, json.loads(v))
return v


Expand Down
58 changes: 43 additions & 15 deletions nonebot/adapters/qq/models/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from typing import Tuple, Union
from typing_extensions import Literal, Annotated

from pydantic import Extra, Field, BaseModel

from ._transformer import BoolToIntTransformer, AliasExportTransformer
from pydantic import Field, BaseModel
from nonebot.compat import PYDANTIC_V2, ConfigDict

PAYLOAD_FIELD_ALIASES = {"opcode": "op", "data": "d", "sequence": "s", "type": "t"}

Expand All @@ -21,14 +20,22 @@ class Opcode(IntEnum):
HTTP_CALLBACK_ACK = 12


class Payload(AliasExportTransformer, BaseModel):
class Config:
extra = Extra.allow
allow_population_by_field_name = True
class Payload(BaseModel):
if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(
extra="allow",
populate_by_name=True,
alias_generator=lambda x: PAYLOAD_FIELD_ALIASES.get(x, x),
)
else:

class Config:
extra = "allow"
allow_population_by_field_name = True

@classmethod
def alias_generator(cls, string: str) -> str:
return PAYLOAD_FIELD_ALIASES.get(string, string)
@classmethod
def alias_generator(cls, string: str) -> str:
return PAYLOAD_FIELD_ALIASES.get(string, string)


class Dispatch(Payload):
Expand All @@ -43,23 +50,37 @@ class Heartbeat(Payload):
data: int


class IdentifyData(BaseModel, extra=Extra.allow):
class IdentifyData(BaseModel):
token: str
intents: int
shard: Tuple[int, int]
properties: dict

if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="allow")
else:

class Config:
extra = "allow"


class Identify(Payload):
opcode: Literal[Opcode.IDENTIFY] = Field(Opcode.IDENTIFY)
data: IdentifyData


class ResumeData(BaseModel, extra=Extra.allow):
class ResumeData(BaseModel):
token: str
session_id: str
seq: int

if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="allow")
else:

class Config:
extra = "allow"


class Resume(Payload):
opcode: Literal[Opcode.RESUME] = Field(Opcode.RESUME)
Expand All @@ -74,9 +95,16 @@ class InvalidSession(Payload):
opcode: Literal[Opcode.INVALID_SESSION] = Field(Opcode.INVALID_SESSION)


class HelloData(BaseModel, extra=Extra.allow):
class HelloData(BaseModel):
heartbeat_interval: int

if PYDANTIC_V2:
model_config: ConfigDict = ConfigDict(extra="allow")
else:

class Config:
extra = "allow"


class Hello(Payload):
opcode: Literal[Opcode.HELLO] = Field(Opcode.HELLO)
Expand All @@ -87,9 +115,9 @@ class HeartbeatAck(Payload):
opcode: Literal[Opcode.HEARTBEAT_ACK] = Field(Opcode.HEARTBEAT_ACK)


class HTTPCallbackAck(BoolToIntTransformer, Payload):
class HTTPCallbackAck(Payload):
opcode: Literal[Opcode.HTTP_CALLBACK_ACK] = Field(Opcode.HTTP_CALLBACK_ACK)
data: bool
data: int


PayloadType = Union[
Expand Down
7 changes: 4 additions & 3 deletions packages/nonebot-adapter-qqguild/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ force_sort_within_sections = true
extra_standard_library = ["typing_extensions"]

[tool.ruff]
select = ["E", "W", "F", "UP", "C", "T", "Q"]
ignore = ["E402", "F403", "F405", "C901", "UP037"]

line-length = 88
target-version = "py38"

[tool.ruff.lint]
select = ["E", "W", "F", "UP", "C", "T", "Q"]
ignore = ["E402", "F403", "F405", "C901", "UP037"]

[tool.pyright]
pythonPlatform = "All"
pythonVersion = "3.8"
Expand Down
Loading

0 comments on commit 486ea8c

Please sign in to comment.