Add more type hinting

This commit is contained in:
Gabriel Augendre 2021-12-28 23:33:50 +01:00
parent 40a4a0a308
commit 2793dc18a0
6 changed files with 61 additions and 32 deletions

View file

@ -1,7 +1,7 @@
from typing import Any from typing import Any
from django.conf import settings from django.conf import settings
from django.http import HttpRequest from django.core.handlers.wsgi import WSGIRequest
from articles.models import Article from articles.models import Article
from attachments.models import Attachment from attachments.models import Attachment
@ -11,7 +11,7 @@ IGNORED_PATHS = [
] ]
def drafts_count(request: HttpRequest) -> dict[str, Any]: def drafts_count(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS: if request.path in IGNORED_PATHS:
return {} return {}
if not request.user.is_authenticated: if not request.user.is_authenticated:
@ -19,13 +19,13 @@ def drafts_count(request: HttpRequest) -> dict[str, Any]:
return {"drafts_count": Article.objects.filter(status=Article.DRAFT).count()} return {"drafts_count": Article.objects.filter(status=Article.DRAFT).count()}
def date_format(request: HttpRequest) -> dict[str, Any]: def date_format(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS: if request.path in IGNORED_PATHS:
return {} return {}
return {"CUSTOM_ISO": r"Y-m-d\TH:i:sO", "ISO_DATE": "Y-m-d"} return {"CUSTOM_ISO": r"Y-m-d\TH:i:sO", "ISO_DATE": "Y-m-d"}
def git_version(request: HttpRequest) -> dict[str, Any]: def git_version(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS: if request.path in IGNORED_PATHS:
return {} return {}
try: try:
@ -39,13 +39,13 @@ def git_version(request: HttpRequest) -> dict[str, Any]:
return {"git_version": version, "git_version_url": url} return {"git_version": version, "git_version_url": url}
def analytics(request: HttpRequest) -> dict[str, Any]: def analytics(request: WSGIRequest) -> dict[str, Any]:
return { return {
"goatcounter_domain": settings.GOATCOUNTER_DOMAIN, "goatcounter_domain": settings.GOATCOUNTER_DOMAIN,
} }
def open_graph_image_url(request: HttpRequest) -> dict[str, Any]: def open_graph_image_url(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS: if request.path in IGNORED_PATHS:
return {} return {}
open_graph_image = Attachment.objects.get_open_graph_image() open_graph_image = Attachment.objects.get_open_graph_image()
@ -55,7 +55,7 @@ def open_graph_image_url(request: HttpRequest) -> dict[str, Any]:
return {"open_graph_image_url": url} return {"open_graph_image_url": url}
def blog_metadata(request: HttpRequest) -> dict[str, Any]: def blog_metadata(request: WSGIRequest) -> dict[str, Any]:
return { return {
"blog_title": settings.BLOG["title"], "blog_title": settings.BLOG["title"],
"blog_description": settings.BLOG["description"], "blog_description": settings.BLOG["description"],

View file

@ -1,7 +1,8 @@
from typing import Any from typing import Any
from django.contrib.auth.decorators import login_required from django.contrib.auth.decorators import login_required
from django.http import HttpRequest, HttpResponse from django.core.handlers.wsgi import WSGIRequest
from django.http import HttpResponse
from django.shortcuts import render from django.shortcuts import render
from django.views.decorators.http import require_POST from django.views.decorators.http import require_POST
@ -10,7 +11,7 @@ from articles.models import Article, Tag
@login_required @login_required
@require_POST @require_POST
def render_article(request: HttpRequest, article_pk: int) -> HttpResponse: def render_article(request: WSGIRequest, article_pk: int) -> HttpResponse:
template = "articles/article_detail.html" template = "articles/article_detail.html"
article = Article.objects.get(pk=article_pk) article = Article.objects.get(pk=article_pk)
article.content = request.POST.get("content", article.content) article.content = request.POST.get("content", article.content)

View file

@ -1,39 +1,49 @@
from datetime import datetime
from typing import Iterable
from django.contrib.syndication.views import Feed from django.contrib.syndication.views import Feed
from django.core.handlers.wsgi import WSGIRequest
from django.db.models import QuerySet
from articles.models import Article, Tag from articles.models import Article, Tag
from blog import settings from blog import settings
class CompleteFeed(Feed): class BaseFeed(Feed):
FEED_LIMIT = 15 FEED_LIMIT = 15
title = settings.BLOG["title"]
link = settings.BLOG["base_url"]
description = settings.BLOG["description"] description = settings.BLOG["description"]
def get_queryset(self, obj): def item_description(self, item: Article) -> str: # type: ignore[override]
return item.get_formatted_content
def item_pubdate(self, item: Article) -> datetime | None:
return item.published_at
def _get_queryset(self) -> QuerySet[Article]:
return Article.objects.filter(status=Article.PUBLISHED).order_by( return Article.objects.filter(status=Article.PUBLISHED).order_by(
"-published_at" "-published_at"
) )
def items(self, obj):
return self.get_queryset(obj)[: self.FEED_LIMIT]
def item_description(self, item: Article): # type: ignore[override] class CompleteFeed(BaseFeed):
return item.get_formatted_content title = settings.BLOG["title"]
link = settings.BLOG["base_url"]
def item_pubdate(self, item: Article): def items(self) -> Iterable[Article]:
return item.published_at return self._get_queryset()[: self.FEED_LIMIT]
class TagFeed(CompleteFeed): class TagFeed(BaseFeed):
def get_object(self, request, *args, **kwargs): def get_object( # type: ignore[override]
self, request: WSGIRequest, *args, **kwargs
) -> Tag:
return Tag.objects.get(slug=kwargs.get("slug")) return Tag.objects.get(slug=kwargs.get("slug"))
def get_queryset(self, tag): def title(self, tag: Tag) -> str:
return super().get_queryset(tag).filter(tags=tag)
def title(self, tag):
return tag.get_feed_title() return tag.get_feed_title()
def link(self, tag): def link(self, tag: Tag) -> str:
return tag.get_absolute_url() return tag.get_absolute_url()
def items(self, tag: Tag) -> Iterable[Article]:
return self._get_queryset().filter(tags=tag)[: self.FEED_LIMIT]

View file

@ -1,11 +1,15 @@
import operator import operator
from functools import reduce from functools import reduce
from typing import Any
from django.conf import settings from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.core.handlers.wsgi import WSGIRequest
from django.core.paginator import Page
from django.db.models import F, Q from django.db.models import F, Q
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.views import generic from django.views import generic
from django.views.generic import DetailView
from articles.models import Article, Tag from articles.models import Article, Tag
@ -16,12 +20,13 @@ class BaseArticleListView(generic.ListView):
paginate_by = 10 paginate_by = 10
main_title = "Blog posts" main_title = "Blog posts"
html_title = "" html_title = ""
request: WSGIRequest
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs) context = super().get_context_data(**kwargs)
context["blog_title"] = settings.BLOG["title"] context["blog_title"] = settings.BLOG["title"]
context["blog_description"] = settings.BLOG["description"] context["blog_description"] = settings.BLOG["description"]
page_obj = context["page_obj"] page_obj: Page = context["page_obj"]
if page_obj.has_next(): if page_obj.has_next():
querystring = self.build_querystring({"page": page_obj.next_page_number()}) querystring = self.build_querystring({"page": page_obj.next_page_number()})
context["next_page_querystring"] = querystring context["next_page_querystring"] = querystring
@ -35,7 +40,7 @@ class BaseArticleListView(generic.ListView):
def get_additional_querystring_params(self) -> dict[str, str]: def get_additional_querystring_params(self) -> dict[str, str]:
return {} return {}
def build_querystring(self, initial_queryparams: dict[str, str]) -> str: def build_querystring(self, initial_queryparams: dict[str, Any]) -> str:
querystring = { querystring = {
**initial_queryparams, **initial_queryparams,
**self.get_additional_querystring_params(), **self.get_additional_querystring_params(),
@ -50,7 +55,7 @@ class PublicArticleListView(BaseArticleListView):
class ArticlesListView(PublicArticleListView): class ArticlesListView(PublicArticleListView):
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs) context = super().get_context_data(**kwargs)
home_article: Article = Article.objects.filter( home_article = Article.objects.filter(
status=Article.PUBLISHED, is_home=True status=Article.PUBLISHED, is_home=True
).first() ).first()
context["article"] = home_article context["article"] = home_article
@ -120,10 +125,11 @@ class DraftsListView(LoginRequiredMixin, BaseArticleListView):
return context return context
class ArticleDetailView(generic.DetailView): class ArticleDetailView(DetailView[Article]):
model = Article model = Article
context_object_name = "article" context_object_name = "article"
template_name = "articles/article_detail.html" template_name = "articles/article_detail.html"
request: WSGIRequest
def get_queryset(self): def get_queryset(self):
key = self.request.GET.get("draft_key") key = self.request.GET.get("draft_key")

View file

@ -12,6 +12,7 @@ https://docs.djangoproject.com/en/3.1/ref/settings/
import os import os
from pathlib import Path from pathlib import Path
import django_stubs_ext
import environ import environ
# Build paths inside the project like this: BASE_DIR / 'subdir'. # Build paths inside the project like this: BASE_DIR / 'subdir'.
@ -40,6 +41,7 @@ if env_file:
environ.Env.read_env(env_file) environ.Env.read_env(env_file)
django_stubs_ext.monkeypatch()
# Quick-start development settings - unsuitable for production # Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/3.1/howto/deployment/checklist/ # See https://docs.djangoproject.com/en/3.1/howto/deployment/checklist/

View file

@ -22,13 +22,23 @@ def test_cov(ctx):
) )
@task(post=[test_cov]) @task
def check(ctx): def pre_commit(ctx):
with ctx.cd(BASE_DIR): with ctx.cd(BASE_DIR):
ctx.run("pre-commit run --all-files", pty=True) ctx.run("pre-commit run --all-files", pty=True)
@task
def mypy(ctx):
with ctx.cd(BASE_DIR):
ctx.run("mypy src", pty=True) ctx.run("mypy src", pty=True)
@task(pre=[pre_commit, mypy, test_cov])
def check(ctx):
pass
@task @task
def build(ctx): def build(ctx):
with ctx.cd(BASE_DIR): with ctx.cd(BASE_DIR):