diff --git a/rvgo/fast/vm.go b/rvgo/fast/vm.go index 3e1755d6..f6b38408 100644 --- a/rvgo/fast/vm.go +++ b/rvgo/fast/vm.go @@ -760,7 +760,7 @@ func (inst *InstrumentedState) riscvStep() (outErr error) { setPC(add64(pc, toU64(4))) case 0x3B: // 011_1011: register arithmetic and logic in 32 bits rs1Value := getRegister(rs1) - rs2Value := getRegister(rs2) + rs2Value := and64(getRegister(rs2), u32Mask()) var rdValue U64 switch funct7 { case 1: // RV M extension diff --git a/rvgo/slow/vm.go b/rvgo/slow/vm.go index 458f00c6..3f8edf04 100644 --- a/rvgo/slow/vm.go +++ b/rvgo/slow/vm.go @@ -933,7 +933,7 @@ func Step(calldata []byte, po PreimageOracle) (stateHash common.Hash, outErr err setPC(add64(pc, toU64(4))) case 0x3B: // 011_1011: register arithmetic and logic in 32 bits rs1Value := getRegister(rs1) - rs2Value := getRegister(rs2) + rs2Value := and64(getRegister(rs2), u32Mask()) var rdValue U64 switch funct7.val() { case 1: // RV M extension diff --git a/rvsol/src/RISCV.sol b/rvsol/src/RISCV.sol index 1bbb681e..4919c854 100644 --- a/rvsol/src/RISCV.sol +++ b/rvsol/src/RISCV.sol @@ -1399,7 +1399,7 @@ contract RISCV is IBigStepper { case 0x3B { // 011_1011: register arithmetic and logic in 32 bits let rs1Value := getRegister(rs1) - let rs2Value := getRegister(rs2) + let rs2Value := and64(getRegister(rs2), u32Mask()) let rdValue := 0 switch funct7 case 1 { diff --git a/rvsol/test/RISCV.t.sol b/rvsol/test/RISCV.t.sol index 79849ff4..9ab04ded 100644 --- a/rvsol/test/RISCV.t.sol +++ b/rvsol/test/RISCV.t.sol @@ -607,6 +607,25 @@ contract RISCV_Test is CommonTest { assertEq(postState, outputState(expect), "unexpected post state"); } + function test_remw_by_zero_succeeds() public { + uint32 insn = encodeRType(0x3b, 27, 6, 22, 21, 1); // remw x27, x22, x21 + (State memory state, bytes memory proof) = constructRISCVState(0, insn); + state.registers[22] = 0x100f00000; //bits > 32 should be ignored + state.registers[21] = 0x200000000; // bits > 32 should be ignored, resulting in division by zero + bytes memory encodedState = encodeState(state); + + State memory expect; + expect.memRoot = state.memRoot; + expect.pc = state.pc + 4; + expect.step = state.step + 1; + expect.registers[27] = 0x00f00000; // should return original dividend (least 32 bits) + expect.registers[22] = state.registers[22]; + expect.registers[21] = state.registers[21]; + + bytes32 postState = riscv.step(encodedState, proof, 0); + assertEq(postState, outputState(expect), "unexpected post state"); + } + function test_remuw_succeeds() public { uint32 insn = encodeRType(0x3b, 30, 7, 27, 9, 1); // remuw x30, x27, x9 (State memory state, bytes memory proof) = constructRISCVState(0, insn);