Source code for django_ca.tests.base.utils

# This file is part of django-ca (https://github.com/mathiasertl/django-ca).
#
# django-ca is free software: you can redistribute it and/or modify it under the terms of the GNU General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# django-ca is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
# for more details.
#
# You should have received a copy of the GNU General Public License along with django-ca. If not, see
# <http://www.gnu.org/licenses/>.

"""Utility functions used in testing."""

import inspect
import ipaddress
import os
import shutil
import tempfile
import textwrap
import typing
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from datetime import datetime
from io import BytesIO, StringIO
from typing import Any
from unittest import mock

from pydantic import BaseModel

from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed448, ed25519, padding, rsa
from cryptography.hazmat.primitives.asymmetric.types import (
    CertificateIssuerPrivateKeyTypes,
    CertificateIssuerPublicKeyTypes,
)
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.x509.oid import AuthorityInformationAccessOID, ExtensionOID, NameOID

from django.conf import settings
from django.core.files.storage import storages
from django.core.management import ManagementUtility, call_command
from django.test import override_settings
from django.urls import reverse
from django.utils.crypto import get_random_string

from django_ca.extensions import extension_as_text
from django_ca.key_backends import KeyBackend
from django_ca.models import Certificate, CertificateAuthority, X509CertMixin
from django_ca.profiles import profiles
from django_ca.tests.acme.views.constants import SERVER_NAME
from django_ca.tests.base.constants import CERT_DATA, FIXTURES_DIR
from django_ca.typehints import ArgumentGroup, CertificateExtension, ParsableKeyType, SignatureHashAlgorithm
from django_ca.utils import get_crl_cache_key


# Field-less pydantic model. DummyBackend uses it for all three type parameters of KeyBackend.
class DummyModel(BaseModel):
    """Dummy model for the dummy backend."""
# pylint: disable-next=abstract-method  # we don't override sign_data to test the base class function.
class DummyBackend(KeyBackend[DummyModel, DummyModel, DummyModel]):  # pragma: no cover
    """Backend with no actions whatsoever.

    All option hooks return an empty :class:`DummyModel`; the key and signing hooks return ``None``
    placeholders. Nothing is generated, stored or signed.
    """

    title = "dummy backend"
    description = "dummy description"

    # This backend only supports RSA and EC keys, but also the (invented) "STRANGE" key type.
    supported_key_types = ("RSA", "EC", "STRANGE")
    supported_elliptic_curves = ("sect571r1",)
    supported_hash_algorithms = ("SHA-256", "SHA-512")

    def __eq__(self, other: Any) -> bool:
        """Compare equal to every other instance of this backend."""
        return isinstance(other, DummyBackend)

    def __hash__(self) -> int:
        """Return the same hash for all instances.

        ``__eq__`` treats all instances as equal, so the hash must not depend on instance identity:
        objects that compare equal are required to have equal hashes.
        """
        return hash(DummyBackend)

    def get_create_private_key_options(
        self,
        key_type: ParsableKeyType,
        key_size: int | None,
        elliptic_curve: str | None,
        options: dict[str, Any],
    ) -> DummyModel:
        """Return empty options, ignoring all arguments."""
        return DummyModel()

    def add_use_parent_private_key_arguments(self, group: ArgumentGroup) -> None:
        """Add no arguments to `group`."""
        return None

    def get_use_parent_private_key_options(
        self, ca: CertificateAuthority, options: dict[str, Any]
    ) -> DummyModel:
        """Return empty options, ignoring all arguments."""
        return DummyModel()

    def get_store_private_key_options(self, options: dict[str, Any]) -> DummyModel:
        """Return empty options, ignoring all arguments."""
        return DummyModel()

    def create_private_key(
        self, ca: CertificateAuthority, key_type: ParsableKeyType, options: DummyModel
    ) -> tuple[CertificateIssuerPublicKeyTypes, DummyModel]:
        """Create nothing: return ``None`` in place of a public key, plus empty options."""
        return None, DummyModel()  # type: ignore[return-value]

    def get_use_private_key_options(self, ca: CertificateAuthority, options: dict[str, Any]) -> DummyModel:
        """Return empty options, ignoring all arguments."""
        return DummyModel()

    def is_usable(
        self, ca: "CertificateAuthority", use_private_key_options: DummyModel | None = None
    ) -> bool:
        """Always report the backend as usable."""
        return True

    def check_usable(self, ca: "CertificateAuthority", use_private_key_options: DummyModel) -> None:
        """Never raise, as the backend always reports itself as usable."""
        return

    def sign_certificate_revocation_list(
        self,
        ca: "CertificateAuthority",
        use_private_key_options: DummyModel,
        builder: x509.CertificateRevocationListBuilder,
        algorithm: SignatureHashAlgorithm | None,
    ) -> x509.CertificateRevocationList:
        """Sign nothing: return ``None`` in place of a CRL."""
        return None  # type: ignore[return-value]

    def sign_certificate(
        self,
        ca: "CertificateAuthority",
        use_private_key_options: DummyModel,
        public_key: CertificateIssuerPublicKeyTypes,
        serial: int,
        algorithm: SignatureHashAlgorithm | None,
        issuer: x509.Name,
        subject: x509.Name,
        not_after: datetime,
        extensions: Sequence[CertificateExtension],
    ) -> x509.Certificate:
        """Sign nothing: return ``None`` in place of a certificate."""
        return None  # type: ignore[return-value]

    def store_private_key(
        self,
        ca: "CertificateAuthority",
        key: CertificateIssuerPrivateKeyTypes,
        certificate: x509.Certificate,
        options: DummyModel,
    ) -> None:
        """Store nothing."""
        return None
def root_reverse(name: str, **kwargs: Any) -> str:
    """Reverse a ``django_ca`` URL name, using the serial of the root CA unless a serial is passed."""
    url_kwargs = {"serial": CERT_DATA["root"]["serial"], **kwargs}
    return reverse(f"django_ca:{name}", kwargs=url_kwargs)
def root_uri(name: str, hostname: str | None = None, **kwargs: Any) -> str:
    """Build the full ``http://`` URI for a ``django_ca`` URL name, defaulting to the root serial.

    An empty or missing `hostname` falls back to ``SERVER_NAME``.
    """
    host = hostname or SERVER_NAME
    return f"http://{host}{root_reverse(name, **kwargs)}"
def authority_information_access(
    ca_issuers: Iterable[x509.GeneralName] | None = None,
    ocsp: Iterable[x509.GeneralName] | None = None,
    critical: bool = False,
) -> x509.Extension[x509.AuthorityInformationAccess]:
    """Build an AuthorityInformationAccess extension from OCSP and CA issuer names."""
    # NOTE: OCSP entries come first because their OID is lexicographically smaller
    ocsp_descriptions = (
        []
        if ocsp is None
        else [
            x509.AccessDescription(access_method=AuthorityInformationAccessOID.OCSP, access_location=name)
            for name in ocsp
        ]
    )
    issuer_descriptions = (
        []
        if ca_issuers is None
        else [
            x509.AccessDescription(
                access_method=AuthorityInformationAccessOID.CA_ISSUERS, access_location=issuer
            )
            for issuer in ca_issuers
        ]
    )

    aia = x509.AuthorityInformationAccess(ocsp_descriptions + issuer_descriptions)
    return x509.Extension(oid=ExtensionOID.AUTHORITY_INFORMATION_ACCESS, critical=critical, value=aia)
def basic_constraints(
    ca: bool = False, path_length: int | None = None, critical: bool = True
) -> x509.Extension[x509.BasicConstraints]:
    """Build a BasicConstraints extension (critical by default)."""
    constraints = x509.BasicConstraints(ca=ca, path_length=path_length)
    return x509.Extension(oid=ExtensionOID.BASIC_CONSTRAINTS, critical=critical, value=constraints)
def certificate_policies(
    *policies: x509.PolicyInformation, critical: bool = False
) -> x509.Extension[x509.CertificatePolicies]:
    """Build a CertificatePolicies extension from the given policy information objects."""
    value = x509.CertificatePolicies(policies)
    return x509.Extension(oid=ExtensionOID.CERTIFICATE_POLICIES, critical=critical, value=value)
@typing.overload
def cmd(*args: Any, stdout: BytesIO, stderr: BytesIO, **kwargs: Any) -> tuple[bytes, bytes]: ...


@typing.overload
def cmd(*args: Any, stdout: BytesIO, stderr: StringIO | None = None, **kwargs: Any) -> tuple[bytes, str]: ...


@typing.overload
def cmd(*args: Any, stdout: StringIO | None = None, stderr: BytesIO, **kwargs: Any) -> tuple[str, bytes]: ...


@typing.overload
def cmd(
    *args: Any, stdout: StringIO | None = None, stderr: StringIO | None = None, **kwargs: Any
) -> tuple[str, str]: ...


def cmd(
    *args: Any,
    stdout: StringIO | BytesIO | None = None,
    stderr: StringIO | BytesIO | None = None,
    **kwargs: Any,
) -> tuple[str | bytes, str | bytes]:
    """Run a manage.py command via ``call_command`` and return what it wrote to stdout and stderr.

    Streams that are not passed default to a fresh ``StringIO``. An optional ``stdin`` keyword argument is
    removed from `kwargs` before the command is called: a ``StringIO`` replaces ``sys.stdin``, any other
    value is returned by a mocked ``sys.stdin.buffer.read``.
    """
    out = StringIO() if stdout is None else stdout
    err = StringIO() if stderr is None else stderr
    stdin = kwargs.pop("stdin", StringIO())

    if isinstance(stdin, StringIO):
        stdin_patch = mock.patch("sys.stdin", stdin)
    else:
        # mock https://docs.python.org/3/library/io.html#io.BufferedReader.read
        def _fake_read(size=None):  # type: ignore # pylint: disable=unused-argument
            return stdin

        # TYPE NOTE: the patcher type differs between branches, all we need is a context manager
        stdin_patch = mock.patch(  # type: ignore[assignment]
            "sys.stdin.buffer.read", side_effect=_fake_read
        )

    with stdin_patch:
        call_command(*args, stdout=out, stderr=err, **kwargs)

    return out.getvalue(), err.getvalue()
@typing.overload
def cmd_e2e(
    args: typing.Sequence[str],
    *,
    stdin: StringIO | bytes | None = None,
    stdout: StringIO | None = None,
    stderr: StringIO | None = None,
) -> tuple[str, str]: ...


@typing.overload
def cmd_e2e(
    args: typing.Sequence[str],
    *,
    stdin: StringIO | bytes | None = None,
    stdout: BytesIO,
    stderr: StringIO | None = None,
) -> tuple[bytes, str]: ...


@typing.overload
def cmd_e2e(
    args: typing.Sequence[str],
    *,
    stdin: StringIO | bytes | None = None,
    stdout: StringIO | None = None,
    stderr: BytesIO,
) -> tuple[str, bytes]: ...


@typing.overload
def cmd_e2e(
    args: typing.Sequence[str],
    *,
    stdin: StringIO | bytes | None = None,
    stdout: BytesIO,
    stderr: BytesIO,
) -> tuple[bytes, bytes]: ...


def cmd_e2e(
    args: typing.Sequence[str],
    stdin: StringIO | bytes | None = None,
    stdout: BytesIO | StringIO | None = None,
    stderr: BytesIO | StringIO | None = None,
) -> tuple[str | bytes, str | bytes]:
    """Run a management command through ``ManagementUtility``, the way manage.py does.

    In contrast to ``call_command``, this also exercises the argparse configuration of the command.
    Returns the values written to stdout and stderr.
    """
    stdout = stdout or StringIO()
    stderr = stderr or StringIO()

    # BinaryCommand commands (such as dump_crl) write to sys.stdout.buffer, but BytesIO does not have a
    # buffer attribute, so we manually add the attribute.
    for stream in (stdout, stderr):
        if isinstance(stream, BytesIO):
            stream.buffer = stream  # type: ignore[attr-defined]

    if stdin is None:
        stdin = StringIO()

    if isinstance(stdin, StringIO):
        stdin_patch = mock.patch("sys.stdin", stdin)
    else:

        def _fake_read(size=None):  # type: ignore # pylint: disable=unused-argument
            return stdin

        # TYPE NOTE: mypy detects a different type, but important thing is it's a context manager
        stdin_patch = mock.patch(  # type: ignore[assignment]
            "sys.stdin.buffer.read", side_effect=_fake_read
        )

    # execute() closes all connections at the end. Somehow this is not an issue in SQLite3, but when
    # using PostgreSQL, tests will fail from here onwards if we don't mock this.
    close_all_patch = mock.patch("django.core.management.base.connections.close_all", return_value=None)

    with (
        stdin_patch,
        mock.patch("sys.stdout", stdout),
        mock.patch("sys.stderr", stderr),
        close_all_patch,
    ):
        ManagementUtility(["manage.py", *args]).execute()

    return stdout.getvalue(), stderr.getvalue()
def cn(value: str) -> "x509.NameAttribute[str]":
    """Return a common name attribute with the given value."""
    return x509.NameAttribute(oid=NameOID.COMMON_NAME, value=value)
def country(value: str) -> "x509.NameAttribute[str]":
    """Return a country name attribute with the given value."""
    return x509.NameAttribute(oid=NameOID.COUNTRY_NAME, value=value)
def crl_cache_key(
    serial: str,
    encoding: Encoding = Encoding.DER,
    only_contains_ca_certs: bool = False,
    only_contains_user_certs: bool = False,
    only_contains_attribute_certs: bool = False,
    only_some_reasons: Iterable[x509.ReasonFlags] | None = None,
) -> str:
    """Return the CRL cache key for the given parameters (wraps ``get_crl_cache_key``)."""
    scope: dict[str, Any] = {
        "only_contains_ca_certs": only_contains_ca_certs,
        "only_contains_user_certs": only_contains_user_certs,
        "only_contains_attribute_certs": only_contains_attribute_certs,
        "only_some_reasons": only_some_reasons,
    }
    return get_crl_cache_key(serial, encoding, **scope)
def crl_distribution_points(
    *distribution_points: x509.DistributionPoint, critical: bool = False
) -> x509.Extension[x509.CRLDistributionPoints]:
    """Build a CRLDistributionPoints extension from the given distribution points."""
    return x509.Extension(
        oid=ExtensionOID.CRL_DISTRIBUTION_POINTS,
        critical=critical,
        value=x509.CRLDistributionPoints(distribution_points),
    )
def distribution_point(
    full_name: Iterable[x509.GeneralName] | None = None,
    relative_name: x509.RelativeDistinguishedName | None = None,
    reasons: frozenset[x509.ReasonFlags] | None = None,
    crl_issuer: Iterable[x509.GeneralName] | None = None,
) -> x509.DistributionPoint:
    """Create a single distribution point; every field defaults to ``None``."""
    point = x509.DistributionPoint(
        crl_issuer=crl_issuer, reasons=reasons, relative_name=relative_name, full_name=full_name
    )
    return point
def extended_key_usage(
    *usages: x509.ObjectIdentifier, critical: bool = False
) -> x509.Extension[x509.ExtendedKeyUsage]:
    """Build an ExtendedKeyUsage extension from the given usage OIDs."""
    value = x509.ExtendedKeyUsage(usages)
    return x509.Extension(oid=ExtensionOID.EXTENDED_KEY_USAGE, critical=critical, value=value)
def freshest_crl(
    *distribution_points: x509.DistributionPoint, critical: bool = False
) -> x509.Extension[x509.FreshestCRL]:
    """Build a FreshestCRL extension from the given distribution points."""
    value = x509.FreshestCRL(distribution_points)
    return x509.Extension(oid=ExtensionOID.FRESHEST_CRL, critical=critical, value=value)
def get_cert_context(name: str) -> dict[str, Any]:
    """Build a context dictionary for testing output, based on the ``CERT_DATA`` entry for `name`.

    Most values are copied unchanged. Extensions additionally get ``<key>_critical`` and ``<key>_text``
    entries, and parent and private key information is added where available.
    """
    cert_data = CERT_DATA[name]
    context: dict[str, Any] = {}

    # NOTE(review): the whitespace inside the two string literals below is reproduced as found in the
    # reviewed copy of this file - confirm the indentation width against the expected command output.
    for key, value in sorted(cert_data.items()):
        if key == "extensions":
            context["extensions"] = {ext["type"]: ext for ext in cert_data.get("extensions", [])}
            continue
        if key == "precert_poison":
            context["precert_poison"] = "* Precert Poison (critical):\n Yes"
            continue
        if isinstance(value, x509.Extension):
            # Handle cryptography extensions
            context[f"{key}_critical"] = " (critical)" if value.critical else ""
            context[f"{key}_text"] = textwrap.indent(extension_as_text(value.value), " ")
            continue

        context[key] = value
        if key == "path_length":
            context[f"{key}_text"] = "unlimited" if value is None else value

    if parent := cert_data.get("parent"):
        parent_data = CERT_DATA[parent]
        context["parent_name"] = parent_data["name"]
        context["parent_serial"] = parent_data["serial"]
        context["parent_serial_colons"] = parent_data["serial_colons"]

    key_filename = cert_data["key_filename"]
    if key_filename is not False:
        context["key_path"] = storages["django-ca"].path(key_filename)

    return context
def get_idp(
    full_name: Iterable[x509.GeneralName] | None = None,
    indirect_crl: bool = False,
    only_contains_attribute_certs: bool = False,
    only_contains_ca_certs: bool = False,
    only_contains_user_certs: bool = False,
    only_some_reasons: frozenset[x509.ReasonFlags] | None = None,
    relative_name: x509.RelativeDistinguishedName | None = None,
) -> "x509.Extension[x509.IssuingDistributionPoint]":
    """Build a critical IssuingDistributionPoint extension from the given fields."""
    idp = x509.IssuingDistributionPoint(
        full_name=full_name,
        relative_name=relative_name,
        indirect_crl=indirect_crl,
        only_some_reasons=only_some_reasons,
        only_contains_user_certs=only_contains_user_certs,
        only_contains_ca_certs=only_contains_ca_certs,
        only_contains_attribute_certs=only_contains_attribute_certs,
    )
    return x509.Extension(oid=ExtensionOID.ISSUING_DISTRIBUTION_POINT, critical=True, value=idp)
def iso_format(value: datetime, timespec: str = "seconds") -> str:
    """Format a timestamp as ISO 8601, writing a ``+00:00`` offset as ``Z``."""
    formatted = value.isoformat(timespec=timespec)
    return formatted.replace("+00:00", "Z")
def issuer_alternative_name(
    *names: x509.GeneralName, critical: bool = False
) -> x509.Extension[x509.IssuerAlternativeName]:
    """Build an IssuerAlternativeName extension from the given general names."""
    value = x509.IssuerAlternativeName(names)
    return x509.Extension(oid=ExtensionOID.ISSUER_ALTERNATIVE_NAME, critical=critical, value=value)
def key_usage(**usages: bool) -> x509.Extension[x509.KeyUsage]:
    """Build a KeyUsage extension; usages that are not passed default to ``False``.

    The special ``critical`` keyword argument sets the criticality of the extension (default: ``True``).
    """
    critical = usages.pop("critical", True)

    # Start with every usage disabled, then apply what the caller passed.
    flags = dict.fromkeys(
        (
            "content_commitment",
            "crl_sign",
            "data_encipherment",
            "decipher_only",
            "digital_signature",
            "encipher_only",
            "key_agreement",
            "key_cert_sign",
            "key_encipherment",
        ),
        False,
    )
    flags.update(usages)

    return x509.Extension(oid=ExtensionOID.KEY_USAGE, critical=critical, value=x509.KeyUsage(**flags))
def name_constraints(
    permitted: Iterable[x509.GeneralName] | None = None,
    excluded: Iterable[x509.GeneralName] | None = None,
    critical: bool = True,
) -> x509.Extension[x509.NameConstraints]:
    """Build a NameConstraints extension (critical by default) from permitted and excluded subtrees."""
    constraints = x509.NameConstraints(permitted_subtrees=permitted, excluded_subtrees=excluded)
    return x509.Extension(oid=ExtensionOID.NAME_CONSTRAINTS, critical=critical, value=constraints)
def ocsp_no_check(critical: bool = False) -> x509.Extension[x509.OCSPNoCheck]:
    """Build an OCSPNoCheck extension."""
    value = x509.OCSPNoCheck()
    return x509.Extension(oid=ExtensionOID.OCSP_NO_CHECK, critical=critical, value=value)
def precert_poison() -> x509.Extension[x509.PrecertPoison]:
    """Build a PrecertPoison extension, which is always marked as critical."""
    value = x509.PrecertPoison()
    return x509.Extension(oid=ExtensionOID.PRECERT_POISON, critical=True, value=value)
def subject_alternative_name(
    *names: x509.GeneralName, critical: bool = False
) -> x509.Extension[x509.SubjectAlternativeName]:
    """Build a SubjectAlternativeName extension from the given general names."""
    value = x509.SubjectAlternativeName(names)
    return x509.Extension(oid=ExtensionOID.SUBJECT_ALTERNATIVE_NAME, critical=critical, value=value)
def subject_key_identifier(
    cert: X509CertMixin | x509.Certificate,
) -> x509.Extension[x509.SubjectKeyIdentifier]:
    """Build the (non-critical) SubjectKeyIdentifier extension for the public key of `cert`.

    Accepts either a model instance or an already loaded ``x509.Certificate``.
    """
    certificate = cert.pub.loaded if isinstance(cert, X509CertMixin) else cert
    identifier = x509.SubjectKeyIdentifier.from_public_key(certificate.public_key())
    return x509.Extension(oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, critical=False, value=identifier)
def state(value: str) -> "x509.NameAttribute[str]":
    """Return a state or province name attribute with the given value."""
    return x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, value)
def tls_feature(*features: x509.TLSFeatureType, critical: bool = False) -> x509.Extension[x509.TLSFeature]:
    """Build a TLSFeature extension from the given feature types."""
    value = x509.TLSFeature(features)
    return x509.Extension(oid=ExtensionOID.TLS_FEATURE, critical=critical, value=value)
# Type variable bound to any callable. It lets a decorator such as override_tmpcadir.__call__ declare that
# it returns the same callable type it was given.
FuncTypeVar = typing.TypeVar("FuncTypeVar", bound=typing.Callable[..., Any])
def dns(name: str) -> x509.DNSName:  # just a shortcut
    """Wrap `name` in a :py:class:`cg:cryptography.x509.DNSName`."""
    general_name = x509.DNSName(name)
    return general_name
def uri(url: str) -> x509.UniformResourceIdentifier:  # just a shortcut
    """Wrap `url` in a :py:class:`cg:cryptography.x509.UniformResourceIdentifier`."""
    general_name = x509.UniformResourceIdentifier(url)
    return general_name
def ip(
    name: ipaddress.IPv4Address | ipaddress.IPv6Address | ipaddress.IPv4Network | ipaddress.IPv6Network,
) -> x509.IPAddress:
    """Wrap an IP address or network in a :py:class:`cg:cryptography.x509.IPAddress`."""
    general_name = x509.IPAddress(name)
    return general_name
def rdn(
    name: Iterable[tuple[x509.ObjectIdentifier, str]],
) -> x509.RelativeDistinguishedName:  # just a shortcut
    """Create a :py:class:`cg:cryptography.x509.RelativeDistinguishedName` from OID/value tuples."""
    attributes = [x509.NameAttribute(*attribute_args) for attribute_args in name]
    return x509.RelativeDistinguishedName(attributes)
@contextmanager
def mock_slug() -> Iterator[str]:
    """Make ``get_random_string`` in ``django_ca.models`` return a fixed value and yield that value."""
    fixed_slug = get_random_string(length=12)
    patcher = mock.patch("django_ca.models.get_random_string", return_value=fixed_slug)
    with patcher:
        yield fixed_slug
class override_tmpcadir(override_settings):  # pylint: disable=invalid-name; in line with parent class
    """Sets the CA_DIR directory to a temporary directory.

    While enabled, ``CA_DIR`` and the location of the ``django-ca`` storage point to a fresh temporary
    directory that contains copies of the private key fixtures and the OCSP public key. The directory is
    removed again when the override is disabled.

    .. NOTE: This also takes any additional settings.
    """

    def __call__(self, test_func: FuncTypeVar) -> FuncTypeVar:
        """Decorate `test_func`.

        Raises
        ------
        ValueError
            If `test_func` is not a plain function (e.g. when decorating a class).
        """
        if not inspect.isfunction(test_func):
            raise ValueError("Only functions can use override_tmpcadir()")
        return super().__call__(test_func)

    def enable(self) -> None:
        """Create and populate the temporary directory, then apply the overridden settings."""
        import copy  # pylint: disable=import-outside-toplevel  # only needed here

        tmpdir = tempfile.mkdtemp()
        try:
            self.options["CA_DIR"] = tmpdir

            # Work on a deep copy: changing the nested dictionaries of settings.STORAGES in place would
            # alter the global configuration, which would then still point to the (removed) temporary
            # directory after disable() has run.
            storages_config = copy.deepcopy(settings.STORAGES)
            storages_config["django-ca"]["OPTIONS"]["location"] = tmpdir
            self.options["STORAGES"] = storages_config

            # copy CAs
            for data in CERT_DATA.values():
                if data["key_filename"] is not False:
                    shutil.copy(os.path.join(FIXTURES_DIR, data["key_filename"]), tmpdir)

            # Copy OCSP public key (required for OCSP tests)
            shutil.copy(os.path.join(FIXTURES_DIR, CERT_DATA["profile-ocsp"]["pub_filename"]), tmpdir)

            # Reset profiles, so that they are loaded again on first access
            profiles._reset()  # pylint: disable=protected-access

            super().enable()
        except Exception:
            # disable() will not clean up after a failed enable(), so remove the directory right away.
            shutil.rmtree(tmpdir, ignore_errors=True)
            raise

    def disable(self) -> None:
        """Restore the original settings and remove the temporary directory."""
        super().disable()
        shutil.rmtree(self.options["CA_DIR"])
def verify_signature(cert: x509.Certificate, issuer: x509.Certificate) -> None:
    """Check that `cert` carries a valid signature made by the key of `issuer`.

    Raises InvalidSignature on failure.

    Requires cryptography >= 42.0.0 (for signature_algorithm_parameters,
    used to detect RSA-PSS vs PKCS1v15).
    """
    issuer_key = issuer.public_key()

    if isinstance(issuer_key, (ed25519.Ed25519PublicKey, ed448.Ed448PublicKey)):
        # Ed25519/Ed448 embed the hash; no algorithm argument
        issuer_key.verify(cert.signature, cert.tbs_certificate_bytes)
        return

    if isinstance(issuer_key, ec.EllipticCurvePublicKey):
        issuer_key.verify(
            cert.signature,
            cert.tbs_certificate_bytes,
            ec.ECDSA(cert.signature_hash_algorithm),  # type: ignore[arg-type]
        )
        return

    if isinstance(issuer_key, dsa.DSAPublicKey):
        issuer_key.verify(
            cert.signature,
            cert.tbs_certificate_bytes,
            cert.signature_hash_algorithm,  # type: ignore[arg-type]
        )
        return

    if isinstance(issuer_key, rsa.RSAPublicKey):
        assert issuer_key.key_size >= 2048  # 1024 bit keys may cause issues in OpenSSL>=4

        # signature_algorithm_parameters returns a PSS instance for RSA-PSS,
        # or None for PKCS1v15
        parameters = cert.signature_algorithm_parameters
        rsa_padding = parameters if isinstance(parameters, padding.PSS) else padding.PKCS1v15()
        issuer_key.verify(
            cert.signature,
            cert.tbs_certificate_bytes,
            rsa_padding,
            cert.signature_hash_algorithm,  # type: ignore[arg-type]
        )
        return

    raise ValueError(f"Unsupported key type: {type(issuer_key)}")  # pragma: no cover
def verify_chain_signatures(chain: list[Certificate | CertificateAuthority]) -> None:
    """Verify signatures for a chain ordered leaf -> ... -> root.

    Every certificate must be signed by its successor in `chain`, and the last certificate (the root)
    must be self-signed.

    Raises InvalidSignature if any signature is invalid, and ValueError if `chain` is empty.
    """
    if not chain:
        raise ValueError("chain must contain at least one certificate.")

    # Pair every certificate with the next one in the chain, which must be its issuer.
    for cert, issuer in zip(chain, chain[1:]):
        verify_signature(cert.pub.loaded, issuer=issuer.pub.loaded)

    # Verify the root is self-signed
    root = chain[-1].pub.loaded
    verify_signature(root, issuer=root)