Skip to content

Commit

Permalink
Resolve configured AUTH_USER_MODEL with a get_type_analyze_hook (#…
Browse files Browse the repository at this point in the history
…2335)

It's nothing special really. We declare a `TypeAlias` in the pyi, thus getting a fullname, somewhere under `django.contrib.auth`. We then build a type analyze hook for that fullname. And the hook simulates what `django.contrib.auth.get_user_model` does

The alias is set up to point to `AbstractBaseUser`, so for a type checker other than mypy nothing should've changed
  • Loading branch information
flaeppe committed Aug 12, 2024
1 parent 3e6f1a6 commit 81ecdcf
Show file tree
Hide file tree
Showing 23 changed files with 220 additions and 126 deletions.
22 changes: 10 additions & 12 deletions django-stubs/contrib/auth/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from django.contrib.auth.backends import BaseBackend
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import _UserModel
from django.contrib.auth.models import AnonymousUser
from django.db.models.options import Options
from django.http.request import HttpRequest
Expand All @@ -18,19 +18,17 @@ REDIRECT_FIELD_NAME: str

def load_backend(path: str) -> BaseBackend: ...
def get_backends() -> list[BaseBackend]: ...
def authenticate(request: HttpRequest | None = ..., **credentials: Any) -> AbstractBaseUser | None: ...
async def aauthenticate(request: HttpRequest | None = ..., **credentials: Any) -> AbstractBaseUser | None: ...
def login(
request: HttpRequest, user: AbstractBaseUser | None, backend: type[BaseBackend] | str | None = ...
) -> None: ...
def authenticate(request: HttpRequest | None = ..., **credentials: Any) -> _UserModel | None: ...
async def aauthenticate(request: HttpRequest | None = ..., **credentials: Any) -> _UserModel | None: ...
def login(request: HttpRequest, user: _UserModel | None, backend: type[BaseBackend] | str | None = ...) -> None: ...
async def alogin(
request: HttpRequest, user: AbstractBaseUser | None, backend: type[BaseBackend] | str | None = ...
request: HttpRequest, user: _UserModel | None, backend: type[BaseBackend] | str | None = ...
) -> None: ...
def logout(request: HttpRequest) -> None: ...
async def alogout(request: HttpRequest) -> None: ...
def get_user_model() -> type[AbstractBaseUser]: ...
def get_user(request: HttpRequest | Client) -> AbstractBaseUser | AnonymousUser: ...
async def aget_user(request: HttpRequest | Client) -> AbstractBaseUser | AnonymousUser: ...
def get_user_model() -> type[_UserModel]: ...
def get_user(request: HttpRequest | Client) -> _UserModel | AnonymousUser: ...
async def aget_user(request: HttpRequest | Client) -> _UserModel | AnonymousUser: ...
def get_permission_codename(action: str, opts: Options) -> str: ...
def update_session_auth_hash(request: HttpRequest, user: AbstractBaseUser) -> None: ...
async def aupdate_session_auth_hash(request: HttpRequest, user: AbstractBaseUser) -> None: ...
def update_session_auth_hash(request: HttpRequest, user: _UserModel) -> None: ...
async def aupdate_session_auth_hash(request: HttpRequest, user: _UserModel) -> None: ...
15 changes: 7 additions & 8 deletions django-stubs/contrib/auth/backends.pyi
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from typing import Any, TypeVar

from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import AbstractBaseUser, _UserModel
from django.contrib.auth.models import AnonymousUser, Permission
from django.db.models import QuerySet
from django.db.models.base import Model
from django.http.request import HttpRequest
from typing_extensions import TypeAlias

_AnyUser: TypeAlias = AbstractBaseUser | AnonymousUser

UserModel: Any
UserModel: TypeAlias = type[_UserModel]
_AnyUser: TypeAlias = _UserModel | AnonymousUser

class BaseBackend:
def authenticate(self, request: HttpRequest | None, **kwargs: Any) -> AbstractBaseUser | None: ...
def get_user(self, user_id: Any) -> AbstractBaseUser | None: ...
def authenticate(self, request: HttpRequest | None, **kwargs: Any) -> _UserModel | None: ...
def get_user(self, user_id: Any) -> _UserModel | None: ...
def get_user_permissions(self, user_obj: _AnyUser, obj: Model | None = ...) -> set[str]: ...
def get_group_permissions(self, user_obj: _AnyUser, obj: Model | None = ...) -> set[str]: ...
def get_all_permissions(self, user_obj: _AnyUser, obj: Model | None = ...) -> set[str]: ...
Expand All @@ -22,7 +21,7 @@ class BaseBackend:
class ModelBackend(BaseBackend):
def authenticate(
self, request: HttpRequest | None, username: str | None = ..., password: str | None = ..., **kwargs: Any
) -> AbstractBaseUser | None: ...
) -> _UserModel | None: ...
def has_module_perms(self, user_obj: _AnyUser, app_label: str) -> bool: ...
def user_can_authenticate(self, user: _AnyUser | None) -> bool: ...
def with_perm(
Expand All @@ -31,7 +30,7 @@ class ModelBackend(BaseBackend):
is_active: bool = ...,
include_superusers: bool = ...,
obj: Model | None = ...,
) -> QuerySet[AbstractBaseUser]: ...
) -> QuerySet[_UserModel]: ...

class AllowAllUsersModelBackend(ModelBackend): ...

Expand Down
6 changes: 6 additions & 0 deletions django-stubs/contrib/auth/base_user.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ from django.db import models
from django.db.models.base import Model
from django.db.models.expressions import Combinable
from django.db.models.fields import BooleanField
from typing_extensions import TypeAlias

_T = TypeVar("_T", bound=Model)

Expand Down Expand Up @@ -41,3 +42,8 @@ class AbstractBaseUser(models.Model):
@classmethod
@overload
def normalize_username(cls, username: Any) -> Any: ...

# This is our "placeholder" type the mypy plugin refines to configured 'AUTH_USER_MODEL'
# wherever it is used as a type. The most recognised example of this is (probably)
# `HttpRequest.user`
_UserModel: TypeAlias = AbstractBaseUser # noqa: PYI047
5 changes: 3 additions & 2 deletions django-stubs/contrib/auth/decorators.pyi
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from collections.abc import Callable, Iterable
from typing import TypeVar, overload

from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
from django.contrib.auth.base_user import _UserModel
from django.contrib.auth.models import AnonymousUser
from django.http.response import HttpResponseBase

_VIEW = TypeVar("_VIEW", bound=Callable[..., HttpResponseBase])

def user_passes_test(
test_func: Callable[[AbstractBaseUser | AnonymousUser], bool],
test_func: Callable[[_UserModel | AnonymousUser], bool],
login_url: str | None = ...,
redirect_field_name: str | None = ...,
) -> Callable[[_VIEW], _VIEW]: ...
Expand Down
11 changes: 6 additions & 5 deletions django-stubs/contrib/auth/forms.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ from collections.abc import Iterable
from typing import Any, TypeVar

from django import forms
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import AbstractBaseUser, _UserModel
from django.contrib.auth.tokens import PasswordResetTokenGenerator
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.fields import _ErrorMessagesDict
from django.forms.fields import _ClassLevelWidgetT
from django.forms.widgets import Widget
from django.http.request import HttpRequest
from typing_extensions import TypeAlias

UserModel: type[AbstractBaseUser]
UserModel: TypeAlias = type[_UserModel]
_User = TypeVar("_User", bound=AbstractBaseUser)

class ReadOnlyPasswordHashWidget(forms.Widget):
Expand Down Expand Up @@ -47,11 +48,11 @@ class AuthenticationForm(forms.Form):
password: forms.Field
error_messages: _ErrorMessagesDict
request: HttpRequest | None
user_cache: Any
user_cache: _UserModel | None
username_field: models.Field
def __init__(self, request: HttpRequest | None = ..., *args: Any, **kwargs: Any) -> None: ...
def confirm_login_allowed(self, user: AbstractBaseUser) -> None: ...
def get_user(self) -> AbstractBaseUser: ...
def get_user(self) -> _UserModel: ...
def get_invalid_login_error(self) -> ValidationError: ...
def clean(self) -> dict[str, Any]: ...

Expand All @@ -66,7 +67,7 @@ class PasswordResetForm(forms.Form):
to_email: str,
html_email_template_name: str | None = ...,
) -> None: ...
def get_users(self, email: str) -> Iterable[AbstractBaseUser]: ...
def get_users(self, email: str) -> Iterable[_UserModel]: ...
def save(
self,
domain_override: str | None = ...,
Expand Down
6 changes: 3 additions & 3 deletions django-stubs/contrib/auth/middleware.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import _UserModel
from django.contrib.auth.models import AnonymousUser
from django.http.request import HttpRequest
from django.utils.deprecation import MiddlewareMixin

def get_user(request: HttpRequest) -> AnonymousUser | AbstractBaseUser: ...
async def auser(request: HttpRequest) -> AnonymousUser | AbstractBaseUser: ...
def get_user(request: HttpRequest) -> AnonymousUser | _UserModel: ...
async def auser(request: HttpRequest) -> AnonymousUser | _UserModel: ...

class AuthenticationMiddleware(MiddlewareMixin):
def process_request(self, request: HttpRequest) -> None: ...
Expand Down
5 changes: 1 addition & 4 deletions django-stubs/contrib/auth/password_validation.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ from collections.abc import Mapping, Sequence
from pathlib import Path, PosixPath
from typing import Any, Protocol, type_check_only

from django.db.models.base import Model
from typing_extensions import TypeAlias

_UserModel: TypeAlias = Model
from django.contrib.auth.base_user import _UserModel

@type_check_only
class PasswordValidator(Protocol):
Expand Down
10 changes: 5 additions & 5 deletions django-stubs/contrib/auth/tokens.pyi
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from datetime import date, datetime
from typing import Any

from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import _UserModel

class PasswordResetTokenGenerator:
key_salt: str
secret: str | bytes
secret_fallbacks: list[str | bytes]
algorithm: str
def make_token(self, user: AbstractBaseUser) -> str: ...
def check_token(self, user: AbstractBaseUser | None, token: str | None) -> bool: ...
def _make_token_with_timestamp(self, user: AbstractBaseUser, timestamp: int, secret: str | bytes = ...) -> str: ...
def _make_hash_value(self, user: AbstractBaseUser, timestamp: int) -> str: ...
def make_token(self, user: _UserModel) -> str: ...
def check_token(self, user: _UserModel | None, token: str | None) -> bool: ...
def _make_token_with_timestamp(self, user: _UserModel, timestamp: int, secret: str | bytes = ...) -> str: ...
def _make_hash_value(self, user: _UserModel, timestamp: int) -> str: ...
def _num_seconds(self, dt: datetime | date) -> int: ...
def _now(self) -> datetime: ...

Expand Down
7 changes: 4 additions & 3 deletions django-stubs/contrib/auth/views.pyi
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Any

from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import _UserModel
from django.contrib.auth.forms import AuthenticationForm
from django.http.request import HttpRequest
from django.http.response import HttpResponse, HttpResponseRedirect
from django.views.generic.base import TemplateView
from django.views.generic.edit import FormView
from typing_extensions import TypeAlias

UserModel: Any
UserModel: TypeAlias = type[_UserModel]

class RedirectURLMixin:
next_page: str | None
Expand Down Expand Up @@ -65,7 +66,7 @@ class PasswordResetConfirmView(PasswordContextMixin, FormView):
token_generator: Any
validlink: bool
user: Any
def get_user(self, uidb64: str) -> AbstractBaseUser | None: ...
def get_user(self, uidb64: str) -> _UserModel | None: ...

class PasswordResetCompleteView(PasswordContextMixin, TemplateView):
title: Any
Expand Down
6 changes: 3 additions & 3 deletions django-stubs/http/request.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from io import BytesIO
from re import Pattern
from typing import Any, Awaitable, BinaryIO, Callable, Literal, NoReturn, TypeVar, overload, type_check_only

from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import _UserModel
from django.contrib.auth.models import AnonymousUser
from django.contrib.sessions.backends.base import SessionBase
from django.contrib.sites.models import Site
Expand Down Expand Up @@ -55,9 +55,9 @@ class HttpRequest(BytesIO):
# django.contrib.admin views:
current_app: str
# django.contrib.auth.middleware.AuthenticationMiddleware:
user: AbstractBaseUser | AnonymousUser
user: _UserModel | AnonymousUser
# django.contrib.auth.middleware.AuthenticationMiddleware:
auser: Callable[[], Awaitable[AbstractBaseUser | AnonymousUser]]
auser: Callable[[], Awaitable[_UserModel | AnonymousUser]]
# django.middleware.locale.LocaleMiddleware:
LANGUAGE_CODE: str
# django.contrib.sites.middleware.CurrentSiteMiddleware
Expand Down
6 changes: 3 additions & 3 deletions django-stubs/test/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ from types import TracebackType
from typing import Any, Generic, Literal, NoReturn, TypedDict, TypeVar, type_check_only

from asgiref.typing import ASGIVersions
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import _UserModel
from django.contrib.sessions.backends.base import SessionBase
from django.core.handlers.asgi import ASGIRequest
from django.core.handlers.base import BaseHandler
Expand Down Expand Up @@ -213,8 +213,8 @@ class ClientMixin:
async def asession(self) -> SessionBase: ...
def login(self, **credentials: Any) -> bool: ...
async def alogin(self, **credentials: Any) -> bool: ...
def force_login(self, user: AbstractBaseUser, backend: str | None = ...) -> None: ...
async def aforce_login(self, user: AbstractBaseUser, backend: str | None = ...) -> None: ...
def force_login(self, user: _UserModel, backend: str | None = ...) -> None: ...
async def aforce_login(self, user: _UserModel, backend: str | None = ...) -> None: ...
def logout(self) -> None: ...
async def alogout(self) -> None: ...

Expand Down
3 changes: 2 additions & 1 deletion mypy_django_plugin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ def parse_ini_file(self, filepath: Path) -> None:
except ValueError:
exit_with_error(INVALID_BOOL_SETTING.format(key="strict_settings"))

def to_json(self) -> Dict[str, Any]:
def to_json(self, extra_data: Dict[str, Any]) -> Dict[str, Any]:
"""We use this method to reset mypy cache via `report_config_data` hook."""
return {
"django_settings_module": self.django_settings_module,
"strict_settings": self.strict_settings,
**dict(sorted(extra_data.items())),
}
4 changes: 4 additions & 0 deletions mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,7 @@ def resolve_lookup_expected_type(

def resolve_f_expression_type(self, f_expression_type: Instance) -> ProperType:
return AnyType(TypeOfAny.explicit)

@cached_property
def is_contrib_auth_installed(self) -> bool:
return "django.contrib.auth" in self.settings.INSTALLED_APPS
3 changes: 2 additions & 1 deletion mypy_django_plugin/lib/fullnames.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
ABSTRACT_BASE_USER_MODEL_FULLNAME = "django.contrib.auth.base_user.AbstractBaseUser"
ABSTRACT_USER_MODEL_FULLNAME = "django.contrib.auth.models.AbstractUser"
PERMISSION_MIXIN_CLASS_FULLNAME = "django.contrib.auth.models.PermissionsMixin"
MODEL_METACLASS_FULLNAME = "django.db.models.base.ModelBase"
Expand Down Expand Up @@ -62,7 +63,7 @@

DJANGO_ABSTRACT_MODELS = frozenset(
(
"django.contrib.auth.base_user.AbstractBaseUser",
ABSTRACT_BASE_USER_MODEL_FULLNAME,
ABSTRACT_USER_MODEL_FULLNAME,
PERMISSION_MIXIN_CLASS_FULLNAME,
"django.contrib.sessions.base_session.AbstractBaseSession",
Expand Down
38 changes: 18 additions & 20 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
meta,
orm_lookups,
querysets,
request,
settings,
)
from mypy_django_plugin.transformers.auth import get_user_model
from mypy_django_plugin.transformers.functional import resolve_str_promise_attribute
from mypy_django_plugin.transformers.managers import (
add_as_manager_to_queryset_class,
Expand Down Expand Up @@ -110,15 +110,14 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
return [self._new_dependency("typing"), self._new_dependency("django_stubs_ext")]

# for `get_user_model()`
if self.django_context.settings:
if file.fullname == "django.contrib.auth" or file.fullname in {"django.http", "django.http.request"}:
auth_user_model_name = self.django_context.settings.AUTH_USER_MODEL
try:
auth_user_module = self.django_context.apps_registry.get_model(auth_user_model_name).__module__
except LookupError:
# get_user_model() model app is not installed
return []
return [self._new_dependency(auth_user_module), self._new_dependency("django_stubs_ext")]
if file.fullname == "django.contrib.auth" or file.fullname in {"django.http", "django.http.request"}:
auth_user_model_name = self.django_context.settings.AUTH_USER_MODEL
try:
auth_user_module = self.django_context.apps_registry.get_model(auth_user_model_name).__module__
except LookupError:
# get_user_model() model app is not installed
return []
return [self._new_dependency(auth_user_module), self._new_dependency("django_stubs_ext")]

# ensure that all mentioned to='someapp.SomeModel' are loaded with corresponding related Fields
defined_model_classes = self.django_context.model_modules.get(file.fullname)
Expand Down Expand Up @@ -149,9 +148,6 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
]

def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], MypyType]]:
if fullname == "django.contrib.auth.get_user_model":
return partial(settings.get_user_model_hook, django_context=self.django_context)

info = self._get_typeinfo_or_none(fullname)
if info:
if info.has_base(fullnames.FIELD_FULLNAME):
Expand Down Expand Up @@ -270,10 +266,6 @@ def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeConte
if info and info.has_base(fullnames.PERMISSION_MIXIN_CLASS_FULLNAME) and attr_name == "is_superuser":
return partial(set_auth_user_model_boolean_fields, django_context=self.django_context)

# Lookup of the 'request.user' attribute
if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == "user":
return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context)

# Lookup of the 'user.is_staff' or 'user.is_active' attribute
if info and info.has_base(fullnames.ABSTRACT_USER_MODEL_FULLNAME) and attr_name in ("is_staff", "is_active"):
return partial(set_auth_user_model_boolean_fields, django_context=self.django_context)
Expand All @@ -299,8 +291,9 @@ def get_type_analyze_hook(self, fullname: str) -> Optional[Callable[[AnalyzeType
"django_stubs_ext.annotations.WithAnnotations",
):
return partial(handle_annotated_type, fullname=fullname)
else:
return None
elif fullname == "django.contrib.auth.base_user._UserModel":
return partial(get_user_model, django_context=self.django_context)
return None

def get_dynamic_class_hook(self, fullname: str) -> Optional[Callable[[DynamicClassDefContext], None]]:
# Create a new manager class definition when a manager's '.from_queryset' classmethod is called
Expand All @@ -313,7 +306,12 @@ def get_dynamic_class_hook(self, fullname: str) -> Optional[Callable[[DynamicCla

def report_config_data(self, ctx: ReportConfigContext) -> Dict[str, Any]:
# Cache would be cleared if any settings do change.
return self.plugin_config.to_json()
extra_data = {}
# In all places we use '_UserModel' alias as a type we want to clear cache if
# AUTH_USER_MODEL setting changes
if ctx.id.startswith("django.contrib.auth") or ctx.id in {"django.http.request", "django.test.client"}:
extra_data["AUTH_USER_MODEL"] = self.django_context.settings.AUTH_USER_MODEL
return self.plugin_config.to_json(extra_data)


def plugin(version: str) -> Type[NewSemanalDjangoPlugin]:
Expand Down
Loading

0 comments on commit 81ecdcf

Please sign in to comment.