diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 24bbd3c78..47fddcfa7 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -37,6 +37,9 @@ const ( // split_felt() hints. splitFeltCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert ids.MAX_HIGH < 2**128 and ids.MAX_LOW < 2**128\nassert PRIME - 1 == ids.MAX_HIGH * 2**128 + ids.MAX_LOW\nassert_integer(ids.value)\nids.low = ids.value & ((1 << 128) - 1)\nids.high = ids.value >> 128" + // sqrt() hint + sqrtCode string = "from starkware.python.math_utils import isqrt\nvalue = ids.value % PRIME\nassert value < 2 ** 250, f\"value={value} is outside of the range [0, 2**250).\"\nassert 2 ** 250 < PRIME\nids.root = isqrt(value)" + // ------ Uint256 hints related code ------ uint256AddCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0" uint256AddLowCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0" diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index dc8a424a6..0798c264e 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -100,6 +100,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createUint256UnsignedDivRemHinter(resolver) case uint256SqrtCode: return createUint256SqrtHinter(resolver) + case sqrtCode: + return createSqrtHinter(resolver) default: return nil, fmt.Errorf("Not identified hint") } diff --git a/pkg/hintrunner/zero/zerohint_math.go b/pkg/hintrunner/zero/zerohint_math.go index 04e046bfc..0587f08f8 100644 --- a/pkg/hintrunner/zero/zerohint_math.go +++ b/pkg/hintrunner/zero/zerohint_math.go @@ -4,6 +4,8 @@ import ( "fmt" "math/big" + "github.com/holiman/uint256" + "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" "github.com/NethermindEth/cairo-vm-go/pkg/utils" @@ -581,3 +583,53 @@ func createSplitFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error return newSplitFeltHint(low, high, value), nil } + +func newSqrtHint(root, value hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "Sqrt", + Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error { + //> from starkware.python.math_utils import isqrt + // value = ids.value % PRIME + // assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)." + // assert 2 ** 250 < PRIME + // ids.root = isqrt(value) + + rootAddr, err := root.GetAddress(vm) + if err != nil { + return err + } + + value, err := hinter.ResolveAsFelt(vm, value) + if err != nil { + return err + } + + if !utils.FeltLt(value, &utils.FeltUpperBound) { + return fmt.Errorf("assertion failed: %v is outside of the range [0, 2**250)", value) + } + + // Conversion needed to handle non-square values + valueU256 := uint256.Int(value.Bits()) + valueU256.Sqrt(&valueU256) + + result := fp.Element{} + result.SetBytes(valueU256.Bytes()) + + v := memory.MemoryValueFromFieldElement(&result) + return vm.Memory.WriteToAddress(&rootAddr, &v) + }, + } +} + +func createSqrtHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + + root, err := resolver.GetResOperander("root") + if err != nil { + return nil, err + } + value, err := resolver.GetResOperander("value") + if err != nil { + return nil, err + } + return newSqrtHint(root, value), nil +} diff --git a/pkg/hintrunner/zero/zerohint_math_test.go b/pkg/hintrunner/zero/zerohint_math_test.go index 1d25db04f..475fad0b4 100644 --- a/pkg/hintrunner/zero/zerohint_math_test.go +++ b/pkg/hintrunner/zero/zerohint_math_test.go @@ -586,5 +586,48 @@ func TestZeroHintMath(t *testing.T) { }), }, }, + + "SqrtHint": { + { + operanders: []*hintOperander{ + {Name: "root", Kind: uninitialized}, + {Name: "value", Kind: fpRelative, Value: feltInt64(25)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"]) + }, + check: varValueEquals("root", feltInt64(5)), + }, + { + operanders: []*hintOperander{ + {Name: "root", Kind: uninitialized}, + {Name: "value", Kind: fpRelative, Value: feltInt64(0)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"]) + }, + check: varValueEquals("root", feltInt64(0)), + }, + { + operanders: []*hintOperander{ + {Name: "root", Kind: uninitialized}, + {Name: "value", Kind: fpRelative, Value: feltInt64(50)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"]) + }, + check: varValueEquals("root", feltInt64(7)), + }, + { + operanders: []*hintOperander{ + {Name: "root", Kind: uninitialized}, + {Name: "value", Kind: fpRelative, Value: feltInt64(-128)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSqrtHint(ctx.operanders["root"], ctx.operanders["value"]) + }, + errCheck: errorTextContains("outside of the range [0, 2**250)"), + }, + }, }) }