Skip to content

Commit

Permalink
Formatted the code in Black style
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexvozhak committed Aug 1, 2024
1 parent 43336d8 commit d445eb3
Show file tree
Hide file tree
Showing 7 changed files with 353 additions and 218 deletions.
222 changes: 164 additions & 58 deletions diffyscan/diffyscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from .utils.common import load_config, load_env
from .utils.constants import *
from .utils.explorer import get_contract_from_explorer
from .utils.github import get_file_from_github, get_file_from_github_recursive, resolve_dep
from .utils.github import (
get_file_from_github,
get_file_from_github_recursive,
resolve_dep,
)
from .utils.helpers import create_dirs
from .utils.logger import logger
from .utils.binary_verifier import *
Expand All @@ -21,53 +25,85 @@

g_skip_user_input: bool = False


def prettify_solidity(solidity_contract_content: str):
github_file_name = os.path.join(tempfile.gettempdir(), "9B91E897-EA51-4FCC-8DAF-FCFF135A6963.sol")
github_file_name = os.path.join(
tempfile.gettempdir(), "9B91E897-EA51-4FCC-8DAF-FCFF135A6963.sol"
)
with open(github_file_name, "w") as fp:
fp.write(solidity_contract_content)
prettier_return_code = subprocess.call(
["npx", "prettier", "--plugin=prettier-plugin-solidity", "--write", github_file_name],
stdout=subprocess.DEVNULL)
[
"npx",
"prettier",
"--plugin=prettier-plugin-solidity",
"--write",
github_file_name,
],
stdout=subprocess.DEVNULL,
)
if prettier_return_code != 0:
logger.error("Prettier/npx subprocess failed (see the error above)")
sys.exit()
with open(github_file_name, "r") as fp:
return fp.read()


def run_binary_diff(remote_contract_address, contract_source_code, config):
logger.info(f'Started binary checking for {remote_contract_address}')

contract_creation_code, immutables, is_valid_constructor = get_contract_creation_code_from_etherscan(contract_source_code, config, remote_contract_address)

logger.info(f"Started binary checking for {remote_contract_address}")

contract_creation_code, immutables, is_valid_constructor = (
get_contract_creation_code_from_etherscan(
contract_source_code, config, remote_contract_address
)
)

if not is_valid_constructor:
logger.error(f'Failed to find constructorArgs, binary diff skipped')
logger.error(f"Failed to find constructorArgs, binary diff skipped")
return

deployer_account = get_account(LOCAL_RPC_URL)
if (deployer_account is None):
logger.error(f'Failed to receive the account, binary diff skipped')

if deployer_account is None:
logger.error(f"Failed to receive the account, binary diff skipped")
return

local_contract_address = deploy_contract(LOCAL_RPC_URL, deployer_account, contract_creation_code)

if (local_contract_address is None):
logger.error(f'Failed to deploy bytecode to {LOCAL_RPC_URL}, binary diff skipped')

local_contract_address = deploy_contract(
LOCAL_RPC_URL, deployer_account, contract_creation_code
)

if local_contract_address is None:
logger.error(
f"Failed to deploy bytecode to {LOCAL_RPC_URL}, binary diff skipped"
)
return

local_deployed_bytecode = get_bytecode(local_contract_address, LOCAL_RPC_URL)
if (local_deployed_bytecode is None):
logger.error(f'Failed to receive bytecode from {LOCAL_RPC_URL}')
if local_deployed_bytecode is None:
logger.error(f"Failed to receive bytecode from {LOCAL_RPC_URL}")
return

remote_deployed_bytecode = get_bytecode(remote_contract_address, REMOTE_RPC_URL)
if remote_deployed_bytecode is None:
logger.error(f'Failed to receive bytecode from {REMOTE_RPC_URL}')
logger.error(f"Failed to receive bytecode from {REMOTE_RPC_URL}")
return

to_match(local_deployed_bytecode, remote_deployed_bytecode, immutables, remote_contract_address)

def run_source_diff(contract_address_from_config, contract_code, config, github_api_token, recursive_parsing=False, prettify=False):
to_match(
local_deployed_bytecode,
remote_deployed_bytecode,
immutables,
remote_contract_address,
)


def run_source_diff(
contract_address_from_config,
contract_code,
config,
github_api_token,
recursive_parsing=False,
prettify=False,
):
logger.divider()
logger.okay("Contract", contract_address_from_config)
logger.okay("Blockchain explorer Hostname", config["explorer_hostname"])
Expand All @@ -81,7 +117,11 @@ def run_source_diff(contract_address_from_config, contract_code, config, github_
f"Fetching source code from blockchain explorer {config['explorer_hostname']} ..."
)

source_files = contract_code["solcInput"].items() if not "sources" in contract_code["solcInput"] else contract_code["solcInput"]["sources"].items()
source_files = (
contract_code["solcInput"].items()
if not "sources" in contract_code["solcInput"]
else contract_code["solcInput"]["sources"].items()
)
files_count = len(source_files)
logger.okay("Contract", contract_code["name"])
logger.okay("Files", files_count)
Expand Down Expand Up @@ -119,9 +159,13 @@ def run_source_diff(contract_address_from_config, contract_code, config, github_
file_found = bool(repo)

if recursive_parsing:
github_file = get_file_from_github_recursive(github_api_token, repo, path_to_file, dep_name)
github_file = get_file_from_github_recursive(
github_api_token, repo, path_to_file, dep_name
)
else:
github_file = get_file_from_github(github_api_token, repo, path_to_file, dep_name)
github_file = get_file_from_github(
github_api_token, repo, path_to_file, dep_name
)

if not github_file:
github_file = "<!-- No file content -->"
Expand All @@ -137,7 +181,9 @@ def run_source_diff(contract_address_from_config, contract_code, config, github_
explorer_lines = explorer_content.splitlines()

diff_html = difflib.HtmlDiff().make_file(github_lines, explorer_lines)
diff_report_filename = f"{DIFFS_DIR}/{contract_address_from_config}/{filename}.html"
diff_report_filename = (
f"{DIFFS_DIR}/{contract_address_from_config}/{filename}.html"
)

create_dirs(diff_report_filename)
with open(diff_report_filename, "w") as f:
Expand Down Expand Up @@ -167,55 +213,97 @@ def run_source_diff(contract_address_from_config, contract_code, config, github_

logger.report_table(report)

def process_config(path: str, recursive_parsing: bool, unify_formatting: bool, binary_check: bool, autoclean: bool):

def process_config(
path: str,
recursive_parsing: bool,
unify_formatting: bool,
binary_check: bool,
autoclean: bool,
):
logger.info(f"Loading config {path}...")
config = load_config(path)

explorer_token = None
if "explorer_token_env_var" in config:
explorer_token = load_env(config["explorer_token_env_var"], masked=True, required=False)
if (explorer_token is None):
explorer_token = os.getenv('ETHERSCAN_EXPLORER_TOKEN', default=None)
if (explorer_token is None):
raise ValueError(f'Failed to find "ETHERSCAN_EXPLORER_TOKEN" in env')

explorer_token = load_env(
config["explorer_token_env_var"], masked=True, required=False
)
if explorer_token is None:
explorer_token = os.getenv("ETHERSCAN_EXPLORER_TOKEN", default=None)
if explorer_token is None:
raise ValueError(f'Failed to find "ETHERSCAN_EXPLORER_TOKEN" in env')

contracts = config["contracts"]

try:
if (binary_check):
ganache.start()
if binary_check:
ganache.start()

for contract_address, contract_name in contracts.items():
contract_code = get_contract_from_explorer(explorer_token, config["explorer_hostname"], contract_address, contract_name)
run_source_diff(contract_address, contract_code, config, GITHUB_API_TOKEN, recursive_parsing, unify_formatting)
if (binary_check):
contract_code = get_contract_from_explorer(
explorer_token,
config["explorer_hostname"],
contract_address,
contract_name,
)
run_source_diff(
contract_address,
contract_code,
config,
GITHUB_API_TOKEN,
recursive_parsing,
unify_formatting,
)
if binary_check:
run_binary_diff(contract_address, contract_code, config)
except KeyboardInterrupt:
logger.info(f'Keyboard interrupt by user')
logger.info(f"Keyboard interrupt by user")
finally:
ganache.stop()


if (autoclean):
if autoclean:
shutil.rmtree(SOLC_DIR)
logger.okay(f'{SOLC_DIR} deleted')
logger.okay(f"{SOLC_DIR} deleted")


def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--version", "-V", action="store_true", help="Display version information")
parser.add_argument("path", nargs="?", default=None, help="Path to config or directory with configs")
parser.add_argument("--yes", "-y", help="If set don't ask for input before validating each contract", action="store_true")
parser.add_argument(
"--version", "-V", action="store_true", help="Display version information"
)
parser.add_argument(
"path", nargs="?", default=None, help="Path to config or directory with configs"
)
parser.add_argument(
"--yes",
"-y",
help="If set don't ask for input before validating each contract",
action="store_true",
)
parser.add_argument(
"--support-brownie",
help="Support recursive retrieving for contracts. It may be useful for contracts whose sources have been verified by the brownie tooling, which automatically replaces relative paths to contracts in imports with plain contract names.",
action=argparse.BooleanOptionalAction,
)
parser.add_argument("--prettify", "-p", help="Unify formatting by prettier before comparing", action="store_true")
parser.add_argument("--binary-check", "-binary", help="Match contracts by binaries such as verify-bytecode.ts", default=True)
parser.add_argument("--autoclean", "-clean", help="Autoclean build dir after work", default=True)
parser.add_argument(
"--prettify",
"-p",
help="Unify formatting by prettier before comparing",
action="store_true",
)
parser.add_argument(
"--binary-check",
"-binary",
help="Match contracts by binaries such as verify-bytecode.ts",
default=True,
)
parser.add_argument(
"--autoclean", "-clean", help="Autoclean build dir after work", default=True
)
return parser.parse_args()


def main():
global g_skip_user_input

Expand All @@ -228,14 +316,32 @@ def main():
logger.divider()

if args.path is None:
process_config(DEFAULT_CONFIG_PATH, args.support_brownie, args.prettify, args.binary_check, args.autoclean)
process_config(
DEFAULT_CONFIG_PATH,
args.support_brownie,
args.prettify,
args.binary_check,
args.autoclean,
)
elif os.path.isfile(args.path):
process_config(args.path, args.support_brownie, args.prettify, args.binary_check, args.autoclean)
process_config(
args.path,
args.support_brownie,
args.prettify,
args.binary_check,
args.autoclean,
)
elif os.path.isdir(args.path):
for filename in os.listdir(args.path):
config_path = os.path.join(args.path, filename)
if os.path.isfile(config_path):
process_config(config_path, args.support_brownie, args.prettify, args.binary_check, args.autoclean)
process_config(
config_path,
args.support_brownie,
args.prettify,
args.binary_check,
args.autoclean,
)
else:
logger.error(f"Specified config path {args.path} not found")
sys.exit(1)
Expand Down
2 changes: 1 addition & 1 deletion diffyscan/utils/binary_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,4 +324,4 @@ def parse(bytecode):
'bytecode': buffer[i:i+length].hex()
})
i += length
return instructions
return instructions
18 changes: 12 additions & 6 deletions diffyscan/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .logger import logger
from .types import Config


def load_env(variable_name, required=True, masked=False):
value = os.getenv(variable_name, default=None)

Expand All @@ -29,6 +30,7 @@ def load_config(path: str) -> Config:
with open(path, mode="r") as config_file:
return json.load(config_file)


def handle_response(response, url):
if response.status_code == 404:
return None
Expand All @@ -40,18 +42,21 @@ def handle_response(response, url):

return response


def fetch(url, headers={}):
logger.log(f"Fetch: {url}")
response = requests.get(url, headers=headers)

return handle_response(response, url)


def pull(url, payload={}):
logger.log(f"Pull: {url}")
response = requests.post(url, data=payload)

return handle_response(response, url)



def mask_text(text, mask_start=3, mask_end=3):
text_length = len(text)
mask = "*" * (text_length - mask_start - mask_end)
Expand All @@ -64,11 +69,12 @@ def parse_repo_link(repo_link):
user_slash_repo = repo_location[0]
return user_slash_repo


def get_solc_native_platform_from_os():
platform_name = sys.platform
if platform_name == 'linux':
return 'linux-amd64'
elif platform_name == 'darwin':
return 'macosx-amd64'
if platform_name == "linux":
return "linux-amd64"
elif platform_name == "darwin":
return "macosx-amd64"
else:
raise ValueError(f'Unsupported platform {platform_name}')
raise ValueError(f"Unsupported platform {platform_name}")
Loading

0 comments on commit d445eb3

Please sign in to comment.