acertmgr/acertmgr/tools.py

453 lines
18 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# acertmgr - various support functions
# Copyright (c) Markus Hauschild & David Klaftenegger, 2016.
# Copyright (c) Rudolf Mayerhofer, 2019.
# available under the ISC license, see LICENSE
import base64
import datetime
import io
import os
import re
import stat
import sys
import traceback
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, ec, padding
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
from cryptography.utils import int_to_bytes
from cryptography.x509.oid import NameOID, ExtensionOID
try:
from cryptography.x509 import ocsp
except ImportError:
pass
try:
from cryptography.hazmat.primitives.asymmetric import ed25519, ed448
except ImportError:
pass
try:
from urllib.request import urlopen, Request # Python 3
except ImportError:
from urllib2 import urlopen, Request # Python 2
# Substring replacements applied by log() to every message text before it is written (search text -> replacement)
LOG_REPLACEMENTS = {}
class InvalidCertificateError(Exception):
    """Signals that a certificate cannot be considered valid (e.g. its validity period has not started yet)."""
# @brief a simple, portable indent function
# @param text the (possibly multi-line) text to indent
# @param spaces the number of space characters put in front of every line
# @return the indented lines joined by os.linesep (no trailing line separator)
def indent(text, spaces=0):
    lead = ' ' * spaces
    indented_lines = [lead + line for line in text.splitlines()]
    return os.linesep.join(indented_lines)
# @brief wrapper for log output
# @param msg the message text to write
# @param exc optional exception whose formatted traceback is appended (indented) below the message
# @param error mark the message as an error ("Error: " prefix, written to stderr)
# @param warning mark the message as a warning ("Warning: " prefix, written to stderr); error takes precedence
def log(msg, exc=None, error=False, warning=False):
    prefix = "Error: " if error else ("Warning: " if warning else "")
    output = prefix + msg
    # Replacements are applied to the prefixed message only; the exception text below is appended afterwards
    for needle, substitute in LOG_REPLACEMENTS.items():
        output = output.replace(needle, substitute)
    if exc:
        current_exc = sys.exc_info()[1]
        exc_tb = getattr(exc, '__traceback__', None)
        if not exc_tb and exc == current_exc:
            # Traceback handling on Python 2 is ugly, so we only output it if the exception is the current sys one
            formatted_exc = traceback.format_exc()
        else:
            formatted_exc = traceback.format_exception(type(exc), exc, exc_tb)
        if isinstance(formatted_exc, list):
            exc_string = ''.join(formatted_exc)
        else:
            exc_string = str(formatted_exc)
        output += os.linesep + indent(exc_string, len(prefix))
    # Errors and warnings go to stderr, everything else to stdout
    stream = sys.stderr if (error or warning) else sys.stdout
    stream.write(output + os.linesep)
    stream.flush()  # force flush buffers after message was written for immediate display
# @brief wrapper for downloading an url
# @param url the url to request
# @param data optional request body handed to the Request object
# @param headers optional dict of request headers (an empty dict is used when omitted)
# @return the response object returned by urlopen
def get_url(url, data=None, headers=None):
    if headers is None:
        headers = {}
    request = Request(url, data=data, headers=headers)
    return urlopen(request)
# @brief check whether existing certificate is still valid or expiring soon
# @param cert certificate object exposing not_valid_before / not_valid_after as naive UTC datetimes
# @param ttl_days the minimum amount of days for which the certificate must be valid
# @return True if certificate is still valid for at least ttl_days, False otherwise
# @raise InvalidCertificateError if the validity period of the certificate has not started yet
def is_cert_valid(cert, ttl_days):
    # The validity dates of a certificate are naive datetimes expressed in UTC, so the reference time must be
    # naive UTC as well. Using the local time would shift every comparison by the local UTC offset.
    try:
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    except AttributeError:
        # Python 2 has no datetime.timezone
        now = datetime.datetime.utcnow()
    if cert.not_valid_before > now:
        raise InvalidCertificateError("Certificate seems to be from the future")
    expiry_limit = now + datetime.timedelta(days=ttl_days)
    return not (cert.not_valid_after < expiry_limit)
# @brief create a certificate signing request
# @param names list of domain names the certificate should be valid for (text, or utf-8 encoded bytes)
# @param key the private key used to sign the request
# @param must_staple whether or not the certificate should include the OCSP must-staple flag
# @return the signed CSR object produced by the CertificateSigningRequestBuilder
def new_cert_request(names, key, must_staple=False):
    def _as_text(name):
        # Anything offering a decode() method is treated as utf-8 encoded bytes
        return name.decode('utf-8') if getattr(name, 'decode', None) else name

    # The first name becomes the subject CN, all names (including the first) go into the SAN extension
    subject = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, _as_text(names[0]))])
    san = x509.SubjectAlternativeName([x509.DNSName(_as_text(name)) for name in names])
    builder = x509.CertificateSigningRequestBuilder().subject_name(subject)
    builder = builder.add_extension(san, critical=False)
    if must_staple:
        if getattr(x509, 'TLSFeature', None):
            staple = x509.TLSFeature(features=[x509.TLSFeatureType.status_request])
            builder = builder.add_extension(staple, critical=False)
        else:
            log('OCSP must-staple ignored as current version of cryptography does not support the flag.', warning=True)
    return builder.sign(key, hashes.SHA256(), default_backend())
# @brief generate a new account key
# @param path path where the new key file should be written in PEM format (optional)
# @param key_algo key algorithm, handed through to new_ssl_key (optional)
# @param key_size key size, handed through to new_ssl_key (optional)
# @return the private key returned by new_ssl_key
def new_account_key(path=None, key_algo=None, key_size=None):
    # Account keys are generated exactly like certificate keys
    account_key = new_ssl_key(path, key_algo, key_size)
    return account_key
# @brief generate a new ssl key
# @param path path where the new key file should be written in PEM format (optional)
# @param key_algo key algorithm (case-insensitive): 'rsa' (default), 'ec'/'ecc', 'ed25519' or 'ed448'
# @param key_size RSA key size in bits (default 4096) or EC curve size (256 (default), 384 or 521)
# @return the generated private key object
# @raise ValueError if the algorithm or EC curve size is unsupported; this includes 'ed25519'/'ed448' when the
#        corresponding cryptography module could not be imported
def new_ssl_key(path=None, key_algo=None, key_size=None):
    algo = key_algo.lower() if key_algo else 'rsa'
    if algo == 'rsa':
        if not key_size:
            key_size = 4096
        key_format = serialization.PrivateFormat.TraditionalOpenSSL
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
            backend=default_backend()
        )
    elif algo in ('ec', 'ecc'):
        if not key_size or key_size == 256:
            key_curve = ec.SECP256R1
        elif key_size == 384:
            key_curve = ec.SECP384R1
        elif key_size == 521:
            key_curve = ec.SECP521R1
        else:
            raise ValueError("Unsupported EC curve size parameter: {}".format(key_size))
        key_format = serialization.PrivateFormat.PKCS8
        # generate_private_key expects a curve instance, not the curve class itself
        private_key = ec.generate_private_key(curve=key_curve(), backend=default_backend())
    elif algo == 'ed25519' and "cryptography.hazmat.primitives.asymmetric.ed25519" in sys.modules:
        # The sys.modules check makes sure the optional import at the top of the file actually succeeded
        key_format = serialization.PrivateFormat.PKCS8
        private_key = ed25519.Ed25519PrivateKey.generate()
    elif algo == 'ed448' and "cryptography.hazmat.primitives.asymmetric.ed448" in sys.modules:
        key_format = serialization.PrivateFormat.PKCS8
        private_key = ed448.Ed448PrivateKey.generate()
    else:
        raise ValueError("Unsupported key algorithm: {}".format(key_algo))
    if path is not None:
        pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=key_format,
            encryption_algorithm=serialization.NoEncryption(),
        )
        with io.open(path, 'wb') as pem_out:
            pem_out.write(pem)
        if hasattr(os, 'chmod'):
            try:
                # Key files are made read-only for the owner and inaccessible for everyone else
                os.chmod(path, int("0400", 8))
            except OSError:
                log('Could not set file permissions on {0}!'.format(path), warning=True)
        else:
            log('Keyfile permission handling unavailable on this platform', warning=True)
    return private_key
# @brief read a key, a csr or certificate(s) from a PEM file
# @param path path to file
# @param key indicate whether we are loading a key
# @param csr indicate whether we are loading a csr
# @return the loaded private key (key=True), the loaded CSR (csr=True) or otherwise whatever
#         convert_pem_str_to_cert returns for the file content
def read_pem_file(path, key=False, csr=False):
    with io.open(path, 'r') as pem_file:
        pem_text = pem_file.read()
    if key:
        # Keys are loaded without a password
        return serialization.load_pem_private_key(pem_text.encode('utf-8'), None, default_backend())
    if csr:
        return x509.load_pem_x509_csr(pem_text.encode('utf8'), default_backend())
    return convert_pem_str_to_cert(pem_text)
# @brief write cert data to PEM formatted file
# @param crt certificate object or list of certificate objects to write
# @param path path of the file to (over)write
# @param perms file mode to set after writing (optional, skipped when not given)
def write_pem_file(crt, path, perms=None):
    can_chmod = hasattr(os, 'chmod')
    if can_chmod and os.path.exists(path):
        # An existing file may be read-only, so try to make it writable first
        try:
            current_mode = os.stat(path).st_mode
            os.chmod(path, current_mode | stat.S_IWRITE)
        except OSError:
            log('Could not make file ({0}) writable'.format(path), warning=True)
    with io.open(path, "w") as pem_out:
        pem_out.write(convert_cert_to_pem_str(crt))
    if not perms:
        return
    if not can_chmod:
        log('PEM-File permission handling unavailable on this platform', warning=True)
        return
    try:
        os.chmod(path, perms)
    except OSError:
        log('Could not set file permissions ({0}) on {1}!'.format(perms, path), warning=True)
# @brief download the issuer ca for a given certificate
# @param cert certificate data
# @returns ca certificate data loaded from the DER response, or None if no CA issuers location is present in
#          the certificate or the download reported an error status
def download_issuer_ca(cert):
    aia = cert.extensions.get_extension_for_oid(ExtensionOID.AUTHORITY_INFORMATION_ACCESS)
    # Use the first "CA Issuers" access location listed in the Authority Information Access extension
    ca_issuers = next((data.access_location.value for data in aia.value
                       if data.access_method == x509.OID_CA_ISSUERS), None)
    if not ca_issuers:
        log("Could not determine issuer CA for given certificate: {}".format(cert), error=True)
        return None

    log("Downloading CA certificate from {}".format(ca_issuers))
    resp = get_url(ca_issuers)
    try:
        code = resp.getcode()
        # getcode() may return None for non-HTTP responses, which must not be compared to an int
        if code is not None and code >= 400:
            log("Could not download issuer CA (error {}) for given certificate: {}".format(code, cert), error=True)
            return None
        der_data = resp.read()
    finally:
        # Always release the connection, also on the error paths
        resp.close()
    return x509.load_der_x509_certificate(der_data, default_backend())
# @brief determine all san domains on a given certificate
# @param cert certificate object
# @return set containing the subject common name (if any) and the values of all subject alternative names (if any)
def get_cert_domains(cert):
    domains = set()
    # Not every certificate carries a common name, so only add it when present
    cn_attributes = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
    if cn_attributes:
        domains.add(cn_attributes[0].value)
    try:
        san_cert = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
    except x509.ExtensionNotFound:
        # get_extension_for_oid raises instead of returning None when the extension is absent
        san_cert = None
    if san_cert:
        for d in san_cert.value:
            domains.add(d.value)
    return domains
# @brief determine certificate cn
# @param cert certificate object
# @return the first common name of the certificate subject, formatted as "CN=<value>"
def get_cert_cn(cert):
    cn_attribute = cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)[0]
    return "CN={}".format(cn_attribute.value)
# @brief determine certificate end of validity
# @param cert certificate object
# @return the value of the certificate's not_valid_after attribute
def get_cert_valid_until(cert):
    valid_until = cert.not_valid_after
    return valid_until
# @brief convert certificate to PEM format
# @param cert certificate object or a list thereof
# @return the certificate(s) in PEM format as one string, individual certificates joined by a newline
def convert_cert_to_pem_str(cert):
    certs = cert if isinstance(cert, list) else [cert]
    pem_parts = [item.public_bytes(serialization.Encoding.PEM).decode('utf8') for item in certs]
    return '\n'.join(pem_parts)
# @brief load a PEM certificate from str
# @param certdata string containing one or more PEM encoded certificates
# @return a single certificate object if exactly one is found, otherwise a list of certificate objects
#         (which is empty if the string contains no certificate)
def convert_pem_str_to_cert(certdata):
    pem_blocks = re.findall(r'(-----BEGIN CERTIFICATE-----\n[^\-]+\n-----END CERTIFICATE-----)',
                            certdata, re.DOTALL)
    loaded = [x509.load_pem_x509_certificate(block.encode('utf8'), default_backend()) for block in pem_blocks]
    if len(loaded) == 1:
        return loaded[0]
    return loaded
# @brief serialize cert/csr to DER bytes
# @param data object offering public_bytes() (cert or csr)
# @return the DER encoded bytes
def convert_cert_to_der_bytes(data):
    serialize = data.public_bytes
    return serialize(serialization.Encoding.DER)
# @brief load a DER certificate from bytes
# @param data the DER encoded certificate
# @return the loaded certificate object
def convert_der_bytes_to_cert(data):
    backend = default_backend()
    return x509.load_der_x509_certificate(data, backend)
# @brief determine key signing algorithm and jwk data
# @param key private key object (RSA, EC on P-256/P-384/P-521, or Ed25519/Ed448 where those modules are loaded)
# @return tuple of (signature algorithm name, public key parameters as a JWK dict)
# @raise ValueError if the key type or the EC curve of the key is not supported
def get_key_alg_and_jwk(key):
    def _okp_jwk(curve_name):
        # Octet key pair JWK carrying the raw public key bytes
        raw_public = key.public_key().public_bytes(encoding=serialization.Encoding.Raw,
                                                   format=serialization.PublicFormat.Raw)
        return {"kty": "OKP", "crv": curve_name, "x": bytes_to_base64url(raw_public)}

    if isinstance(key, rsa.RSAPrivateKey):
        # See https://tools.ietf.org/html/rfc7518#section-6.3
        rsa_numbers = key.public_key().public_numbers()
        jwk = {"kty": "RSA",
               "e": bytes_to_base64url(int_to_bytes(rsa_numbers.e)),
               "n": bytes_to_base64url(int_to_bytes(rsa_numbers.n))}
        return "RS256", jwk

    if isinstance(key, ec.EllipticCurvePrivateKey):
        # See https://tools.ietf.org/html/rfc7518#section-6.2
        ec_numbers = key.public_key().public_numbers()
        supported_curves = (
            (ec.SECP256R1, 'ES256', 'P-256'),
            (ec.SECP384R1, 'ES384', 'P-384'),
            (ec.SECP521R1, 'ES512', 'P-521'),
        )
        for curve_type, alg, crv in supported_curves:
            if isinstance(ec_numbers.curve, curve_type):
                break
        else:
            raise ValueError("Unsupported EC curve in key: {}".format(key))
        # Coordinates are padded to the full octet length of the curve size taken from the curve name
        full_octets = (int(crv[2:]) + 7) // 8
        jwk = {"kty": "EC", "crv": crv,
               "x": bytes_to_base64url(int_to_bytes(ec_numbers.x, full_octets)),
               "y": bytes_to_base64url(int_to_bytes(ec_numbers.y, full_octets))}
        return alg, jwk

    # The sys.modules checks guard against the optional ed25519/ed448 imports not being available
    if "cryptography.hazmat.primitives.asymmetric.ed25519" in sys.modules \
            and isinstance(key, ed25519.Ed25519PrivateKey):
        # See https://tools.ietf.org/html/rfc8037#appendix-A.2
        return "EdDSA", _okp_jwk("Ed25519")

    if "cryptography.hazmat.primitives.asymmetric.ed448" in sys.modules \
            and isinstance(key, ed448.Ed448PrivateKey):
        return "EdDSA", _okp_jwk("Ed448")

    raise ValueError("Unsupported key: {}".format(key))
# @brief sign string with key
# @param key private key object supported by get_key_alg_and_jwk
# @param string the text to sign (encoded as utf8 before signing)
# @return the raw signature bytes; for EC keys the fixed-width concatenation of r and s
# @raise ValueError if the signature algorithm determined for the key is not supported
def signature_of_str(key, string):
    alg, _ = get_key_alg_and_jwk(key)
    data = string.encode('utf8')
    if alg == 'RS256':
        return key.sign(data, padding.PKCS1v15(), hashes.SHA256())
    elif alg.startswith('ES'):
        # Hash and octet length of r/s per algorithm (https://tools.ietf.org/html/rfc7518#section-3.4).
        # ES512 uses curve P-521, so each integer takes 66 octets; deriving the length from the "512" in the
        # algorithm name would yield 64 octets, which is too small for most P-521 signatures.
        ec_parameters = {
            'ES256': (hashes.SHA256, 32),
            'ES384': (hashes.SHA384, 48),
            'ES512': (hashes.SHA512, 66),
        }
        if alg not in ec_parameters:
            raise ValueError("Unsupported EC signature algorithm: {}".format(alg))
        hash_type, full_octets = ec_parameters[alg]
        der_sig = key.sign(data, ec.ECDSA(hash_type()))
        # convert DER signature to RAW format (https://tools.ietf.org/html/rfc7518#section-3.4)
        r, s = decode_dss_signature(der_sig)
        return int_to_bytes(r, full_octets) + int_to_bytes(s, full_octets)
    elif alg == 'EdDSA':
        return key.sign(data)
    else:
        raise ValueError("Unsupported signature algorithm: {}".format(alg))
# @brief hash a string
# @param string the text to hash (encoded as utf8 before hashing)
# @return the SHA256 digest bytes
def hash_of_str(string):
    encoded = string.encode('utf8')
    digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
    digest.update(encoded)
    return digest.finalize()
# @brief helper function to base64 encode for JSON objects
# @param b the byte-string to encode
# @return the url-safe base64 encoded string with the '=' padding removed
def bytes_to_base64url(b):
    encoded = base64.urlsafe_b64encode(b)
    # '=' only ever occurs as trailing padding in base64 output, so stripping it from the end removes all of it
    return encoded.rstrip(b'=').decode('utf8')
# @brief check whether existing target file is still valid or source crt has been updated
# @param target string containing the path to the target file
# @param file string containing the path to the certificate file
# @return True if target file is at least as new as the certificate, False otherwise (also if target is missing)
def target_is_current(target, file):
    # Modification times are only compared when the target exists as a regular file
    return os.path.isfile(target) and os.path.getmtime(target) >= os.path.getmtime(file)
# @brief convert domain to idna representation (if applicable)
# @param domain the domain name, optionally starting with a '*.' wildcard label
# @return the IDNA (ascii) form of a domain containing non-ascii characters; the unchanged domain if it is
#         plain ascii or if the conversion failed (the failure is logged as an error)
def idna_convert(domain):
    try:
        if all(ord(char) < 128 for char in domain):
            return domain
        # Translate IDNA domain name from a unicode domain (handle wildcards separately)
        if domain.startswith('*.'):
            return "*.{}".format(domain[2:].encode('idna').decode('ascii'))
        return domain.encode('idna').decode('ascii')
    except Exception as e:
        log("Unicode domain(s) found but IDNA names could not be translated due to error: {}".format(e), error=True)
    return domain
# @brief validate the OCSP status for a given certificate by the given issuer
# @param cert the certificate to check
# @param issuer the issuer certificate, or a CA chain list whose first entry is the immediate issuer
# @param hash_algo name of the hash used for the request: 'sha1', 'sha224', 'sha256', 'sha384' or 'sha512'
# @return False only if an OCSP responder answered successfully and reported the certificate as revoked;
#         True otherwise (also when the hash name is invalid or an error occurs, as validation is then ignored)
def is_ocsp_valid(cert, issuer, hash_algo):
    if hash_algo == 'sha1':
        algorithm = hashes.SHA1
    elif hash_algo == 'sha224':
        algorithm = hashes.SHA224
    elif hash_algo == 'sha256':
        algorithm = hashes.SHA256
    elif hash_algo in ('sha384', 'sha385'):
        # 'sha385' is a misspelling that used to be the only accepted name; kept so existing configurations work
        algorithm = hashes.SHA384
    elif hash_algo == 'sha512':
        algorithm = hashes.SHA512
    else:
        log("Invalid hash algorithm '{}' used for OCSP validation. Validation ignored.".format(hash_algo), warning=True)
        return True

    if isinstance(issuer, list):
        issuer = issuer[0]  # First certificate in the CA chain is the immediate issuer

    try:
        # Collect all OCSP responder urls from the Authority Information Access extension
        aia = cert.extensions.get_extension_for_oid(ExtensionOID.AUTHORITY_INFORMATION_ACCESS)
        ocsp_urls = [data.access_location.value for data in aia.value if data.access_method == x509.OID_OCSP]

        # This is a bit of a hack due to validation problems within cryptography (TODO: Check if this is still true)
        # Correct replacement: ocsprequest = ocsp.OCSPRequestBuilder().add_certificate(cert, issuer, algorithm).build()
        ocsprequest = ocsp.OCSPRequestBuilder((cert, issuer, (algorithm)())).build()
        ocsprequestdata = ocsprequest.public_bytes(serialization.Encoding.DER)
        for ocsp_url in ocsp_urls:
            response = get_url(ocsp_url,
                               ocsprequestdata,
                               {
                                   'Accept': 'application/ocsp-response',
                                   'Content-Type': 'application/ocsp-request',
                               })
            try:
                ocspresponsedata = response.read()
            finally:
                # Release the connection to the responder once the answer has been read
                response.close()
            ocspresponse = ocsp.load_der_ocsp_response(ocspresponsedata)
            if ocspresponse.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL \
                    and ocspresponse.certificate_status == ocsp.OCSPCertStatus.REVOKED:
                return False
    except Exception as e:
        # Deliberately best-effort: any failure (missing extension, network error, missing ocsp support) is logged
        # and treated as "not revoked"
        log("An exception occurred during OCSP validation (Validation will be ignored): {}".format(e), error=True)
    return True