diff --git a/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts b/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts index 591274fe..a34363a9 100644 --- a/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts +++ b/src/boilerplate/circuit/zokrates/nodes/BoilerplateGenerator.ts @@ -86,6 +86,7 @@ const collectIncrements = (bpg: BoilerplateGenerator) => { incrementsString += ` + ${modName}`; } } + return { incrementsArray, incrementsString }; }; @@ -354,7 +355,6 @@ class BoilerplateGenerator { } else { throw new Error('This should be unreachable.'); } - return index; } diff --git a/src/transformers/visitors/checks/incrementedVisitor.ts b/src/transformers/visitors/checks/incrementedVisitor.ts index 2c7bff04..e4bba3ab 100644 --- a/src/transformers/visitors/checks/incrementedVisitor.ts +++ b/src/transformers/visitors/checks/incrementedVisitor.ts @@ -21,11 +21,24 @@ const literalOneNode = { precedingOperator: '', }; -const collectIncrements = (increments: any, incrementedIdentifier: any) => { +const collectIncrements = (increments: any, incrementedIdentifier: any, assignmentOperator: any, isTupleExpression: boolean) => { const { operands, precedingOperator } = increments; const newIncrements: any[] = []; const Idname = incrementedIdentifier.name || incrementedIdentifier.expression?.name; for (const [index, operand] of operands.entries()) { +// This Logic, changes the sign in case of decrements when don't have a tuple expression as in the circuits/Orchestration we +// translate a = a - b + c - d as a = a - (b - c + d) + if(assignmentOperator === '=' && precedingOperator[1] === '-' && index != 0){ + if(index == 1) + operand.precedingOperator = '+'; + else { + if(!isTupleExpression) + operand.precedingOperator = precedingOperator[index] === '+' ? '-' : '+'; + else + operand.precedingOperator = precedingOperator[index]; + } + } + else operand.precedingOperator = precedingOperator[index]; if ( operand.name !== Idname && @@ -37,6 +50,15 @@ const collectIncrements = (increments: any, incrementedIdentifier: any) => { return newIncrements; }; +const mixedOperatorsWarning = (path: NodePath) => { + backtrace.getSourceCode(path.node.src); + logger.warn( + `When we mix positive and negative operands in assigning to a secret variable, we may encounter underflow errors. Make sure that incrementing (a = a + ...) always increases the secret state value while decrementing (a = a - ...) decreases it. + \nWhenever we see something like a = a + b - c, we assume it's a positive incrementation, so b > c. Similarly, we assume a = a - b + c is a decrementation, so c - b < a.`, + ); + +}; + // marks the parent ExpressionStatement const markParentIncrementation = ( path: NodePath, @@ -51,6 +73,8 @@ const markParentIncrementation = ( : incrementedIdentifier; const parent = path.getAncestorOfType('ExpressionStatement'); if (!parent) throw new Error(`No parent of node ${path.node.name} found`); + const isTupleExpression = parent?.node.expression?.rightHandSide?.nodeType === 'TupleExpression' || parent?.node.expression?.rightHandSide?.rightExpression?.nodeType === 'TupleExpression' + const assignmentOp = parent?.node.expression?.operator; parent.isIncremented = isIncremented; parent.isDecremented = isDecremented; parent.incrementedDeclaration = incrementedIdentifier.referencedDeclaration; @@ -58,7 +82,7 @@ const markParentIncrementation = ( state.unmarkedIncrementation = false; state.incrementedIdentifier = incrementedIdentifier; if (increments?.operands) - increments = collectIncrements(increments, incrementedIdentifier); + increments = collectIncrements(increments, incrementedIdentifier, assignmentOp, isTupleExpression); increments?.forEach((inc: any) => { if ( inc.precedingOperator === '-' && @@ -94,23 +118,20 @@ const getIncrementedPath = (path: NodePath, state: any) => { state.stopTraversal = !!state.incrementedPath?.node; }; -const mixedOperatorsWarning = (path: NodePath) => { - backtrace.getSourceCode(path.node.src); - logger.warn( - `When we mix positive and negative operands in assigning to a secret variable, we may encounter underflow errors. Make sure that incrementing (a = a + ...) always increases the secret state value while decrementing (a = a - ...) decreases it. \nWhenever we see something like a = a + b - c, we assume it's a positive incrementation, so b > c. Similarly, we assume a = a - b + c is a decrementation, so c - b < a.`, - ); -}; + const binOpToIncrements = (path: NodePath, state: any) => { - const parentExpressionStatement = path.getAncestorOfType( + let parentExpressionStatement = path.getAncestorOfType( 'ExpressionStatement', ); const lhsNode = parentExpressionStatement?.node.expression?.leftHandSide; const assignmentOp = parentExpressionStatement?.node.expression?.operator; - const { operator, leftExpression, rightExpression } = path.node; - const operands = [leftExpression, rightExpression]; - const precedingOperator = ['+', operator]; + const { operator, leftExpression, rightExpression } = path.node ; + let operands = [leftExpression, rightExpression]; + + const precedingOperator = ['+', operator]; + const isTupleExpression = operands[1].nodeType === 'TupleExpression'; // if we dont have any + or -, it can't be an incrementation if ( !operator.includes('+') && @@ -121,25 +142,61 @@ const binOpToIncrements = (path: NodePath, state: any) => { markParentIncrementation(path, state, false, false, lhsNode); return; } - + // correct the operands for case when a = a - (b + c + d). + if(isTupleExpression) { + operands[0] = operands[1].components[0].rightExpression; + precedingOperator.push(operands[1].components[0].operator); + operands[1] = operands[1].components[0].leftExpression; + + for (const [index, operand] of operands.entries()) { + if (operand.nodeType === 'BinaryOperation') { + operands[index] = operand.leftExpression; + operands.splice(0, 0, operand.rightExpression); + precedingOperator.splice(2, 0, operand.operator); + } + } + operands.splice(0, 0, operands[operands.length -1]).slice(0, -1); + + } // fills an array of operands // e.g. if we have a = b - c + a + d, operands = [b, c, a, d] + if(!isTupleExpression){ + operands = operands.reverse(); for (const [index, operand] of operands.entries()) { if (operand.nodeType === 'BinaryOperation') { operands[index] = operand.leftExpression; - operands.push(operand.rightExpression); - precedingOperator.push(operand.operator); + operands.splice(0, 0, operand.rightExpression); + precedingOperator.splice(1, 0, operand.operator); } } + operands.splice(0, 0, operands[operands.length -1]); +} + // if we have mixed operators, we may have an underflow or not be able to tell whether this is increasing (incrementation) or decreasing (decrementation) the secret value + // Here we give out a warning when we don't use parentheses. if ( precedingOperator.length > 2 && precedingOperator.includes('+') && precedingOperator.includes('-') && parentExpressionStatement ) - mixedOperatorsWarning(parentExpressionStatement); - + { + mixedOperatorsWarning(parentExpressionStatement); + if(!isTupleExpression) + logger.warn( + `Whenever you have multiple operands in an expression, such as a = a - b - c + d, it's better to use parentheses for clarity. For example, rewrite it as a = a - (b + c - d). This makes the expression easier to understand. `, + ); +} +if(assignmentOp === '=' && precedingOperator.length > 2) { + if(isTupleExpression) { + operands.splice(0, 0, path.node.leftExpression); + } else { + if(operands[0].rightExpression){ + operands.splice(1, 0, operands[0].rightExpression); + precedingOperator.splice(1, 0, operands[0].operator); + operands[0] = operands[0].leftExpression;} + } + } return { operands, precedingOperator }; }; @@ -154,6 +211,7 @@ export default { ExpressionStatement: { enter(path: NodePath, state: any) { // starts here - if the path hasn't yet been marked as incremented, we find out if it is + if (path.isIncremented === undefined) { state.unmarkedIncrementation = true; state.increments = []; @@ -176,7 +234,6 @@ export default { const { isIncremented, isDecremented } = path; expressionNode.isIncremented = isIncremented; expressionNode.isDecremented = isDecremented; - // print if in debug mode if (logger.level === 'debug') backtrace.getSourceCode(node.src); logger.debug(`statement is incremented? ${isIncremented}`); @@ -239,7 +296,6 @@ export default { const { operator, leftHandSide, rightHandSide } = node; const lhsSecret = !!scope.getReferencedBinding(leftHandSide)?.isSecret; - if (['bool', 'address'].includes(leftHandSide.typeDescriptions.typeString)) { markParentIncrementation(path, state, false, false, leftHandSide); const lhsBinding = scope.getReferencedBinding(leftHandSide) @@ -351,7 +407,6 @@ export default { } const { operands, precedingOperator } = binOpToIncrements(path, state) || {}; - if (!operands || !precedingOperator) return; // if we find our lhs variable (a) on the rhs (a = a + b), then we make sure we don't find it again (a = a + b + a = b + 2a) @@ -395,7 +450,6 @@ export default { discoveredLHS += 1; isIncremented = { incremented: true, decremented: true }; } - // a = something - a if ( nameMatch && @@ -417,17 +471,19 @@ export default { ) { // a = a + b - c - d counts as an incrementation since the 1st operator is a plus // the mixed operators warning will have been given + // length of operator will be more than 2 in case of mixed operators if ( + precedingOperator.length > 2 && precedingOperator.includes('+') && - precedingOperator.includes('-') && - precedingOperator[0] === '+' - ) - isIncremented.decremented = false; + precedingOperator.includes('-') + ){ + isIncremented.decremented = precedingOperator[1] === '+' ? false : true; + } markParentIncrementation( path, state, isIncremented.incremented, - false, + isIncremented.decremented, lhsNode.baseExpression || lhsNode, { operands, precedingOperator }, ); diff --git a/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts b/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts index 44466416..3417b712 100644 --- a/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts +++ b/src/transformers/visitors/orchestrationInternalFunctionCallVisitor.ts @@ -25,7 +25,7 @@ const internalCallVisitor = { let sendTransactionNode : any; let newdecrementedSecretStates = []; node._newASTPointer.forEach(file => { - state.intFnindex = {}; + state.intFnindex = {}; state.internalFncName?.forEach( (name, index)=> { let callingFncName = state.callingFncName[index].name; if(file.fileName === name && file.nodeType === 'File') { diff --git a/test/contracts/Arrays1.zol b/test/contracts/Arrays1.zol index 2bd6f19c..87c5c43b 100644 --- a/test/contracts/Arrays1.zol +++ b/test/contracts/Arrays1.zol @@ -18,7 +18,10 @@ contract Assign { index++; a = a + index; index++; + j =j + 2; + j++; b[index] = value; + a += j; } function remove(secret uint256 value) public { diff --git a/test/contracts/Arrays2.zol b/test/contracts/Arrays2.zol index d85cc8a0..45b72b90 100644 --- a/test/contracts/Arrays2.zol +++ b/test/contracts/Arrays2.zol @@ -20,7 +20,7 @@ contract Assign { b[index] = value; index++; j++; - b[index] = value +index +j; + b[index] = (value - index +j); index += 1; a += value + index; b[index] = value + index; diff --git a/test/contracts/Assign-Increment.zol b/test/contracts/Assign-Increment.zol new file mode 100644 index 00000000..68316e74 --- /dev/null +++ b/test/contracts/Assign-Increment.zol @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: CC0 + +pragma solidity ^0.8.0; + +contract Assign { + + secret uint256 private a; + secret uint256 private b; + function add(secret uint256 value) public { + a += value; + unknown b += value; + } + + function remove(secret uint256 value, secret uint256 value1) public { + a += value; + b -= value + value1; + } + + function add1(secret uint256 value, secret uint256 value1, secret uint256 value2, secret uint256 value3, secret uint256 value4) public { + a = a + value - value1 + value3; + unknown b = b + (value1 - value2 - value3 + value4); +} + + +function remove1(secret uint256 value, secret uint256 value1, secret uint256 value2, secret uint256 value3, secret uint256 value4) public { + a = a - value - value1 + value3; + b = b - value1 - value2 + value3 - value4; +} + +} \ No newline at end of file diff --git a/test/contracts/Assign.zol b/test/contracts/Assign.zol index 0abde841..b7c546fb 100644 --- a/test/contracts/Assign.zol +++ b/test/contracts/Assign.zol @@ -9,9 +9,8 @@ contract Assign { unknown a += value; } - function remove(secret uint256 value) public returns (uint256) { - a -= value; - return a; + function remove(secret uint256 value) public { + a -= value ; } } diff --git a/test/contracts/BucketsOfBalls.zol b/test/contracts/BucketsOfBalls.zol index ba2ab2ac..95632a59 100644 --- a/test/contracts/BucketsOfBalls.zol +++ b/test/contracts/BucketsOfBalls.zol @@ -12,6 +12,6 @@ contract BucketsOfBalls { function transfer(secret address toBucketId, secret uint256 numberOfBalls) public { buckets[msg.sender] -= numberOfBalls; - encrypt unknown buckets[toBucketId] += numberOfBalls; + unknown buckets[toBucketId] += numberOfBalls; } }