diff --git a/modes/dns/__init__.py b/modes/dns/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modes/dns/abstract.py b/modes/dns/abstract.py new file mode 100644 index 0000000..835a84f --- /dev/null +++ b/modes/dns/abstract.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# dns.nsupdate - rfc2136 based challenge handler +# Copyright (c) Rudolf Mayerhofer, 2018-2019 +# available under the ISC license, see LICENSE + +import dns +import dns.query +import dns.resolver +import dns.tsigkeyring +import dns.update +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes + +import tools +from modes.abstract import AbstractChallengeHandler + + +class DNSChallengeHandler(AbstractChallengeHandler): + @staticmethod + def get_challenge_type(): + return "dns-01" + + def __init__(self, config): + AbstractChallengeHandler.__init__(self, config) + self.dns_updatedomain = config.get("dns_updatedomain") + + def _determine_challenge_domain(self, domain): + if self.dns_updatedomain: + domain = self.dns_updatedomain + else: + domain = "_acme-challenge.{0}".format(domain) + + domain = dns.name.from_text(domain) + if not domain.is_absolute(): + domain = domain.concatenate(dns.name.root) + + return domain.to_text() + + @staticmethod + def _determine_txtvalue(thumbprint, token): + keyauthorization = "{0}.{1}".format(token, thumbprint) + digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) + digest.update(keyauthorization.encode('utf8')) + return tools.to_json_base64(digest.finalize()) + + def create_challenge(self, domain, thumbprint, token): + domain = self._determine_challenge_domain(domain) + txtvalue = self._determine_txtvalue(thumbprint, token) + return self.add_dns_record(domain, txtvalue) + + def add_dns_record(self, domain, txtvalue): + raise NotImplementedError + + def destroy_challenge(self, domain, thumbprint, token): + domain = self._determine_challenge_domain(domain) + txtvalue = self._determine_txtvalue(thumbprint, token) + return self.remove_dns_record(domain, txtvalue) + + def remove_dns_record(self, domain, txtvalue): + raise NotImplementedError diff --git a/modes/dns/nsupdate.py b/modes/dns/nsupdate.py new file mode 100644 index 0000000..5ef8821 --- /dev/null +++ b/modes/dns/nsupdate.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# dns.nsupdate - rfc2136 based challenge handler +# Copyright (c) Rudolf Mayerhofer, 2019 +# available under the ISC license, see LICENSE +import datetime +import ipaddress +import re +import socket +import io + +import dns +import dns.query +import dns.resolver +import dns.tsigkeyring +import dns.update + +from modes.dns.abstract import DNSChallengeHandler + + +DEFAULT_KEY_ALGORITHM = "HMAC-MD5.SIG-ALG.REG.INT" + + +class ChallengeHandler(DNSChallengeHandler): + @staticmethod + def _find_first_key_name(tsig_key_file): + try: + with io.open(tsig_key_file) as key_file: + key_struct = key_file.read() + return re.search(r"key \"?([^\"{ ]+?)\"? {.*};", key_struct, re.DOTALL).group(1) + except IOError as exc: + print(exc) + raise Exception( + "A problem was encountered opening your keyfile, %s." % tsig_key_file) + except AttributeError as exc: + print(exc) + raise Exception("Failed to find first key Name") + + @staticmethod + def _read_tsigkey(tsig_key_file, key_name): + try: + with io.open(tsig_key_file) as key_file: + key_struct = key_file.read() + key_file.close() + key_data = re.search(r"key \"?%s\"? {(.*?)};" % key_name, key_struct, re.DOTALL).group(1) + algorithm = re.search(r"algorithm ([a-zA-Z0-9_-]+?);", key_data, re.DOTALL).group(1) + tsig_secret = re.search(r"secret \"(.*?)\"", key_data, re.DOTALL).group(1) + except IOError as exc: + print(exc) + raise Exception( + "A problem was encountered opening your keyfile, %s." % tsig_key_file) + except AttributeError as exc: + print(exc) + raise Exception("Unable to decipher the keyname and secret from your key file.") + + keyring = dns.tsigkeyring.from_text({ + key_name: tsig_secret + }) + + if not algorithm: + algorithm = DEFAULT_KEY_ALGORITHM + + return keyring, algorithm + + @staticmethod + def _lookup_dns_server(domain_or_ip): + try: + return str(ipaddress.ip_address(domain_or_ip)) + except ValueError: + result = socket.getaddrinfo(domain_or_ip, 53) + if len(result) > 0: + return result[0][4][0] + else: + raise ValueError("Could not lookup dns ip for {}".format(domain_or_ip)) + + @staticmethod + def _get_soa(domain, nameserver=None): + if nameserver: + nameservers = [nameserver] + else: + nameservers = dns.resolver.get_default_resolver().nameservers + + domain = dns.name.from_text(domain) + if not domain.is_absolute(): + domain = domain.concatenate(dns.name.root) + + while domain.parent() != dns.name.root: + request = dns.message.make_query(domain, dns.rdatatype.SOA) + for nameserver in nameservers: + try: + response = dns.query.udp(request, nameserver) + if response.rcode() == dns.rcode.NOERROR: + for answer in response.answer: + for item in answer: + if item.rdtype == dns.rdatatype.SOA: + zone = domain.to_text() + authoritative_ns = item.mname.to_text().split(' ')[0] + return zone, authoritative_ns + else: + break + except dns.exception.Timeout: + # Go to next nameserver on timeout + continue + except dns.exception.DNSException: + # Break loop on any other error + break + domain = domain.parent() + raise Exception('Could not find Zone SOA for "{0}"'.format(domain)) + + @staticmethod + def get_challenge_type(): + return "dns-01" + + def __init__(self, config): + DNSChallengeHandler.__init__(self, config) + if 'nsupdate_keyfile' in config: + if 'nsupdate_keyname' in config: + nsupdate_keyname = config.get("nsupdate_keyname") + else: + nsupdate_keyname = self._find_first_key_name(config.get("nsupdate_keyfile")) + + self.keyring, self.keyalgorithm = self._read_tsigkey(config.get("nsupdate_keyfile"), nsupdate_keyname) + else: + self.keyring = dns.tsigkeyring.from_text({ + config.get("nsupdate_keyname"): config.get("nsupdate_keyvalue") + }) + self.keyalgorithm = config.get("nsupdate_keyalgorithm", DEFAULT_KEY_ALGORITHM) + self.dns_server = config.get("nsupdate_server") + self.dns_ttl = int(config.get("nsupdate_ttl", "60")) + + def _determine_zone_and_nameserverip(self, domain): + nameserver = self.dns_server + if nameserver: + nameserverip = self._lookup_dns_server(nameserver) + zone, _ = self._get_soa(domain, nameserverip) + else: + zone, nameserver = self._get_soa(domain) + nameserverip = self._lookup_dns_server(nameserver) + + return zone, nameserverip + + def add_dns_record(self, domain, txtvalue): + zone, nameserverip = self._determine_zone_and_nameserverip(domain) + update = dns.update.Update(zone, keyring=self.keyring, keyalgorithm=self.keyalgorithm) + update.add(domain, self.dns_ttl, 'TXT', txtvalue) + dns.query.tcp(update, nameserverip) + print('Added \'{} 60 IN TXT "{}"\' to {}'.format(domain, txtvalue, nameserverip)) + return datetime.datetime.now() + datetime.timedelta(seconds=2 * self.dns_ttl) + + def remove_dns_record(self, domain, txtvalue): + zone, nameserverip = self._determine_zone_and_nameserverip(domain) + update = dns.update.Update(zone, keyring=self.keyring, keyalgorithm=self.keyalgorithm) + update.delete(domain, 'TXT', txtvalue) + dns.query.tcp(update, nameserverip) + print('Deleted \'{} 60 IN TXT "{}"\' from {}'.format(domain, txtvalue, nameserverip))