From b84d859de1f70e9400ef931a23a90fe8682a3b9b Mon Sep 17 00:00:00 2001
From: TKF
Date: Sun, 4 Aug 2024 13:40:10 +0800
Subject: [PATCH] Protect memory mappings

---
 src/area.rs  | 42 ++++++++++++++++++++++---
 src/set.rs   | 87 +++++++++++++++++++++++++++++++++++++++++++++++++---
 src/tests.rs | 56 ++++++++++++++++++++++++++++++++-
 3 files changed, 175 insertions(+), 10 deletions(-)

diff --git a/src/area.rs b/src/area.rs
index 18cd40b..0341e44 100644
--- a/src/area.rs
+++ b/src/area.rs
@@ -12,11 +12,20 @@ use crate::{MappingError, MappingResult};
 /// mappings, the target physical address is known when it is added to the page
 /// table. For lazy mappings, an empty mapping needs to be added to the page table
 /// to trigger a page fault.
-pub trait MappingBackend<F: Copy, P>: Clone {
+pub trait MappingBackend<F: Copy + PartialEq, P>: Clone {
     /// What to do when mapping a region within the area with the given flags.
     fn map(&self, start: VirtAddr, size: usize, flags: F, page_table: &mut P) -> bool;
     /// What to do when unmaping a memory region within the area.
     fn unmap(&self, start: VirtAddr, size: usize, page_table: &mut P) -> bool;
+    /// What to do when changing access flags.
+    fn protect(
+        &self,
+        start: VirtAddr,
+        size: usize,
+        old_flags: F,
+        new_flags: F,
+        page_table: &mut P,
+    ) -> Option<F>;
 }
 
 /// A memory area represents a continuous range of virtual memory with the same
@@ -24,14 +33,15 @@ pub trait MappingBackend<F: Copy, P>: Clone {
 ///
 /// The target physical memory frames are determined by [`MappingBackend`] and
 /// may not be contiguous.
-pub struct MemoryArea<F: Copy, P, B: MappingBackend<F, P>> {
+#[derive(Clone)]
+pub struct MemoryArea<F: Copy + PartialEq, P, B: MappingBackend<F, P>> {
     va_range: VirtAddrRange,
     flags: F,
     backend: B,
     _phantom: PhantomData<(F, P)>,
 }
 
-impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
+impl<F: Copy + PartialEq, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
     /// Creates a new memory area.
     pub const fn new(start: VirtAddr, size: usize, flags: F, backend: B) -> Self {
         Self {
@@ -71,9 +81,19 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
     pub const fn backend(&self) -> &B {
         &self.backend
     }
+
+    /// Changes the flags.
+    pub fn set_flags(&mut self, new_flags: F) {
+        self.flags = new_flags;
+    }
+
+    /// Changes the end address of the memory area.
+    pub fn set_end(&mut self, new_end: VirtAddr) {
+        self.va_range.end = new_end;
+    }
 }
 
-impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
+impl<F: Copy + PartialEq, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
     /// Maps the whole memory area in the page table.
     pub(crate) fn map_area(&self, page_table: &mut P) -> MappingResult {
         self.backend
@@ -90,6 +110,18 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
             .ok_or(MappingError::BadState)
     }
 
+    /// Changes the flags in the page table.
+    pub(crate) fn protect_area(
+        &mut self,
+        old_flags: F,
+        new_flags: F,
+        page_table: &mut P,
+    ) -> MappingResult<F> {
+        self.backend
+            .protect(self.start(), self.size(), old_flags, new_flags, page_table)
+            .ok_or(MappingError::BadState)
+    }
+
     /// Shrinks the memory area at the left side.
     ///
     /// The start address of the memory area is increased by `new_size`. The
@@ -146,7 +178,7 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
 
 impl<F, P, B: MappingBackend<F, P>> fmt::Debug for MemoryArea<F, P, B>
 where
-    F: fmt::Debug + Copy,
+    F: fmt::Debug + Copy + PartialEq,
 {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("MemoryArea")
diff --git a/src/set.rs b/src/set.rs
index 1da2e81..1d5a71a 100644
--- a/src/set.rs
+++ b/src/set.rs
@@ -1,4 +1,4 @@
-use alloc::collections::BTreeMap;
+use alloc::{collections::BTreeMap, vec::Vec};
 use core::fmt;
 
 use memory_addr::{VirtAddr, VirtAddrRange};
@@ -6,11 +6,11 @@ use memory_addr::{VirtAddr, VirtAddrRange};
 use crate::{MappingBackend, MappingError, MappingResult, MemoryArea};
 
 /// A container that maintains memory mappings ([`MemoryArea`]).
-pub struct MemorySet<F: Copy, P, B: MappingBackend<F, P>> {
+pub struct MemorySet<F: Copy + PartialEq, P, B: MappingBackend<F, P>> {
     areas: BTreeMap<VirtAddr, MemoryArea<F, P, B>>,
 }
 
-impl<F: Copy, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
+impl<F: Copy + PartialEq, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
     /// Creates a new memory set.
     pub const fn new() -> Self {
         Self {
@@ -176,9 +176,88 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
         self.areas.clear();
         Ok(())
     }
+
+    /// Change the flags of memory mappings within the given address range.
+    ///
+    /// Memory areas with the same flags will be skipped. Memory areas that
+    /// are fully contained in the range, contain the whole range, or intersect
+    /// a boundary will be handled similarly to `munmap`.
+    pub fn protect(
+        &mut self,
+        start: VirtAddr,
+        size: usize,
+        new_flags: F,
+        page_table: &mut P,
+    ) -> MappingResult {
+        let end = start + size;
+        let mut to_insert = Vec::new();
+        for (_, area) in self.areas.iter_mut() {
+            if area.start() >= end {
+                /*
+                 * [ prot ]
+                 *          [ area ]
+                 */
+                break;
+            } else if area.end() <= start {
+                /*
+                 *          [ prot ]
+                 * [ area ]
+                 */
+                // Do nothing
+            } else if area.start() >= start && area.end() <= end {
+                /*
+                 * [    prot    ]
+                 *    [ area ]
+                 */
+                let new_flags = area.protect_area(area.flags(), new_flags, page_table)?;
+                area.set_flags(new_flags);
+            } else if area.start() < start && area.end() > end {
+                /*
+                 *        [ prot ]
+                 * [ left |  area  | right ]
+                 */
+                let right_part = area.split(end).unwrap();
+                area.set_end(start);
+
+                let mut middle_part =
+                    MemoryArea::new(start, size, area.flags(), area.backend().clone());
+                let new_flags = middle_part.protect_area(area.flags(), new_flags, page_table)?;
+                middle_part.set_flags(new_flags);
+
+                to_insert.push((right_part.start(), right_part));
+                to_insert.push((middle_part.start(), middle_part));
+            } else if area.end() > end {
+                /*
+                 * [ prot ]
+                 *    [ area | right ]
+                 */
+                let right_part = area.split(end).unwrap();
+
+                let new_flags = area.protect_area(area.flags(), new_flags, page_table)?;
+                area.set_flags(new_flags);
+
+                to_insert.push((right_part.start(), right_part));
+            } else {
+                /*
+                 *        [ prot ]
+                 * [ left | area ]
+                 */
+                let mut right_part = area.split(start).unwrap();
+
+                let new_flags = right_part.protect_area(area.flags(), new_flags, page_table)?;
+                right_part.set_flags(new_flags);
+
+                to_insert.push((right_part.start(), right_part));
+            }
+        }
+        self.areas.extend(to_insert.into_iter());
+        Ok(())
+    }
 }
 
-impl<F: Copy + fmt::Debug, P, B: MappingBackend<F, P>> fmt::Debug for MemorySet<F, P, B> {
+impl<F: Copy + fmt::Debug + PartialEq, P, B: MappingBackend<F, P>> fmt::Debug
+    for MemorySet<F, P, B>
+{
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_list().entries(self.areas.values()).finish()
     }
diff --git a/src/tests.rs b/src/tests.rs
index f8fc586..e5f66f1 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -32,6 +32,21 @@ impl MappingBackend<MockFlags, MockPageTable> for MockBackend {
         }
         true
     }
+
+    fn protect(
+        &self,
+        start: VirtAddr,
+        size: usize,
+        old_flags: MockFlags,
+        new_flags: MockFlags,
+        pt: &mut MockPageTable,
+    ) -> Option<MockFlags> {
+        let flags = (new_flags & 0x7) | (old_flags & !0x7);
+        for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
+            *entry = flags;
+        }
+        Some(flags)
+    }
 }
 
 macro_rules! assert_ok {
@@ -168,7 +183,7 @@ fn test_unmap_split() {
         }
     }
-    // Unmap [0x800, 0x900), [0x2800, 0x4400), [0x4800, 0x4900), ...
+    // Unmap [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
     // The areas are split into two areas.
     for start in (0..MAX_ADDR).step_by(0x2000) {
         assert_ok!(set.unmap((start + 0x800).into(), 0x100, &mut pt));
     }
@@ -208,3 +223,42 @@ fn test_unmap_split() {
         assert_eq!(pt[addr], 0);
     }
 }
+
+#[test]
+fn test_protect() {
+    let mut set = MockMemorySet::new();
+    let mut pt = [0; MAX_ADDR];
+
+    // Map [0, 0x1000), [0x2000, 0x3000), [0x4000, 0x5000), ...
+    for start in (0..MAX_ADDR).step_by(0x2000) {
+        assert_ok!(set.map(
+            MemoryArea::new(start.into(), 0x1000, 0x7, MockBackend),
+            &mut pt,
+            false,
+        ));
+    }
+    assert_eq!(set.len(), 8);
+
+    // Protect [0xc00, 0x2400), [0x2c00, 0x4400), [0x4c00, 0x6400), ...
+    // The areas are shrunk at the left and right boundaries.
+    for start in (0..MAX_ADDR).step_by(0x2000) {
+        assert_ok!(set.protect((start + 0xc00).into(), 0x1800, 0x1, &mut pt));
+    }
+    dump_memory_set(&set);
+    assert_eq!(set.len(), 23);
+
+    // Protect [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
+    // The areas are split into two areas.
+    for start in (0..MAX_ADDR).step_by(0x2000) {
+        assert_ok!(set.protect((start + 0x800).into(), 0x100, 0x13, &mut pt));
+    }
+    dump_memory_set(&set);
+    assert_eq!(set.len(), 39);
+
+    // Unmap all areas.
+    assert_ok!(set.unmap(0.into(), MAX_ADDR, &mut pt));
+    assert_eq!(set.len(), 0);
+    for addr in 0..MAX_ADDR {
+        assert_eq!(pt[addr], 0);
+    }
+}
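
For a sense of how the new API fits together once the patch is applied, here is a rough sketch that drives MemorySet::protect from a small standalone program. It is modelled on the mock fixtures in src/tests.rs; the crate name `memory_set`, the `MemorySet<MockFlags, MockPageTable, MockBackend>` parameter order, and the flag encoding (low three bits are permission bits) are assumptions borrowed from those tests, not anything this patch defines.

use memory_addr::VirtAddr;
use memory_set::{MappingBackend, MemoryArea, MemorySet};

const MAX_ADDR: usize = 0x4000;

// Flags are plain bits and the "page table" is a flat array of per-byte flags,
// mirroring the MockBackend used by the crate's own tests.
type MockFlags = u8;
type MockPageTable = [MockFlags; MAX_ADDR];

#[derive(Clone)]
struct MockBackend;

impl MappingBackend<MockFlags, MockPageTable> for MockBackend {
    fn map(&self, start: VirtAddr, size: usize, flags: MockFlags, pt: &mut MockPageTable) -> bool {
        for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
            if *entry != 0 {
                return false; // refuse to map over an existing entry
            }
            *entry = flags;
        }
        true
    }

    fn unmap(&self, start: VirtAddr, size: usize, pt: &mut MockPageTable) -> bool {
        for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
            if *entry == 0 {
                return false; // not mapped
            }
            *entry = 0;
        }
        true
    }

    fn protect(
        &self,
        start: VirtAddr,
        size: usize,
        old_flags: MockFlags,
        new_flags: MockFlags,
        pt: &mut MockPageTable,
    ) -> Option<MockFlags> {
        // Replace the low permission bits, keep the rest, as the test backend does.
        let flags = (new_flags & 0x7) | (old_flags & !0x7);
        for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
            *entry = flags;
        }
        Some(flags)
    }
}

fn main() {
    let mut pt: MockPageTable = [0; MAX_ADDR];
    let mut set = MemorySet::<MockFlags, MockPageTable, MockBackend>::new();

    // One read-write-execute (0x7) area covering [0x0, 0x2000).
    assert!(set
        .map(
            MemoryArea::new(0usize.into(), 0x2000, 0x7, MockBackend),
            &mut pt,
            false,
        )
        .is_ok());

    // Make the middle range [0x800, 0x1800) read-only (0x1). The protect range
    // lies strictly inside the area, so the area is split into left/middle/right.
    assert!(set.protect(0x800usize.into(), 0x1000, 0x1, &mut pt).is_ok());

    assert_eq!(set.len(), 3);
    assert_eq!(pt[0x7ff], 0x7); // left part keeps the old flags
    assert_eq!(pt[0x800], 0x1); // middle part got the new flags
    assert_eq!(pt[0x1800], 0x7); // right part keeps the old flags
}

Protecting the interior of a single area exercises the `[ left | area | right ]` branch of MemorySet::protect: the existing area is truncated to the left part, the right part is split off, and a freshly created middle area is inserted carrying whatever flags the backend's protect callback returned.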