diff --git a/crates/cannon/src/kernel.rs b/crates/cannon/src/kernel.rs index ed13d1c..dcf56fa 100644 --- a/crates/cannon/src/kernel.rs +++ b/crates/cannon/src/kernel.rs @@ -108,7 +108,7 @@ where delta.subsec_millis(), step, self.ins_state.state.pc, - self.ins_state.state.memory.get_memory(self.ins_state.state.pc)?, + self.ins_state.state.memory.get_memory_b4(self.ins_state.state.pc)?, (step - start_step) as f64 / delta.as_secs_f64(), self.ins_state.state.memory.page_count(), self.ins_state.state.memory.usage(), diff --git a/crates/cannon/src/types.rs b/crates/cannon/src/types.rs index 24731fd..ce5c61e 100644 --- a/crates/cannon/src/types.rs +++ b/crates/cannon/src/types.rs @@ -18,7 +18,7 @@ pub struct Proof { pub step_input: Vec, pub oracle_key: Option>, pub oracle_value: Option>, - pub oracle_offset: Option, + pub oracle_offset: Option, pub oracle_input: Option>, } diff --git a/crates/mipsevm/src/memory.rs b/crates/mipsevm/src/memory.rs index 11e5f9e..9159f32 100644 --- a/crates/mipsevm/src/memory.rs +++ b/crates/mipsevm/src/memory.rs @@ -11,6 +11,8 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use std::{io::Read, rc::Rc}; +pub(crate) const PROOF_LEN: usize = 64 - 4; + /// The [Memory] struct represents the MIPS emulator's memory. #[derive(Clone, Debug, Eq, PartialEq)] pub struct Memory { @@ -62,13 +64,13 @@ impl Memory { } // Find the page and invalidate the address within it. - match self.page_lookup(address as u64 >> page::PAGE_ADDRESS_SIZE) { + match self.page_lookup(address >> page::PAGE_ADDRESS_SIZE) { Some(page) => { let mut page = page.borrow_mut(); let prev_valid = !page.valid[1]; // Invalidate the address within the page. - page.invalidate(address & page::PAGE_ADDRESS_MASK as u32)?; + page.invalidate(address & page::PAGE_ADDRESS_MASK as u64)?; // If the page was already invalid before, then nodes to the memory // root will also still be invalid. @@ -83,7 +85,8 @@ impl Memory { } // Find the generalized index of the first page covering the address - let mut g_index = ((1u64 << 32) | address as u64) >> page::PAGE_ADDRESS_SIZE; + let mut g_index = + (1u64 << (64 - page::PAGE_ADDRESS_SIZE)) | (address >> page::PAGE_ADDRESS_SIZE); // Invalidate all nodes in the branch while g_index > 0 { self.nodes.insert(g_index, None); @@ -119,7 +122,7 @@ impl Memory { pub fn merkleize_subtree(&mut self, g_index: Gindex) -> Result<[u8; 32]> { // Fetch the amount of bits required to represent the generalized index let bits = 64 - g_index.leading_zeros(); - if bits > 28 { + if bits > PROOF_LEN as u32 { anyhow::bail!("Gindex is too deep") } @@ -127,7 +130,7 @@ impl Memory { let depth_into_page = bits - 1 - page::PAGE_KEY_SIZE as u32; let page_index = (g_index >> depth_into_page) & page::PAGE_KEY_MASK as u64; return self.pages.get(&page_index).map_or( - Ok(page::ZERO_HASHES[28 - bits as usize]), + Ok(page::ZERO_HASHES[PROOF_LEN - bits as usize]), |page| { let page_g_index = (1 << depth_into_page) | (g_index & ((1 << depth_into_page) - 1)); @@ -142,7 +145,7 @@ impl Memory { match self.nodes.get(&g_index) { Some(Some(node)) => return Ok(*node), - None => return Ok(page::ZERO_HASHES[28 - bits as usize]), + None => return Ok(page::ZERO_HASHES[PROOF_LEN - bits as usize]), _ => { /* noop */ } } @@ -170,7 +173,7 @@ impl Memory { /// /// ### Returns /// - The 896 bit merkle proof for the given address. - pub fn merkle_proof(&mut self, address: Address) -> Result<[u8; 28 * 32]> { + pub fn merkle_proof(&mut self, address: Address) -> Result<[u8; PROOF_LEN * 32]> { let proof = self.traverse_branch(1, address, 0)?; proof @@ -196,19 +199,19 @@ impl Memory { address: Address, depth: u8, ) -> Result> { - if depth == 32 - 5 { - let mut proof = Vec::with_capacity(32 - 5 + 1); + if depth == (PROOF_LEN - 1) as u8 { + let mut proof = Vec::with_capacity(PROOF_LEN); proof.push(self.merkleize_subtree(parent)?); return Ok(proof); } - if depth > 32 - 5 { + if depth > (PROOF_LEN - 1) as u8 { anyhow::bail!("Traversed too deep") } let mut local = parent << 1; let mut sibling = local | 1; - if address & (1 << (31 - depth)) != 0 { + if address & (1 << (63 - depth)) != 0 { (local, sibling) = (sibling, local); } @@ -229,7 +232,7 @@ impl Memory { /// ### Returns /// - A [Result] indicating if the operation was successful. #[inline(always)] - pub fn set_memory(&mut self, address: Address, value: u32) -> Result<()> { + pub fn set_memory_b4(&mut self, address: Address, value: u32) -> Result<()> { // Address must be aligned to 4 bytes if address & 0x3 != 0 { anyhow::bail!("Unaligned memory access: {:x}", address); @@ -269,13 +272,13 @@ impl Memory { /// ### Returns /// - The 32 bit value at the given address. #[inline(always)] - pub fn get_memory(&mut self, address: Address) -> Result { + pub fn get_memory_b4(&mut self, address: Address) -> Result { // Address must be aligned to 4 bytes if address & 0x3 != 0 { anyhow::bail!("Unaligned memory access: {:x}", address); } - match self.page_lookup(address as u64 >> page::PAGE_ADDRESS_SIZE as u64) { + match self.page_lookup(address >> page::PAGE_ADDRESS_SIZE as u64) { Some(page) => { let page_address = address as usize & page::PAGE_ADDRESS_MASK; Ok(u32::from_be_bytes( @@ -286,6 +289,73 @@ impl Memory { } } + /// Set a 64 bit value in the [Memory] at a given address. + /// This will invalidate the page at the given address, or allocate a new page if it does not exist. + /// + /// ### Takes + /// - `address`: The address to set the value at. + /// - `value`: The 64 bit value to set. + /// + /// ### Returns + /// - A [Result] indicating if the operation was successful. + #[inline(always)] + pub fn set_memory(&mut self, address: Address, value: u64) -> Result<()> { + // Address must be aligned to 8 bytes or 4 bytes + if (address & 0x7 != 0) && (address & 0x3 != 0) { + anyhow::bail!("Unaligned memory access: {:x}", address); + } + + let page_index = address as PageIndex >> page::PAGE_ADDRESS_SIZE as u64; + let page_address = address as usize & page::PAGE_ADDRESS_MASK; + + // Attempt to look up the page. + // - If it does exist, invalidate it before changing it. + // - If it does not exist, allocate it. + let page = self + .page_lookup(page_index) + .map(|page| { + // If the page exists, invalidate it - the value will change. + self.invalidate(address)?; + Ok::<_, anyhow::Error>(page) + }) + .unwrap_or_else(|| { + let page = self.alloc_page(page_index)?; + let _ = page.borrow_mut().invalidate(page_address as Address); + Ok(page) + })?; + + // Copy the 64 bit value into the page + page.borrow_mut().data[page_address..page_address + 8] + .copy_from_slice(&value.to_be_bytes()); + + Ok(()) + } + + /// Retrieve a 64 bit value from the [Memory] at a given address. + /// + /// ### Takes + /// - `address`: The [Address] to retrieve the value from. + /// + /// ### Returns + /// - The 64 bit value at the given address. + #[inline(always)] + pub fn get_memory(&mut self, address: Address) -> Result { + // Address must be aligned to 8 bytes or 4 bytes + if (address & 0x7 != 0) && (address & 0x3 != 0) { + anyhow::bail!("Unaligned memory access: {:x}", address); + } + + match self.page_lookup(address >> page::PAGE_ADDRESS_SIZE as u64) { + Some(page) => { + let page_address = address as usize & page::PAGE_ADDRESS_MASK; + Ok(u64::from_be_bytes( + page.borrow().data[page_address..page_address + 8].try_into()?, + )) + } + None => Ok(0), + } + } + /// Allocate a new page in the [Memory] at a given page index. /// /// ### Takes @@ -331,7 +401,7 @@ impl Memory { if n == 0 { return Ok(()); } - address += n as u32; + address += n as u64; } Err(e) => return Err(e.into()), }; @@ -431,11 +501,11 @@ impl<'de> Deserialize<'de> for Memory { pub struct MemoryReader<'a> { memory: &'a mut Memory, address: Address, - count: u32, + count: u64, } impl<'a> MemoryReader<'a> { - pub fn new(memory: &'a mut Memory, address: Address, count: u32) -> Self { + pub fn new(memory: &'a mut Memory, address: Address, count: u64) -> Self { Self { memory, address, @@ -456,7 +526,7 @@ impl<'a> Read for MemoryReader<'a> { let start = self.address as usize & page::PAGE_ADDRESS_MASK; let mut end = page::PAGE_SIZE; - if page_index == (end_address as u64 >> page::PAGE_ADDRESS_SIZE as u64) { + if page_index == (end_address >> page::PAGE_ADDRESS_SIZE as u64) { end = end_address as usize & page::PAGE_ADDRESS_MASK; } let n = end - start; @@ -468,8 +538,8 @@ impl<'a> Read for MemoryReader<'a> { std::io::copy(&mut vec![0; n].as_slice(), &mut buf)?; } }; - self.address += n as u32; - self.count -= n as u32; + self.address += n as u64; + self.count -= n as u64; Ok(n) } } @@ -484,7 +554,7 @@ mod test { #[test] fn small_tree() { let mut memory = Memory::default(); - memory.set_memory(0x10000, 0xaabbccdd).unwrap(); + memory.set_memory_b4(0x10000, 0xaabbccdd).unwrap(); let proof = memory.merkle_proof(0x10000).unwrap(); assert_eq!([0xaa, 0xbb, 0xcc, 0xdd], proof[..4]); (0..32 - 5).for_each(|i| { @@ -496,9 +566,9 @@ mod test { #[test] fn larger_tree() { let mut memory = Memory::default(); - memory.set_memory(0x10000, 0xaabbccdd).unwrap(); - memory.set_memory(0x80004, 42).unwrap(); - memory.set_memory(0x13370000, 123).unwrap(); + memory.set_memory_b4(0x10000, 0xaabbccdd).unwrap(); + memory.set_memory_b4(0x80004, 42).unwrap(); + memory.set_memory_b4(0x13370000, 123).unwrap(); let root = memory.merkle_root().unwrap(); let proof = memory.merkle_proof(0x80004).unwrap(); assert_eq!([0x00, 0x00, 0x00, 0x2a], proof[4..8]); @@ -525,7 +595,7 @@ mod test { let mut memory = Memory::default(); let root = memory.merkle_root().unwrap(); assert_eq!( - page::ZERO_HASHES[32 - 5], + page::ZERO_HASHES[64 - 5], root, "Fully zeroed memory should have expected zero hash" ); @@ -534,10 +604,10 @@ mod test { #[test] fn empty_page() { let mut memory = Memory::default(); - memory.set_memory(0xF000, 0).unwrap(); + memory.set_memory_b4(0xF000, 0).unwrap(); let root = memory.merkle_root().unwrap(); assert_eq!( - page::ZERO_HASHES[32 - 5], + page::ZERO_HASHES[64 - 5], root, "Fully zeroed memory should have expected zero hash" ); @@ -546,10 +616,10 @@ mod test { #[test] fn single_page() { let mut memory = Memory::default(); - memory.set_memory(0xF000, 1).unwrap(); + memory.set_memory_b4(0xF000, 1).unwrap(); let root = memory.merkle_root().unwrap(); assert_ne!( - page::ZERO_HASHES[32 - 5], + page::ZERO_HASHES[64 - 5], root, "Non-zero memory should not have expected zero hash" ); @@ -558,11 +628,11 @@ mod test { #[test] fn repeat_zero() { let mut memory = Memory::default(); - memory.set_memory(0xF000, 0).unwrap(); - memory.set_memory(0xF004, 0).unwrap(); + memory.set_memory_b4(0xF000, 0).unwrap(); + memory.set_memory_b4(0xF004, 0).unwrap(); let root = memory.merkle_root().unwrap(); assert_eq!( - page::ZERO_HASHES[32 - 5], + page::ZERO_HASHES[64 - 5], root, "Still should have expected zero hash" ); @@ -572,13 +642,13 @@ mod test { fn random_few_pages() { let mut memory = Memory::default(); memory - .set_memory(page::PAGE_SIZE as Address * 3, 1) + .set_memory_b4(page::PAGE_SIZE as Address * 3, 1) .unwrap(); memory - .set_memory(page::PAGE_SIZE as Address * 5, 42) + .set_memory_b4(page::PAGE_SIZE as Address * 5, 42) .unwrap(); memory - .set_memory(page::PAGE_SIZE as Address * 6, 123) + .set_memory_b4(page::PAGE_SIZE as Address * 6, 123) .unwrap(); let p3 = memory .merkleize_subtree((1 << page::PAGE_KEY_SIZE) | 3) @@ -614,21 +684,21 @@ mod test { #[test] fn invalidate_page() { let mut memory = Memory::default(); - memory.set_memory(0xF000, 0).unwrap(); + memory.set_memory_b4(0xF000, 0).unwrap(); assert_eq!( - page::ZERO_HASHES[32 - 5], + page::ZERO_HASHES[64 - 5], memory.merkle_root().unwrap(), "Zero at first" ); - memory.set_memory(0xF004, 1).unwrap(); + memory.set_memory_b4(0xF004, 1).unwrap(); assert_ne!( - page::ZERO_HASHES[32 - 5], + page::ZERO_HASHES[64 - 5], memory.merkle_root().unwrap(), "Non-zero" ); - memory.set_memory(0xF004, 0).unwrap(); + memory.set_memory_b4(0xF004, 0).unwrap(); assert_eq!( - page::ZERO_HASHES[32 - 5], + page::ZERO_HASHES[64 - 5], memory.merkle_root().unwrap(), "Zero again" ); @@ -650,7 +720,7 @@ mod test { .set_memory_range(0, &data[..]) .expect("Should not error"); for i in [0, 4, 1000, 20_000 - 4] { - let value = memory.get_memory(i).expect("Should not error"); + let value = memory.get_memory_b4(i).expect("Should not error"); let expected = u32::from_be_bytes(data[i as usize..i as usize + 4].try_into().unwrap()); assert_eq!(expected, value, "read at {}", i); @@ -665,7 +735,7 @@ mod test { .set_memory_range(0x1337, &data[..]) .expect("Should not error"); - let mut reader = MemoryReader::new(&mut memory, 0x1337 - 10, data.len() as u32 + 20); + let mut reader = MemoryReader::new(&mut memory, 0x1337 - 10, data.len() as u64 + 20); let mut buf = Vec::with_capacity(1260); reader.read_to_end(&mut buf).unwrap(); @@ -677,33 +747,33 @@ mod test { #[test] fn read_write() { let mut memory = Memory::default(); - memory.set_memory(12, 0xaabbccdd).unwrap(); - assert_eq!(0xaabbccdd, memory.get_memory(12).unwrap()); - memory.set_memory(12, 0xaabbc1dd).unwrap(); - assert_eq!(0xaabbc1dd, memory.get_memory(12).unwrap()); + memory.set_memory_b4(12, 0xaabbccdd).unwrap(); + assert_eq!(0xaabbccdd, memory.get_memory_b4(12).unwrap()); + memory.set_memory_b4(12, 0xaabbc1dd).unwrap(); + assert_eq!(0xaabbc1dd, memory.get_memory_b4(12).unwrap()); } #[test] fn unaligned_read() { let mut memory = Memory::default(); - memory.set_memory(12, 0xaabbccdd).unwrap(); - memory.set_memory(16, 0x11223344).unwrap(); - assert!(memory.get_memory(13).is_err()); - assert!(memory.get_memory(14).is_err()); - assert!(memory.get_memory(15).is_err()); - assert_eq!(0x11223344, memory.get_memory(16).unwrap()); - assert_eq!(0, memory.get_memory(20).unwrap()); - assert_eq!(0xaabbccdd, memory.get_memory(12).unwrap()); + memory.set_memory_b4(12, 0xaabbccdd).unwrap(); + memory.set_memory_b4(16, 0x11223344).unwrap(); + assert!(memory.get_memory_b4(13).is_err()); + assert!(memory.get_memory_b4(14).is_err()); + assert!(memory.get_memory_b4(15).is_err()); + assert_eq!(0x11223344, memory.get_memory_b4(16).unwrap()); + assert_eq!(0, memory.get_memory_b4(20).unwrap()); + assert_eq!(0xaabbccdd, memory.get_memory_b4(12).unwrap()); } #[test] fn unaligned_write() { let mut memory = Memory::default(); - memory.set_memory(12, 0xaabbccdd).unwrap(); - assert!(memory.set_memory(13, 0x11223344).is_err()); - assert!(memory.set_memory(14, 0x11223344).is_err()); - assert!(memory.set_memory(15, 0x11223344).is_err()); - assert_eq!(0xaabbccdd, memory.get_memory(12).unwrap()); + memory.set_memory_b4(12, 0xaabbccdd).unwrap(); + assert!(memory.set_memory_b4(13, 0x11223344).is_err()); + assert!(memory.set_memory_b4(14, 0x11223344).is_err()); + assert!(memory.set_memory_b4(15, 0x11223344).is_err()); + assert_eq!(0xaabbccdd, memory.get_memory_b4(12).unwrap()); } } diff --git a/crates/mipsevm/src/mips/instrumented.rs b/crates/mipsevm/src/mips/instrumented.rs index 0844240..e30ebc8 100644 --- a/crates/mipsevm/src/mips/instrumented.rs +++ b/crates/mipsevm/src/mips/instrumented.rs @@ -1,11 +1,12 @@ //! This module contains the [InstrumentedState] definition. +use crate::memory::PROOF_LEN; use crate::{traits::PreimageOracle, Address, State, StepWitness}; use anyhow::Result; use std::io::{BufWriter, Write}; -pub(crate) const MIPS_EBADF: u32 = 0x9; -pub(crate) const MIPS_EINVAL: u32 = 0x16; +pub(crate) const MIPS_EBADF: u64 = 0x9; +pub(crate) const MIPS_EINVAL: u64 = 0x16; /// The [InstrumentedState] is a wrapper around [State] that contains cached machine state, /// the input and output buffers, and an implementation of the MIPS VM. @@ -25,16 +26,16 @@ pub struct InstrumentedState { /// Whether or not the memory proof generation is enabled. pub(crate) mem_proof_enabled: bool, /// The memory proof, if it is enabled. - pub(crate) mem_proof: [u8; 28 * 32], + pub(crate) mem_proof: [u8; PROOF_LEN * 32], /// The [PreimageOracle] used to fetch preimages. pub(crate) preimage_oracle: P, /// Cached pre-image data, including 8 byte length prefix pub(crate) last_preimage: Vec, /// Key for the above preimage pub(crate) last_preimage_key: [u8; 32], - /// The offset we last read from, or max u32 if nothing is read at + /// The offset we last read from, or max u64 if nothing is read at /// the current step. - pub(crate) last_preimage_offset: u32, + pub(crate) last_preimage_offset: u64, } impl InstrumentedState @@ -50,7 +51,7 @@ where std_err: BufWriter::new(std_err), last_mem_access: 0, mem_proof_enabled: false, - mem_proof: [0u8; 28 * 32], + mem_proof: [0u8; PROOF_LEN * 32], preimage_oracle: oracle, last_preimage: Vec::default(), last_preimage_key: [0u8; 32], @@ -66,15 +67,15 @@ where #[inline(always)] pub fn step(&mut self, proof: bool) -> Result> { self.mem_proof_enabled = proof; - self.last_mem_access = !0u32 as Address; - self.last_preimage_offset = !0u32; + self.last_mem_access = !0u64 as Address; + self.last_preimage_offset = !0u64; let mut witness = None; if proof { let instruction_proof = self.state.memory.merkle_proof(self.state.pc as Address)?; - let mut mem_proof = vec![0; 28 * 32 * 2]; - mem_proof[0..28 * 32].copy_from_slice(instruction_proof.as_slice()); + let mut mem_proof = vec![0; PROOF_LEN * 32 * 2]; + mem_proof[0..PROOF_LEN * 32].copy_from_slice(instruction_proof.as_slice()); witness = Some(StepWitness { state: self.state.encode_witness()?, mem_proof, @@ -86,8 +87,8 @@ where if proof { witness = witness.map(|mut wit| { - wit.mem_proof[28 * 32..].copy_from_slice(self.mem_proof.as_slice()); - if self.last_preimage_offset != u32::MAX { + wit.mem_proof[PROOF_LEN * 32..].copy_from_slice(self.mem_proof.as_slice()); + if self.last_preimage_offset != u64::MAX { wit.preimage_key = Some(self.last_preimage_key); wit.preimage_value = Some(self.last_preimage.clone()); wit.preimage_offset = Some(self.last_preimage_offset); @@ -185,8 +186,8 @@ mod test { assert_eq!(END_ADDR, ins.state.pc, "must reach end"); let mut state = ins.state.memory; let (done, result) = ( - state.get_memory((BASE_ADDR_END + 4) as Address).unwrap(), - state.get_memory((BASE_ADDR_END + 8) as Address).unwrap(), + state.get_memory_b4((BASE_ADDR_END + 4) as Address).unwrap(), + state.get_memory_b4((BASE_ADDR_END + 8) as Address).unwrap(), ); assert_eq!(done, 1, "must set done to 1"); assert_eq!(result, 1, "must have success result {:?}", f.file_name()); @@ -223,8 +224,8 @@ mod test { let mut expected_witness = [0u8; STATE_WITNESS_SIZE]; let mem_root = state.memory.merkle_root().unwrap(); expected_witness[..32].copy_from_slice(mem_root.as_slice()); - expected_witness[32 * 2 + 4 * 6] = exit_code; - expected_witness[32 * 2 + 4 * 6 + 1] = exited as u8; + expected_witness[32 * 2 + 8 * 6] = exit_code; + expected_witness[32 * 2 + 8 * 6 + 1] = exited as u8; assert_eq!(actual_witness, expected_witness, "Incorrect witness"); diff --git a/crates/mipsevm/src/mips/mips_vm.rs b/crates/mipsevm/src/mips/mips_vm.rs index a83c997..62e948a 100644 --- a/crates/mipsevm/src/mips/mips_vm.rs +++ b/crates/mipsevm/src/mips/mips_vm.rs @@ -1,5 +1,6 @@ //! This module contains the MIPS VM implementation for the [InstrumentedState]. +use crate::utils::sign_extend; use crate::{ memory::MemoryReader, mips::instrumented::{MIPS_EBADF, MIPS_EINVAL}, @@ -29,7 +30,7 @@ where pub(crate) fn read_preimage( &mut self, key: [u8; 32], - offset: u32, + offset: u64, ) -> Result<([u8; 32], usize)> { if key != self.last_preimage_key { let data = self.preimage_oracle.get(key)?; @@ -74,7 +75,7 @@ where /// /// ### Returns /// - A [Result] indicating if the step was successful. - #[inline(always)] + // #[inline(always)] pub(crate) fn inner_step(&mut self) -> Result<()> { if self.state.exited { return Ok(()); @@ -83,7 +84,7 @@ where self.state.step += 1; // Fetch the instruction - let instruction = self.state.memory.get_memory(self.state.pc as Address)?; + let instruction = self.state.memory.get_memory_b4(self.state.pc as Address)?; let opcode = instruction >> 26; // j-type j/jal @@ -91,7 +92,8 @@ where let link_reg = if opcode == 3 { 31 } else { 0 }; // Take the top 4 bits of the next PC (its 256MB region), and concatenate with the // 26-bit offset - let target = self.state.next_pc & 0xF0000000 | ((instruction & 0x03FFFFFF) << 2); + let target = self.state.next_pc & 0xFFFFFFFFF0000000 + | (((instruction & 0x03FFFFFF) << 2) as u64); return self.handle_jump(link_reg, target); } @@ -111,10 +113,10 @@ where // Don't sign extend for andi, ori, xori if (0x0c..=0x0e).contains(&opcode) { // ZeroExtImm - rt = instruction & 0xFFFF; + rt = (instruction & 0xFFFF) as u64; } else { // SignExtImm - rt = sign_extend(instruction & 0xFFFF, 16); + rt = sign_extend((instruction & 0xFFFF) as u64, 16); } } else if opcode >= 0x28 || [0x22, 0x26].contains(&opcode) { // Store rt value with store @@ -128,14 +130,14 @@ where return self.handle_branch(opcode, instruction, rt_reg, rs); } - let mut store_address: u32 = 0xFFFFFFFF; + let mut store_address: u64 = 0xFFFFFFFFFFFFFFFF; let mut mem = 0; // Memory fetch (all I-type) // We also do the load for stores if opcode >= 0x20 { // M[R[rs]+SignExtImm] - rs += sign_extend(instruction & 0xFFFF, 16); - let address = rs & 0xFFFFFFFC; + rs += sign_extend((instruction & 0xFFFF) as u64, 16); + let address = rs & 0xFFFFFFFFFFFFFFF8; self.track_mem_access(address as Address)?; mem = self.state.memory.get_memory(address as Address)?; @@ -183,7 +185,7 @@ where } // Write memory - if store_address != 0xFFFFFFFF { + if store_address != 0xFFFFFFFFFFFFFFFF { self.track_mem_access(store_address as Address)?; self.state .memory @@ -216,9 +218,9 @@ where // Adjust the size to align with the page size if the size // cannot fit within the page address mask. - let masked_size = sz & page::PAGE_ADDRESS_MASK as u32; + let masked_size = sz & page::PAGE_ADDRESS_MASK as u64; if masked_size != 0 { - sz += page::PAGE_SIZE as u32 - masked_size; + sz += page::PAGE_SIZE as u64 - masked_size; } if a0 == 0 { @@ -245,10 +247,10 @@ where // Nothing to do; Leave v0 and v1 zero, read nothing, and give no error. } Ok(Fd::PreimageRead) => { - let effective_address = (a1 & 0xFFFFFFFC) as Address; + let effective_address = (a1 & 0xFFFFFFFFFFFFFFFC) as Address; self.track_mem_access(effective_address)?; - let memory = self.state.memory.get_memory(effective_address)?; + let memory = self.state.memory.get_memory_b4(effective_address)?; let (data, mut data_len) = self .read_preimage(self.state.preimage_key, self.state.preimage_offset)?; @@ -266,9 +268,9 @@ where out_mem[alignment..alignment + data_len].copy_from_slice(&data[..data_len]); self.state .memory - .set_memory(effective_address, u32::from_be_bytes(out_mem))?; - self.state.preimage_offset += data_len as u32; - v0 = data_len as u32; + .set_memory_b4(effective_address, u32::from_be_bytes(out_mem))?; + self.state.preimage_offset += data_len as u64; + v0 = data_len as u64; } Ok(Fd::HintRead) => { // Don't actually read anything into memory, just say we read it. The @@ -276,7 +278,7 @@ where v0 = a2; } _ => { - v0 = 0xFFFFFFFF; + v0 = 0xFFFFFFFFFFFFFFFF; v1 = MIPS_EBADF; } }, @@ -318,10 +320,13 @@ where v0 = a2; } Ok(Fd::PreimageWrite) => { - let effective_address = a1 & 0xFFFFFFFC; + let effective_address = a1 & 0xFFFFFFFFFFFFFFFC; self.track_mem_access(effective_address as Address)?; - let memory = self.state.memory.get_memory(effective_address as Address)?; + let memory = self + .state + .memory + .get_memory_b4(effective_address as Address)?; let mut key = self.state.preimage_key; let alignment = a1 & 0x3; let space = 4 - alignment; @@ -345,7 +350,7 @@ where v0 = a2; } _ => { - v0 = 0xFFFFFFFF; + v0 = 0xFFFFFFFFFFFFFFFF; v1 = MIPS_EBADF; } }, @@ -359,13 +364,13 @@ where v0 = 1; // O_WRONLY } _ => { - v0 = 0xFFFFFFFF; + v0 = 0xFFFFFFFFFFFFFFFF; v1 = MIPS_EBADF; } } } else { // The command is not recognized by this kernel. - v0 = 0xFFFFFFFF; + v0 = 0xFFFFFFFFFFFFFFFF; v1 = MIPS_EINVAL; } } @@ -397,7 +402,7 @@ where opcode: u32, instruction: u32, rt_reg: u32, - rs: u32, + rs: u64, ) -> Result<()> { if self.state.next_pc != self.state.pc + 4 { anyhow::bail!("Unexpected branch in delay slot at {:x}", self.state.pc,); @@ -434,7 +439,8 @@ where self.state.pc = self.state.next_pc; if should_branch { - self.state.next_pc = prev_pc + 4 + (sign_extend(instruction & 0xFFFF, 16) << 2); + self.state.next_pc = + prev_pc + 4 + (sign_extend((instruction & 0xFFFF) as u64, 16) << 2); } else { // Branch not taken; proceed as normal. self.state.next_pc += 4; @@ -457,8 +463,8 @@ where pub(crate) fn handle_hi_lo( &mut self, fun: u32, - rs: u32, - rt: u32, + rs: u64, + rt: u64, store_reg: u32, ) -> Result<()> { let val = match fun { @@ -483,27 +489,27 @@ where 0x18 => { // mult let acc = ((rs as i32) as i64) as u64 * ((rt as i32) as i64) as u64; - self.state.hi = (acc >> 32) as u32; - self.state.lo = acc as u32; + self.state.hi = sign_extend(acc >> 32, 32); + self.state.lo = sign_extend((acc as u32) as u64, 32); 0 } 0x19 => { // multu - let acc = rs as u64 * rt as u64; - self.state.hi = (acc >> 32) as u32; - self.state.lo = acc as u32; + let acc = (rs as u32) as u64 * (rt as u32) as u64; + self.state.hi = sign_extend(acc >> 32, 32); + self.state.lo = sign_extend((acc as u32) as u64, 32); 0 } 0x1a => { // div - self.state.hi = (rs as i32 % rt as i32) as u32; - self.state.lo = (rs as i32 / rt as i32) as u32; + self.state.hi = sign_extend((rs as i32 % rt as i32) as u64, 32); + self.state.lo = sign_extend((rs as i32 / rt as i32) as u64, 32); 0 } 0x1b => { // divu - self.state.hi = rs % rt; - self.state.lo = rs / rt; + self.state.hi = sign_extend((rs as u32 % rt as u32) as u64, 32); + self.state.lo = sign_extend((rs as u32 / rt as u32) as u64, 32); 0 } _ => 0, @@ -528,7 +534,7 @@ where /// ### Returns /// - A [Result] indicating if the branch dispatch was successful. #[inline(always)] - pub(crate) fn handle_jump(&mut self, link_reg: u32, dest: u32) -> Result<()> { + pub(crate) fn handle_jump(&mut self, link_reg: u32, dest: u64) -> Result<()> { if self.state.next_pc != self.state.pc + 4 { anyhow::bail!("Unexpected jump in delay slot at {:x}", self.state.pc); } @@ -552,7 +558,7 @@ where /// ### Returns /// - A [Result] indicating if the branch dispatch was successful. #[inline(always)] - pub(crate) fn handle_rd(&mut self, store_reg: u32, val: u32, conditional: bool) -> Result<()> { + pub(crate) fn handle_rd(&mut self, store_reg: u32, val: u64, conditional: bool) -> Result<()> { if store_reg >= 32 { anyhow::bail!("Invalid register index {}", store_reg); } @@ -578,7 +584,7 @@ where /// - `Ok(n)` - The result of the instruction execution. /// - `Err(_)`: An error occurred while executing the instruction. #[inline(always)] - pub(crate) fn execute(&mut self, instruction: u32, rs: u32, rt: u32, mem: u32) -> Result { + pub(crate) fn execute(&mut self, instruction: u32, rs: u64, rt: u64, mem: u64) -> Result { // Opcodes in MIPS are 6 bits in size, and stored in the high-order bits of the big-endian // instruction. let opcode = instruction >> 26; @@ -604,19 +610,29 @@ where match fun { // sll - 0 => Ok(rt << ((instruction >> 6) & 0x1F)), + 0 => Ok(sign_extend( + (rt & 0xFFFFFFFF) << ((instruction >> 6) & 0x1F), + 32, + )), // srl - 2 => Ok(rt >> ((instruction >> 6) & 0x1F)), + 2 => Ok(sign_extend( + (rt & 0xFFFFFFFF) >> ((instruction >> 6) & 0x1F), + 32, + )), // sra 3 => { let shamt = (instruction >> 6) & 0x1F; - Ok(sign_extend(rt >> shamt, 32 - shamt)) + Ok(sign_extend((rt & 0xFFFFFFFF) >> shamt, (32 - shamt) as u64)) } - // sslv - 4 => Ok(rt << (rs & 0x1F)), + // sllv + 4 => Ok(sign_extend((rt & 0xFFFFFFFF) << (rs & 0x1F), 32)), // srlv - 6 => Ok(rt >> (rs & 0x1F)), - 7 => Ok(sign_extend(rt >> rs, 32 - rs)), + 6 => Ok(sign_extend((rt & 0xFFFFFFFF) >> (rs & 0x1F), 32)), + // srav + 7 => { + let shamt = rs & 0x1F; + Ok(sign_extend((rt & 0xFFFFFFFF) >> shamt, 32 - shamt)) + } // Functions in range [0x8, 0x1b] are handled specially by other functions. @@ -627,9 +643,9 @@ where // The rest are transformed R-type arithmetic imm instructions. // add / addu - 0x20 | 0x21 => Ok(rs + rt), + 0x20 | 0x21 => Ok(sign_extend(rs + rt, 32)), // sub / subu - 0x22 | 0x23 => Ok(rs - rt), + 0x22 | 0x23 => Ok(sign_extend(rs - rt, 32)), // and 0x24 => Ok(rs & rt), // or @@ -639,9 +655,9 @@ where // nor 0x27 => Ok(!(rs | rt)), // slti - 0x2a => Ok(((rs as i32) < (rt as i32)) as u32), + 0x2a => Ok(((rs as i32) < (rt as i32)) as u64), // sltiu - 0x2b => Ok((rs < rt) as u32), + 0x2b => Ok(((rs as u32) < (rt as u32)) as u64), _ => anyhow::bail!("Invalid function code {:x}", fun), } } else { @@ -651,14 +667,14 @@ where let fun = instruction & 0x3F; match fun { // mul - 0x02 => Ok(((rs as i32) * (rt as i32)) as u32), - // clo + 0x02 => Ok(sign_extend(((rs as i32) * (rt as i32)) as u64, 32)), + // clo / clz 0x20 | 0x21 => { let mut rs = rs; if fun == 0x20 { rs = !rs; } - let mut i = 0u32; + let mut i = 0u64; // TODO(clabby): Remove loop, do some good ol' bit twiddling instead. while rs & 0x80000000 != 0 { @@ -671,88 +687,92 @@ where } } // lui - 0x0F => Ok(rt << 16), + 0x0F => Ok(sign_extend(rt << 16, 32)), // lb - 0x20 => Ok(sign_extend((mem >> (24 - ((rs & 0x3) << 3))) & 0xFF, 8)), + 0x20 => Ok(sign_extend((mem >> (56 - ((rs & 0x7) << 3))) & 0xFF, 8)), // lh - 0x21 => Ok(sign_extend((mem >> (16 - ((rs & 0x2) << 3))) & 0xFFFF, 16)), + 0x21 => Ok(sign_extend((mem >> (48 - ((rs & 0x6) << 3))) & 0xFFFF, 16)), // lwl 0x22 => { let sl = (rs & 0x3) << 3; - let val = mem << sl; - let mask = 0xFFFFFFFF << sl; - Ok((rt & !mask) | val) + let val = ((mem >> (32 - ((rs & 0x4) << 3))) << sl) & 0xFFFFFFFF; + let mask = (0xFFFFFFFFu32 << sl) as u64; + Ok(sign_extend((rt & !mask) | val, 32)) } - // lw - 0x23 => Ok(mem), + // lw / ll + 0x23 | 0x30 => Ok(sign_extend( + (mem >> (32 - ((rs & 0x4) << 3))) & 0xFFFFFFFF, + 32, + )), // lbu - 0x24 => Ok((mem >> (24 - ((rs & 0x3) << 3))) & 0xFF), + 0x24 => Ok((mem >> (56 - ((rs & 0x7) << 3))) & 0xFF), // lhu - 0x25 => Ok((mem >> (16 - ((rs & 0x2) << 3))) & 0xFFFF), + 0x25 => Ok((mem >> (48 - ((rs & 0x6) << 3))) & 0xFFFF), // lwr 0x26 => { let sr = 24 - ((rs & 0x3) << 3); - let val = mem >> sr; - let mask = 0xFFFFFFFFu32 >> sr; - Ok((rt & !mask) | val) + let val = ((mem >> (32 - ((rs & 0x4) << 3))) >> sr) & 0xFFFFFFFF; + let mask = (0xFFFFFFFFu32 >> sr) as u64; + Ok(sign_extend((rt & !mask) | val, 32)) } // sb 0x28 => { - let sl = 24 - ((rs & 0x3) << 3); + let sl = 56 - ((rs & 0x7) << 3); let val = (rt & 0xFF) << sl; - let mask = 0xFFFFFFFF ^ (0xFF << sl); + let mask = 0xFFFFFFFFFFFFFFFF ^ (0xFF << sl); Ok((mem & mask) | val) } // sh 0x29 => { - let sl = 16 - ((rs & 0x2) << 3); + let sl = 48 - ((rs & 0x6) << 3); let val = (rt & 0xFFFF) << sl; - let mask = 0xFFFFFFFF ^ (0xFFFF << sl); + let mask = 0xFFFFFFFFFFFFFFFF ^ (0xFFFF << sl); Ok((mem & mask) | val) } // swl 0x2a => { let sr = (rs & 0x3) << 3; - let val = rt >> sr; - let mask = 0xFFFFFFFFu32 >> sr; + let val = ((rt & 0xFFFFFFFF) >> sr) << (32 - ((rs & 0x4) << 3)); + let mask = ((0xFFFFFFFFu32 >> sr) as u64) << (32 - ((rs & 0x4) << 3)); Ok((mem & !mask) | val) } - // sw - 0x2b => Ok(rt), + // sw / sc + 0x2b | 0x38 => { + let sl = 32 - ((rs & 0x4) << 3); + let val = (rt & 0xFFFFFFFF) << sl; + let mask = 0xFFFFFFFFFFFFFFFF ^ (0xFFFFFFFF << sl); + Ok((mem & mask) | val) + } // swr 0x2e => { let sl = 24 - ((rs & 0x3) << 3); - let val = rt << sl; - let mask = 0xFFFFFFFF << sl; + let val = ((rt & 0xFFFFFFFF) << sl) << (32 - ((rs & 0x4) << 3)); + let mask = ((0xFFFFFFFFu32 << sl) as u64) << (32 - ((rs & 0x4) << 3)); Ok((mem & !mask) | val) } - // ll - 0x30 => Ok(mem), - // sc - 0x38 => Ok(rt), _ => anyhow::bail!("Invalid opcode {:x}", opcode), } } } } - -/// Perform a sign extension of a value embedded in the lower bits of `data` up to -/// the `index`th bit. -/// -/// ### Takes -/// - `data`: The data to sign extend. -/// - `index`: The index of the bit to sign extend to. -/// -/// ### Returns -/// - The sign extended value. -#[inline(always)] -pub(crate) fn sign_extend(data: u32, index: u32) -> u32 { - let is_signed = (data >> (index - 1)) != 0; - let signed = ((1 << (32 - index)) - 1) << index; - let mask = (1 << index) - 1; - if is_signed { - (data & mask) | signed - } else { - data & mask - } -} +// +// /// Perform a sign extension of a value embedded in the lower bits of `data` up to +// /// the `index`th bit. +// /// +// /// ### Takes +// /// - `data`: The data to sign extend. +// /// - `index`: The index of the bit to sign extend to. +// /// +// /// ### Returns +// /// - The sign extended value. +// #[inline(always)] +// pub(crate) fn sign_extend(data: u64, index: u64) -> u64 { +// let is_signed = (data >> (index - 1)) != 0; +// let signed = ((1 << (64 - index)) - 1) << index; +// let mask = (1 << index) - 1; +// if is_signed { +// (data & mask) | signed +// } else { +// data & mask +// } +// } diff --git a/crates/mipsevm/src/page.rs b/crates/mipsevm/src/page.rs index c720409..26052cb 100644 --- a/crates/mipsevm/src/page.rs +++ b/crates/mipsevm/src/page.rs @@ -8,7 +8,7 @@ use once_cell::sync::Lazy; use crate::utils::keccak256; pub(crate) const PAGE_ADDRESS_SIZE: usize = 12; -pub(crate) const PAGE_KEY_SIZE: usize = 32 - PAGE_ADDRESS_SIZE; +pub(crate) const PAGE_KEY_SIZE: usize = 64 - PAGE_ADDRESS_SIZE; pub(crate) const PAGE_SIZE: usize = 1 << PAGE_ADDRESS_SIZE; pub(crate) const PAGE_SIZE_WORDS: usize = PAGE_SIZE >> 5; pub(crate) const PAGE_ADDRESS_MASK: usize = PAGE_SIZE - 1; diff --git a/crates/mipsevm/src/patch.rs b/crates/mipsevm/src/patch.rs index f142a57..f01fad7 100644 --- a/crates/mipsevm/src/patch.rs +++ b/crates/mipsevm/src/patch.rs @@ -1,5 +1,6 @@ //! This module contains utilities for loading ELF files into [State] objects. +use crate::utils::sign_extend; use crate::{page, Address, State}; use anyhow::Result; use elf::{abi::PT_LOAD, endian::AnyEndian, ElfBytes}; @@ -35,9 +36,9 @@ pub fn load_elf(raw: &[u8]) -> Result { let elf = ElfBytes::::minimal_parse(raw)?; let mut state = State { - pc: elf.ehdr.e_entry as u32, - next_pc: elf.ehdr.e_entry as u32 + 4, - heap: 0x20000000, + pc: elf.ehdr.e_entry, + next_pc: elf.ehdr.e_entry + 4, + heap: 0x20000000u64, ..Default::default() }; @@ -78,9 +79,9 @@ pub fn load_elf(raw: &[u8]) -> Result { } } - if header.p_vaddr + header.p_memsz >= 1 << 32 { + if header.p_vaddr + header.p_memsz >= 1 << 47 { anyhow::bail!( - "Program segment {} out of 32-bit mem range: {} - {} (size: {})", + "Program segment {} out of 64-bit mem range: {} - {} (size: {})", i, header.p_vaddr, header.p_vaddr + header.p_memsz, @@ -88,9 +89,7 @@ pub fn load_elf(raw: &[u8]) -> Result { ); } - state - .memory - .set_memory_range(header.p_vaddr as u32, reader)?; + state.memory.set_memory_range(header.p_vaddr, reader)?; } Ok(state) @@ -120,12 +119,12 @@ pub fn patch_go(raw: &[u8], state: &mut State) -> Result<()> { // 03e00008 = jr $ra = ret (pseudo instruction) // 00000000 = nop (executes with delay-slot, but does nothing) state.memory.set_memory_range( - symbol.st_value as u32, + symbol.st_value, [0x03, 0xe0, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00].as_slice(), )?; } else if name == "runtime.MemProfileRate" { // disable mem profiling, to avoid a lot of unnecessary floating point ops - state.memory.set_memory(symbol.st_value as u32, 0)?; + state.memory.set_memory_b4(symbol.st_value, 0)?; } } Ok(()) @@ -141,34 +140,36 @@ pub fn patch_go(raw: &[u8], state: &mut State) -> Result<()> { /// - `Err(_)` if the patch failed pub fn patch_stack(state: &mut State) -> Result<()> { // Setup stack pointer - let ptr = 0x7F_FF_D0_00_u32; + let ptr = 0x7F_FF_FF_FF_D0_00_u64; // Allocate 1 page for the initial stack data, and 16KB = 4 pages for the stack to grow. state.memory.set_memory_range( - ptr - 4 * page::PAGE_SIZE as u32, + sign_extend(ptr - 4 * page::PAGE_SIZE as u64, 47), [0u8; page::PAGE_SIZE * 5].as_slice(), )?; state.registers[29] = ptr; #[inline(always)] fn store_mem(st: &mut State, address: Address, value: u32) -> Result<()> { - st.memory.set_memory(address, value) + st.memory.set_memory_b4(address, value) } // init argc, argv, aux on stack - store_mem(state, ptr + 4, 0x42)?; // argc = 0 (argument count) - store_mem(state, ptr + 4 * 2, 0x35)?; // argv[n] = 0 (terminating argv) - store_mem(state, ptr + 4 * 3, 0)?; // envp[term] = 0 (no env vars) - store_mem(state, ptr + 4 * 4, 6)?; // auxv[0] = _AT_PAGESZ = 6 (key) - store_mem(state, ptr + 4 * 5, 4096)?; // auxv[1] = page size of 4 KiB (value) - (== minPhysPageSize) - store_mem(state, ptr + 4 * 6, 25)?; // auxv[2] = AT_RANDOM - store_mem(state, ptr + 4 * 7, ptr + 4 * 9)?; // auxv[3] = address of 16 bytes containing random value - store_mem(state, ptr + 4 * 8, 0)?; // auxv[term] = 0 + store_mem(state, sign_extend(ptr + 4, 47), 0x42)?; // argc = 0 (argument count) + store_mem(state, sign_extend(ptr + 4 * 2, 47), 0x35)?; // argv[n] = 0 (terminating argv) + store_mem(state, sign_extend(ptr + 4 * 3, 47), 0)?; // envp[term] = 0 (no env vars) + store_mem(state, sign_extend(ptr + 4 * 4, 47), 6)?; // auxv[0] = _AT_PAGESZ = 6 (key) + store_mem(state, sign_extend(ptr + 4 * 5, 47), 4096)?; // auxv[1] = page size of 4 KiB (value) - (== minPhysPageSize) + store_mem(state, sign_extend(ptr + 4 * 6, 47), 25)?; // auxv[2] = AT_RANDOM + state + .memory + .set_memory(sign_extend(ptr + 4 * 7, 47), sign_extend(ptr + 4 * 9, 47))?; // auxv[3] = address of 16 bytes + store_mem(state, sign_extend(ptr + 4 * 8, 47), 0)?; // auxv[term] = 0 // 16 bytes of "randomness" state .memory - .set_memory_range(ptr + 4 * 9, b"4;byfairdiceroll".as_slice())?; + .set_memory_range(sign_extend(ptr + 4 * 9, 47), b"4;byfairdiceroll".as_slice())?; Ok(()) } diff --git a/crates/mipsevm/src/state.rs b/crates/mipsevm/src/state.rs index 4d169c6..9ede701 100644 --- a/crates/mipsevm/src/state.rs +++ b/crates/mipsevm/src/state.rs @@ -17,17 +17,17 @@ pub struct State { #[serde(with = "crate::ser::fixed_32_hex")] pub preimage_key: [u8; 32], /// The preimage offset. - pub preimage_offset: u32, + pub preimage_offset: u64, /// The current program counter. - pub pc: u32, + pub pc: u64, /// The next program counter. - pub next_pc: u32, + pub next_pc: u64, /// The lo register - pub lo: u32, + pub lo: u64, /// The hi register - pub hi: u32, + pub hi: u64, /// The heap pointer - pub heap: u32, + pub heap: u64, /// The exit code of the MIPS emulator. pub exit_code: u8, /// The exited status of the MIPS emulator. @@ -35,7 +35,7 @@ pub struct State { /// The current step of the MIPS emulator. pub step: u64, /// The MIPS emulator's registers. - pub registers: [u32; 32], + pub registers: [u64; 32], /// The last hint sent to the host. #[serde(with = "crate::ser::vec_u8_hex")] pub last_hint: Vec, @@ -50,18 +50,18 @@ impl State { let mut witness: StateWitness = [0u8; STATE_WITNESS_SIZE]; witness[..32].copy_from_slice(self.memory.merkle_root()?.as_slice()); witness[32..64].copy_from_slice(self.preimage_key.as_slice()); - witness[64..68].copy_from_slice(&self.preimage_offset.to_be_bytes()); - witness[68..72].copy_from_slice(&self.pc.to_be_bytes()); - witness[72..76].copy_from_slice(&self.next_pc.to_be_bytes()); - witness[76..80].copy_from_slice(&self.lo.to_be_bytes()); - witness[80..84].copy_from_slice(&self.hi.to_be_bytes()); - witness[84..88].copy_from_slice(&self.heap.to_be_bytes()); - witness[88] = self.exit_code; - witness[89] = self.exited as u8; - witness[90..98].copy_from_slice(&self.step.to_be_bytes()); + witness[64..72].copy_from_slice(&self.preimage_offset.to_be_bytes()); + witness[72..80].copy_from_slice(&self.pc.to_be_bytes()); + witness[80..88].copy_from_slice(&self.next_pc.to_be_bytes()); + witness[88..96].copy_from_slice(&self.lo.to_be_bytes()); + witness[96..104].copy_from_slice(&self.hi.to_be_bytes()); + witness[104..112].copy_from_slice(&self.heap.to_be_bytes()); + witness[112] = self.exit_code; + witness[113] = self.exited as u8; + witness[114..122].copy_from_slice(&self.step.to_be_bytes()); for (i, r) in self.registers.iter().enumerate() { - let start = 98 + i * 4; - witness[start..start + 4].copy_from_slice(&r.to_be_bytes()); + let start = 122 + i * 8; + witness[start..start + 8].copy_from_slice(&r.to_be_bytes()); } Ok(witness) } diff --git a/crates/mipsevm/src/test_utils/evm.rs b/crates/mipsevm/src/test_utils/evm.rs index 34a4695..a9b9e00 100644 --- a/crates/mipsevm/src/test_utils/evm.rs +++ b/crates/mipsevm/src/test_utils/evm.rs @@ -320,7 +320,7 @@ mod test { let instruction = instrumented .state .memory - .get_memory(instrumented.state.pc as Address) + .get_memory_b4(instrumented.state.pc as Address) .unwrap(); println!( "{}", @@ -347,8 +347,8 @@ mod test { assert_eq!(END_ADDR, instrumented.state.pc, "must reach end"); let mut state = instrumented.state.memory; let (done, result) = ( - state.get_memory((BASE_ADDR_END + 4) as Address).unwrap(), - state.get_memory((BASE_ADDR_END + 8) as Address).unwrap(), + state.get_memory_b4((BASE_ADDR_END + 4) as Address).unwrap(), + state.get_memory_b4((BASE_ADDR_END + 8) as Address).unwrap(), ); assert_eq!(done, 1, "must set done to 1"); assert_eq!(result, 1, "must have success result {:?}", f.file_name()); @@ -385,7 +385,7 @@ mod test { let mut state = State::default(); state.pc = pc; state.next_pc = next_pc; - state.memory.set_memory(pc, instruction).unwrap(); + state.memory.set_memory_b4(pc, instruction).unwrap(); let mut instrumented = InstrumentedState::new( state, @@ -420,7 +420,7 @@ mod test { next_pc: next_pc as Address, ..Default::default() }; - state.memory.set_memory(0, instruction).unwrap(); + state.memory.set_memory_b4(0, instruction).unwrap(); // Set the return address ($ra) to jump to when the test completes. state.registers[31] = END_ADDR; @@ -472,7 +472,7 @@ mod test { let instruction = instrumented .state .memory - .get_memory(instrumented.state.pc as Address) + .get_memory_b4(instrumented.state.pc as Address) .unwrap(); println!( "step: {} pc: 0x{:08x} instruction: {:08x}", @@ -516,7 +516,7 @@ mod test { let instruction = instrumented .state .memory - .get_memory(instrumented.state.pc as Address) + .get_memory_b4(instrumented.state.pc as Address) .unwrap(); println!( "step: {} pc: 0x{:08x} instruction: {:08x}", diff --git a/crates/mipsevm/src/test_utils/mod.rs b/crates/mipsevm/src/test_utils/mod.rs index a8f11c2..f0adc4b 100644 --- a/crates/mipsevm/src/test_utils/mod.rs +++ b/crates/mipsevm/src/test_utils/mod.rs @@ -9,10 +9,10 @@ use rustc_hash::FxHashMap; pub mod evm; /// Used in tests to write the results to -pub const BASE_ADDR_END: u32 = 0xBF_FF_FF_F0; +pub const BASE_ADDR_END: u64 = 0xFF_FF_FF_FF_BF_FF_FF_F0; /// Used as the return-address for tests -pub const END_ADDR: u32 = 0xA7_EF_00_D0; +pub const END_ADDR: u64 = 0xA7_EF_00_D0; #[derive(Default)] pub struct StaticOracle { diff --git a/crates/mipsevm/src/types.rs b/crates/mipsevm/src/types.rs index cf931d4..7def7d5 100644 --- a/crates/mipsevm/src/types.rs +++ b/crates/mipsevm/src/types.rs @@ -19,7 +19,7 @@ pub type PageIndex = u64; pub type Gindex = u64; /// An [Address] is a 32 bit address in the MIPS emulator's memory. -pub type Address = u32; +pub type Address = u64; /// The [VMStatus] is an indicator within the [StateWitness] hash that indicates /// the current status of the MIPS emulator. @@ -71,10 +71,10 @@ pub enum Syscall { Fcntl = 4055, } -impl TryFrom for Syscall { +impl TryFrom for Syscall { type Error = anyhow::Error; - fn try_from(n: u32) -> Result { + fn try_from(n: u64) -> Result { match n { 4090 => Ok(Syscall::Mmap), 4045 => Ok(Syscall::Brk), diff --git a/crates/mipsevm/src/utils.rs b/crates/mipsevm/src/utils.rs index d46cf19..5fb21a4 100644 --- a/crates/mipsevm/src/utils.rs +++ b/crates/mipsevm/src/utils.rs @@ -35,3 +35,24 @@ pub(crate) fn keccak256>(input: T) -> B256 { xkcp_rs::keccak256(input.as_ref(), out.as_mut()); out } + +/// Perform a sign extension of a value embedded in the lower bits of `data` up to +/// the `index`th bit. +/// +/// ### Takes +/// - `data`: The data to sign extend. +/// - `index`: The index of the bit to sign extend to. +/// +/// ### Returns +/// - The sign extended value. +#[inline(always)] +pub(crate) fn sign_extend(data: u64, index: u64) -> u64 { + let is_signed = (data >> (index - 1)) != 0; + let signed = ((1 << (64 - index)) - 1) << index; + let mask = (1 << index) - 1; + if is_signed { + (data & mask) | signed + } else { + data & mask + } +} diff --git a/crates/mipsevm/src/witness.rs b/crates/mipsevm/src/witness.rs index 53344f8..297aa5f 100644 --- a/crates/mipsevm/src/witness.rs +++ b/crates/mipsevm/src/witness.rs @@ -1,5 +1,6 @@ //! This module contains the various witness types. +use crate::memory::PROOF_LEN; use crate::{utils::keccak256, State, StateWitness, StateWitnessHasher}; use alloy_primitives::{B256, U256}; use alloy_sol_types::{sol, SolCall}; @@ -7,12 +8,12 @@ use preimage_oracle::KeyType; use revm::primitives::Bytes; /// The size of an encoded [StateWitness] in bytes. -pub const STATE_WITNESS_SIZE: usize = 226; +pub const STATE_WITNESS_SIZE: usize = 378; impl StateWitnessHasher for StateWitness { fn state_hash(&self) -> [u8; 32] { let mut hash = keccak256(self); - let offset = 32 * 2 + 4 * 6; + let offset = 32 * 2 + 8 * 6; let exit_code = self[offset]; let exited = self[offset + 1] == 1; hash[0] = State::vm_status(exited, exit_code) as u8; @@ -33,14 +34,14 @@ pub struct StepWitness { /// The preimage value pub preimage_value: Option>, /// The preimage offset - pub preimage_offset: Option, + pub preimage_offset: Option, } impl Default for StepWitness { fn default() -> Self { Self { state: [0u8; crate::witness::STATE_WITNESS_SIZE], - mem_proof: Vec::with_capacity(28 * 32 * 2), + mem_proof: Vec::with_capacity(PROOF_LEN * 32 * 2), preimage_key: Default::default(), preimage_value: Default::default(), preimage_offset: Default::default(),