From 190975420840845bfbe54bef9e128c8abe8b6762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Mon, 15 Jun 2026 18:35:39 +0200 Subject: [PATCH 1/8] refactor: enhance type hinting for field constructors --- tortoise/fields/base.py | 72 ++++++++++++++++++++++++++++- tortoise/fields/data.py | 85 ++++++++++++++++++++++++----------- tortoise/fields/relational.py | 26 ++++++++--- 3 files changed, 151 insertions(+), 32 deletions(-) diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 81db786c0..aa82dd9ba 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -7,7 +7,7 @@ from collections.abc import Callable from enum import Enum from functools import reduce -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar, overload from pypika_tortoise.terms import Term @@ -90,6 +90,76 @@ class OnDelete(StrEnum): NO_ACTION = OnDelete.NO_ACTION +class _FieldKwargsCommon(TypedDict, total=False): + """:class:`Field` constructor arguments that are never declared as explicit parameters. + + Used with :data:`typing.Unpack` to give ``**kwargs`` explicit type hints. This is the + smallest set; fields that declare ``unique``/``db_index``/``primary_key`` explicitly + (e.g. ``TextField``) unpack this directly to avoid PEP 692 parameter-name collisions. + """ + + source_field: str | None + generated: bool + default: Any + db_default: Any + description: str | None + model: Model | None + validators: list[Validator | Callable] + pk: bool # deprecated alias for primary_key + index: bool # deprecated alias for db_index + + +class _FieldKwargsNoPk(_FieldKwargsCommon, total=False): + """Common arguments excluding ``primary_key`` and ``null``. + + For constructors that declare ``primary_key`` and ``null`` as explicit parameters + (e.g. ``IntField``). + """ + + unique: bool + db_index: bool | None + + +class FieldKwargs(_FieldKwargsNoPk, total=False): + """Common arguments excluding ``null``. + + For constructors that declare only ``null`` as an explicit parameter (the majority). + """ + + primary_key: bool | None + + +class JSONFieldKwargs(FieldKwargs, total=False): + """Constructor arguments for :class:`JSONField`. + + ``JSONField`` declares neither ``null`` nor ``primary_key`` explicitly, and also accepts + a custom ``field_type`` (e.g. a Pydantic model class). + """ + + null: bool + field_type: Any + +class RelationalFieldKwargs(FieldKwargs, total=False): + """Constructor arguments for :func:`ForeignKeyField` and :func:`OneToOneField`. + + Extends the common :class:`~tortoise.fields.base.FieldKwargs` with ``to_field``. + ``null`` is declared as an explicit parameter on those constructors, so it is omitted. + """ + + to_field: str | None + + +class ManyToManyFieldKwargs(_FieldKwargsCommon, total=False): + """Constructor arguments for :func:`ManyToManyField`. + + ``unique`` is declared as an explicit parameter, so it is omitted here; the deprecated + ``create_unique_index`` alias is still accepted. + """ + + create_unique_index: bool # deprecated alias for unique + + + class _FieldMeta(type): # TODO: Require functions to return field instances instead of this hack def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict) -> type: diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 6109d0022..0ba5fa94d 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -4,6 +4,7 @@ import datetime import functools import json +import sys import warnings from collections.abc import Callable from decimal import Decimal @@ -17,7 +18,18 @@ from tortoise import timezone from tortoise.exceptions import ConfigurationError, FieldError -from tortoise.fields.base import Field +from tortoise.fields.base import ( + Field, + FieldKwargs, + JSONFieldKwargs, + _FieldKwargsCommon, + _FieldKwargsNoPk, +) + +if sys.version_info >= (3, 11): + from typing import Unpack +else: # pragma: no cover + from typing_extensions import Unpack from tortoise.timezone import get_default_timezone, get_timezone, get_use_tz, localtime from tortoise.validators import MaxLengthValidator @@ -107,7 +119,7 @@ def __init__( primary_key: bool | None = None, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[_FieldKwargsNoPk], ) -> None: ... @overload @@ -116,7 +128,7 @@ def __init__( primary_key: bool | None = None, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[_FieldKwargsNoPk], ) -> None: ... def __init__(self, primary_key: bool | None = None, **kwargs: Any) -> None: @@ -222,12 +234,20 @@ class CharField(Field[T_STR]): @overload def __init__( - self: CharField[str], max_length: int, *, null: Literal[False] = False, **kwargs: Any + self: CharField[str], + max_length: int, + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: CharField[str | None], max_length: int, *, null: Literal[True], **kwargs: Any + self: CharField[str | None], + max_length: int, + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, max_length: int, **kwargs: Any) -> None: @@ -269,7 +289,7 @@ def __init__( primary_key: bool | None = None, unique: bool = False, db_index: bool = False, - **kwargs: Any, + **kwargs: Unpack[_FieldKwargsCommon], ) -> None: if primary_key or kwargs.get("pk"): warnings.warn( @@ -315,12 +335,12 @@ class BooleanField(Field[T_BOOL]): @overload def __init__( - self: BooleanField[bool], *, null: Literal[False] = False, **kwargs: Any + self: BooleanField[bool], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] ) -> None: ... @overload def __init__( - self: BooleanField[bool | None], *, null: Literal[True], **kwargs: Any + self: BooleanField[bool | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -357,7 +377,7 @@ def __init__( decimal_places: int, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload @@ -367,7 +387,7 @@ def __init__( decimal_places: int, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, max_digits: int, decimal_places: int, **kwargs: Any) -> None: @@ -440,7 +460,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload @@ -450,7 +470,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: @@ -530,12 +550,15 @@ class DateField(Field[T_DATE], datetime.date): @overload def __init__( - self: DateField[datetime.date], *, null: Literal[False] = False, **kwargs: Any + self: DateField[datetime.date], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: DateField[datetime.date | None], *, null: Literal[True], **kwargs: Any + self: DateField[datetime.date | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -574,7 +597,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload @@ -584,7 +607,7 @@ def __init__( auto_now_add: bool = False, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: @@ -654,12 +677,18 @@ class TimeDeltaField(Field[T_TIMEDELTA]): @overload def __init__( - self: TimeDeltaField[datetime.timedelta], *, null: Literal[False] = False, **kwargs: Any + self: TimeDeltaField[datetime.timedelta], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: TimeDeltaField[datetime.timedelta | None], *, null: Literal[True], **kwargs: Any + self: TimeDeltaField[datetime.timedelta | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -692,11 +721,13 @@ class FloatField(Field[T_FLOAT], float): @overload def __init__( - self: FloatField[float], *, null: Literal[False] = False, **kwargs: Any + self: FloatField[float], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] ) -> None: ... @overload - def __init__(self: FloatField[float | None], *, null: Literal[True], **kwargs: Any) -> None: ... + def __init__( + self: FloatField[float | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] + ) -> None: ... def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -748,7 +779,7 @@ def __init__( self, encoder: JsonDumpsFunc = JSON_DUMPS, decoder: JsonLoadsFunc = JSON_LOADS, - **kwargs: Any, + **kwargs: Unpack[JSONFieldKwargs], ) -> None: super().__init__(**kwargs) self.encoder = encoder @@ -820,10 +851,14 @@ class _db_postgres: SQL_TYPE = "UUID" @overload - def __init__(self: UUIDField[UUID], *, null: Literal[False] = False, **kwargs: Any) -> None: ... + def __init__( + self: UUIDField[UUID], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] + ) -> None: ... @overload - def __init__(self: UUIDField[UUID | None], *, null: Literal[True], **kwargs: Any) -> None: ... + def __init__( + self: UUIDField[UUID | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] + ) -> None: ... def __init__(self, **kwargs: Any) -> None: if (kwargs.get("primary_key") or kwargs.get("pk", False)) and "default" not in kwargs: @@ -852,12 +887,12 @@ class BinaryField(Field[T_BINARY], bytes): # type: ignore @overload def __init__( - self: BinaryField[bytes], *, null: Literal[False] = False, **kwargs: Any + self: BinaryField[bytes], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] ) -> None: ... @overload def __init__( - self: BinaryField[bytes | None], *, null: Literal[True], **kwargs: Any + self: BinaryField[bytes | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] ) -> None: ... def __init__(self, **kwargs: Any) -> None: diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index 442707848..d69e6c24d 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import warnings from collections.abc import AsyncGenerator, Generator, Iterator from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload @@ -7,7 +8,19 @@ from pypika_tortoise.queries import Table from tortoise.exceptions import ConfigurationError, NoValuesFetched, OperationalError -from tortoise.fields.base import CASCADE, SET_NULL, Field, OnDelete +from tortoise.fields.base import ( + CASCADE, + SET_NULL, + Field, + ManyToManyFieldKwargs, + OnDelete, + RelationalFieldKwargs, +) + +if sys.version_info >= (3, 11): + from typing import Unpack +else: # pragma: no cover + from typing_extensions import Unpack if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient @@ -17,6 +30,7 @@ MODEL = TypeVar("MODEL", bound="Model") + class _NoneAwaitable: __slots__ = () @@ -441,7 +455,7 @@ def OneToOneField( db_constraint: bool = True, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> OneToOneNullableRelation[MODEL]: ... @@ -452,7 +466,7 @@ def OneToOneField( on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> OneToOneRelation[MODEL]: ... @@ -516,7 +530,7 @@ def ForeignKeyField( db_constraint: bool = True, *, null: Literal[True], - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> ForeignKeyNullableRelation[MODEL]: ... @@ -527,7 +541,7 @@ def ForeignKeyField( on_delete: OnDelete = CASCADE, db_constraint: bool = True, null: Literal[False] = False, - **kwargs: Any, + **kwargs: Unpack[RelationalFieldKwargs], ) -> ForeignKeyRelation[MODEL]: ... @@ -592,7 +606,7 @@ def ManyToManyField( on_delete: OnDelete = CASCADE, db_constraint: bool = True, unique: bool = True, - **kwargs: Any, + **kwargs: Unpack[ManyToManyFieldKwargs], ) -> ManyToManyRelation[MODEL]: """ ManyToMany relation field. From 390b8df4e2918a89dc6cc6300f046aa3830952ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Mon, 15 Jun 2026 18:36:05 +0200 Subject: [PATCH 2/8] updated changelog --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 78e832ee1..1a4a22ab2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ Added - ``QuerySet.contains()`` method to check if an object exists in a queryset. - Added comprehensive EXPLAIN support for MySQL and PostgreSQL. - Built-in ``DomainNameValidator``, ``URLValidator``, and ``EmailValidator`` classes for common validation patterns. (#2162) +- Typed ``**kwargs`` on field constructors via PEP 692 (``Unpack[TypedDict]``), so IDEs and type checkers can autocomplete and validate common field arguments (``default``, ``null``, ``unique``, ``db_index``, ``description``, etc.). (#2168) Fixed ^^^^^ From dbd4819c9726a49f8822954e01a41919b460718a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Tue, 16 Jun 2026 07:58:31 +0200 Subject: [PATCH 3/8] run `make style` --- tortoise/fields/base.py | 2 +- tortoise/fields/relational.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index aa82dd9ba..bd6562c0f 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -139,6 +139,7 @@ class JSONFieldKwargs(FieldKwargs, total=False): null: bool field_type: Any + class RelationalFieldKwargs(FieldKwargs, total=False): """Constructor arguments for :func:`ForeignKeyField` and :func:`OneToOneField`. @@ -159,7 +160,6 @@ class ManyToManyFieldKwargs(_FieldKwargsCommon, total=False): create_unique_index: bool # deprecated alias for unique - class _FieldMeta(type): # TODO: Require functions to return field instances instead of this hack def __new__(mcs, name: str, bases: tuple[type, ...], attrs: dict) -> type: diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index d69e6c24d..b5e7ab568 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -30,7 +30,6 @@ MODEL = TypeVar("MODEL", bound="Model") - class _NoneAwaitable: __slots__ = () From fab2ca66a88aea011db2752c75016803285a00cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Wed, 17 Jun 2026 08:47:53 +0200 Subject: [PATCH 4/8] Added two @overload signatures with explicit null: Literal[False] / null: Literal[True] --- tortoise/fields/data.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 0ba5fa94d..4e3453429 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -284,12 +284,34 @@ class TextField(Field[str], str): # type: ignore indexable = False SQL_TYPE = "TEXT" + @overload + def __init__( + self, + *, + primary_key: bool | None = None, + unique: bool = False, + db_index: bool = False, + null: Literal[False] = False, + **kwargs: Unpack[_FieldKwargsCommon], + ) -> None: ... + + @overload def __init__( self, + *, primary_key: bool | None = None, unique: bool = False, db_index: bool = False, + null: Literal[True], **kwargs: Unpack[_FieldKwargsCommon], + ) -> None: ... + + def __init__( + self, + primary_key: bool | None = None, + unique: bool = False, + db_index: bool = False, + **kwargs: Any, ) -> None: if primary_key or kwargs.get("pk"): warnings.warn( From ae4953c608c71203646817165f5d3625d1c040fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Wed, 17 Jun 2026 08:57:58 +0200 Subject: [PATCH 5/8] fix: changed Textfield type to Field[T_STR] instead of Field[str] --- tortoise/fields/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 4e3453429..ac70562e4 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -276,7 +276,7 @@ def SQL_TYPE(self) -> str: return f"NVARCHAR2({self.field.max_length})" -class TextField(Field[str], str): # type: ignore +class TextField(Field[T_STR], str): # type: ignore """ Large Text field. """ @@ -286,7 +286,7 @@ class TextField(Field[str], str): # type: ignore @overload def __init__( - self, + self: TextField[str], *, primary_key: bool | None = None, unique: bool = False, @@ -297,7 +297,7 @@ def __init__( @overload def __init__( - self, + self: TextField[str | None], *, primary_key: bool | None = None, unique: bool = False, From 538d1f9c41c83ef9a59c98ee948c4a0c5be84958 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Thu, 18 Jun 2026 16:44:54 +0200 Subject: [PATCH 6/8] refactor: changed import order for consistency --- tortoise/fields/data.py | 60 ++++++++++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index ac70562e4..f36332432 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -25,11 +25,6 @@ _FieldKwargsCommon, _FieldKwargsNoPk, ) - -if sys.version_info >= (3, 11): - from typing import Unpack -else: # pragma: no cover - from typing_extensions import Unpack from tortoise.timezone import get_default_timezone, get_timezone, get_use_tz, localtime from tortoise.validators import MaxLengthValidator @@ -42,7 +37,9 @@ try: from pydantic import BaseModel as _PydanticBaseModel - from pydantic._internal._model_construction import ModelMetaclass as _PydanticModelMetaclass + from pydantic._internal._model_construction import ( + ModelMetaclass as _PydanticModelMetaclass, + ) except ImportError: _PydanticBaseModel = None # type: ignore[assignment,misc] _PydanticModelMetaclass = None # type: ignore[assignment,misc] @@ -50,6 +47,12 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.models import Model + +if sys.version_info >= (3, 11): + from typing import Unpack +else: # pragma: no cover + from typing_extensions import Unpack + __all__ = ( "BigIntField", "BinaryField", @@ -357,12 +360,18 @@ class BooleanField(Field[T_BOOL]): @overload def __init__( - self: BooleanField[bool], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] + self: BooleanField[bool], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: BooleanField[bool | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] + self: BooleanField[bool | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -580,7 +589,10 @@ def __init__( @overload def __init__( - self: DateField[datetime.date | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] + self: DateField[datetime.date | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -743,12 +755,18 @@ class FloatField(Field[T_FLOAT], float): @overload def __init__( - self: FloatField[float], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] + self: FloatField[float], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: FloatField[float | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] + self: FloatField[float | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -874,12 +892,18 @@ class _db_postgres: @overload def __init__( - self: UUIDField[UUID], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] + self: UUIDField[UUID], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: UUIDField[UUID | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] + self: UUIDField[UUID | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: @@ -909,12 +933,18 @@ class BinaryField(Field[T_BINARY], bytes): # type: ignore @overload def __init__( - self: BinaryField[bytes], *, null: Literal[False] = False, **kwargs: Unpack[FieldKwargs] + self: BinaryField[bytes], + *, + null: Literal[False] = False, + **kwargs: Unpack[FieldKwargs], ) -> None: ... @overload def __init__( - self: BinaryField[bytes | None], *, null: Literal[True], **kwargs: Unpack[FieldKwargs] + self: BinaryField[bytes | None], + *, + null: Literal[True], + **kwargs: Unpack[FieldKwargs], ) -> None: ... def __init__(self, **kwargs: Any) -> None: From 8cb6ba78a7d3ca84577859898c361dab7f1b4823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Mon, 22 Jun 2026 08:14:32 +0200 Subject: [PATCH 7/8] refactor: moved imports under TYPE_CHECKING --- tortoise/fields/data.py | 56 ++++++++++++++++++-------- tortoise/fields/relational.py | 76 ++++++++++++++++++++++++----------- 2 files changed, 92 insertions(+), 40 deletions(-) diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index f36332432..7834ba12b 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -4,7 +4,6 @@ import datetime import functools import json -import sys import warnings from collections.abc import Callable from decimal import Decimal @@ -45,13 +44,14 @@ _PydanticModelMetaclass = None # type: ignore[assignment,misc] if TYPE_CHECKING: # pragma: nocoverage - from tortoise.models import Model + import sys + from tortoise.models import Model -if sys.version_info >= (3, 11): - from typing import Unpack -else: # pragma: no cover - from typing_extensions import Unpack + if sys.version_info >= (3, 11): + from typing import Unpack + else: # pragma: no cover + from typing_extensions import Unpack __all__ = ( "BigIntField", @@ -333,7 +333,9 @@ def __init__( stacklevel=2, ) if index or db_index: - raise ConfigurationError("TextField can't be indexed, consider CharField") + raise ConfigurationError( + "TextField can't be indexed, consider CharField" + ) elif db_index: raise ConfigurationError("TextField can't be indexed, consider CharField") @@ -429,7 +431,9 @@ def __init__(self, max_digits: int, decimal_places: int, **kwargs: Any) -> None: super().__init__(**kwargs) self.max_digits = max_digits self.decimal_places = decimal_places - self.quant = Decimal("1" if decimal_places == 0 else f"1.{('0' * decimal_places)}") + self.quant = Decimal( + "1" if decimal_places == 0 else f"1.{('0' * decimal_places)}" + ) def to_python_value(self, value: Any) -> Decimal | None: if value is not None: @@ -454,7 +458,9 @@ def function_cast(self, term: Term) -> Term: DatetimeFieldQueryValueType = TypeVar( "DatetimeFieldQueryValueType", datetime.datetime, int, float, str ) -DateFieldQueryValueType = TypeVar("DateFieldQueryValueType", datetime.date, int, float, str) +DateFieldQueryValueType = TypeVar( + "DateFieldQueryValueType", datetime.date, int, float, str +) class DatetimeField(Field[T_DATETIME], datetime.datetime): @@ -504,7 +510,9 @@ def __init__( **kwargs: Unpack[FieldKwargs], ) -> None: ... - def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: + def __init__( + self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any + ) -> None: if auto_now_add and auto_now: raise ConfigurationError("You can choose only 'auto_now' or 'auto_now_add'") super().__init__(**kwargs) @@ -644,7 +652,9 @@ def __init__( **kwargs: Unpack[FieldKwargs], ) -> None: ... - def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: + def __init__( + self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any + ) -> None: if auto_now_add and auto_now: raise ConfigurationError("You can choose only 'auto_now' or 'auto_now_add'") super().__init__(**kwargs) @@ -743,7 +753,9 @@ def to_db_value( if value is None: return None - return (value.days * 86400000000) + (value.seconds * 1000000) + value.microseconds + return ( + (value.days * 86400000000) + (value.seconds * 1000000) + value.microseconds + ) class FloatField(Field[T_FLOAT], float): @@ -907,7 +919,9 @@ def __init__( ) -> None: ... def __init__(self, **kwargs: Any) -> None: - if (kwargs.get("primary_key") or kwargs.get("pk", False)) and "default" not in kwargs: + if ( + kwargs.get("primary_key") or kwargs.get("pk", False) + ) and "default" not in kwargs: kwargs["default"] = uuid4 super().__init__(**kwargs) @@ -982,7 +996,9 @@ def __init__( # Automatic description for the field if not specified by the user if description is None: - description = "\n".join([f"{e.name}: {int(e.value)}" for e in enum_type])[:2048] + description = "\n".join([f"{e.name}: {int(e.value)}" for e in enum_type])[ + :2048 + ] super().__init__(description=description, **kwargs) self.enum_type = enum_type @@ -991,7 +1007,9 @@ def to_python_value(self, value: int | None) -> IntEnum | None: value = self.enum_type(value) if value is not None else None return value - def to_db_value(self, value: IntEnum | None | int, instance: type[Model] | Model) -> int | None: + def to_db_value( + self, value: IntEnum | None | int, instance: type[Model] | Model + ) -> int | None: if isinstance(value, IntEnum): value = int(value.value) if isinstance(value, int): @@ -1038,7 +1056,9 @@ def __init__( ) -> None: # Automatic description for the field if not specified by the user if description is None: - description = "\n".join([f"{e.name}: {str(e.value)}" for e in enum_type])[:2048] + description = "\n".join([f"{e.name}: {str(e.value)}" for e in enum_type])[ + :2048 + ] # Automatic CharField max_length if max_length == 0: @@ -1053,7 +1073,9 @@ def __init__( def to_python_value(self, value: str | None) -> Enum | None: return self.enum_type(value) if value is not None else None - def to_db_value(self, value: Enum | None | str, instance: type[Model] | Model) -> str | None: + def to_db_value( + self, value: Enum | None | str, instance: type[Model] | Model + ) -> str | None: self.validate(value) if isinstance(value, Enum): return str(value.value) diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index b5e7ab568..617e3e869 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys import warnings from collections.abc import AsyncGenerator, Generator, Iterator from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, overload @@ -17,16 +16,18 @@ RelationalFieldKwargs, ) -if sys.version_info >= (3, 11): - from typing import Unpack -else: # pragma: no cover - from typing_extensions import Unpack - if TYPE_CHECKING: # pragma: nocoverage + import sys + from tortoise.backends.base.client import BaseDBAsyncClient from tortoise.models import Model from tortoise.queryset import Q, QuerySet + if sys.version_info >= (3, 11): + from typing import Unpack + else: # pragma: no cover + from typing_extensions import Unpack + MODEL = TypeVar("MODEL", bound="Model") @@ -132,7 +133,9 @@ def offset(self, offset: int) -> QuerySet[MODEL]: """ return self._query.offset(offset) - async def create(self, using_db: BaseDBAsyncClient | None = None, **kwargs: Any) -> MODEL: + async def create( + self, using_db: BaseDBAsyncClient | None = None, **kwargs: Any + ) -> MODEL: """ Create a related record in the DB and returns the object, automatically setting the foreign key relationship to the parent instance. @@ -164,7 +167,9 @@ async def create(self, using_db: BaseDBAsyncClient | None = None, **kwargs: Any) # Call remote model's create method return await self.remote_model.create(using_db=using_db, **kwargs) - def _set_result_for_query(self, sequence: list[MODEL], attr: str | None = None) -> None: + def _set_result_for_query( + self, sequence: list[MODEL], attr: str | None = None + ) -> None: self._fetched = True self.related_objects = sequence if attr: @@ -182,12 +187,18 @@ class ManyToManyRelation(ReverseRelation[MODEL]): Many-to-many relation container for :func:`.ManyToManyField`. """ - def __init__(self, instance: Model, m2m_field: ManyToManyFieldInstance[MODEL]) -> None: - super().__init__(m2m_field.related_model, m2m_field.related_name, instance, "pk") + def __init__( + self, instance: Model, m2m_field: ManyToManyFieldInstance[MODEL] + ) -> None: + super().__init__( + m2m_field.related_model, m2m_field.related_name, instance, "pk" + ) self.field = m2m_field self.instance = instance - async def add(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None) -> None: + async def add( + self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None + ) -> None: """ Adds one or more of ``instances`` to the relation. @@ -206,16 +217,25 @@ async def add(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None pks_f: list = [] for instance_to_add in instances: if not instance_to_add._saved_in_db: - raise OperationalError(f"You should first call .save() on {instance_to_add}") + raise OperationalError( + f"You should first call .save() on {instance_to_add}" + ) pk_f = related_pk_formatting_func(instance_to_add.pk, instance_to_add) pks_f.append(pk_f) through_table = Table(self.field.through, schema=self.field.through_schema) backward_key, forward_key = self.field.backward_key, self.field.forward_key - backward_field, forward_field = through_table[backward_key], through_table[forward_key] + backward_field, forward_field = ( + through_table[backward_key], + through_table[forward_key], + ) select_query = ( - db.query_class.from_(through_table).where(backward_field == pk_b).select(forward_key) + db.query_class.from_(through_table) + .where(backward_field == pk_b) + .select(forward_key) + ) + criterion = ( + forward_field == pks_f[0] if len(pks_f) == 1 else forward_field.isin(pks_f) ) - criterion = forward_field == pks_f[0] if len(pks_f) == 1 else forward_field.isin(pks_f) select_query = select_query.where(criterion) _, already_existing_relations_raw = await db.execute_query( @@ -227,7 +247,9 @@ async def add(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None } if pks_f_to_insert := set(pks_f) - already_existing_forward_pks: - query = db.query_class.into(through_table).columns(forward_field, backward_field) + query = db.query_class.into(through_table).columns( + forward_field, backward_field + ) for pk_f in pks_f_to_insert: query = query.insert(pk_f, pk_b) await db.execute_query(*query.get_parameterized_sql()) @@ -238,7 +260,9 @@ async def clear(self, using_db: BaseDBAsyncClient | None = None) -> None: """ await self._remove_or_clear(using_db=using_db) - async def remove(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None) -> None: + async def remove( + self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None + ) -> None: """ Removes one or more of ``instances`` from the relation. @@ -263,9 +287,9 @@ async def _remove_or_clear( if instances: related_pk_formatting_func = type(instances[0])._meta.pk.to_db_value if len(instances) == 1: - condition &= through_table[self.field.forward_key] == related_pk_formatting_func( - instances[0].pk, instances[0] - ) + condition &= through_table[ + self.field.forward_key + ] == related_pk_formatting_func(instances[0].pk, instances[0]) else: condition &= through_table[self.field.forward_key].isin( [related_pk_formatting_func(i.pk, i) for i in instances] @@ -293,7 +317,9 @@ def __init__( if TYPE_CHECKING: @overload - def __get__(self, instance: None, owner: type[Model]) -> RelationalField[MODEL]: ... + def __get__( + self, instance: None, owner: type[Model] + ) -> RelationalField[MODEL]: ... @overload def __get__(self, instance: Model, owner: type[Model]) -> MODEL: ... @@ -322,7 +348,9 @@ def validate_model_name(cls, model_name: str | type[Model]) -> None: ) from None elif len(model_name.split(".")) != 2: field_type = cls.__name__.replace("Instance", "") - raise ConfigurationError(f'{field_type} accepts model name in format "app.Model"') + raise ConfigurationError( + f'{field_type} accepts model name in format "app.Model"' + ) class ForeignKeyFieldInstance(RelationalField[MODEL]): @@ -342,7 +370,9 @@ def __init__( "on_delete can only be CASCADE, RESTRICT, SET_NULL, SET_DEFAULT or NO_ACTION" ) if on_delete == SET_NULL and not bool(kwargs.get("null")): - raise ConfigurationError("If on_delete is SET_NULL, then field must have null=True set") + raise ConfigurationError( + "If on_delete is SET_NULL, then field must have null=True set" + ) self.on_delete = on_delete def describe(self, serializable: bool) -> dict: From 3a8b565e3cf217e1137fee21bc3426d5498c34b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrea=20Magist=C3=A0?= Date: Mon, 22 Jun 2026 08:16:05 +0200 Subject: [PATCH 8/8] run `make style` --- tortoise/fields/data.py | 44 +++++++------------------- tortoise/fields/relational.py | 58 ++++++++++------------------------- 2 files changed, 27 insertions(+), 75 deletions(-) diff --git a/tortoise/fields/data.py b/tortoise/fields/data.py index 7834ba12b..ce297e002 100644 --- a/tortoise/fields/data.py +++ b/tortoise/fields/data.py @@ -333,9 +333,7 @@ def __init__( stacklevel=2, ) if index or db_index: - raise ConfigurationError( - "TextField can't be indexed, consider CharField" - ) + raise ConfigurationError("TextField can't be indexed, consider CharField") elif db_index: raise ConfigurationError("TextField can't be indexed, consider CharField") @@ -431,9 +429,7 @@ def __init__(self, max_digits: int, decimal_places: int, **kwargs: Any) -> None: super().__init__(**kwargs) self.max_digits = max_digits self.decimal_places = decimal_places - self.quant = Decimal( - "1" if decimal_places == 0 else f"1.{('0' * decimal_places)}" - ) + self.quant = Decimal("1" if decimal_places == 0 else f"1.{('0' * decimal_places)}") def to_python_value(self, value: Any) -> Decimal | None: if value is not None: @@ -458,9 +454,7 @@ def function_cast(self, term: Term) -> Term: DatetimeFieldQueryValueType = TypeVar( "DatetimeFieldQueryValueType", datetime.datetime, int, float, str ) -DateFieldQueryValueType = TypeVar( - "DateFieldQueryValueType", datetime.date, int, float, str -) +DateFieldQueryValueType = TypeVar("DateFieldQueryValueType", datetime.date, int, float, str) class DatetimeField(Field[T_DATETIME], datetime.datetime): @@ -510,9 +504,7 @@ def __init__( **kwargs: Unpack[FieldKwargs], ) -> None: ... - def __init__( - self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any - ) -> None: + def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: if auto_now_add and auto_now: raise ConfigurationError("You can choose only 'auto_now' or 'auto_now_add'") super().__init__(**kwargs) @@ -652,9 +644,7 @@ def __init__( **kwargs: Unpack[FieldKwargs], ) -> None: ... - def __init__( - self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any - ) -> None: + def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs: Any) -> None: if auto_now_add and auto_now: raise ConfigurationError("You can choose only 'auto_now' or 'auto_now_add'") super().__init__(**kwargs) @@ -753,9 +743,7 @@ def to_db_value( if value is None: return None - return ( - (value.days * 86400000000) + (value.seconds * 1000000) + value.microseconds - ) + return (value.days * 86400000000) + (value.seconds * 1000000) + value.microseconds class FloatField(Field[T_FLOAT], float): @@ -919,9 +907,7 @@ def __init__( ) -> None: ... def __init__(self, **kwargs: Any) -> None: - if ( - kwargs.get("primary_key") or kwargs.get("pk", False) - ) and "default" not in kwargs: + if (kwargs.get("primary_key") or kwargs.get("pk", False)) and "default" not in kwargs: kwargs["default"] = uuid4 super().__init__(**kwargs) @@ -996,9 +982,7 @@ def __init__( # Automatic description for the field if not specified by the user if description is None: - description = "\n".join([f"{e.name}: {int(e.value)}" for e in enum_type])[ - :2048 - ] + description = "\n".join([f"{e.name}: {int(e.value)}" for e in enum_type])[:2048] super().__init__(description=description, **kwargs) self.enum_type = enum_type @@ -1007,9 +991,7 @@ def to_python_value(self, value: int | None) -> IntEnum | None: value = self.enum_type(value) if value is not None else None return value - def to_db_value( - self, value: IntEnum | None | int, instance: type[Model] | Model - ) -> int | None: + def to_db_value(self, value: IntEnum | None | int, instance: type[Model] | Model) -> int | None: if isinstance(value, IntEnum): value = int(value.value) if isinstance(value, int): @@ -1056,9 +1038,7 @@ def __init__( ) -> None: # Automatic description for the field if not specified by the user if description is None: - description = "\n".join([f"{e.name}: {str(e.value)}" for e in enum_type])[ - :2048 - ] + description = "\n".join([f"{e.name}: {str(e.value)}" for e in enum_type])[:2048] # Automatic CharField max_length if max_length == 0: @@ -1073,9 +1053,7 @@ def __init__( def to_python_value(self, value: str | None) -> Enum | None: return self.enum_type(value) if value is not None else None - def to_db_value( - self, value: Enum | None | str, instance: type[Model] | Model - ) -> str | None: + def to_db_value(self, value: Enum | None | str, instance: type[Model] | Model) -> str | None: self.validate(value) if isinstance(value, Enum): return str(value.value) diff --git a/tortoise/fields/relational.py b/tortoise/fields/relational.py index 617e3e869..8523efa18 100644 --- a/tortoise/fields/relational.py +++ b/tortoise/fields/relational.py @@ -133,9 +133,7 @@ def offset(self, offset: int) -> QuerySet[MODEL]: """ return self._query.offset(offset) - async def create( - self, using_db: BaseDBAsyncClient | None = None, **kwargs: Any - ) -> MODEL: + async def create(self, using_db: BaseDBAsyncClient | None = None, **kwargs: Any) -> MODEL: """ Create a related record in the DB and returns the object, automatically setting the foreign key relationship to the parent instance. @@ -167,9 +165,7 @@ async def create( # Call remote model's create method return await self.remote_model.create(using_db=using_db, **kwargs) - def _set_result_for_query( - self, sequence: list[MODEL], attr: str | None = None - ) -> None: + def _set_result_for_query(self, sequence: list[MODEL], attr: str | None = None) -> None: self._fetched = True self.related_objects = sequence if attr: @@ -187,18 +183,12 @@ class ManyToManyRelation(ReverseRelation[MODEL]): Many-to-many relation container for :func:`.ManyToManyField`. """ - def __init__( - self, instance: Model, m2m_field: ManyToManyFieldInstance[MODEL] - ) -> None: - super().__init__( - m2m_field.related_model, m2m_field.related_name, instance, "pk" - ) + def __init__(self, instance: Model, m2m_field: ManyToManyFieldInstance[MODEL]) -> None: + super().__init__(m2m_field.related_model, m2m_field.related_name, instance, "pk") self.field = m2m_field self.instance = instance - async def add( - self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None - ) -> None: + async def add(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None) -> None: """ Adds one or more of ``instances`` to the relation. @@ -217,9 +207,7 @@ async def add( pks_f: list = [] for instance_to_add in instances: if not instance_to_add._saved_in_db: - raise OperationalError( - f"You should first call .save() on {instance_to_add}" - ) + raise OperationalError(f"You should first call .save() on {instance_to_add}") pk_f = related_pk_formatting_func(instance_to_add.pk, instance_to_add) pks_f.append(pk_f) through_table = Table(self.field.through, schema=self.field.through_schema) @@ -229,13 +217,9 @@ async def add( through_table[forward_key], ) select_query = ( - db.query_class.from_(through_table) - .where(backward_field == pk_b) - .select(forward_key) - ) - criterion = ( - forward_field == pks_f[0] if len(pks_f) == 1 else forward_field.isin(pks_f) + db.query_class.from_(through_table).where(backward_field == pk_b).select(forward_key) ) + criterion = forward_field == pks_f[0] if len(pks_f) == 1 else forward_field.isin(pks_f) select_query = select_query.where(criterion) _, already_existing_relations_raw = await db.execute_query( @@ -247,9 +231,7 @@ async def add( } if pks_f_to_insert := set(pks_f) - already_existing_forward_pks: - query = db.query_class.into(through_table).columns( - forward_field, backward_field - ) + query = db.query_class.into(through_table).columns(forward_field, backward_field) for pk_f in pks_f_to_insert: query = query.insert(pk_f, pk_b) await db.execute_query(*query.get_parameterized_sql()) @@ -260,9 +242,7 @@ async def clear(self, using_db: BaseDBAsyncClient | None = None) -> None: """ await self._remove_or_clear(using_db=using_db) - async def remove( - self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None - ) -> None: + async def remove(self, *instances: MODEL, using_db: BaseDBAsyncClient | None = None) -> None: """ Removes one or more of ``instances`` from the relation. @@ -287,9 +267,9 @@ async def _remove_or_clear( if instances: related_pk_formatting_func = type(instances[0])._meta.pk.to_db_value if len(instances) == 1: - condition &= through_table[ - self.field.forward_key - ] == related_pk_formatting_func(instances[0].pk, instances[0]) + condition &= through_table[self.field.forward_key] == related_pk_formatting_func( + instances[0].pk, instances[0] + ) else: condition &= through_table[self.field.forward_key].isin( [related_pk_formatting_func(i.pk, i) for i in instances] @@ -317,9 +297,7 @@ def __init__( if TYPE_CHECKING: @overload - def __get__( - self, instance: None, owner: type[Model] - ) -> RelationalField[MODEL]: ... + def __get__(self, instance: None, owner: type[Model]) -> RelationalField[MODEL]: ... @overload def __get__(self, instance: Model, owner: type[Model]) -> MODEL: ... @@ -348,9 +326,7 @@ def validate_model_name(cls, model_name: str | type[Model]) -> None: ) from None elif len(model_name.split(".")) != 2: field_type = cls.__name__.replace("Instance", "") - raise ConfigurationError( - f'{field_type} accepts model name in format "app.Model"' - ) + raise ConfigurationError(f'{field_type} accepts model name in format "app.Model"') class ForeignKeyFieldInstance(RelationalField[MODEL]): @@ -370,9 +346,7 @@ def __init__( "on_delete can only be CASCADE, RESTRICT, SET_NULL, SET_DEFAULT or NO_ACTION" ) if on_delete == SET_NULL and not bool(kwargs.get("null")): - raise ConfigurationError( - "If on_delete is SET_NULL, then field must have null=True set" - ) + raise ConfigurationError("If on_delete is SET_NULL, then field must have null=True set") self.on_delete = on_delete def describe(self, serializable: bool) -> dict: