Skip to content

Commit

Permalink
refactoring: define an enum for DNS resource record types (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
keepsimple1 authored Nov 24, 2024
1 parent d117f4f commit e185d6f
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 198 deletions.
43 changes: 21 additions & 22 deletions src/dns_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
use crate::log::debug;
use crate::{
dns_parser::{
current_time_millis, split_sub_domain, DnsAddress, DnsPointer, DnsRecordBox, DnsSrv,
RR_TYPE_A, RR_TYPE_AAAA, RR_TYPE_NSEC, RR_TYPE_PTR, RR_TYPE_SRV, RR_TYPE_TXT,
current_time_millis, split_sub_domain, DnsAddress, DnsPointer, DnsRecordBox, DnsSrv, RRType,
},
service_info::valid_two_addrs_on_intf,
};
Expand Down Expand Up @@ -116,12 +115,12 @@ impl DnsCache {
&mut self,
instance: &str,
expire_at: Option<u64>,
) -> Vec<(String, u16)> {
) -> Vec<(String, RRType)> {
let Some(srv_vec) = self.srv.get_mut(instance) else {
return Vec::new();
};

let mut query_vec = vec![(instance.to_string(), RR_TYPE_SRV)];
let mut query_vec = vec![(instance.to_string(), RRType::SRV)];

for srv in srv_vec {
if let Some(new_expire) = expire_at {
Expand All @@ -133,8 +132,8 @@ impl DnsCache {
};

// Will verify addresses for the hostname.
query_vec.push((srv_record.host.clone(), RR_TYPE_A));
query_vec.push((srv_record.host.clone(), RR_TYPE_AAAA));
query_vec.push((srv_record.host.clone(), RRType::A));
query_vec.push((srv_record.host.clone(), RRType::AAAA));

if let Some(new_expire) = expire_at {
if let Some(addrs) = self.addr.get_mut(&srv_record.host) {
Expand Down Expand Up @@ -164,7 +163,7 @@ impl DnsCache {

// If it is PTR with subtype, store a mapping from the instance fullname
// to the subtype in this cache.
if incoming.get_type() == RR_TYPE_PTR {
if incoming.get_type() == RRType::PTR {
let (_, subtype_opt) = split_sub_domain(&entry_name);
if let Some(subtype) = subtype_opt {
if let Some(ptr) = incoming.any().downcast_ref::<DnsPointer>() {
Expand All @@ -177,11 +176,11 @@ impl DnsCache {

// get the existing records for the type.
let record_vec = match incoming.get_type() {
RR_TYPE_PTR => self.ptr.entry(entry_name).or_default(),
RR_TYPE_SRV => self.srv.entry(entry_name).or_default(),
RR_TYPE_TXT => self.txt.entry(entry_name).or_default(),
RR_TYPE_A | RR_TYPE_AAAA => self.addr.entry(entry_name).or_default(),
RR_TYPE_NSEC => self.nsec.entry(entry_name).or_default(),
RRType::PTR => self.ptr.entry(entry_name).or_default(),
RRType::SRV => self.srv.entry(entry_name).or_default(),
RRType::TXT => self.txt.entry(entry_name).or_default(),
RRType::A | RRType::AAAA => self.addr.entry(entry_name).or_default(),
RRType::NSEC => self.nsec.entry(entry_name).or_default(),
_ => return None,
};

Expand All @@ -208,7 +207,7 @@ impl DnsCache {
should_flush = true;

// additional checks for address records.
if rtype == RR_TYPE_A || rtype == RR_TYPE_AAAA {
if rtype == RRType::A || rtype == RRType::AAAA {
if let Some(addr) = r.any().downcast_ref::<DnsAddress>() {
if let Some(addr_b) = incoming.any().downcast_ref::<DnsAddress>() {
should_flush =
Expand Down Expand Up @@ -255,10 +254,10 @@ impl DnsCache {
let mut found = false;
let record_name = record.get_name();
let record_vec = match record.get_type() {
RR_TYPE_PTR => self.ptr.get_mut(record_name),
RR_TYPE_SRV => self.srv.get_mut(record_name),
RR_TYPE_TXT => self.txt.get_mut(record_name),
RR_TYPE_A | RR_TYPE_AAAA => self.addr.get_mut(record_name),
RRType::PTR => self.ptr.get_mut(record_name),
RRType::SRV => self.srv.get_mut(record_name),
RRType::TXT => self.txt.get_mut(record_name),
RRType::A | RRType::AAAA => self.addr.get_mut(record_name),
_ => return found,
};
if let Some(record_vec) = record_vec {
Expand Down Expand Up @@ -497,14 +496,14 @@ impl DnsCache {
pub(crate) fn get_known_answers<'a>(
&'a self,
name: &str,
qtype: u16,
qtype: RRType,
now: u64,
) -> Vec<&'a DnsRecordBox> {
let records_opt = match qtype {
RR_TYPE_PTR => self.get_ptr(name),
RR_TYPE_SRV => self.get_srv(name),
RR_TYPE_A | RR_TYPE_AAAA => self.get_addr(name),
RR_TYPE_TXT => self.get_txt(name),
RRType::PTR => self.get_ptr(name),
RRType::SRV => self.get_srv(name),
RRType::A | RRType::AAAA => self.get_addr(name),
RRType::TXT => self.get_txt(name),
_ => None,
};

Expand Down
Loading

0 comments on commit e185d6f

Please sign in to comment.