# This file is part of django-ca (https://github.com/mathiasertl/django-ca).
#
# django-ca is free software: you can redistribute it and/or modify it under the terms of the GNU General
# Public License as published by the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# django-ca is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
# for more details.
#
# You should have received a copy of the GNU General Public License along with django-ca. If not, see
# <http://www.gnu.org/licenses/>.
""":py:mod:`django_ca.tests.base.assertions` collects assertions used throughout the entire test suite."""
import collections
import io
import re
import typing
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from datetime import UTC, datetime, timedelta
from typing import Any
from unittest.mock import Mock
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, x448, x25519
from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.x509.oid import ExtensionOID
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.core.management import CommandError
import pytest
from django_ca.conf import model_settings
from django_ca.constants import ReasonFlags
from django_ca.deprecation import RemovedInDjangoCA320Warning
from django_ca.key_backends.storages.models import StoragesUsePrivateKeyOptions
from django_ca.models import Certificate, CertificateAuthority, CertificateRevocationList, X509CertMixin
from django_ca.signals import post_create_ca, post_issue_cert, post_sign_cert, pre_create_ca, pre_sign_cert
from django_ca.tests.base.mocks import mock_signal
from django_ca.tests.base.utils import (
authority_information_access,
basic_constraints,
cmd_e2e,
crl_cache_key,
crl_distribution_points,
distribution_point,
get_idp,
uri,
verify_chain_signatures,
)
@contextmanager
def assert_removed_in_320(match: str | re.Pattern[str] | None = None) -> Iterator[None]:  # pragma: no cover
    """Context manager failing unless a ``RemovedInDjangoCA320Warning`` is emitted inside of it.

    `match` is passed to :py:func:`pytest.warns` and (if given) must match the warning message.
    """
    expected_warning = RemovedInDjangoCA320Warning
    with pytest.warns(expected_warning, match=match):
        yield
def assert_authority_key_identifier(issuer: CertificateAuthority, cert: X509CertMixin) -> None:
    """Assert that the AuthorityKeyIdentifier extension of `cert` refers to `issuer`.

    The ``key_identifier`` of the :py:class:`~cg:cryptography.x509.AuthorityKeyIdentifier` extension of
    `cert` must be equal to the ``key_identifier`` of the
    :py:class:`~cg:cryptography.x509.SubjectKeyIdentifier` extension of `issuer`.
    """
    aki_extension = cert.extensions[ExtensionOID.AUTHORITY_KEY_IDENTIFIER]
    ski_extension = issuer.extensions[ExtensionOID.SUBJECT_KEY_IDENTIFIER]
    aki, ski = aki_extension.value, ski_extension.value
    assert aki.key_identifier == ski.key_identifier  # type: ignore[union-attr]
def assert_ca_properties(
    ca: CertificateAuthority,
    name: str,
    parent: CertificateAuthority | None = None,
    private_key_type: type[CertificateIssuerPrivateKeyTypes] = rsa.RSAPrivateKey,
    acme_enabled: bool = False,
    acme_profile: str | None = None,
    acme_requires_contact: bool = True,
    password: bytes | None = None,
) -> None:
    """Assert some basic properties of a CA.

    Parameters
    ----------
    ca : CertificateAuthority
        The certificate authority to test.
    name : str
        The expected name of the CA.
    parent : CertificateAuthority, optional
        The expected parent CA. If not given, `ca` is expected to be its own issuer (a root CA).
    private_key_type : type, optional
        The expected class of the private key loaded through the key backend.
    acme_enabled : bool, optional
        The expected value of ``acme_enabled``.
    acme_profile : str, optional
        The expected ACME profile. If not given, ``CA_DEFAULT_PROFILE`` is expected.
    acme_requires_contact : bool, optional
        The expected value of ``acme_requires_contact``.
    password : bytes, optional
        The password used to load the private key.
    """
    parent_ca = parent or ca
    parent_serial = parent_ca.serial
    issuer = parent_ca.subject
    base_url = f"http://{model_settings.CA_DEFAULT_HOSTNAME}/django_ca/"

    assert ca.name == name
    assert ca.enabled is True
    assert ca.parent == parent

    # Test ACME properties
    assert ca.acme_enabled is acme_enabled
    # The fallback must be parenthesized: without parentheses, this is parsed as
    # ``assert (ca.acme_profile == acme_profile) or CA_DEFAULT_PROFILE`` and passes whenever the default is set.
    assert ca.acme_profile == (acme_profile or model_settings.CA_DEFAULT_PROFILE)
    assert ca.acme_requires_contact is acme_requires_contact

    # Test certificate properties
    assert ca.issuer == issuer

    # TYPEHINT NOTE: We assume a StoragesBackend here
    assert isinstance(
        ca.key_backend.get_key(ca, StoragesUsePrivateKeyOptions(password=password)),  # type: ignore[attr-defined]
        private_key_type,
    )

    # Test AuthorityKeyIdentifier extension
    assert_authority_key_identifier(parent_ca, ca)

    # Test the BasicConstraints extension
    basic_constraints_ext = typing.cast(
        x509.Extension[x509.BasicConstraints], ca.extensions[ExtensionOID.BASIC_CONSTRAINTS]
    )
    assert basic_constraints_ext.critical is True
    assert basic_constraints_ext.value.ca is True

    # Test default signing extensions
    assert ca.sign_authority_information_access == authority_information_access(
        ocsp=[uri(f"{base_url}ocsp/{ca.serial}/cert/")],
        ca_issuers=[uri(f"{base_url}issuer/{parent_serial}.der")],
    )
    assert ca.sign_certificate_policies is None
    assert ca.sign_crl_distribution_points == crl_distribution_points(
        distribution_point([uri(f"{base_url}crl/{ca.serial}/")])
    )
    assert ca.sign_issuer_alternative_name is None
def assert_certificate(
    cert: x509.Certificate | CertificateAuthority | Certificate,
    subject: x509.Name,
    algorithm: type[hashes.HashAlgorithm] = hashes.SHA512,
    signer: CertificateAuthority | x509.Certificate | None = None,
) -> None:
    """Assert basic properties of `cert`.

    `cert` must be an X.509v3 certificate with the given `subject`, whose issuer is the subject of the
    signer and whose signature hash algorithm is an instance of `algorithm`.

    The expected signer is the parent if `cert` is a CA with a parent, or the CA if `cert` is a
    certificate. Otherwise `signer` is used, and if that is not given either, `cert` is expected to be
    self-signed.
    """
    if isinstance(cert, CertificateAuthority):
        # Root CAs keep the passed signer (default: themselves), intermediate CAs are signed by their parent.
        signer = cert.parent if cert.parent is not None else signer
        cert = cert.pub.loaded
    elif isinstance(cert, Certificate):  # pragma: no cover  # not used, currently
        signer = cert.ca
        cert = cert.pub.loaded

    assert cert.version == x509.Version.v3
    issuing_cert = signer if signer is not None else cert
    assert cert.issuer == issuing_cert.subject
    assert cert.subject == subject
    assert isinstance(cert.signature_hash_algorithm, algorithm)
@contextmanager
def assert_command_error(msg: str, returncode: int = 1) -> Iterator[None]:
    """Context manager asserting that ``CommandError`` is raised.

    Parameters
    ----------
    msg : str
        Regular expression that must match the exception message (passed to :py:func:`pytest.raises`).
    returncode : int, optional
        The expected ``returncode`` attribute of the raised exception (default: ``1``).
    """
    with pytest.raises(CommandError, match=msg) as raised:
        yield
    actual_returncode = raised.value.returncode
    assert actual_returncode == returncode
def assert_count_equal(first: Iterable[Any], second: Iterable[Any]) -> None:
    """Assert that `first` and `second` contain the same elements, regardless of their order.

    Roughly equivalent to ``unittest.TestCase.assertCountEqual()``: every element must occur the same number
    of times in both iterables. Elements must be hashable, as they are counted with a ``Counter``.
    """
    first_items = list(first)
    second_items = list(second)
    first_counts = collections.Counter(first_items)
    second_counts = collections.Counter(second_items)
    assert first_counts == second_counts
@contextmanager
def assert_create_ca_signals(pre: bool = True, post: bool = True) -> Iterator[tuple[Mock, Mock]]:
    """Context manager asserting whether the `pre_create_ca`/`post_create_ca` signals are called.

    The mocks returned by ``mock_signal`` are yielded as ``(pre, post)``. When the block exits (even with
    an exception), the ``called`` attribute of each mock must be `pre` and `post` respectively.
    """
    with mock_signal(pre_create_ca) as pre_mock, mock_signal(post_create_ca) as post_mock:
        try:
            yield pre_mock, post_mock
        finally:
            assert pre_mock.called is pre
            assert post_mock.called is post
@contextmanager
def assert_create_cert_signals(pre: bool = True, post: bool = True) -> Iterator[tuple[Mock, Mock]]:
    """Context manager asserting that the `pre_sign_cert`/`post_issue_cert` signals are (not) called.

    The mocks for both signals are yielded as ``(pre, post)``. When the block exits (even with an
    exception), the ``called`` attribute of each mock must be `pre` and `post` respectively.
    """
    with mock_signal(pre_sign_cert) as pre_sig, mock_signal(post_issue_cert) as post_sig:
        try:
            yield pre_sig, post_sig
        finally:
            assert pre_sig.called is pre
            assert post_sig.called is post
def assert_crl(
    crl: bytes,
    expected: typing.Sequence[X509CertMixin] | None = None,
    signer: CertificateAuthority | None = None,
    expires: int = 86400,
    algorithm: hashes.HashAlgorithm | None = None,
    encoding: Encoding = Encoding.PEM,
    idp: x509.Extension[x509.IssuingDistributionPoint] | None = None,
    extensions: list[x509.Extension[x509.ExtensionType]] | None = None,
    crl_number: int = 0,
    entry_extensions: tuple[list[x509.Extension[x509.ExtensionType]], ...] | None = None,
) -> None:
    """Test the given CRL.

    Parameters
    ----------
    crl : bytes
        The raw CRL.
    expected : list, optional
        Certificates expected to be listed as revoked in the CRL (default: none).
    signer : CertificateAuthority, optional
        The CA expected to have signed the CRL (default: the CA named ``"child"``).
    expires : int, optional
        Seconds from now at which the CRL is expected to expire (its ``next_update``).
    algorithm : HashAlgorithm, optional
        Expected signature hash algorithm (default: ``signer.algorithm``).
    encoding : Encoding, optional
        Encoding of `crl` (``Encoding.PEM`` or ``Encoding.DER``).
    idp : Extension, optional
        Expected IssuingDistributionPoint extension.
    extensions : list of Extension, optional
        Additional expected CRL extensions. The AuthorityKeyIdentifier and CRLNumber extensions are always
        expected as well. The passed list is not modified.
    crl_number : int, optional
        Expected CRL number.
    entry_extensions : tuple of lists of Extension, optional
        Expected extensions of each revoked entry, in the order the entries appear in the CRL. If not given,
        entries are expected to have no extensions.
    """
    expected = expected or []
    signer = signer or CertificateAuthority.objects.get(name="child")
    # Copy the list: the extensions appended below must not leak into the list passed by the caller.
    extensions = list(extensions or [])
    now = datetime.now(tz=UTC)
    expires_timestamp = (now + timedelta(seconds=expires)).replace(microsecond=0)

    if idp is not None:  # pragma: no branch
        extensions.append(idp)
    last_update = now.replace(microsecond=0)
    extensions.append(signer.get_authority_key_identifier_extension())
    extensions.append(
        x509.Extension(
            value=x509.CRLNumber(crl_number=crl_number), critical=False, oid=ExtensionOID.CRL_NUMBER
        )
    )

    if encoding == Encoding.PEM:
        parsed_crl = x509.load_pem_x509_crl(crl)
    else:
        parsed_crl = x509.load_der_x509_crl(crl)

    if algorithm is None:
        algorithm = signer.algorithm

    public_key = signer.pub.loaded.public_key()
    if isinstance(public_key, x448.X448PublicKey | x25519.X25519PublicKey):  # pragma: no cover
        raise TypeError()  # just to make mypy happy

    assert isinstance(parsed_crl.signature_hash_algorithm, type(algorithm))
    assert parsed_crl.is_signature_valid(public_key) is True
    assert parsed_crl.issuer == signer.pub.loaded.subject
    assert parsed_crl.last_update_utc == last_update
    assert parsed_crl.next_update_utc == expires_timestamp

    def ext_sorter(ext: x509.Extension[x509.ExtensionType]) -> str:
        return ext.oid.dotted_string

    assert sorted(parsed_crl.extensions, key=ext_sorter) == sorted(extensions, key=ext_sorter)

    # Map serial -> revoked entry; dict insertion order follows the order of entries in the CRL.
    entries = {e.serial_number: e for e in parsed_crl}
    assert sorted(entries) == sorted(c.pub.loaded.serial_number for c in expected)
    for i, entry in enumerate(entries.values()):
        if entry_extensions:
            assert list(entry.extensions) == entry_extensions[i]
        else:
            assert not list(entry.extensions)
def assert_crls(
    ca: CertificateAuthority,
    number: int = 0,
    expected_user: typing.Sequence[X509CertMixin] | None = None,
    expected_ca: typing.Sequence[X509CertMixin] | None = None,
) -> None:
    """Test the CRLs for the given certificate authority.

    Two scopes are checked: CRLs containing only CA certificates and CRLs containing only user certificates.
    For each scope, the DER- and PEM-encoded CRLs stored in the cache are tested with :py:func:`assert_crl`.
    The newest CRL of that scope stored in the database must contain the same data as the cached DER CRL.

    Parameters
    ----------
    ca : CertificateAuthority
        The certificate authority whose CRLs are tested.
    number : int, optional
        The expected CRL number (in the CRL itself and in the database).
    expected_user : list, optional
        Certificates expected to be revoked in the user-certificates-only CRL.
    expected_ca : list, optional
        Certificates expected to be revoked in the CA-certificates-only CRL.
    """
    for kwargs in [{"only_contains_ca_certs": True}, {"only_contains_user_certs": True}]:
        der_key = crl_cache_key(ca.serial, **kwargs)  # type: ignore[arg-type]
        pem_key = crl_cache_key(ca.serial, Encoding.PEM, **kwargs)  # type: ignore[arg-type]
        idp = get_idp(full_name=None, **kwargs)  # type: ignore[arg-type]

        # Determine expected certificates based on scope
        if kwargs.get("only_contains_user_certs"):
            expected = expected_user
        else:
            expected = expected_ca

        # Fetch and test CRLs from the cache. Fail with a clear message on a cache miss, instead of an
        # obscure error when parsing ``None``.
        der_crl = cache.get(der_key)
        pem_crl = cache.get(pem_key)
        assert der_crl is not None, f"DER CRL not found in cache (key: {der_key})"
        assert pem_crl is not None, f"PEM CRL not found in cache (key: {pem_key})"
        assert_crl(der_crl, expected, ca, encoding=Encoding.DER, idp=idp, crl_number=number)
        assert_crl(pem_crl, expected, ca, encoding=Encoding.PEM, idp=idp, crl_number=number)

        # Fetch from the database and verify
        db_crl = CertificateRevocationList.objects.scope(serial=ca.serial, **kwargs).newest()  # type: ignore[arg-type]
        assert db_crl is not None
        assert db_crl.data == der_crl
        assert db_crl.number == number
def assert_e2e_error(
    cmd: typing.Sequence[str],
    stdout: str | re.Pattern[str] = "",
    stderr: str | re.Pattern[str] = "",
    code: int = 2,
) -> None:
    """Assert that an error was thrown in an e2e command.

    Parameters
    ----------
    cmd : sequence of str
        The command to run (passed to ``cmd_e2e``).
    stdout, stderr : str or re.Pattern, optional
        The expected output. A ``str`` must be equal to the complete output; a compiled pattern must match
        somewhere in the output (using :py:meth:`re.Pattern.search`).
    code : int, optional
        The expected exit code of the command (default: ``2``).
    """
    actual_stdout = io.StringIO()
    actual_stderr = io.StringIO()
    with assert_system_exit(code):
        cmd_e2e(cmd, stdout=actual_stdout, stderr=actual_stderr)

    # Output is captured as text (StringIO), so only ``str`` values can ever compare equal.
    if isinstance(stdout, str):
        assert stdout == actual_stdout.getvalue()
    else:  # pragma: no cover
        assert stdout.search(actual_stdout.getvalue())
    if isinstance(stderr, str):  # pragma: no cover
        assert stderr == actual_stderr.getvalue()
    else:
        assert stderr.search(actual_stderr.getvalue())
def assert_extension_equal(
    first: x509.Extension[x509.ExtensionType] | None, second: x509.Extension[x509.ExtensionType] | None
) -> None:
    """Assert that two extensions are equal, or that both are ``None``.

    For ExtendedKeyUsage, TLSFeature, SubjectAlternativeName and IssuerAlternativeName extensions, the order
    of the items in the extension value is ignored: only OID, criticality and the items themselves (with
    their number of occurrences) must match. Use this only when the order of items cannot be guaranteed.
    All other extensions are compared with ``==``.
    """
    if first is None and second is None:
        return  # Two missing extensions are considered equal.
    if first is None or second is None:  # pragma: no cover
        raise AssertionError("One of the values is None.")

    unordered_oids = (
        ExtensionOID.EXTENDED_KEY_USAGE,
        ExtensionOID.TLS_FEATURE,
        ExtensionOID.SUBJECT_ALTERNATIVE_NAME,
        ExtensionOID.ISSUER_ALTERNATIVE_NAME,
    )
    if second.oid not in unordered_oids:
        assert first == second
        return

    assert first.oid == second.oid
    assert first.critical == second.critical
    assert_count_equal(first.value, second.value)  # type: ignore[arg-type]
def assert_extensions(
    cert: X509CertMixin | x509.Certificate,
    extensions: Iterable[x509.Extension[x509.ExtensionType]],
    signer: CertificateAuthority | None = None,
    expect_defaults: bool = True,
) -> None:
    """Assert that `cert` has exactly the given extensions.

    Parameters
    ----------
    cert : Certificate, CertificateAuthority or x509.Certificate
        The certificate to test.
    extensions : iterable of Extension
        The expected extensions. Any iterable (including a generator) is accepted; it is consumed once.
    signer : CertificateAuthority, optional
        The signer of `cert`. Only used for ``x509.Certificate``. For a Certificate, the signer is its CA; for
        a CertificateAuthority, it is its parent (or the CA itself if it has no parent).
    expect_defaults : bool, optional
        If ``True`` (the default), default extensions are expected as well, unless given in `extensions`.
        These are BasicConstraints (certificates only), AuthorityKeyIdentifier (if the signer is known),
        CRLDistributionPoints and AuthorityInformationAccess (certificates only, if configured for the
        signer) and SubjectKeyIdentifier.
    """
    # Materialize first: `extensions` is iterated twice below, and a generator would silently yield nothing
    # on the second pass (leaving no expected extensions at all).
    expected_extensions = list(extensions)

    # temporary fast check
    for ext in expected_extensions:
        assert isinstance(ext, x509.Extension)

    expected = {e.oid: e for e in expected_extensions}

    if isinstance(cert, Certificate):
        pubkey = cert.pub.loaded.public_key()
        actual = cert.extensions
        signer = cert.ca
    elif isinstance(cert, CertificateAuthority):
        pubkey = cert.pub.loaded.public_key()
        actual = cert.extensions

        if cert.parent is None:  # root CA
            signer = cert
        else:  # intermediate CA
            signer = cert.parent
    elif isinstance(cert, x509.Certificate):  # cg cert
        pubkey = cert.public_key()
        actual = {e.oid: e for e in cert.extensions}
    else:  # pragma: no cover
        raise ValueError("cert must be Certificate(Authority) or x509.Certificate)")

    if expect_defaults is True:
        if isinstance(cert, Certificate):
            expected.setdefault(ExtensionOID.BASIC_CONSTRAINTS, basic_constraints(ca=False))
        if signer is not None:  # pragma: no branch
            expected.setdefault(
                ExtensionOID.AUTHORITY_KEY_IDENTIFIER,
                signer.get_authority_key_identifier_extension(),
            )

            if isinstance(cert, Certificate) and signer.sign_crl_distribution_points:
                expected.setdefault(ExtensionOID.CRL_DISTRIBUTION_POINTS, signer.sign_crl_distribution_points)

            if isinstance(cert, Certificate) and signer.sign_authority_information_access:
                expected.setdefault(
                    ExtensionOID.AUTHORITY_INFORMATION_ACCESS, signer.sign_authority_information_access
                )

        ski = x509.SubjectKeyIdentifier.from_public_key(pubkey)
        expected.setdefault(
            ExtensionOID.SUBJECT_KEY_IDENTIFIER,
            x509.Extension(oid=ExtensionOID.SUBJECT_KEY_IDENTIFIER, critical=False, value=ski),
        )

    # Diff output is bad for dicts, so we sort this based on dotted string to get better output
    actual_tuple = sorted(actual.items(), key=lambda t: t[0].dotted_string)
    expected_tuple = sorted(expected.items(), key=lambda t: t[0].dotted_string)
    assert actual_tuple == expected_tuple
def assert_issuing_distribution_point(
    extension: x509.Extension[x509.IssuingDistributionPoint],
    full_name: Iterable[x509.GeneralName] | None = None,
    relative_name: x509.RelativeDistinguishedName | None = None,
    only_contains_user_certs: bool = False,
    only_contains_ca_certs: bool = False,
    only_some_reasons: frozenset[x509.ReasonFlags] | None = None,
    indirect_crl: bool = False,
    only_contains_attribute_certs: bool = False,
    critical: bool = True,
) -> None:
    """Shortcut for asserting an Issuing Distribution Point extension.

    `extension` must equal an IssuingDistributionPoint extension built from the given values. All
    parameters except `critical` are passed to :py:class:`~cg:cryptography.x509.IssuingDistributionPoint`;
    `critical` is the expected criticality of the extension.
    """
    expected_value = x509.IssuingDistributionPoint(
        full_name=full_name,
        relative_name=relative_name,
        only_contains_user_certs=only_contains_user_certs,
        only_contains_ca_certs=only_contains_ca_certs,
        only_some_reasons=only_some_reasons,
        indirect_crl=indirect_crl,
        only_contains_attribute_certs=only_contains_attribute_certs,
    )
    expected_extension = x509.Extension(
        oid=ExtensionOID.ISSUING_DISTRIBUTION_POINT, critical=critical, value=expected_value
    )
    assert extension == expected_extension
def assert_post_issue_cert(post: Mock, cert: Certificate) -> None:
    """Assert that the mocked `post_issue_cert` signal was called exactly once, for `cert`.

    The single call must have used the keyword arguments ``cert``, ``signal`` and ``sender`` with the values
    `cert`, ``post_issue_cert`` and ``Certificate``.
    """
    expected_kwargs = {"cert": cert, "signal": post_issue_cert, "sender": Certificate}
    post.assert_called_once_with(**expected_kwargs)
def assert_revoked(
    cert: X509CertMixin, reason: str | None = None, compromised: datetime | None = None
) -> None:
    """Assert that the certificate is now revoked.

    The certificate (or certificate authority) is reloaded from the database by serial before checking,
    so the passed instance does not need to be up to date.

    Parameters
    ----------
    cert : X509CertMixin
        The certificate or certificate authority to check.
    reason : str, optional
        The expected revocation reason. If not given, ``ReasonFlags.unspecified.name`` is expected.
    compromised : datetime, optional
        The expected value of ``compromised``.
    """
    refreshed: X509CertMixin
    if isinstance(cert, CertificateAuthority):
        refreshed = CertificateAuthority.objects.get(serial=cert.serial)
    else:
        refreshed = Certificate.objects.get(serial=cert.serial)

    expected_reason = ReasonFlags.unspecified.name if reason is None else reason
    assert refreshed.revoked
    assert refreshed.compromised == compromised
    assert refreshed.revoked_reason == expected_reason
@contextmanager
def assert_sign_cert_signals(pre: bool = True, post: bool = True) -> Iterator[tuple[Mock, Mock]]:
    """Context manager asserting that the `pre_sign_cert`/`post_sign_cert` signals are (not) called.

    The mocks for both signals are yielded as ``(pre, post)``. When the block exits (even with an
    exception), the ``called`` attribute of each mock must be `pre` and `post` respectively.
    """
    with mock_signal(pre_sign_cert) as pre_sig, mock_signal(post_sign_cert) as post_sig:
        try:
            yield pre_sig, post_sig
        finally:
            assert pre_sig.called is pre
            assert post_sig.called is post
def assert_signature(chain: Iterable[CertificateAuthority], cert: Certificate | CertificateAuthority) -> None:
    """Assert that `cert` is properly signed by `chain`.

    The certificates are passed to ``verify_chain_signatures`` starting with `cert`, followed by the
    elements of `chain` in reverse order.

    .. seealso:: http://stackoverflow.com/questions/30700348
    """
    certificates = [*chain, cert]
    certificates.reverse()
    verify_chain_signatures(certificates)
@contextmanager
def assert_system_exit(code: int) -> Iterator[None]:
    """Context manager asserting that ``SystemExit`` is raised with exit status `code`."""
    exact_code_pattern = rf"^{code}$"
    with pytest.raises(SystemExit, match=exact_code_pattern) as exc_info:
        yield
    assert exc_info.value.args == (code,)
@contextmanager
def assert_validation_error(errors: dict[str, list[str]]) -> Iterator[None]:
    """Context manager asserting that a ``ValidationError`` is raised.

    The ``message_dict`` of the raised exception (field name mapped to a list of messages) must be equal to
    `errors`.
    """
    with pytest.raises(ValidationError) as raised:
        yield
    actual_errors = raised.value.message_dict
    assert actual_errors == errors