Make the page encryption intent explicit.
Previously we just OR-ed the encryption bit into the physical address
whenever needed, but this won't work with TDX, where the bit works the
other way around (it marks a page as shared rather than private).
Therefore, let's add a new interface that lets you set page table
entries with the encryption state of the page made explicit.
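Roughly, the new call shape looks like this (an illustrative sketch: `map_frame` and its parameters are made up, but `set_address`, `address` and `PageEncryption` come from the diff below):

```rust
use x86_64::structures::paging::{
    page_table::{PageTableEntry, PageTableFlags},
    PhysFrame, Size4KiB,
};

use crate::paging::{PageEncryption, PageTableEntryWithState};

/// Illustrative helper: map `frame` into `entry`, stating the encryption
/// intent explicitly instead of OR-ing the C-bit into the address here.
fn map_frame(entry: &mut PageTableEntry, frame: PhysFrame<Size4KiB>, encrypted: bool) {
    let state =
        if encrypted { PageEncryption::Encrypted } else { PageEncryption::Unencrypted };
    entry.set_address(
        frame.start_address(),
        PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
        state,
    );
    // Reading the address back through the trait strips the encryption bit,
    // so callers never see it by accident.
    assert_eq!(entry.address(), frame.start_address());
}
```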

The next step is to generalize the handling of the encrypted bit out of
`paging.rs` and into the HAL.
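One hypothetical shape for that HAL hook (the trait name and signatures below are assumptions for illustration, not part of this change):

```rust
use x86_64::PhysAddr;

use crate::paging::PageEncryption;

/// Hypothetical HAL hook: each platform (SEV, TDX, bare metal) decides how an
/// explicit encryption intent maps onto address bits.
pub trait PageEncryptionBits {
    /// Translate the physical address plus the caller's intent into the value
    /// written to the page-table entry. Under SEV this would OR in the C-bit
    /// for `Encrypted`; under TDX it would set the shared bit for
    /// `Unencrypted` instead.
    fn encode(&self, addr: PhysAddr, state: PageEncryption) -> PhysAddr;

    /// Strip any platform-specific bits so callers get a clean address back.
    fn decode(&self, addr: PhysAddr) -> PhysAddr;
}
```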

One downside of the current implementation is that you can still call
the underlying `set_addr()`/`addr()` directly, netting you a `PhysAddr`
with the encryption bit set unexpectedly. Preventing that would likely
mean reimplementing `PageTable`/`PageTableEntry` ourselves, and I'm
questioning whether that's worth the trouble.
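For instance (illustrative snippet; `peek` and `index` are placeholders), both accessors remain callable and only the trait method masks the bit:

```rust
use x86_64::PhysAddr;

use crate::paging::{PageTableEntryWithState, PAGE_TABLE_REFS};

/// Illustrative only: the plain x86_64 accessor is still reachable, so the
/// encryption bit can leak out of an entry that was set as `Encrypted`.
fn peek(index: usize) -> (PhysAddr, PhysAddr) {
    let tables = PAGE_TABLE_REFS.get().unwrap().lock();
    let raw = tables.pt_0[index].addr(); // may still carry the C-bit
    let clean = tables.pt_0[index].address(); // trait accessor masks it off
    (raw, clean)
}
```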

Bug: 350496083
Change-Id: Id7d19773749a35429459523cd156d6ca44568ea6
andrisaar committed Aug 1, 2024
1 parent b97b04f commit 1307898
Showing 6 changed files with 92 additions and 70 deletions.
5 changes: 3 additions & 2 deletions stage0/src/hal/base/mmio.rs
@@ -23,7 +23,7 @@ use x86_64::{
PhysAddr, VirtAddr,
};

use crate::paging::PAGE_TABLE_REFS;
use crate::paging::{PageEncryption, PageTableEntryWithState, PAGE_TABLE_REFS};

pub struct Mmio<S: PageSize> {
pub base_address: PhysAddr,
@@ -49,9 +49,10 @@ impl<S: PageSize> Mmio<S> {
}
let mut tables = PAGE_TABLE_REFS.get().unwrap().lock();
let old_pte = tables.pt_0[mmio_memory.p1_index()].clone();
tables.pt_0[mmio_memory.p1_index()].set_addr(
tables.pt_0[mmio_memory.p1_index()].set_address(
base_address,
PageTableFlags::PRESENT | PageTableFlags::WRITABLE | PageTableFlags::NO_CACHE,
PageEncryption::Unencrypted,
);
flush_all();

70 changes: 29 additions & 41 deletions stage0/src/hal/sev/accept_memory.rs
@@ -31,6 +31,8 @@ use x86_64::{
};
use zeroize::Zeroize;

use crate::paging::{PageEncryption, PageTableEntryWithState};

//
// Page tables come in three sizes: for 1 GiB, 2 MiB and 4 KiB pages. However,
// `PVALIDATE` can only be invoked on 2 MiB and 4 KiB pages.
@@ -80,7 +82,6 @@ impl<S: PageSize> MappedPage<S> {
fn pvalidate_range<S: NotGiantPageSize + ValidatablePageSize, T: PageSize, F>(
range: &PhysFrameRange<S>,
memory: &mut MappedPage<T>,
encrypted: u64,
flags: PageTableFlags,
success_counter: &AtomicUsize,
mut f: F,
@@ -108,7 +109,11 @@ where
.iter_mut()
.filter_map(|entry| range.next().map(|frame| (entry, frame)))
.map(|(entry, frame)| {
entry.set_addr(frame.start_address() + encrypted, PageTableFlags::PRESENT | flags)
entry.set_address(
frame.start_address(),
PageTableFlags::PRESENT | flags,
PageEncryption::Encrypted,
)
})
.count()
> 0
@@ -123,7 +128,7 @@ where
.zip(pages)
.filter(|(entry, _)| !entry.is_unused())
.map(|(entry, page)| (entry, page.pvalidate(success_counter)))
.map(|(entry, result)| result.or_else(|err| f(entry.addr(), err)))
.map(|(entry, result)| result.or_else(|err| f(entry.address(), err)))
.find(|result| result.is_err())
{
return err;
@@ -158,27 +163,14 @@ trait Validatable4KiB {
///
/// Args:
/// pt: pointer to the page table we can mutate to map 4 KiB pages to
/// memory encrypted: value of the encrypted bit in the page table
fn pvalidate(
&self,
pt: &mut MappedPage<Size2MiB>,
encrypted: u64,
) -> Result<(), InstructionError>;
/// memory
fn pvalidate(&self, pt: &mut MappedPage<Size2MiB>) -> Result<(), InstructionError>;
}

impl Validatable4KiB for PhysFrameRange<Size4KiB> {
fn pvalidate(
&self,
pt: &mut MappedPage<Size2MiB>,
encrypted: u64,
) -> Result<(), InstructionError> {
pvalidate_range(
self,
pt,
encrypted,
PageTableFlags::empty(),
&counters::VALIDATED_4K,
|_addr, err| match err {
fn pvalidate(&self, pt: &mut MappedPage<Size2MiB>) -> Result<(), InstructionError> {
pvalidate_range(self, pt, PageTableFlags::empty(), &counters::VALIDATED_4K, |_addr, err| {
match err {
InstructionError::ValidationStatusNotUpdated => {
// We don't treat this as an error. It only happens if SEV-SNP is not enabled,
// or it is already validated. See the PVALIDATE instruction in
@@ -187,8 +179,8 @@ impl Validatable4KiB for PhysFrameRange<Size4KiB> {
Ok(())
}
other => Err(other),
},
)
}
})
}
}

@@ -205,7 +197,6 @@ trait Validatable2MiB {
&self,
pd: &mut MappedPage<Size1GiB>,
pt: &mut MappedPage<Size2MiB>,
encrypted: u64,
) -> Result<(), InstructionError>;
}

@@ -214,26 +205,21 @@ impl Validatable2MiB for PhysFrameRange<Size2MiB> {
&self,
pd: &mut MappedPage<Size1GiB>,
pt: &mut MappedPage<Size2MiB>,
encrypted: u64,
) -> Result<(), InstructionError> {
pvalidate_range(
self,
pd,
encrypted,
PageTableFlags::HUGE_PAGE,
&counters::VALIDATED_2M,
|addr, err| match err {
InstructionError::FailSizeMismatch => {
                    // 2MiB is no go, fall back to 4KiB pages.
counters::ERROR_FAIL_SIZE_MISMATCH.fetch_add(1, Ordering::SeqCst);
// This will not panic as every address that is 2 MiB-aligned is by definition
// also 4 KiB-aligned.
counters::ERROR_FAIL_SIZE_MISMATCH.fetch_add(1, Ordering::SeqCst);
let start = PhysFrame::<Size4KiB>::from_start_address(PhysAddr::new(
addr.as_u64() & !encrypted,
))
.unwrap();
let start = PhysFrame::<Size4KiB>::from_start_address(addr).unwrap();
let range = PhysFrame::range(start, start + 512);
range.pvalidate(pt, encrypted)
range.pvalidate(pt)
}
InstructionError::ValidationStatusNotUpdated => {
// We don't treat this as an error. It only happens if SEV-SNP is not enabled,
@@ -269,7 +255,7 @@ impl<S: NotGiantPageSize> PageStateChange for PhysFrameRange<S> {

/// Calls `PVALIDATE` on all memory ranges specified in the E820 table with type
/// `RAM`.
pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) {
pub fn validate_memory(e820_table: &[BootE820Entry]) {
log::info!("starting SEV-SNP memory validation");

let mut page_tables = crate::paging::PAGE_TABLE_REFS.get().unwrap().lock();
@@ -281,9 +267,10 @@ pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) {
if page_tables.pdpt[1].flags().contains(PageTableFlags::PRESENT) {
panic!("PDPT[1] is in use");
}
page_tables.pdpt[1].set_addr(
PhysAddr::new(&validation_pd.page_table as *const _ as u64 | encrypted),
page_tables.pdpt[1].set_address(
PhysAddr::new(&validation_pd.page_table as *const _ as u64),
PageTableFlags::PRESENT,
PageEncryption::Encrypted,
);

// Page table, for validation with 4 KiB pages.
@@ -294,9 +281,10 @@ pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) {
if page_tables.pd_0[1].flags().contains(PageTableFlags::PRESENT) {
panic!("PD_0[1] is in use");
}
page_tables.pd_0[1].set_addr(
PhysAddr::new(&validation_pt.page_table as *const _ as u64 | encrypted),
page_tables.pd_0[1].set_address(
PhysAddr::new(&validation_pt.page_table as *const _ as u64),
PageTableFlags::PRESENT,
PageEncryption::Encrypted,
);

// We already pvalidated the memory in the first 640KiB of RAM in the boot
@@ -340,7 +328,7 @@ pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) {
PhysFrame::from_start_address(limit).unwrap(),
);
range.page_state_change(PageAssignment::Private).unwrap();
range.pvalidate(&mut validation_pt, encrypted).expect("failed to validate memory");
range.pvalidate(&mut validation_pt).expect("failed to validate memory");
}

// If hugepage_limit > hugepage_start, we've got some contiguous 2M chunks that
@@ -354,7 +342,7 @@ pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) {
);
range.page_state_change(PageAssignment::Private).unwrap();
range
.pvalidate(&mut validation_pd, &mut validation_pt, encrypted)
.pvalidate(&mut validation_pd, &mut validation_pt)
.expect("failed to validate memory");
}

@@ -368,7 +356,7 @@ pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) {
PhysFrame::from_start_address(limit_address).unwrap(),
);
range.page_state_change(PageAssignment::Private).unwrap();
range.pvalidate(&mut validation_pt, encrypted).expect("failed to validate memory");
range.pvalidate(&mut validation_pt).expect("failed to validate memory");
}
}

@@ -393,7 +381,7 @@ pub fn validate_memory(e820_table: &[BootE820Entry], encrypted: u64) {
))
.unwrap(),
);
range.pvalidate(&mut validation_pt, encrypted).expect("failed to validate SMBIOS memory");
range.pvalidate(&mut validation_pt).expect("failed to validate SMBIOS memory");

// Safety: the E820 table indicates that this is the correct memory segment.
let legacy_smbios_range_bytes = unsafe {
2 changes: 1 addition & 1 deletion stage0/src/hal/sev/mod.rs
@@ -29,6 +29,6 @@ pub use port::*;

pub fn accept_memory(e820_table: &[BootE820Entry]) {
if crate::sev_status().contains(SevStatus::SNP_ACTIVE) {
accept_memory::validate_memory(e820_table, crate::encrypted())
accept_memory::validate_memory(e820_table)
}
}
10 changes: 0 additions & 10 deletions stage0/src/lib.rs
@@ -138,16 +138,6 @@ pub fn sev_status() -> SevStatus {
unsafe { SEV_STATUS }
}

/// Returns the location of the ENCRYPTED bit when running under AMD SEV.
pub fn encrypted() -> u64 {
#[no_mangle]
static mut ENCRYPTED: u64 = 0;

// Safety: we don't allow mutation and this is initialized in the bootstrap
// assembly.
unsafe { ENCRYPTED }
}

/// Entry point for the Rust code in the stage0 BIOS.
///
/// # Arguments
62 changes: 50 additions & 12 deletions stage0/src/paging.rs
@@ -21,7 +21,10 @@ use oak_core::sync::OnceCell;
use spinning_top::Spinlock;
use x86_64::{
instructions::tlb::flush_all,
structures::paging::{page_table::PageTableFlags, PageSize, PageTable, Size2MiB, Size4KiB},
structures::paging::{
page_table::{PageTableEntry, PageTableFlags},
PageSize, PageTable, Size2MiB, Size4KiB,
},
PhysAddr,
};

@@ -57,6 +60,40 @@ pub struct PageTableRefs {
/// References to all the pages tables we care about.
pub static PAGE_TABLE_REFS: OnceCell<Spinlock<PageTableRefs>> = OnceCell::new();

pub enum PageEncryption {
Encrypted,
Unencrypted,
}

pub trait PageTableEntryWithState {
fn set_address(&mut self, addr: PhysAddr, flags: PageTableFlags, state: PageEncryption);
fn address(&self) -> PhysAddr;
}

/// Returns the location of the ENCRYPTED bit when running under AMD SEV.
fn encrypted() -> u64 {
#[no_mangle]
static mut ENCRYPTED: u64 = 0;

// Safety: we don't allow mutation and this is initialized in the bootstrap
// assembly.
unsafe { ENCRYPTED }
}

impl PageTableEntryWithState for PageTableEntry {
fn set_address(&mut self, addr: PhysAddr, flags: PageTableFlags, state: PageEncryption) {
let addr = match state {
PageEncryption::Encrypted => PhysAddr::new(addr.as_u64() | encrypted()),
PageEncryption::Unencrypted => addr,
};
self.set_addr(addr, flags);
}

fn address(&self) -> PhysAddr {
PhysAddr::new(self.addr().as_u64() & !encrypted())
}
}

/// Initialises the page table references.
pub fn init_page_table_refs() {
// Safety: accessing the mutable statics here is safe since we only do it once
@@ -72,14 +109,16 @@ pub fn init_page_table_refs() {
// using an identity mapping between virtual and physical addresses.
let mut pt_0 = Box::new_in(PageTable::new(), &BOOT_ALLOC);
pt_0.iter_mut().enumerate().skip(1).for_each(|(i, entry)| {
entry.set_addr(
PhysAddr::new(((i as u64) * Size4KiB::SIZE) | crate::encrypted()),
entry.set_address(
PhysAddr::new((i as u64) * Size4KiB::SIZE),
PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
PageEncryption::Encrypted,
);
});
pd_0[0].set_addr(
PhysAddr::new(pt_0.as_ref() as *const _ as usize as u64 | crate::encrypted()),
pd_0[0].set_address(
PhysAddr::new(pt_0.as_ref() as *const _ as usize as u64),
PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
PageEncryption::Encrypted,
);

let page_tables = PageTableRefs { pml4, pdpt, pd_0, pd_3, pt_0 };
@@ -98,9 +137,10 @@ pub fn map_additional_memory() {
    let mut page_tables = PAGE_TABLE_REFS.get().expect("page tables not initialized").lock();
let pd = &mut page_tables.pd_0;
pd.iter_mut().enumerate().skip(1).for_each(|(i, entry)| {
entry.set_addr(
PhysAddr::new(((i as u64) * Size2MiB::SIZE) | crate::encrypted()),
entry.set_address(
PhysAddr::new((i as u64) * Size2MiB::SIZE),
PageTableFlags::PRESENT | PageTableFlags::WRITABLE | PageTableFlags::HUGE_PAGE,
PageEncryption::Encrypted,
);
});
}
@@ -115,12 +155,10 @@ pub fn remap_first_huge_page() {
    let mut page_tables = PAGE_TABLE_REFS.get().expect("page tables not initialized").lock();
let pd = &mut page_tables.pd_0;

// Allow identity-op to keep the fact that the address we're talking about here
// is 0x00.
#[allow(clippy::identity_op)]
pd[0].set_addr(
PhysAddr::new(0x0 | crate::encrypted()),
pd[0].set_address(
PhysAddr::new(0x0),
PageTableFlags::PRESENT | PageTableFlags::WRITABLE | PageTableFlags::HUGE_PAGE,
PageEncryption::Encrypted,
);
}

13 changes: 9 additions & 4 deletions stage0/src/sev.rs
@@ -38,7 +38,10 @@ use x86_64::{
use zerocopy::{AsBytes, FromBytes};
use zeroize::Zeroize;

use crate::{sev_status, BootAllocator};
use crate::{
paging::{PageEncryption, PageTableEntryWithState},
sev_status, BootAllocator,
};

pub static GHCB_WRAPPER: Ghcb = Ghcb::new();

@@ -199,9 +202,10 @@ fn share_page(page: Page<Size4KiB>) {
{
let mut page_tables = crate::paging::PAGE_TABLE_REFS.get().unwrap().lock();
let pt = &mut page_tables.pt_0;
pt[page.p1_index()].set_addr(
pt[page.p1_index()].set_address(
PhysAddr::new(page_start),
PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
PageEncryption::Unencrypted,
);
}
tlb::flush_all();
@@ -230,9 +234,10 @@ fn unshare_page(page: Page<Size4KiB>) {
{
let mut page_tables = crate::paging::PAGE_TABLE_REFS.get().unwrap().lock();
let pt = &mut page_tables.pt_0;
pt[page.p1_index()].set_addr(
PhysAddr::new(page_start | crate::encrypted()),
pt[page.p1_index()].set_address(
PhysAddr::new(page_start),
PageTableFlags::PRESENT | PageTableFlags::WRITABLE,
PageEncryption::Encrypted,
);
}
tlb::flush_all();