Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Uint256SquareRoot #134

Merged
merged 33 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
213cfe4
added UpdatePc tests and fixed a bug in UpdatePc
mmk-1 Sep 22, 2023
1e1cfe7
Fixed accessing field value without Read()
mmk-1 Sep 25, 2023
39f71f6
Merge branch 'main' into main
mmk-1 Sep 26, 2023
60194fb
fixed failing tests
mmk-1 Sep 26, 2023
df8cd7f
small refactor for TestUpdatePcJump
mmk-1 Sep 26, 2023
98a3d1c
Merge branch 'NethermindEth:main' into main
mmk-1 Oct 2, 2023
1a2280f
Merge branch 'NethermindEth:main' into main
mmk-1 Oct 13, 2023
8dec869
Merge branch 'NethermindEth:main' into main
mmk-1 Oct 17, 2023
0447eff
added SquareRoot hint + test
mmk-1 Oct 17, 2023
0900990
removed usage of U256
mmk-1 Oct 18, 2023
1e0188d
fixed the hint for sqrt
mmk-1 Oct 20, 2023
1973382
removed unnecessary byte conversion in squareroot hint
mmk-1 Oct 20, 2023
b344646
u256sqrt Execute() function done. need test
mmk-1 Oct 20, 2023
e5e5562
remove unnecessary comment
mmk-1 Oct 20, 2023
a0eedf5
refactoring u256sqrt method
mmk-1 Oct 20, 2023
b793e2e
fixed uint256sqrt with added test
mmk-1 Oct 23, 2023
ae53fc3
add new test case for high bytes
mmk-1 Oct 23, 2023
e9be2d0
refactoring u256sqrt method
mmk-1 Oct 23, 2023
df57e98
more refactoring and removed a .clone()
mmk-1 Oct 23, 2023
c85abad
Merge branch 'main' into u256sqrt
mmk-1 Oct 23, 2023
084cc48
removed whitespace
mmk-1 Oct 23, 2023
b32a7b9
fix whitespace
mmk-1 Oct 23, 2023
65892ee
add new test
mmk-1 Oct 24, 2023
1a8e38a
Merge branch 'NethermindEth:main' into main
mmk-1 Oct 25, 2023
0e77687
Merge branch 'main' into u256sqrt
mmk-1 Oct 31, 2023
e12f6f7
minor fixes
mmk-1 Oct 31, 2023
8d08f1a
Merge branch 'NethermindEth:main' into main
mmk-1 Oct 31, 2023
f75ca1a
fix test error
mmk-1 Oct 31, 2023
66db9bd
removed the usage of Clone() in Uin256SqrtRoot
mmk-1 Nov 1, 2023
85052a0
add benchmark test
mmk-1 Nov 1, 2023
16bf020
minor refactor
mmk-1 Nov 2, 2023
6912eb4
Merge branch 'NethermindEth:main' into main
mmk-1 Nov 6, 2023
36d994f
Merge branch 'main' into u256sqrt
mmk-1 Nov 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions pkg/hintrunner/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,161 @@ func (hint SquareRoot) Execute(vm *VM.VirtualMachine) error {
}
return nil
}

type Uint256SquareRoot struct {
valueLow ResOperander
valueHigh ResOperander
sqrt0 CellRefer
sqrt1 CellRefer
remainderLow CellRefer
remainderHigh CellRefer
sqrtMul2MinusRemainderGeU128 CellRefer
}

func (hint Uint256SquareRoot) String() string {
return "Uint256SquareRoot"
}

func (hint Uint256SquareRoot) Execute(vm *VM.VirtualMachine) error {
valueLow, err := hint.valueLow.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve valueLow operand %s: %v", hint.valueLow, err)
}

valueHigh, err := hint.valueHigh.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve valueHigh operand %s: %v", hint.valueHigh, err)
}

valueLowFelt, err := valueLow.FieldElement()
if err != nil {
return err
}

valueHighFelt, err := valueHigh.FieldElement()
if err != nil {
return err
}

// value = {value_low} + {value_high} * 2**128
valueLowU256 := uint256.Int(valueLowFelt.Bits())
valueHighU256 := uint256.Int(valueHighFelt.Bits())
valueHighU256.Lsh(&valueHighU256, 128)
value := valueHighU256.Add(&valueHighU256, &valueLowU256)
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved

// root = math.isqrt(value)
root := value.Clone()
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved
root.Sqrt(value)

// remainder = value - root ** 2
root2 := root.Clone()
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved
root2.Mul(root, root)
remainder := value.Clone()
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved
remainder.Sub(value, root2)

// memory{sqrt0} = root & 0xFFFFFFFFFFFFFFFF
// memory{sqrt1} = root >> 64
mask64 := uint256.NewInt(0xFFFFFFFFFFFFFFFF)
rootMasked := root.Clone()
rootMasked.And(root, mask64)
rootShifted := root.Rsh(root, 64)

sqrt0 := f.Element{}
sqrt0.SetBytes(rootMasked.Bytes())
fmt.Println(rootMasked.Dec())
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved

sqrt1 := f.Element{}
sqrt1.SetBytes(rootShifted.Bytes())
fmt.Println(rootShifted.Dec())
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved

sqrt0Addr, err := hint.sqrt0.Get(vm)
if err != nil {
return fmt.Errorf("get sqrt0 cell: %v", err)
}

sqrt1Addr, err := hint.sqrt1.Get(vm)
if err != nil {
return fmt.Errorf("get sqrt1 cell: %v", err)
}

sqrt0Val := memory.MemoryValueFromFieldElement(&sqrt0)
err = vm.Memory.WriteToAddress(&sqrt0Addr, &sqrt0Val)
if err != nil {
return fmt.Errorf("write sqrt0 cell: %v", err)
}

sqrt1Val := memory.MemoryValueFromFieldElement(&sqrt1)
err = vm.Memory.WriteToAddress(&sqrt1Addr, &sqrt1Val)
if err != nil {
return fmt.Errorf("write sqrt1 cell: %v", err)
}

// memory{remainder_low} = remainder & 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
// memory{remainder_high} = remainder >> 128
mask128 := uint256.NewInt(0xFFFFFFFFFFFFFFFF)
mask128.Lsh(mask128, 64)
mask128.Or(mask128, mask64)
remainderMasked := remainder.Clone()
remainderMasked.And(remainder, mask128)
remainderLow := f.Element{}
remainderLow.SetBytes(remainderMasked.Bytes())

remainderShifted := remainder.Clone()
remainderShifted.Rsh(remainder, 128)
remainderHigh := f.Element{}
remainderHigh.SetBytes(remainderShifted.Bytes())

remainderLowAddr, err := hint.remainderLow.Get(vm)
if err != nil {
return fmt.Errorf("get remainderLow cell: %v", err)
}

remainderHighAddr, err := hint.remainderHigh.Get(vm)
if err != nil {
return fmt.Errorf("get remainderHigh cell: %v", err)
}

remainderLowVal := memory.MemoryValueFromFieldElement(&remainderLow)
err = vm.Memory.WriteToAddress(&remainderLowAddr, &remainderLowVal)
if err != nil {
return fmt.Errorf("write remainderLow cell: %v", err)
}

remainderHighVal := memory.MemoryValueFromFieldElement(&remainderHigh)
err = vm.Memory.WriteToAddress(&remainderHighAddr, &remainderHighVal)
if err != nil {
return fmt.Errorf("write remainderHigh cell: %v", err)
}

// memory{sqrt_mul_2_minus_remainder_ge_u128} = root * 2 - remainder >= 2**128
rootMul2 := root.Clone()
rootMul2.Lsh(root, 1)
lhs := rootMul2.Clone()
lhs.Sub(rootMul2, remainder)

rhs := uint256.NewInt(1)
rhs.Lsh(rhs, 128)
// rhs.Mul(mask128, mask128)
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved
result := rhs.Gt(lhs)
result = !result

sqrtMul2MinusRemainderGeU128 := f.Element{}
if result {
sqrtMul2MinusRemainderGeU128.SetOne()
} else {
sqrtMul2MinusRemainderGeU128.SetZero()
}
mmk-1 marked this conversation as resolved.
Show resolved Hide resolved

sqrtMul2MinusRemainderGeU128Addr, err := hint.sqrtMul2MinusRemainderGeU128.Get(vm)
if err != nil {
return fmt.Errorf("get sqrtMul2MinusRemainderGeU128Addr cell: %v", err)
}

sqrtMul2MinusRemainderGeU128AddrVal := memory.MemoryValueFromFieldElement(&sqrtMul2MinusRemainderGeU128)
err = vm.Memory.WriteToAddress(&sqrtMul2MinusRemainderGeU128Addr, &sqrtMul2MinusRemainderGeU128AddrVal)
if err != nil {
return fmt.Errorf("write sqrtMul2MinusRemainderGeU128Addr cell: %v", err)
}

return nil
}
141 changes: 141 additions & 0 deletions pkg/hintrunner/hint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,144 @@ func TestSquareRoot(t *testing.T) {
readFrom(vm, VM.ExecutionSegment, 1),
)
}

func TestUint256SquareRootLow(t *testing.T) {
vm := defaultVirtualMachine()
vm.Context.Ap = 0
vm.Context.Fp = 0

var sqrt0 ApCellRef = 1
var sqrt1 ApCellRef = 2
var remainderLow ApCellRef = 3
var remainderHigh ApCellRef = 4
var sqrtMul2MinusRemainderGeU128 ApCellRef = 5

valueLow := Immediate(*big.NewInt(121))
valueHigh := Immediate(*big.NewInt(0))

hint := Uint256SquareRoot{
valueLow: valueLow,
valueHigh: valueHigh,
sqrt0: sqrt0,
sqrt1: sqrt1,
remainderLow: remainderLow,
remainderHigh: remainderHigh,
sqrtMul2MinusRemainderGeU128: sqrtMul2MinusRemainderGeU128,
}

err := hint.Execute(vm)

require.NoError(t, err)

expectedSqrt0 := memory.MemoryValueFromInt(11)
expectedSqrt1 := memory.MemoryValueFromInt(0)
expectedRemainderLow := memory.MemoryValueFromInt(0)
expectedRemainderHigh := memory.MemoryValueFromInt(0)
expectedSqrtMul2MinusRemainderGeU128 := memory.MemoryValueFromInt(0)

actualSqrt0 := readFrom(vm, VM.ExecutionSegment, 1)
actualSqrt1 := readFrom(vm, VM.ExecutionSegment, 2)
actualRemainderLow := readFrom(vm, VM.ExecutionSegment, 3)
actualRemainderHigh := readFrom(vm, VM.ExecutionSegment, 4)
actualSqrtMul2MinusRemainderGeU128 := readFrom(vm, VM.ExecutionSegment, 5)

require.Equal(t, expectedSqrt0, actualSqrt0)
require.Equal(t, expectedSqrt1, actualSqrt1)
require.Equal(t, expectedRemainderLow, actualRemainderLow)
require.Equal(t, expectedRemainderHigh, actualRemainderHigh)
require.Equal(t, expectedSqrtMul2MinusRemainderGeU128, actualSqrtMul2MinusRemainderGeU128)
}

func TestUint256SquareRootHigh(t *testing.T) {
vm := defaultVirtualMachine()
vm.Context.Ap = 0
vm.Context.Fp = 0

var sqrt0 ApCellRef = 1
var sqrt1 ApCellRef = 2
var remainderLow ApCellRef = 3
var remainderHigh ApCellRef = 4
var sqrtMul2MinusRemainderGeU128 ApCellRef = 5

valueLow := Immediate(*big.NewInt(0))
valueHigh := Immediate(*big.NewInt(1 << 8))

hint := Uint256SquareRoot{
valueLow: valueLow,
valueHigh: valueHigh,
sqrt0: sqrt0,
sqrt1: sqrt1,
remainderLow: remainderLow,
remainderHigh: remainderHigh,
sqrtMul2MinusRemainderGeU128: sqrtMul2MinusRemainderGeU128,
}

err := hint.Execute(vm)

require.NoError(t, err)

expectedSqrt0 := memory.MemoryValueFromInt(0)
expectedSqrt1 := memory.MemoryValueFromInt(16)
expectedRemainderLow := memory.MemoryValueFromInt(0)
expectedRemainderHigh := memory.MemoryValueFromInt(0)
expectedSqrtMul2MinusRemainderGeU128 := memory.MemoryValueFromInt(0)

actualSqrt0 := readFrom(vm, VM.ExecutionSegment, 1)
actualSqrt1 := readFrom(vm, VM.ExecutionSegment, 2)
actualRemainderLow := readFrom(vm, VM.ExecutionSegment, 3)
actualRemainderHigh := readFrom(vm, VM.ExecutionSegment, 4)
actualSqrtMul2MinusRemainderGeU128 := readFrom(vm, VM.ExecutionSegment, 5)

require.Equal(t, expectedSqrt0, actualSqrt0)
require.Equal(t, expectedSqrt1, actualSqrt1)
require.Equal(t, expectedRemainderLow, actualRemainderLow)
require.Equal(t, expectedRemainderHigh, actualRemainderHigh)
require.Equal(t, expectedSqrtMul2MinusRemainderGeU128, actualSqrtMul2MinusRemainderGeU128)
}

func TestUint256SquareRoot(t *testing.T) {
vm := defaultVirtualMachine()
vm.Context.Ap = 0
vm.Context.Fp = 0

var sqrt0 ApCellRef = 1
var sqrt1 ApCellRef = 2
var remainderLow ApCellRef = 3
var remainderHigh ApCellRef = 4
var sqrtMul2MinusRemainderGeU128 ApCellRef = 5

valueLow := Immediate(*big.NewInt(51))
valueHigh := Immediate(*big.NewInt(1024))

hint := Uint256SquareRoot{
valueLow: valueLow,
valueHigh: valueHigh,
sqrt0: sqrt0,
sqrt1: sqrt1,
remainderLow: remainderLow,
remainderHigh: remainderHigh,
sqrtMul2MinusRemainderGeU128: sqrtMul2MinusRemainderGeU128,
}

err := hint.Execute(vm)

require.NoError(t, err)

expectedSqrt0 := memory.MemoryValueFromInt(0)
expectedSqrt1 := memory.MemoryValueFromInt(32)
expectedRemainderLow := memory.MemoryValueFromInt(51)
expectedRemainderHigh := memory.MemoryValueFromInt(0)
expectedSqrtMul2MinusRemainderGeU128 := memory.MemoryValueFromInt(0)

actualSqrt0 := readFrom(vm, VM.ExecutionSegment, 1)
actualSqrt1 := readFrom(vm, VM.ExecutionSegment, 2)
actualRemainderLow := readFrom(vm, VM.ExecutionSegment, 3)
actualRemainderHigh := readFrom(vm, VM.ExecutionSegment, 4)
actualSqrtMul2MinusRemainderGeU128 := readFrom(vm, VM.ExecutionSegment, 5)

require.Equal(t, expectedSqrt0, actualSqrt0)
require.Equal(t, expectedSqrt1, actualSqrt1)
require.Equal(t, expectedRemainderLow, actualRemainderLow)
require.Equal(t, expectedRemainderHigh, actualRemainderHigh)
require.Equal(t, expectedSqrtMul2MinusRemainderGeU128, actualSqrtMul2MinusRemainderGeU128)
}