Add more type hinting

This commit is contained in:
Gabriel Augendre 2021-12-28 23:33:50 +01:00
parent 40a4a0a308
commit 2793dc18a0
6 changed files with 61 additions and 32 deletions

View file

@ -1,7 +1,7 @@
from typing import Any
from django.conf import settings
from django.http import HttpRequest
from django.core.handlers.wsgi import WSGIRequest
from articles.models import Article
from attachments.models import Attachment
@ -11,7 +11,7 @@ IGNORED_PATHS = [
]
def drafts_count(request: HttpRequest) -> dict[str, Any]:
def drafts_count(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS:
return {}
if not request.user.is_authenticated:
@ -19,13 +19,13 @@ def drafts_count(request: HttpRequest) -> dict[str, Any]:
return {"drafts_count": Article.objects.filter(status=Article.DRAFT).count()}
def date_format(request: HttpRequest) -> dict[str, Any]:
def date_format(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS:
return {}
return {"CUSTOM_ISO": r"Y-m-d\TH:i:sO", "ISO_DATE": "Y-m-d"}
def git_version(request: HttpRequest) -> dict[str, Any]:
def git_version(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS:
return {}
try:
@ -39,13 +39,13 @@ def git_version(request: HttpRequest) -> dict[str, Any]:
return {"git_version": version, "git_version_url": url}
def analytics(request: HttpRequest) -> dict[str, Any]:
def analytics(request: WSGIRequest) -> dict[str, Any]:
return {
"goatcounter_domain": settings.GOATCOUNTER_DOMAIN,
}
def open_graph_image_url(request: HttpRequest) -> dict[str, Any]:
def open_graph_image_url(request: WSGIRequest) -> dict[str, Any]:
if request.path in IGNORED_PATHS:
return {}
open_graph_image = Attachment.objects.get_open_graph_image()
@ -55,7 +55,7 @@ def open_graph_image_url(request: HttpRequest) -> dict[str, Any]:
return {"open_graph_image_url": url}
def blog_metadata(request: HttpRequest) -> dict[str, Any]:
def blog_metadata(request: WSGIRequest) -> dict[str, Any]:
return {
"blog_title": settings.BLOG["title"],
"blog_description": settings.BLOG["description"],

View file

@ -1,7 +1,8 @@
from typing import Any
from django.contrib.auth.decorators import login_required
from django.http import HttpRequest, HttpResponse
from django.core.handlers.wsgi import WSGIRequest
from django.http import HttpResponse
from django.shortcuts import render
from django.views.decorators.http import require_POST
@ -10,7 +11,7 @@ from articles.models import Article, Tag
@login_required
@require_POST
def render_article(request: HttpRequest, article_pk: int) -> HttpResponse:
def render_article(request: WSGIRequest, article_pk: int) -> HttpResponse:
template = "articles/article_detail.html"
article = Article.objects.get(pk=article_pk)
article.content = request.POST.get("content", article.content)

View file

@ -1,39 +1,49 @@
from datetime import datetime
from typing import Iterable
from django.contrib.syndication.views import Feed
from django.core.handlers.wsgi import WSGIRequest
from django.db.models import QuerySet
from articles.models import Article, Tag
from blog import settings
class CompleteFeed(Feed):
class BaseFeed(Feed):
FEED_LIMIT = 15
title = settings.BLOG["title"]
link = settings.BLOG["base_url"]
description = settings.BLOG["description"]
def get_queryset(self, obj):
def item_description(self, item: Article) -> str: # type: ignore[override]
return item.get_formatted_content
def item_pubdate(self, item: Article) -> datetime | None:
return item.published_at
def _get_queryset(self) -> QuerySet[Article]:
return Article.objects.filter(status=Article.PUBLISHED).order_by(
"-published_at"
)
def items(self, obj):
return self.get_queryset(obj)[: self.FEED_LIMIT]
def item_description(self, item: Article): # type: ignore[override]
return item.get_formatted_content
class CompleteFeed(BaseFeed):
title = settings.BLOG["title"]
link = settings.BLOG["base_url"]
def item_pubdate(self, item: Article):
return item.published_at
def items(self) -> Iterable[Article]:
return self._get_queryset()[: self.FEED_LIMIT]
class TagFeed(CompleteFeed):
def get_object(self, request, *args, **kwargs):
class TagFeed(BaseFeed):
def get_object( # type: ignore[override]
self, request: WSGIRequest, *args, **kwargs
) -> Tag:
return Tag.objects.get(slug=kwargs.get("slug"))
def get_queryset(self, tag):
return super().get_queryset(tag).filter(tags=tag)
def title(self, tag):
def title(self, tag: Tag) -> str:
return tag.get_feed_title()
def link(self, tag):
def link(self, tag: Tag) -> str:
return tag.get_absolute_url()
def items(self, tag: Tag) -> Iterable[Article]:
return self._get_queryset().filter(tags=tag)[: self.FEED_LIMIT]

View file

@ -1,11 +1,15 @@
import operator
from functools import reduce
from typing import Any
from django.conf import settings
from django.contrib.auth.mixins import LoginRequiredMixin
from django.core.handlers.wsgi import WSGIRequest
from django.core.paginator import Page
from django.db.models import F, Q
from django.shortcuts import get_object_or_404
from django.views import generic
from django.views.generic import DetailView
from articles.models import Article, Tag
@ -16,12 +20,13 @@ class BaseArticleListView(generic.ListView):
paginate_by = 10
main_title = "Blog posts"
html_title = ""
request: WSGIRequest
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
context["blog_title"] = settings.BLOG["title"]
context["blog_description"] = settings.BLOG["description"]
page_obj = context["page_obj"]
page_obj: Page = context["page_obj"]
if page_obj.has_next():
querystring = self.build_querystring({"page": page_obj.next_page_number()})
context["next_page_querystring"] = querystring
@ -35,7 +40,7 @@ class BaseArticleListView(generic.ListView):
def get_additional_querystring_params(self) -> dict[str, str]:
return {}
def build_querystring(self, initial_queryparams: dict[str, str]) -> str:
def build_querystring(self, initial_queryparams: dict[str, Any]) -> str:
querystring = {
**initial_queryparams,
**self.get_additional_querystring_params(),
@ -50,7 +55,7 @@ class PublicArticleListView(BaseArticleListView):
class ArticlesListView(PublicArticleListView):
def get_context_data(self, **kwargs):
context = super().get_context_data(**kwargs)
home_article: Article = Article.objects.filter(
home_article = Article.objects.filter(
status=Article.PUBLISHED, is_home=True
).first()
context["article"] = home_article
@ -120,10 +125,11 @@ class DraftsListView(LoginRequiredMixin, BaseArticleListView):
return context
class ArticleDetailView(generic.DetailView):
class ArticleDetailView(DetailView[Article]):
model = Article
context_object_name = "article"
template_name = "articles/article_detail.html"
request: WSGIRequest
def get_queryset(self):
key = self.request.GET.get("draft_key")

View file

@ -12,6 +12,7 @@ https://docs.djangoproject.com/en/3.1/ref/settings/
import os
from pathlib import Path
import django_stubs_ext
import environ
# Build paths inside the project like this: BASE_DIR / 'subdir'.
@ -40,6 +41,7 @@ if env_file:
environ.Env.read_env(env_file)
django_stubs_ext.monkeypatch()
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/3.1/howto/deployment/checklist/

View file

@ -22,13 +22,23 @@ def test_cov(ctx):
)
@task(post=[test_cov])
def check(ctx):
@task
def pre_commit(ctx):
with ctx.cd(BASE_DIR):
ctx.run("pre-commit run --all-files", pty=True)
@task
def mypy(ctx):
with ctx.cd(BASE_DIR):
ctx.run("mypy src", pty=True)
@task(pre=[pre_commit, mypy, test_cov])
def check(ctx):
pass
@task
def build(ctx):
with ctx.cd(BASE_DIR):