Source code for django_ca.tasks

# This file is part of django-ca (https://github.com/mathiasertl/django-ca).
#
# django-ca is free software: you can redistribute it and/or modify it under the terms of the GNU General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# django-ca is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
# for more details.
#
# You should have received a copy of the GNU General Public License along with django-ca. If not, see
# <http://www.gnu.org/licenses/>.

"""Asynchronous Celery tasks for django-ca.

.. seealso:: https://docs.celeryproject.org/en/stable/index.html
"""

import logging
import typing
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone as tz
from http import HTTPStatus
from typing import Any, Optional

import requests

from cryptography import x509
from cryptography.x509.oid import ExtensionOID

from django.db import transaction
from django.utils import timezone

from django_ca import ca_settings
from django_ca.acme.validation import validate_dns_01
from django_ca.constants import EXTENSION_DEFAULT_CRITICAL
from django_ca.models import (
    AcmeAuthorization,
    AcmeCertificate,
    AcmeChallenge,
    AcmeOrder,
    Certificate,
    CertificateAuthority,
    CertificateOrder,
)
from django_ca.profiles import profiles
from django_ca.pydantic.messages import GenerateOCSPKeyMessage, SignCertificateMessage
from django_ca.typehints import (
    JSON,
    EllipticCurves,
    HashAlgorithms,
    ParsableKeyType,
    SerializedPydanticExtension,
    SerializedPydanticName,
)
from django_ca.utils import parse_general_name

log = logging.getLogger(__name__)

FuncTypeVar = typing.TypeVar("FuncTypeVar", bound=typing.Callable[..., Any])

try:
    from celery import shared_task
    from celery.local import Proxy
except ImportError:

    def shared_task(func: FuncTypeVar) -> "Proxy[FuncTypeVar]":
        """Dummy decorator so that we can use the decorator whether celery is installed or not."""
        # We do not yet need this, but might come in handy in the future:
        # func.delay = lambda *a, **kw: func(*a, **kw)
        # func.apply_async = lambda *a, **kw: func(*a, **kw)
        func.delay = func  # type: ignore[attr-defined]
        return typing.cast("Proxy[FuncTypeVar]", func)


# pragma: only py<3.10: Use typing.ParamSpec for better type hinting
def run_task(task: "Proxy[FuncTypeVar]", *args: Any, **kwargs: Any) -> Any:
    """Function that passes `task` to celery or invokes it directly, depending on if Celery is installed."""
    eager = kwargs.pop("eager", False)

    if ca_settings.CA_USE_CELERY is True and eager is False:
        return task.delay(*args, **kwargs)

    return task(*args, **kwargs)


[docs] @shared_task def cache_crl(serial: str, key_backend_options: Optional[dict[str, JSON]] = None) -> None: """Task to cache the CRL for a given CA.""" if key_backend_options is None: key_backend_options = {} ca = CertificateAuthority.objects.get(serial=serial) key_backend_options_model = ca.key_backend.use_model.model_validate( key_backend_options, context={"ca": ca}, strict=True ) ca.cache_crls(key_backend_options_model)
[docs] @shared_task def cache_crls( serials: Optional[Iterable[str]] = None, key_backend_options: Optional[dict[str, dict[str, JSON]]] = None ) -> None: """Task to cache the CRLs for all CAs.""" if serials is None: serials = [] if key_backend_options is None: key_backend_options = {} if not serials: serials = typing.cast( Iterable[str], CertificateAuthority.objects.usable().values_list("serial", flat=True) ) for serial in serials: try: run_task(cache_crl, serial, key_backend_options=key_backend_options.get(serial, {})) except Exception: # pylint: disable=broad-exception-caught # NOTE: When using Celery, an exception will only be raised here if task.delay() itself raises an # exception, e.g. if the connection to the broker fails. Without celery, exceptions in cache_crl() # are raised here directly. log.exception("Error caching CRL for %s", serial)
[docs] @shared_task def generate_ocsp_key( serial: str, key_backend_options: Optional[dict[str, JSON]] = None, profile: str = "ocsp", expires: Optional[int] = None, algorithm: Optional[HashAlgorithms] = None, key_size: Optional[int] = None, key_type: Optional[ParsableKeyType] = None, elliptic_curve: Optional[EllipticCurves] = None, autogenerated: bool = True, force: bool = False, ) -> Optional[tuple[str, str, int]]: """Task to generate an OCSP key for the CA named by `serial`. The `serial` names the certificate authority for which to regenerate the OCSP responder certificate. All other arguments are passed on to :py:func:`~django_ca.models.CertificateAuthority.generate_ocsp_key`. The task returns the private and public key paths and the *primary key* of the generated certificate if a new certificate was generated, otherwise it returns ``None``. """ if key_backend_options is None: key_backend_options = {} parameters = GenerateOCSPKeyMessage( serial=serial, profile=profile, expires=expires, key_type=key_type, key_size=key_size, elliptic_curve=elliptic_curve, algorithm=algorithm, autogenerated=autogenerated, force=force, ) ca: CertificateAuthority = CertificateAuthority.objects.get(serial=parameters.serial) key_backend_options_model = ca.key_backend.use_model.model_validate( key_backend_options, context={"ca": ca}, strict=True ) value = ca.generate_ocsp_key( key_backend_options=key_backend_options_model, profile=parameters.profile, expires=parameters.expires, algorithm=parameters.algorithm, key_size=parameters.key_size, key_type=parameters.key_type, elliptic_curve=parameters.elliptic_curve, autogenerated=parameters.autogenerated, force=parameters.force, ) if value is not None: private_path, cert_path, cert = value return private_path, cert_path, cert.pk return None
[docs] @shared_task def generate_ocsp_keys( serials: Optional[Iterable[str]] = None, key_backend_options: Optional[dict[str, dict[str, JSON]]] = None ) -> None: """Task to generate an OCSP keys for all usable CAs.""" if serials is None: serials = [] if key_backend_options is None: key_backend_options = {} if not serials: serials = typing.cast( Iterable[str], CertificateAuthority.objects.usable().values_list("serial", flat=True) ) for serial in serials: try: run_task(generate_ocsp_key, serial, key_backend_options=key_backend_options.get(serial, {})) except Exception: # pylint: disable=broad-exception-caught # NOTE: When using Celery, an exception will only be raised here if task.delay() itself raises an # exception, e.g. if the connection to the broker fails. Without celery, exceptions in # generate_ocsp_key() are raised here directly. log.exception("Error creating OCSP responder key for %s", serial)
[docs] @shared_task @transaction.atomic def sign_certificate( order_pk: int, csr: str, subject: SerializedPydanticName, algorithm: Optional[HashAlgorithms] = None, expires: Optional[str] = None, extensions: Optional[list[SerializedPydanticExtension]] = None, profile: str = ca_settings.CA_DEFAULT_PROFILE, autogenerated: bool = False, ) -> int: """Sign a certificate from the given order with the given parameters.""" order = CertificateOrder.objects.select_related("certificate_authority").get(pk=order_pk) ca: CertificateAuthority = order.certificate_authority message = SignCertificateMessage( algorithm=algorithm, autogenerated=autogenerated, csr=csr, expires=expires, extensions=extensions, profile=profile, subject=subject, ) key_backend_options = ca.key_backend.get_use_private_key_options(ca, {}) parsed_extensions = message.get_extensions() extension_oids = [ext.oid for ext in parsed_extensions] for oid, extension in ca.extensions_for_certificate.items(): if oid not in extension_oids: parsed_extensions.append(extension) # Create a signed certificate certificate = ca.sign( key_backend_options, message.get_csr(), subject=message.subject.cryptography, # pylint: disable=no-member # false positive algorithm=message.get_algorithm(), expires=message.expires, extensions=parsed_extensions, ) # Store certificate in database certificate_obj = Certificate(ca=ca, profile=message.profile, autogenerated=message.autogenerated) certificate_obj.update_certificate(certificate) certificate_obj.save() # Update certificate order order.status = CertificateOrder.STATUS_ISSUED order.certificate = certificate_obj order.save() return certificate_obj.pk
[docs] @shared_task @transaction.atomic def acme_validate_challenge(challenge_pk: int) -> None: """Validate an ACME challenge.""" if not ca_settings.CA_ENABLE_ACME: log.error("ACME is not enabled.") return try: challenge = AcmeChallenge.objects.url().get(pk=challenge_pk) except AcmeChallenge.DoesNotExist: log.error("Challenge with id=%s not found", challenge_pk) return # Whoever is invoking this task is responsible for setting the status to "processing" first. if challenge.status != AcmeChallenge.STATUS_PROCESSING: log.error( "%s: %s: Invalid state (must be %s)", challenge, challenge.status, AcmeChallenge.STATUS_PROCESSING ) return # If the auth cannot be used for validation, neither can this challenge. We check auth.usable instead of # challenge.usable b/c a challenge in the "processing" state is not "usable" (= it is already being used). if challenge.auth.usable is False: log.error("%s: Authentication is not usable", challenge) return # General data for challenge validation value = challenge.auth.value # Challenge is marked as invalid by default challenge_valid = False # Validate HTTP challenge (only thing supported so far) if challenge.type == AcmeChallenge.TYPE_HTTP_01: decoded_token = challenge.encoded_token.decode("utf-8") expected = challenge.expected if requests is None: # pragma: no cover log.error("requests is not installed, cannot do http-01 challenge validation.") return url = f"http://{value}/.well-known/acme-challenge/{decoded_token}" try: with requests.get(url, timeout=1, stream=True) as response: # Only fetch the response body if the status code is HTTP 200 (OK) if response.status_code == HTTPStatus.OK: # Only fetch the expected number of bytes to prevent a large file ending up in memory # But fetch one extra byte (if available) to make sure that response has no extra bytes received = response.raw.read(len(expected) + 1, decode_content=True) challenge_valid = received == expected except Exception as ex: # pylint: disable=broad-except log.exception(ex) elif challenge.type == AcmeChallenge.TYPE_DNS_01: challenge_valid = validate_dns_01(challenge) # TODO: support ALPN_01 challenges # elif challenge.type == AcmeChallenge.TYPE_TLS_ALPN_01: # host = socket.gethostbyname(value) # sni_cert = crypto_util.probe_sni( # host=host, port=443, name=value, alpn_protocols=[TlsAlpnProtocol.V1] # ) else: log.error("%s: Challenge type is not supported.", challenge) # Transition state of the challenge depending on if the challenge is valid or not. RFC8555, Section 7.1.6: # # "If validation is successful, the challenge moves to the "valid" state; if there is an error, the # challenge moves to the "invalid" state." # # We also transition the matching authorization object: # # "If one of the challenges listed in the authorization transitions to the "valid" state, then the # authorization also changes to the "valid" state. If the client attempts to fulfill a challenge and # fails, or if there is an error while the authorization is still pending, then the authorization # transitions to the "invalid" state. # # We also transition the matching order object (section 7.4): # # "* ready: The server agrees that the requirements have been fulfilled, and is awaiting finalization. # Submit a finalization request." if challenge_valid: challenge.status = AcmeChallenge.STATUS_VALID challenge.validated = timezone.now() challenge.auth.status = AcmeAuthorization.STATUS_VALID # Set the order status to READY if all challenges are valid auths = AcmeAuthorization.objects.filter(order=challenge.auth.order) auths = auths.exclude(status=AcmeAuthorization.STATUS_VALID) if not auths.exclude(pk=challenge.auth.pk).exists(): log.info("Order is now valid") challenge.auth.order.status = AcmeOrder.STATUS_READY else: challenge.status = AcmeChallenge.STATUS_INVALID # RFC 8555, section 7.1.6: # # If the client attempts to fulfill a challenge and fails, or if there is an error while the # authorization is still pending, then the authorization transitions to the "invalid" state. challenge.auth.status = AcmeAuthorization.STATUS_INVALID # RFC 8555, section 7.1.6: # # If an error occurs at any of these stages, the order moves to the "invalid" state. challenge.auth.order.status = AcmeOrder.STATUS_INVALID log.info("%s is %s", challenge, challenge.status) challenge.save() challenge.auth.save() challenge.auth.order.save()
[docs] @shared_task @transaction.atomic def acme_issue_certificate(acme_certificate_pk: int) -> None: """Actually issue an ACME certificate.""" if not ca_settings.CA_ENABLE_ACME: log.error("ACME is not enabled.") return try: acme_cert = AcmeCertificate.objects.select_related("order__account__ca").get(pk=acme_certificate_pk) except AcmeCertificate.DoesNotExist: log.error("Certificate with id=%s not found", acme_certificate_pk) return if acme_cert.usable is False: log.error("%s: Cannot issue certificate for this order", acme_cert.order) return names = [a.subject_alternative_name for a in acme_cert.order.authorizations.all()] log.info("%s: Issuing certificate for %s", acme_cert.order, ",".join(names)) subject_alternative_names = x509.SubjectAlternativeName([parse_general_name(name) for name in names]) extensions = [ x509.Extension( oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, critical=EXTENSION_DEFAULT_CRITICAL[ExtensionOID.SUBJECT_ALTERNATIVE_NAME], value=subject_alternative_names, ) ] ca = acme_cert.order.account.ca profile = profiles[ca.acme_profile] # Honor not_after from the order if set if acme_cert.order.not_after: expires = acme_cert.order.not_after # Make sure expires_datetime is tz-aware, even if USE_TZ=False. if timezone.is_naive(expires): expires = timezone.make_aware(expires) else: expires = datetime.now(tz=tz.utc) + ca_settings.ACME_DEFAULT_CERT_VALIDITY csr = acme_cert.parse_csr() # Initialize key backend options key_backend_options = ca.key_backend.get_use_private_key_options(ca, {}) # Finally, actually create a certificate cert = Certificate.objects.create_cert( ca, key_backend_options, csr=csr, profile=profile, expires=expires, extensions=extensions ) acme_cert.cert = cert acme_cert.order.status = AcmeOrder.STATUS_VALID acme_cert.order.save() acme_cert.save()
[docs] @shared_task @transaction.atomic def acme_cleanup() -> None: """Cleanup expired ACME orders.""" if not ca_settings.CA_ENABLE_ACME: # NOTE: Since this task does only cleanup, log message is only info. log.info("ACME is not enabled, not doing anything.") return # Delete orders that expired more than a day ago. threshold = timezone.now() - timedelta(days=1) AcmeOrder.objects.filter(expires__lt=threshold).delete()