diff --git a/oak_restricted_kernel/src/lib.rs b/oak_restricted_kernel/src/lib.rs index a1be29d9244..8cc935fb222 100644 --- a/oak_restricted_kernel/src/lib.rs +++ b/oak_restricted_kernel/src/lib.rs @@ -37,6 +37,8 @@ #![feature(naked_functions)] #![feature(c_size_t)] #![feature(never_type)] +#![feature(offset_of_nested)] +#![feature(asm_const)] mod acpi; mod args; diff --git a/oak_restricted_kernel/src/syscall/mod.rs b/oak_restricted_kernel/src/syscall/mod.rs index 32c14915159..102215e4e13 100644 --- a/oak_restricted_kernel/src/syscall/mod.rs +++ b/oak_restricted_kernel/src/syscall/mod.rs @@ -29,7 +29,7 @@ mod switch_process; mod tests; use alloc::boxed::Box; -use core::{arch::asm, ffi::c_void}; +use core::{arch::asm, ffi::c_void, mem::offset_of}; use oak_channel::Channel; use oak_restricted_kernel_interface::{Errno, Syscall}; @@ -51,24 +51,30 @@ use self::{ use crate::mm; /// State we need to track for system calls. -/// -/// Do not change the order of the fields here, as this is accessed from -/// assembly! -#[repr(C)] +#[repr(C, align(16))] #[derive(Debug)] struct GsData { /// Kernel stack pointer (what to set in RSP after saving user RSP). kernel_sp: VirtAddr, - /// User stack pointer. Saved from RSP after SYSCALL. - user_sp: VirtAddr, - - /// User instruction pointer (where to return after SYSCALL). Saved from - /// RCX. - user_ip: VirtAddr, + /// User context data + user_ctx: UserContext, +} - /// User flags. Saved from R11. 
- user_flags: usize, +#[repr(C, align(16))] +#[derive(Debug, Default)] +struct UserContext { + rsp: u64, + rcx: u64, + rdi: u64, + rsi: u64, + rdx: u64, + rbx: u64, + r8: u64, + r9: u64, + r10: u64, + r11: u64, + ymm: [[u64; 4]; 16], } pub fn enable_syscalls(channel: Box<dyn Channel>, dice_data: dice_data::DiceData) { @@ -85,9 +91,7 @@ pub fn enable_syscalls(channel: Box<dyn Channel>, dice_data: dice_data::DiceData let gsdata = Box::leak(Box::new(GsData { // Stack grows down, so SP points to the end of the page kernel_sp, - user_sp: VirtAddr::zero(), - user_ip: VirtAddr::zero(), - user_flags: 0, + user_ctx: UserContext::default(), })); KernelGsBase::write(VirtAddr::from_ptr(gsdata)); @@ -170,43 +174,38 @@ extern "C" fn syscall_entrypoint() { asm! { // Switch to the syscall stack "swapgs", // switch to kernel GS - "mov gs:[0x8], rsp", // save user RSP - "mov gs:[0x10], rcx", // save user RIP - "mov gs:[0x18], r11", // save user RFLAGS - "mov rsp, gs:[0x0]", // switch to kernel stack - // Save mutable registers other than RAX, RCX and R11 (the first is trashed for the return value; - // the latter two are are stored in gsdata). - "push rsi", - "push rdi", - "push rdx", - "push rbx", - "push r8", - "push r9", - "push r10", + // Save user context to GsData + "mov gs:[{OFFSET_RSP}], rsp", + "mov gs:[{OFFSET_RCX}], rcx", + "mov gs:[{OFFSET_RDI}], rdi", + "mov gs:[{OFFSET_RSI}], rsi", + "mov gs:[{OFFSET_RDX}], rdx", + "mov gs:[{OFFSET_RBX}], rbx", + "mov gs:[{OFFSET_R8}], r8", + "mov gs:[{OFFSET_R9}], r9", + "mov gs:[{OFFSET_R10}], r10", + "mov gs:[{OFFSET_R11}], r11", - // Make sure the stack is 16-byte aligned. 
- "sub rsp, 8", + // Save AVX registers to GsData + "vmovups gs:[{OFFSET_YMM} + 0*32], YMM0", + "vmovups gs:[{OFFSET_YMM} + 1*32], YMM1", + "vmovups gs:[{OFFSET_YMM} + 2*32], YMM2", + "vmovups gs:[{OFFSET_YMM} + 3*32], YMM3", + "vmovups gs:[{OFFSET_YMM} + 4*32], YMM4", + "vmovups gs:[{OFFSET_YMM} + 5*32], YMM5", + "vmovups gs:[{OFFSET_YMM} + 6*32], YMM6", + "vmovups gs:[{OFFSET_YMM} + 7*32], YMM7", + "vmovups gs:[{OFFSET_YMM} + 8*32], YMM8", + "vmovups gs:[{OFFSET_YMM} + 9*32], YMM9", + "vmovups gs:[{OFFSET_YMM} + 10*32], YMM10", + "vmovups gs:[{OFFSET_YMM} + 11*32], YMM11", + "vmovups gs:[{OFFSET_YMM} + 12*32], YMM12", + "vmovups gs:[{OFFSET_YMM} + 13*32], YMM13", + "vmovups gs:[{OFFSET_YMM} + 14*32], YMM14", + "vmovups gs:[{OFFSET_YMM} + 15*32], YMM15", - // Back up the AVX registers. - // TODO(#3329): Update interrupt handler macro to support AVX, SSE or neither. - "sub rsp, 16*32", - "vmovups [rsp + 0*32], YMM0", - "vmovups [rsp + 1*32], YMM1", - "vmovups [rsp + 2*32], YMM2", - "vmovups [rsp + 3*32], YMM3", - "vmovups [rsp + 4*32], YMM4", - "vmovups [rsp + 5*32], YMM5", - "vmovups [rsp + 6*32], YMM6", - "vmovups [rsp + 7*32], YMM7", - "vmovups [rsp + 8*32], YMM8", - "vmovups [rsp + 9*32], YMM9", - "vmovups [rsp + 10*32], YMM10", - "vmovups [rsp + 11*32], YMM11", - "vmovups [rsp + 12*32], YMM12", - "vmovups [rsp + 13*32], YMM13", - "vmovups [rsp + 14*32], YMM14", - "vmovups [rsp + 15*32], YMM15", + "mov rsp, gs:[{OFFSET_KERNEL_STACK_POINTER}]", // switch to kernel stack // Shuffle around register values to match sysv calling convention, and escape into // proper Rust code from the assembly. @@ -222,47 +221,54 @@ extern "C" fn syscall_entrypoint() { "pop r9", "add rsp, 8", - // Restore AVX registers. 
- "vmovups YMM0, [rsp + 0*32]", - "vmovups YMM1, [rsp + 1*32]", - "vmovups YMM2, [rsp + 2*32]", - "vmovups YMM3, [rsp + 3*32]", - "vmovups YMM4, [rsp + 4*32]", - "vmovups YMM5, [rsp + 5*32]", - "vmovups YMM6, [rsp + 6*32]", - "vmovups YMM7, [rsp + 7*32]", - "vmovups YMM8, [rsp + 8*32]", - "vmovups YMM9, [rsp + 9*32]", - "vmovups YMM10, [rsp + 10*32]", - "vmovups YMM11, [rsp + 11*32]", - "vmovups YMM12, [rsp + 12*32]", - "vmovups YMM13, [rsp + 13*32]", - "vmovups YMM14, [rsp + 14*32]", - "vmovups YMM15, [rsp + 15*32]", - "add rsp, 16*32", + // Restore AVX registers from GsData. + "vmovups YMM0, gs:[{OFFSET_YMM} + 0*32]", + "vmovups YMM1, gs:[{OFFSET_YMM} + 1*32]", + "vmovups YMM2, gs:[{OFFSET_YMM} + 2*32]", + "vmovups YMM3, gs:[{OFFSET_YMM} + 3*32]", + "vmovups YMM4, gs:[{OFFSET_YMM} + 4*32]", + "vmovups YMM5, gs:[{OFFSET_YMM} + 5*32]", + "vmovups YMM6, gs:[{OFFSET_YMM} + 6*32]", + "vmovups YMM7, gs:[{OFFSET_YMM} + 7*32]", + "vmovups YMM8, gs:[{OFFSET_YMM} + 8*32]", + "vmovups YMM9, gs:[{OFFSET_YMM} + 9*32]", + "vmovups YMM10, gs:[{OFFSET_YMM} + 10*32]", + "vmovups YMM11, gs:[{OFFSET_YMM} + 11*32]", + "vmovups YMM12, gs:[{OFFSET_YMM} + 12*32]", + "vmovups YMM13, gs:[{OFFSET_YMM} + 13*32]", + "vmovups YMM14, gs:[{OFFSET_YMM} + 14*32]", + "vmovups YMM15, gs:[{OFFSET_YMM} + 15*32]", - // Undo stack alignment. - "add rsp, 8", - // Restore scratch registers. - "pop r10", - "pop r9", - "pop r8", - "pop rbx", - "pop rdx", - "pop rdi", - "pop rsi", + // Restore scratch registers from GsData + "mov rcx, gs:[{OFFSET_RCX}]", + "mov rdi, gs:[{OFFSET_RDI}]", + "mov rsi, gs:[{OFFSET_RSI}]", + "mov rdx, gs:[{OFFSET_RDX}]", + "mov rbx, gs:[{OFFSET_RBX}]", + "mov r8, gs:[{OFFSET_R8}]", + "mov r9, gs:[{OFFSET_R9}]", + "mov r10, gs:[{OFFSET_R10}]", + "mov r11, gs:[{OFFSET_R11}]", // restore user RFLAGS - // Restore user values in preparation for SYSRET. - // We don't save the kernel stack value; we'll just overwrite what's there next time - // a syscall is invoked. 
- "mov rsp, gs:[0x8]", // restore user RSP - "mov rcx, gs:[0x10]", // restore user RIP - "mov r11, gs:[0x18]", // restore user RFLAGS + // Restore user RSP in preparation for SYSRET. + "mov rsp, gs:[{OFFSET_RSP}]", "swapgs", // restore user GS // Back to user code in Ring 3. "sysretq", HANDLER = sym syscall_handler, + OFFSET_KERNEL_STACK_POINTER = const(offset_of!(GsData, kernel_sp)), + OFFSET_RSP = const(offset_of!(GsData, user_ctx.rsp)), + OFFSET_RCX = const(offset_of!(GsData, user_ctx.rcx)), + OFFSET_RDI = const(offset_of!(GsData, user_ctx.rdi)), + OFFSET_RSI = const(offset_of!(GsData, user_ctx.rsi)), + OFFSET_RDX = const(offset_of!(GsData, user_ctx.rdx)), + OFFSET_RBX = const(offset_of!(GsData, user_ctx.rbx)), + OFFSET_R8 = const(offset_of!(GsData, user_ctx.r8)), + OFFSET_R9 = const(offset_of!(GsData, user_ctx.r9)), + OFFSET_R10 = const(offset_of!(GsData, user_ctx.r10)), + OFFSET_R11 = const(offset_of!(GsData, user_ctx.r11)), + OFFSET_YMM = const(offset_of!(GsData, user_ctx.ymm)), options(noreturn) } }