Skip to content

Commit

Permalink
Refactor API and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Karlinde committed Mar 16, 2024
1 parent ba37045 commit 386b00c
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 92 deletions.
56 changes: 3 additions & 53 deletions src/zonefilegen/cli.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,10 @@
import argparse
import logging
import pathlib
import re

import zonefilegen
import zonefilegen.parsing
import zonefilegen.core

# First line of each generated zone file should be a comment with the
# SHA-1 hex digest of the input toml file:
# ; Generated by zonefilegen, INPUT_SHA1: ea03443f2d9f8c580e73d2f8cd1016dc7174bddc
FIRST_LINE_PATTERN = re.compile(r'^\s*;.*INPUT_DIGEST:\s+(?P<digest>[0-9a-f]+)')
SOA_PATTERN = re.compile(r'^.+SOA[ \t]+(?P<mname>[\.\w]+)[ \t]+(?P<rname>[\.\w]+)[ \t]+\(\s*(?P<serial>[0-9]+)', re.MULTILINE)


def gen_zone(zone: zonefilegen.core.Zone, output_dir: pathlib.Path, soa_dict: dict, input_digest: str):
out_filepath: pathlib.Path = output_dir / f"{zone.name}zone"
logging.info(f"Generating zone file {out_filepath}")
serial_number = None
if out_filepath.exists():
with open(out_filepath, 'r') as f:
first_line_matches = FIRST_LINE_PATTERN.match(f.readline())
soa_matches = SOA_PATTERN.search(f.read(), )
old_digest = None
old_serial = None
if first_line_matches:
old_digest = first_line_matches.group('digest')
else:
logging.error(f"Existing zone file {out_filepath} was not generated by this tool. Aborting.")
exit(1)

if soa_matches:
old_serial = soa_matches.group('serial')
else:
logging.warning(f"Didn't find or recognize SOA record in existing zone file {out_filepath}. Serial number will"
" be reset.")

if old_serial and old_digest:
if old_digest != input_digest:
serial_number = (int(old_serial) + 1) % pow(2, 32)
logging.info(f"Changes detected, updating serial to {serial_number}")
else:
serial_number = int(old_serial)
logging.info(f"No changes detected, serial remains at {serial_number}")

if serial_number is None:
serial_number = 1
soa_rec = zonefilegen.generation.build_soa_record(soa_dict, serial_number)

with open(out_filepath, 'w') as f:
f.write(zonefilegen.generation.generate_header(input_digest) + '\n')
f.write(zone.generate_origin() + '\n')
f.write(zone.generate_ttl() + '\n')
f.write(soa_rec.to_line() + '\n')
for rec in zone.records:
f.write(rec.to_line() + '\n')
import zonefilegen.generation


def generate():
Expand All @@ -75,6 +25,6 @@ def generate():
(fwd_zone, reverse_zones, soa_dict, input_digest) = zonefilegen.parsing.parse_toml_file(args.input_file)

for zone in reverse_zones:
gen_zone(zone, args.output_dir, soa_dict, input_digest)
zonefilegen.generation.gen_zone(zone, args.output_dir, soa_dict, input_digest)

gen_zone(fwd_zone, args.output_dir, soa_dict, input_digest)
zonefilegen.generation.gen_zone(fwd_zone, args.output_dir, soa_dict, input_digest)
38 changes: 38 additions & 0 deletions src/zonefilegen/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ipaddress

RECORD_CLASSES = [
'IN',
'CH',
Expand Down Expand Up @@ -119,3 +121,39 @@ def generate_origin(self):

def generate_ttl(self):
return f"$TTL {self.default_ttl}"


def get_rev_zone_name(network) -> str:
"""
Cuts off the first few blocks of a reverse pointer for a network address
to create a suitable reverse zone name for a certain prefix length.
"""
if type(network) is ipaddress.IPv4Network:
divisor = 8
address_len = 32
elif type(network) is ipaddress.IPv6Network:
divisor = 4
address_len = 128
else:
raise Exception(f"Invalid network type: {network}")

blocks_to_cut = int((address_len - network.prefixlen) / divisor)
return '.'.join(network.network_address.reverse_pointer.split('.')[blocks_to_cut:None]) + '.'


def get_rev_ptr_name(address, prefix_len) -> str:
"""
Cuts off the last few blocks of a reverse pointer for an address
to create a suitable reverse pointer name for a certain prefix length.
"""
if type(address) is ipaddress.IPv4Address:
divisor = 8
address_len = 32
elif type(address) is ipaddress.IPv6Address:
divisor = 4
address_len = 128
else:
raise Exception(f"Invalid address type: {address}")

blocks_to_cut = int(((address_len - prefix_len) / divisor))
return '.'.join(address.reverse_pointer.split('.')[None:blocks_to_cut])
56 changes: 53 additions & 3 deletions src/zonefilegen/generation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import logging
import pathlib
import re
from typing import List, Tuple

import zonefilegen.core
import zonefilegen.parsing

# First line of each generated zone file should be a comment with the
# SHA-1 hex digest of the input toml file:
# ; Generated by zonefilegen, INPUT_SHA1: ea03443f2d9f8c580e73d2f8cd1016dc7174bddc
FIRST_LINE_PATTERN = re.compile(r'^\s*;.*INPUT_DIGEST:\s+(?P<digest>[0-9a-f]+)')
SOA_PATTERN = re.compile(r'^.+SOA[ \t]+(?P<mname>[\.\w]+)[ \t]+(?P<rname>[\.\w]+)[ \t]+\(\s*(?P<serial>[0-9]+)', re.MULTILINE)


def build_soa_record(soa_dict: dict, serial_number: int) -> zonefilegen.core.ResourceRecord:
Expand Down Expand Up @@ -31,7 +38,7 @@ def build_reverse_zone(network,
Checks a set of addresses if they are part of a network and
include them as PTR records in a reverse zone for that network in such case.
"""
rev_zone = zonefilegen.core.Zone(zonefilegen.parsing.get_rev_zone_name(network), default_ttl)
rev_zone = zonefilegen.core.Zone(zonefilegen.core.get_rev_zone_name(network), default_ttl)
included_ptr_names = set()

# Use the same NS records for reverse zones as for the forward zones
Expand All @@ -47,7 +54,7 @@ def build_reverse_zone(network,
for (name, ttl, addr) in ptr_candidates:
if addr in network:
rec = zonefilegen.core.ResourceRecord()
rec.name = zonefilegen.parsing.get_rev_ptr_name(addr, network.prefixlen)
rec.name = zonefilegen.core.get_rev_ptr_name(addr, network.prefixlen)
rec.record_type = 'PTR'
rec.record_class = 'IN'
rec.ttl = ttl
Expand Down Expand Up @@ -95,3 +102,46 @@ def build_fwd_zone(zone_name: str, rrset_dict: dict, default_ttl: int) -> zonefi

def generate_header(digest: str):
return f"; Generated by zonefilegen, INPUT_DIGEST: {digest}"


def gen_zone(zone: zonefilegen.core.Zone, output_dir: pathlib.Path, soa_dict: dict, input_digest: str):
out_filepath: pathlib.Path = output_dir / f"{zone.name}zone"
logging.info(f"Generating zone file {out_filepath}")
serial_number = None
if out_filepath.exists():
with open(out_filepath, 'r') as f:
first_line_matches = FIRST_LINE_PATTERN.match(f.readline())
soa_matches = SOA_PATTERN.search(f.read(), )
old_digest = None
old_serial = None
if first_line_matches:
old_digest = first_line_matches.group('digest')
else:
logging.error(f"Existing zone file {out_filepath} was not generated by this tool. Aborting.")
exit(1)

if soa_matches:
old_serial = soa_matches.group('serial')
else:
logging.warning(f"Didn't find or recognize SOA record in existing zone file {out_filepath}. Serial number will"
" be reset.")

if old_serial and old_digest:
if old_digest != input_digest:
serial_number = (int(old_serial) + 1) % pow(2, 32)
logging.info(f"Changes detected, updating serial to {serial_number}")
else:
serial_number = int(old_serial)
logging.info(f"No changes detected, serial remains at {serial_number}")

if serial_number is None:
serial_number = 1
soa_rec = zonefilegen.generation.build_soa_record(soa_dict, serial_number)

with open(out_filepath, 'w') as f:
f.write(zonefilegen.generation.generate_header(input_digest) + '\n')
f.write(zone.generate_origin() + '\n')
f.write(zone.generate_ttl() + '\n')
f.write(soa_rec.to_line() + '\n')
for rec in zone.records:
f.write(rec.to_line() + '\n')
36 changes: 0 additions & 36 deletions src/zonefilegen/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,6 @@
import zonefilegen.generation


def get_rev_zone_name(network) -> str:
"""
Cuts off the first few blocks of a reverse pointer for a network address
to create a suitable reverse zone name for a certain prefix length.
"""
if type(network) is ipaddress.IPv4Network:
divisor = 8
address_len = 32
elif type(network) is ipaddress.IPv6Network:
divisor = 4
address_len = 128
else:
raise Exception(f"Invalid network type: {network}")

blocks_to_cut = int((address_len - network.prefixlen) / divisor)
return '.'.join(network.network_address.reverse_pointer.split('.')[blocks_to_cut:None]) + '.'


def get_rev_ptr_name(address, prefix_len) -> str:
"""
Cuts off the last few blocks of a reverse pointer for an address
to create a suitable reverse pointer name for a certain prefix length.
"""
if type(address) is ipaddress.IPv4Address:
divisor = 8
address_len = 32
elif type(address) is ipaddress.IPv6Address:
divisor = 4
address_len = 128
else:
raise Exception(f"Invalid address type: {address}")

blocks_to_cut = int(((address_len - prefix_len) / divisor))
return '.'.join(address.reverse_pointer.split('.')[None:blocks_to_cut])


def parse_toml_file(input_file_path: pathlib.Path) -> Tuple[zonefilegen.core.Zone, List[zonefilegen.core.Zone], dict, str]:
"""
Parses a toml file with DNS records and generates one forward zone and one
Expand Down
20 changes: 20 additions & 0 deletions tests/test_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import zonefilegen.parsing


def test_simple_parsing():
input_file = "docs/sample.toml"

(fwd_zone, reverse_zones, soa_dict, input_digest) = zonefilegen.parsing.parse_toml_file(input_file)
assert fwd_zone.name == "example.com."
assert fwd_zone.default_ttl == 3600

found_mail1 = False
for rec in fwd_zone.records:
if rec.name == 'mail1.example.com.':
found_mail1 = True
assert rec.record_class == 'IN'
assert rec.record_type == 'A'
assert rec.data == '198.51.100.3'
assert rec.ttl == 300

assert found_mail1

0 comments on commit 386b00c

Please sign in to comment.