Skip to content

Commit

Permalink
fix caching
Browse files Browse the repository at this point in the history
  • Loading branch information
hanneary committed Jul 30, 2024
1 parent 28e7ab2 commit 2a3fb33
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 61 deletions.
1 change: 1 addition & 0 deletions control-plane/src/dnsproxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ impl<R: RngCore + Clone> DnsProxy<R> {
socket.send(&request_buffer[..packet_size]).await?;
let (amt, _) = socket.recv_from(&mut response_buffer).await?;
let response_bytes = &response_buffer[..amt];
println!("About to cache IP");
cache_ip_for_allowlist(response_bytes)?;
stream.write_all(response_bytes).await?;
stream.flush().await?;
Expand Down
39 changes: 19 additions & 20 deletions data-plane/src/dns/enclavedns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@ use super::error::DNSError;
use bytes::Bytes;
use shared::server::egress::check_dns_allowed_for_domain;
use shared::server::egress::get_cached_dns;
use shared::server::egress::{cache_ip_for_allowlist, EgressDestinations};
use shared::server::egress::EgressDestinations;
use shared::server::get_vsock_client;
use shared::server::CID::Parent;
use shared::DNS_PROXY_VSOCK_PORT;
use std::net::Ipv4Addr;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UdpSocket;
use tokio::sync::{mpsc::Receiver, Semaphore};
use tokio::time::timeout;
use trust_dns_proto::op::{Message, MessageType, OpCode, ResponseCode};
use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordType};
use trust_dns_proto::serialize::binary::{BinEncodable, BinEncoder};
use trust_dns_proto::rr::Record;
use trust_dns_proto::serialize::binary::BinEncodable;
use shared::server::egress::cache_ip_for_allowlist;

/// Empty struct for the DNS proxy that runs in the data plane
pub struct EnclaveDnsProxy;
Expand Down Expand Up @@ -136,21 +135,12 @@ impl EnclaveDnsDriver {
message.set_recursion_desired(true);
message.set_response_code(ResponseCode::NoError);

// Create an answer
let mut record: Record = Record::new();
record.set_name(Name::from_str("jsonplaceholder.typicode.com").unwrap());
record.set_record_type(RecordType::A);
record.set_dns_class(DNSClass::IN);
record.set_ttl(300);
record.set_data(Some(RData::A(trust_dns_proto::rr::rdata::A(
Ipv4Addr::new(172, 67, 167, 151),
))));
message.add_answer(record);

let bbb = message.to_bytes().unwrap();
println!("TESTTTTT::::::: {:?}", bbb.len());
let response_bytes = message.to_bytes().unwrap();
println!("TESTTTTT::::::: {:?}", response_bytes.len());

return bbb;
response_bytes
}

/// Perform a DNS lookup using the proxy running on the Host
Expand All @@ -163,17 +153,26 @@ impl EnclaveDnsDriver {
let packet = check_dns_allowed_for_domain(&dns_packet, &allowed_destinations)?;
match get_cached_dns(packet.clone()) {
Ok(record) => {
let dns_response = Self::get_dns_answer(packet.header().id(), record);
return Ok(Bytes::copy_from_slice(&dns_response));
match record {
Some(record) => {
let dns_response = Self::get_dns_answer(packet.header().id(), record);
return Ok(Bytes::copy_from_slice(&dns_response));
}
None => {
let dns_response = timeout(request_upper_bound, Self::forward_dns_lookup(dns_packet)).await??;
cache_ip_for_allowlist(&dns_response.clone())?;
Ok(dns_response)
}
}
}
Err(_) => {
// // Attempt DNS lookup wth a timeout, flatten timeout errors into a DNS Error
let dns_response =
timeout(request_upper_bound, Self::forward_dns_lookup(dns_packet)).await??;
cache_ip_for_allowlist(&dns_response.clone())?;
Ok(dns_response)
}
}

}

/// Takes a DNS lookup as `Bytes` and sends forwards it over VSock to the host process to be sent to
Expand Down
91 changes: 50 additions & 41 deletions shared/src/server/egress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::sync::Mutex;
use std::time::Duration;
use thiserror::Error;
use trust_dns_proto::op::Message;
use trust_dns_proto::rr::RData;
use trust_dns_proto::rr::Record;
use trust_dns_proto::serialize::binary::BinDecodable;
use ttl_cache::TtlCache;
Expand All @@ -33,7 +32,7 @@ pub static ALLOWED_IPS_FROM_DNS: Lazy<Mutex<TtlCache<String, String>>> =
Lazy::new(|| Mutex::new(TtlCache::new(1000)));

pub static DOMAINS_CACHED_DNS: Lazy<Mutex<TtlCache<String, Record>>> =
Lazy::new(|| Mutex::new(TtlCache::new(1000)));
Lazy::new(|| Mutex::new(TtlCache::new(1000)));

pub fn get_egress_allow_list_from_env() -> EgressDestinations {
let domain_str = std::env::var("EV_EGRESS_ALLOW_LIST").unwrap_or("".to_string());
Expand Down Expand Up @@ -80,11 +79,11 @@ pub fn check_domain_allow_list(
}
}

pub fn check_dns_allowed_for_domain<'a>(
packet: &'a [u8],
pub fn check_dns_allowed_for_domain(
packet: &[u8],
destinations: &EgressDestinations,
) -> Result<Message, EgressError> {
let parsed_packet = Message::from_bytes(&packet).unwrap();
let parsed_packet = Message::from_bytes(packet).unwrap();
parsed_packet.queries().iter().try_for_each(|q| {
let domain = q.name().to_string();
let domain = &domain[..domain.len() - 1];
Expand All @@ -93,45 +92,63 @@ pub fn check_dns_allowed_for_domain<'a>(
Ok(parsed_packet)
}

pub fn cache_ip_for_allowlist(packet: &[u8]) -> Result<Record, EgressError> {
let packet = Message::from_bytes(packet)?;
let ip = packet.answers().get(0).unwrap().data().unwrap().ip_addr().unwrap().to_string();
match get_ip_from_cache(ip)? {
Some(record) => Ok(record),
None => {
packet.answers().iter().try_for_each(|ans| {
cache_ip(
ans.data().unwrap().ip_addr().unwrap().to_string(),
ans.clone(),
)
});
Ok(packet.answers().get(0).unwrap().clone())
}
}
pub fn cache_ip_for_allowlist(packet: &[u8]) -> Result<(), EgressError> {
let parsed_packet = Message::from_bytes(packet).unwrap();
parsed_packet.answers().iter().try_for_each(|ans| {
let ip = ans.data().unwrap().ip_addr().unwrap().to_string();
println!("Caching IP: {}", ip);
cache_ip(
ip,
ans.name().to_string(),
ans.ttl(),
)
})
}

pub fn get_cached_dns(packet: Message) -> Result<Record, EgressError> {
let ip = packet.queries().get(0).unwrap();
match get_ip_from_cache(ip)? {
Some(record) => Ok(record),
None => {
packet.answers().iter().try_for_each(|ans| {
cache_ip(
ans.data().unwrap().ip_addr().unwrap().to_string(),
ans.clone(),
)
});
Ok(packet.answers().get(0).unwrap().clone())
pub fn get_ip_from_cache(ip: String) -> Result<(), EgressError> {
println!("Checking IP cache!!!!!!!!");
ALLOWED_IPS_FROM_DNS.lock().unwrap().iter().for_each(|(k, v)| println!("{}: {}", k, v));
let cache = match ALLOWED_IPS_FROM_DNS.lock() {
Ok(cache) => cache,
Err(_) => return Err(EgressError::CouldntObtainLock),
};
match cache.get(&ip) {
Some(_) => {
println!("Found IP: {} in cache", ip);
Ok(())
}
None => Err(EgressError::EgressIpNotAllowed(ip)),
}
}

pub fn get_cached_dns(parsed_packet: Message) -> Result<Option<Record>, EgressError> {
let domains = parsed_packet
.queries()
.iter()
.map(|query| query.name().to_string());
let domain = domains.clone().next().unwrap().to_string();
let cache = match DOMAINS_CACHED_DNS.lock() {
Ok(cache) => cache,
Err(_) => return Err(EgressError::CouldntObtainLock),
};
Ok(cache.get(&domain).cloned())
}

fn cache_ip(ip: String, answer: Record) -> Result<(), EgressError> {
fn cache_ip(ip: String, name: String, ttl: u32) -> Result<(), EgressError> {
let mut cache = match ALLOWED_IPS_FROM_DNS.lock() {
Ok(cache) => cache,
Err(_) => return Err(EgressError::CouldntObtainLock),
};
println!("Caching IP: {} for domain: {}", ip, name);
cache.insert(ip, name, Duration::from_secs(ttl as u64));
Ok(())
}

fn cache_dns_record(ip: String, answer: Record) -> Result<(), EgressError> {

Check failure on line 147 in shared/src/server/egress.rs

View workflow job for this annotation

GitHub Actions / clippy

function `cache_dns_record` is never used

error: function `cache_dns_record` is never used --> shared/src/server/egress.rs:147:4 | 147 | fn cache_dns_record(ip: String, answer: Record) -> Result<(), EgressError> { | ^^^^^^^^^^^^^^^^ | = note: `-D dead-code` implied by `-D warnings` = help: to override `-D warnings` add `#[allow(dead_code)]`
let mut cache = match DOMAINS_CACHED_DNS.lock() {
Ok(cache) => cache,
Err(_) => return Err(EgressError::CouldntObtainLock),
};
cache.insert(ip, answer.clone(), Duration::from_secs(answer.ttl() as u64));
Ok(())
}
Expand All @@ -152,14 +169,6 @@ pub fn check_ip_allow_list(
}
}

fn get_dns_from_cache(ip: String) -> Result<Option<Record>, EgressError> {
let cache = match DOMAINS_CACHED_DNS.lock() {
Ok(cache) => cache,
Err(_) => return Err(EgressError::CouldntObtainLock),
};
Ok(cache.get(&ip).cloned())
}

#[derive(Clone, PartialEq, Debug, Deserialize)]
pub struct EgressDestinations {
pub wildcard: Vec<String>,
Expand Down

0 comments on commit 2a3fb33

Please sign in to comment.