Protect memory mappings
tkf2019 committed Aug 4, 2024
1 parent c5e582a commit b84d859
Showing 3 changed files with 175 additions and 10 deletions.
42 changes: 37 additions & 5 deletions src/area.rs
@@ -12,26 +12,36 @@ use crate::{MappingError, MappingResult};
 /// mappings, the target physical address is known when it is added to the page
 /// table. For lazy mappings, an empty mapping needs to be added to the page table
 /// to trigger a page fault.
-pub trait MappingBackend<F: Copy, P>: Clone {
+pub trait MappingBackend<F: Copy + PartialEq, P>: Clone {
     /// What to do when mapping a region within the area with the given flags.
     fn map(&self, start: VirtAddr, size: usize, flags: F, page_table: &mut P) -> bool;
     /// What to do when unmapping a memory region within the area.
     fn unmap(&self, start: VirtAddr, size: usize, page_table: &mut P) -> bool;
+    /// What to do when changing access flags.
+    fn protect(
+        &self,
+        start: VirtAddr,
+        size: usize,
+        old_flags: F,
+        new_flags: F,
+        page_table: &mut P,
+    ) -> Option<F>;
 }
 
 /// A memory area represents a continuous range of virtual memory with the same
 /// flags.
 ///
 /// The target physical memory frames are determined by [`MappingBackend`] and
 /// may not be contiguous.
-pub struct MemoryArea<F: Copy, P, B: MappingBackend<F, P>> {
+#[derive(Clone)]
+pub struct MemoryArea<F: Copy + PartialEq, P, B: MappingBackend<F, P>> {
     va_range: VirtAddrRange,
     flags: F,
     backend: B,
     _phantom: PhantomData<(F, P)>,
 }
 
-impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
+impl<F: Copy + PartialEq, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
     /// Creates a new memory area.
     pub const fn new(start: VirtAddr, size: usize, flags: F, backend: B) -> Self {
         Self {
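The new `protect` callback completes the backend interface: it receives both the current and the requested flags and returns `Some` with the flags that were actually written (a backend may normalize the request or honor only part of it), or `None` to signal failure. The `F: PartialEq` bound added throughout presumably supports the documented behavior of skipping areas whose flags already match. As a rough sketch of an implementor — the `ToyBackend` name and the one-`u8`-of-flags-per-address page table are illustrative assumptions modeled on the mock backend in `src/tests.rs` below, not part of this commit:

```rust
use memory_addr::VirtAddr;

/// Hypothetical toy backend: one `u8` of flags per unit of address space.
#[derive(Clone)]
struct ToyBackend;

impl MappingBackend<u8, Vec<u8>> for ToyBackend {
    fn map(&self, start: VirtAddr, size: usize, flags: u8, pt: &mut Vec<u8>) -> bool {
        // Mark every entry in the range as mapped with `flags`.
        pt[start.as_usize()..start.as_usize() + size]
            .iter_mut()
            .for_each(|e| *e = flags);
        true
    }

    fn unmap(&self, start: VirtAddr, size: usize, pt: &mut Vec<u8>) -> bool {
        // Clear every entry in the range.
        pt[start.as_usize()..start.as_usize() + size]
            .iter_mut()
            .for_each(|e| *e = 0);
        true
    }

    fn protect(
        &self,
        start: VirtAddr,
        size: usize,
        _old_flags: u8,
        new_flags: u8,
        pt: &mut Vec<u8>,
    ) -> Option<u8> {
        // Rewrite the flags in place and report what was actually set.
        pt[start.as_usize()..start.as_usize() + size]
            .iter_mut()
            .for_each(|e| *e = new_flags);
        Some(new_flags)
    }
}
```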
@@ -71,9 +81,19 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
     pub const fn backend(&self) -> &B {
         &self.backend
     }
+
+    /// Changes the flags.
+    pub fn set_flags(&mut self, new_flags: F) {
+        self.flags = new_flags;
+    }
+
+    /// Changes the end address of the memory area.
+    pub fn set_end(&mut self, new_end: VirtAddr) {
+        self.va_range.end = new_end;
+    }
 }
 
-impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
+impl<F: Copy + PartialEq, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
     /// Maps the whole memory area in the page table.
     pub(crate) fn map_area(&self, page_table: &mut P) -> MappingResult {
         self.backend
@@ -90,6 +110,18 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
             .ok_or(MappingError::BadState)
     }
 
+    /// Changes the flags in the page table.
+    pub(crate) fn protect_area(
+        &mut self,
+        old_flags: F,
+        new_flags: F,
+        page_table: &mut P,
+    ) -> MappingResult<F> {
+        self.backend
+            .protect(self.start(), self.size(), old_flags, new_flags, page_table)
+            .ok_or(MappingError::BadState)
+    }
+
     /// Shrinks the memory area at the left side.
     ///
     /// The start address of the memory area is increased by `new_size`. The
@@ -146,7 +178,7 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemoryArea<F, P, B> {
 
 impl<F, P, B: MappingBackend<F, P>> fmt::Debug for MemoryArea<F, P, B>
 where
-    F: fmt::Debug + Copy,
+    F: fmt::Debug + Copy + PartialEq,
 {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_struct("MemoryArea")
87 changes: 83 additions & 4 deletions src/set.rs
@@ -1,16 +1,16 @@
-use alloc::collections::BTreeMap;
+use alloc::{collections::BTreeMap, vec::Vec};
 use core::fmt;
 
 use memory_addr::{VirtAddr, VirtAddrRange};
 
 use crate::{MappingBackend, MappingError, MappingResult, MemoryArea};
 
 /// A container that maintains memory mappings ([`MemoryArea`]).
-pub struct MemorySet<F: Copy, P, B: MappingBackend<F, P>> {
+pub struct MemorySet<F: Copy + PartialEq, P, B: MappingBackend<F, P>> {
     areas: BTreeMap<VirtAddr, MemoryArea<F, P, B>>,
 }
 
-impl<F: Copy, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
+impl<F: Copy + PartialEq, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
     /// Creates a new memory set.
     pub const fn new() -> Self {
         Self {
@@ -176,9 +176,88 @@ impl<F: Copy, P, B: MappingBackend<F, P>> MemorySet<F, P, B> {
         self.areas.clear();
         Ok(())
     }
+
+    /// Changes the flags of memory mappings within the given address range.
+    ///
+    /// Memory areas with the same flags will be skipped. Memory areas that
+    /// are fully contained in the range, contain the range, or intersect its
+    /// boundary are handled similarly to `munmap`.
+    pub fn protect(
+        &mut self,
+        start: VirtAddr,
+        size: usize,
+        new_flags: F,
+        page_table: &mut P,
+    ) -> MappingResult {
+        let end = start + size;
+        let mut to_insert = Vec::new();
+        for (_, area) in self.areas.iter_mut() {
+            if area.start() >= end {
+                /*
+                 * [ prot ]
+                 *          [ area ]
+                 */
+                break;
+            } else if area.end() <= start {
+                /*
+                 *          [ prot ]
+                 * [ area ]
+                 */
+                // Do nothing
+            } else if area.start() >= start && area.end() <= end {
+                /*
+                 * [    prot    ]
+                 *    [ area ]
+                 */
+                let new_flags = area.protect_area(area.flags(), new_flags, page_table)?;
+                area.set_flags(new_flags);
+            } else if area.start() < start && area.end() > end {
+                /*
+                 *        [ prot ]
+                 * [ left | area | right ]
+                 */
+                let right_part = area.split(end).unwrap();
+                area.set_end(start);
+
+                let mut middle_part =
+                    MemoryArea::new(start, size, area.flags(), area.backend().clone());
+                let new_flags = middle_part.protect_area(area.flags(), new_flags, page_table)?;
+                middle_part.set_flags(new_flags);
+
+                to_insert.push((right_part.start(), right_part));
+                to_insert.push((middle_part.start(), middle_part));
+            } else if area.end() > end {
+                /*
+                 * [ prot ]
+                 *    [ area | right ]
+                 */
+                let right_part = area.split(end).unwrap();
+
+                let new_flags = area.protect_area(area.flags(), new_flags, page_table)?;
+                area.set_flags(new_flags);
+
+                to_insert.push((right_part.start(), right_part));
+            } else {
+                /*
+                 *        [ prot ]
+                 * [ left | area ]
+                 */
+                let mut right_part = area.split(start).unwrap();
+
+                let new_flags = right_part.protect_area(area.flags(), new_flags, page_table)?;
+                right_part.set_flags(new_flags);
+
+                to_insert.push((right_part.start(), right_part));
+            }
+        }
+        self.areas.extend(to_insert.into_iter());
+        Ok(())
+    }
 }
 
-impl<F: Copy + fmt::Debug, P, B: MappingBackend<F, P>> fmt::Debug for MemorySet<F, P, B> {
+impl<F: Copy + PartialEq + fmt::Debug, P, B: MappingBackend<F, P>> fmt::Debug
+    for MemorySet<F, P, B>
+{
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         f.debug_list().entries(self.areas.values()).finish()
     }
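To make the overlap cases concrete, here is a sketch of the middle-overlap path, reusing the `MockMemorySet`, `MockBackend`, `assert_ok!`, and `MAX_ADDR` test fixtures from `src/tests.rs` below (addresses and flag values are illustrative):

```rust
let mut set = MockMemorySet::new();
let mut pt = [0; MAX_ADDR];

// One area [0x0, 0x4000) with flags 0x7 (rwx).
assert_ok!(set.map(
    MemoryArea::new(0.into(), 0x4000, 0x7, MockBackend),
    &mut pt,
    false,
));
assert_eq!(set.len(), 1);

// Restrict the middle page only: [0x1000, 0x2000) becomes 0x5 (r-x).
assert_ok!(set.protect(0x1000.into(), 0x1000, 0x5, &mut pt));

// The single area is split in three:
// [0x0, 0x1000): 0x7 | [0x1000, 0x2000): 0x5 | [0x2000, 0x4000): 0x7
assert_eq!(set.len(), 3);
```

Note the `to_insert` buffer in the implementation: the parts split off inside the loop cannot be inserted while `self.areas` is mutably borrowed by `iter_mut()`, so they are collected and merged back with a single `extend` after the loop.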
56 changes: 55 additions & 1 deletion src/tests.rs
@@ -32,6 +32,21 @@ impl MappingBackend<MockFlags, MockPageTable> for MockBackend {
         }
         true
     }
+
+    fn protect(
+        &self,
+        start: VirtAddr,
+        size: usize,
+        old_flags: MockFlags,
+        new_flags: MockFlags,
+        pt: &mut MockPageTable,
+    ) -> Option<MockFlags> {
+        let flags = (new_flags & 0x7) | (old_flags & !0x7);
+        for entry in pt.iter_mut().skip(start.as_usize()).take(size) {
+            *entry = flags;
+        }
+        Some(flags)
+    }
 }
 
 macro_rules! assert_ok {
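The mock's `protect` treats only the low three bits as permission bits: `(new_flags & 0x7) | (old_flags & !0x7)` swaps in the requested rwx bits while preserving whatever higher bits were already set. A worked example with illustrative values (assuming `MockFlags` is an unsigned integer alias, as the bit arithmetic implies):

```rust
let old_flags: MockFlags = 0x13; // rw- plus a high bit 0x10
let new_flags: MockFlags = 0x5;  // r-x requested
let result = (new_flags & 0x7) | (old_flags & !0x7);
assert_eq!(result, 0x15);        // permission bits replaced, high bit kept
```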
@@ -168,7 +183,7 @@ fn test_unmap_split() {
         }
     }
 
-    // Unmap [0x800, 0x900), [0x2800, 0x4400), [0x4800, 0x4900), ...
+    // Unmap [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
     // The areas are split into two areas.
     for start in (0..MAX_ADDR).step_by(0x2000) {
         assert_ok!(set.unmap((start + 0x800).into(), 0x100, &mut pt));
@@ -208,3 +223,42 @@ fn test_unmap_split() {
         assert_eq!(pt[addr], 0);
     }
 }
+
+#[test]
+fn test_protect() {
+    let mut set = MockMemorySet::new();
+    let mut pt = [0; MAX_ADDR];
+
+    // Map [0, 0x1000), [0x2000, 0x3000), [0x4000, 0x5000), ...
+    for start in (0..MAX_ADDR).step_by(0x2000) {
+        assert_ok!(set.map(
+            MemoryArea::new(start.into(), 0x1000, 0x7, MockBackend),
+            &mut pt,
+            false,
+        ));
+    }
+    assert_eq!(set.len(), 8);
+
+    // Protect [0xc00, 0x2400), [0x2c00, 0x4400), [0x4c00, 0x6400), ...
+    // The areas are shrunk at the left and right boundaries.
+    for start in (0..MAX_ADDR).step_by(0x2000) {
+        assert_ok!(set.protect((start + 0xc00).into(), 0x1800, 0x1, &mut pt));
+    }
+    dump_memory_set(&set);
+    assert_eq!(set.len(), 23);
+
+    // Protect [0x800, 0x900), [0x2800, 0x2900), [0x4800, 0x4900), ...
+    // Each affected area is split into three parts.
+    for start in (0..MAX_ADDR).step_by(0x2000) {
+        assert_ok!(set.protect((start + 0x800).into(), 0x100, 0x13, &mut pt));
+    }
+    dump_memory_set(&set);
+    assert_eq!(set.len(), 39);
+
+    // Unmap all areas.
+    assert_ok!(set.unmap(0.into(), MAX_ADDR, &mut pt));
+    assert_eq!(set.len(), 0);
+    for addr in 0..MAX_ADDR {
+        assert_eq!(pt[addr], 0);
+    }
+}
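The expected counts follow from the case analysis in `MemorySet::protect`. Each range in the first loop clips the right edge of one area and the left edge of the next, so the first area splits in two and each of the remaining seven splits in three: 2 + 7 × 3 = 23. Each range in the second loop lands strictly inside one still-unprotected area per 0x2000 block, and the middle-overlap path replaces that area with three: 23 + 8 × 2 = 39.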
