Skip to content

Commit

Permalink
Add: Cairo0 Sqrt hint (#316)
Browse files Browse the repository at this point in the history
* Add: Cairo0 Sqrt hint

* add: complete hintcode in python for sqrt hint
  • Loading branch information
Tomi-3-0 authored Mar 20, 2024
1 parent 735c143 commit ce40eaf
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ const (
// split_felt() hints.
splitFeltCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128\nassert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW\nassert_integer(ids.value)\nids.low = ids.value & ((1 << 128) - 1)\nids.high = ids.value >> 128"

// sqrt() hint
sqrtCode string = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)"

// ------ Uint256 hints related code ------
uint256AddCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"
uint256AddLowCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createUint256UnsignedDivRemHinter(resolver)
case uint256SqrtCode:
return createUint256SqrtHinter(resolver)
case sqrtCode:
return createSqrtHinter(resolver)
default:
return nil, fmt.Errorf("Not identified hint")
}
Expand Down
52 changes: 52 additions & 0 deletions pkg/hintrunner/zero/zerohint_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"math/big"

"github.com/holiman/uint256"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
Expand Down Expand Up @@ -581,3 +583,53 @@ func createSplitFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error

return newSplitFeltHint(low, high, value), nil
}

func newSqrtHint(root, value hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "Sqrt",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
//> from starkware.python.math_utils import isqrt
// value = ids.value % PRIME
// assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)."
// assert 2 ** 250 < PRIME
// ids.root = isqrt(value)

rootAddr, err := root.GetAddress(vm)
if err != nil {
return err
}

value, err := hinter.ResolveAsFelt(vm, value)
if err != nil {
return err
}

if !utils.FeltLt(value, &utils.FeltUpperBound) {
return fmt.Errorf("assertion failed: %v is outside of the range [0, 2**250)", value)
}

// Conversion needed to handle non-square values
valueU256 := uint256.Int(value.Bits())
valueU256.Sqrt(&valueU256)

result := fp.Element{}
result.SetBytes(valueU256.Bytes())

v := memory.MemoryValueFromFieldElement(&result)
return vm.Memory.WriteToAddress(&rootAddr, &v)
},
}
}

func createSqrtHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {

root, err := resolver.GetResOperander("root")
if err != nil {
return nil, err
}
value, err := resolver.GetResOperander("value")
if err != nil {
return nil, err
}
return newSqrtHint(root, value), nil
}
43 changes: 43 additions & 0 deletions pkg/hintrunner/zero/zerohint_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,5 +586,48 @@ func TestZeroHintMath(t *testing.T) {
}),
},
},

"SqrtHint": {
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(25)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
check: varValueEquals("root", feltInt64(5)),
},
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(0)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
check: varValueEquals("root", feltInt64(0)),
},
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(50)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
check: varValueEquals("root", feltInt64(7)),
},
{
operanders: []*hintOperander{
{Name: "root", Kind: uninitialized},
{Name: "value", Kind: fpRelative, Value: feltInt64(-128)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"])
},
errCheck: errorTextContains("outside of the range [0, 2**250)"),
},
},
})
}

0 comments on commit ce40eaf

Please sign in to comment.