From 245820d83af311f6a134481244d0b55bb07ec87b Mon Sep 17 00:00:00 2001 From: Eric Brown Date: Tue, 22 Feb 2022 17:54:52 -0800 Subject: [PATCH] Suggest fixes for issues This change introduces a new feature that will suggest a fix in the form of a line of code as replacement for the line range of the issue. This is the first step to have the ability to auto-correct problems detected. Later more changes can be merged to modify the file with the suggested fix. The Issue class has a new fix string attribute that denotes how the lines of affected code can be replaced. This suggested fix will not preserve code comments and possibly other optimizations the AST does not capture. Closes #439 Signed-off-by: Eric Brown Signed-off-by: Eric Brown --- bandit/core/context.py | 44 ++++++++++++++++ bandit/core/issue.py | 7 ++- bandit/core/node_visitor.py | 2 +- bandit/formatters/html.py | 14 +++++ bandit/formatters/screen.py | 7 +++ bandit/formatters/text.py | 4 ++ bandit/plugins/app_debug.py | 3 ++ .../crypto_request_no_cert_validation.py | 3 ++ bandit/plugins/hashlib_insecure_functions.py | 18 +++++++ bandit/plugins/jinja2_templates.py | 12 +++++ .../plugins/ssh_no_host_key_verification.py | 3 ++ bandit/plugins/yaml_load.py | 16 ++++++ examples/suggest_fix.py | 52 +++++++++++++++++++ 13 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 examples/suggest_fix.py diff --git a/bandit/core/context.py b/bandit/core/context.py index 8a2d4fbbc..84fb7b979 100644 --- a/bandit/core/context.py +++ b/bandit/core/context.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 import ast +import linecache +import sys from bandit.core import utils @@ -311,6 +313,48 @@ def is_module_imported_like(self, module): return True return False + def get_outer_text(self): + """Get the text to the left and right of the node in context. + + Gets the text to the left and text to the right of the node in + context. This function depends on knowing the line range, col_offset, + and end_col_offset. + + :return: outer text as tuple + """ + lineno = self._context.get("linerange")[0] + end_lineno = self._context.get("linerange")[-1] + col_offset = self._context.get("col_offset") + end_col_offset = self._context.get("end_col_offset") + + if self._context.get("filename") == "": + self._context.get("file_data").seek(0) + for line_num in range(1, lineno): + self._context.get("file_data").readline() + line = self._context.get("file_data").readline() + end_line = line + if end_lineno > lineno: + for line_num in range(1, end_lineno): + self._context.get("file_data").readline() + end_line = self._context.get("file_data").readline() + else: + line = linecache.getline(self._context.get("filename"), lineno) + end_line = linecache.getline( + self._context.get("filename"), end_lineno + ) + + return (line[:col_offset], end_line[end_col_offset:]) + + def unparse(self, transformer): + """Unparse an ast node using given transformer + + :param transformer: NodeTransformer that fixes the ast + :return: node as statement string + """ + fixed_node = ast.fix_missing_locations(transformer) + outer_text = self.get_outer_text() + return outer_text[0] + ast.unparse(fixed_node) + outer_text[1] + @property def filename(self): return self._context.get("filename") diff --git a/bandit/core/issue.py b/bandit/core/issue.py index bfa583356..767ff7832 100644 --- a/bandit/core/issue.py +++ b/bandit/core/issue.py @@ -87,6 +87,7 @@ def __init__( test_id="", col_offset=-1, end_col_offset=0, + fix=None, ): self.severity = severity self.cwe = Cwe(cwe) @@ -103,6 +104,7 @@ def __init__( self.col_offset = col_offset self.end_col_offset = end_col_offset self.linerange = [] + self.fix = fix def __str__(self): return ( @@ -195,7 +197,7 @@ def get_code(self, max_lines=3, tabbed=False): if not len(text): break lines.append(tmplt % (line, text)) - return "".join(lines) + return "".join(lines).rstrip() def as_dict(self, with_code=True, max_lines=3): """Convert the issue to a dict of values for outputting.""" @@ -215,6 +217,8 @@ def as_dict(self, with_code=True, max_lines=3): if with_code: out["code"] = self.get_code(max_lines=max_lines) + if self.fix: + out["fix"] = self.fix return out def from_dict(self, data, with_code=True): @@ -230,6 +234,7 @@ def from_dict(self, data, with_code=True): self.linerange = data["line_range"] self.col_offset = data.get("col_offset", 0) self.end_col_offset = data.get("end_col_offset", 0) + self.fix = data.get("fix") def cwe_from_dict(data): diff --git a/bandit/core/node_visitor.py b/bandit/core/node_visitor.py index 938e8733b..38e1825fe 100644 --- a/bandit/core/node_visitor.py +++ b/bandit/core/node_visitor.py @@ -13,7 +13,7 @@ LOG = logging.getLogger(__name__) -class BanditNodeVisitor: +class BanditNodeVisitor(ast.NodeTransformer): def __init__( self, fname, fdata, metaast, testset, debug, nosec_lines, metrics ): diff --git a/bandit/formatters/html.py b/bandit/formatters/html.py index fb09f835f..91f0bfabe 100644 --- a/bandit/formatters/html.py +++ b/bandit/formatters/html.py @@ -270,9 +270,19 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): Line number: {line_number}
More info: {url}
{code} + Suggested Fix:
+{fix} {candidates} +""" + + fix_block = """ +
+
+{fix}
+
+
""" code_block = """ @@ -357,6 +367,9 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates = candidate_block.format(candidate_list=candidates_str) url = docs_utils.get_url(issue.test_id) + fix = ( + fix_block.format(fix=html_escape(issue.fix)) if issue.fix else None + ) results_str += issue_block.format( issue_no=index, issue_class=f"issue-sev-{issue.severity.lower()}", @@ -372,6 +385,7 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates=candidates, url=url, line_number=issue.lineno, + fix=fix, ) # build the metrics string to insert in the report diff --git a/bandit/formatters/screen.py b/bandit/formatters/screen.py index 906e86b73..b4b956706 100644 --- a/bandit/formatters/screen.py +++ b/bandit/formatters/screen.py @@ -146,6 +146,13 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append( + f"{indent} {COLOR[issue.severity]}" + f"Suggested Fix:{COLOR['DEFAULT']}" + ) + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/formatters/text.py b/bandit/formatters/text.py index e6918e37b..ddcf77b0a 100644 --- a/bandit/formatters/text.py +++ b/bandit/formatters/text.py @@ -112,6 +112,10 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append(f"{indent} Suggested Fix:") + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/plugins/app_debug.py b/bandit/plugins/app_debug.py index 3b18996fe..c64f206ee 100644 --- a/bandit/plugins/app_debug.py +++ b/bandit/plugins/app_debug.py @@ -52,6 +52,8 @@ def flask_debug_true(context): if context.is_module_imported_like("flask"): if context.call_function_name_qual.endswith(".run"): if context.check_call_arg_value("debug", "True"): + context.node.keywords[0].value.value = False + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -60,4 +62,5 @@ def flask_debug_true(context): "which exposes the Werkzeug debugger and allows " "the execution of arbitrary code.", lineno=context.get_lineno_for_call_arg("debug"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/crypto_request_no_cert_validation.py b/bandit/plugins/crypto_request_no_cert_validation.py index 11791ed1e..f42a40827 100644 --- a/bandit/plugins/crypto_request_no_cert_validation.py +++ b/bandit/plugins/crypto_request_no_cert_validation.py @@ -65,6 +65,8 @@ def request_with_no_cert_validation(context): and context.call_function_name in HTTPX_ATTRS ): if context.check_call_arg_value("verify", "False"): + context.node.keywords[0].value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -72,4 +74,5 @@ def request_with_no_cert_validation(context): text=f"Call to {qualname} with verify=False disabling SSL " "certificate checks, security issue.", lineno=context.get_lineno_for_call_arg("verify"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/hashlib_insecure_functions.py b/bandit/plugins/hashlib_insecure_functions.py index 710800a72..49bc4eba8 100644 --- a/bandit/plugins/hashlib_insecure_functions.py +++ b/bandit/plugins/hashlib_insecure_functions.py @@ -48,6 +48,7 @@ Added check for the crypt module weak hashes """ # noqa: E501 +import ast import sys import bandit @@ -58,6 +59,18 @@ WEAK_CRYPT_HASHES = ("METHOD_CRYPT", "METHOD_MD5", "METHOD_BLOWFISH") +def transform(node): + found = False + for keyword in node.keywords: + if keyword.arg == "usedforsecurity": + keyword.value.value = False + found = True + if not found: + keyword = ast.keyword("usedforsecurity", ast.Constant(False)) + node.keywords.append(keyword) + return node + + def _hashlib_func(context, func): keywords = context.call_keywords @@ -70,6 +83,7 @@ def _hashlib_func(context, func): text=f"Use of weak {func.upper()} hash for security. " "Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(context.node), ) elif func == "new": args = context.call_args @@ -83,6 +97,7 @@ def _hashlib_func(context, func): text=f"Use of weak {name.upper()} hash for " "security. Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(context.node), ) @@ -91,6 +106,8 @@ def _hashlib_new(context, func): args = context.call_args keywords = context.call_keywords name = args[0] if args else keywords.get("name", None) + if len(context.node.args): + context.node.args[0].value = "sha224" if isinstance(name, str) and name.lower() in WEAK_HASHES: return bandit.Issue( severity=bandit.MEDIUM, @@ -98,6 +115,7 @@ def _hashlib_new(context, func): cwe=issue.Cwe.BROKEN_CRYPTO, text=f"Use of insecure {name.upper()} hash function.", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/jinja2_templates.py b/bandit/plugins/jinja2_templates.py index 667b79d62..403cac1d9 100644 --- a/bandit/plugins/jinja2_templates.py +++ b/bandit/plugins/jinja2_templates.py @@ -85,6 +85,8 @@ def jinja2_autoescape_false(context): getattr(node.value, "id", None) == "False" or getattr(node.value, "value", None) is False ): + node.value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -94,6 +96,7 @@ def jinja2_autoescape_false(context): "Use autoescape=True or use the " "select_autoescape function to mitigate XSS " "vulnerabilities.", + fix=context.unparse(context.node), ) # found autoescape if getattr(node, "arg", None) == "autoescape": @@ -112,6 +115,8 @@ def jinja2_autoescape_false(context): ): return else: + node.value = ast.Constant(value=True, kind=None) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -121,9 +126,15 @@ def jinja2_autoescape_false(context): "Ensure autoescape=True or use the " "select_autoescape function to mitigate " "XSS vulnerabilities.", + fix=context.unparse(context.node), ) # We haven't found a keyword named autoescape, indicating default # behavior + keyword = ast.keyword( + "autoescape", ast.Constant(value=True, kind=None) + ) + context.node.keywords.append(keyword) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -131,4 +142,5 @@ def jinja2_autoescape_false(context): text="By default, jinja2 sets autoescape to False. Consider " "using autoescape=True or use the select_autoescape " "function to mitigate XSS vulnerabilities.", + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/ssh_no_host_key_verification.py b/bandit/plugins/ssh_no_host_key_verification.py index 51be2eb4a..d7f5e9806 100644 --- a/bandit/plugins/ssh_no_host_key_verification.py +++ b/bandit/plugins/ssh_no_host_key_verification.py @@ -64,6 +64,8 @@ def ssh_no_host_key_verification(context): policy_argument_value = policy_argument.func.id if policy_argument_value in ["AutoAddPolicy", "WarningPolicy"]: + context.node.args[0].attr = "RejectPolicy" + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -73,4 +75,5 @@ def ssh_no_host_key_verification(context): lineno=context.get_lineno_for_call_arg( "set_missing_host_key_policy" ), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/yaml_load.py b/bandit/plugins/yaml_load.py index 2304c1d7d..4f155dec1 100644 --- a/bandit/plugins/yaml_load.py +++ b/bandit/plugins/yaml_load.py @@ -66,6 +66,21 @@ def yaml_load(context): not context.get_call_arg_at_position(1) == "CSafeLoader", ] ): + if getattr(context.node.func, "attr", None) == "load": + context.node.func.attr = "safe_load" + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + elif getattr(context.node.func, "id", None) == "load": + # Suggesting a switch to safe_load won't work without the import. + # Therefore switch to a SafeLoader. + # TODO: fix this + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, @@ -73,4 +88,5 @@ def yaml_load(context): text="Use of unsafe yaml load. Allows instantiation of" " arbitrary objects. Consider yaml.safe_load().", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/examples/suggest_fix.py b/examples/suggest_fix.py new file mode 100644 index 000000000..eb71e1477 --- /dev/null +++ b/examples/suggest_fix.py @@ -0,0 +1,52 @@ +import hashlib + +import flask +import jinja2 +from paramiko import client +import requests +import yaml +from yaml import load +from yaml import Loader + +app = flask.Flask(__name__) + + +@app.route('/') +def main(): + # Test call within if statement + if requests.get('https://google.com', verify=False): + + # Test complex call within dict of multiple lines + yaml_dict = { + "first": yaml.load(""" +a: 1 +b: 2 +c: 3""" +), + } + + load("{}") # Test trailing comment + + # Newer PyYAML load() requires a Loader + load("{}", Loader=Loader) + + # Test multiple calls on same line + data = b"abcd" + print(hashlib.md4(data), + hashlib.md5(data), hashlib.sha(data), + hashlib.sha1(data)) + + # Test a call over multiple lines + ssh_client = client.SSHClient() + ssh_client.set_missing_host_key_policy( + client.AutoAddPolicy # This comment will get lost + ) + + jinja2.Environment(loader=templateLoader, + load=templateLoader) + +if debug: + app.run() +else: + app.run(debug=True) +main()