From f4d553562ddaca04976d7778d1c5197d9efd08bc Mon Sep 17 00:00:00 2001 From: Eric Brown Date: Tue, 22 Feb 2022 17:54:52 -0800 Subject: [PATCH] Suggest fixes for issues This change introduces a new feature that will suggest a fix in the form of a line of code as replacement for the line range of the issue. This is the first step to have the ability to auto-correct problems detected. Later more changes can be merged to modify the file with the suggested fix. The Issue class has a new fix string attribute that denotes how the lines of affected code can be replaced. This suggested fix will not preserve code comments and possibly other optimizations the AST does not capture. Closes #439 Signed-off-by: Eric Brown Signed-off-by: Eric Brown --- bandit/core/context.py | 44 ++++++++++++++++ bandit/core/issue.py | 7 ++- bandit/core/node_visitor.py | 2 +- bandit/formatters/html.py | 14 +++++ bandit/formatters/screen.py | 7 +++ bandit/formatters/text.py | 4 ++ bandit/plugins/app_debug.py | 3 ++ .../crypto_request_no_cert_validation.py | 3 ++ bandit/plugins/hashlib_insecure_functions.py | 18 +++++++ bandit/plugins/jinja2_templates.py | 12 +++++ .../plugins/ssh_no_host_key_verification.py | 3 ++ bandit/plugins/yaml_load.py | 16 ++++++ examples/suggest_fix.py | 52 +++++++++++++++++++ 13 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 examples/suggest_fix.py diff --git a/bandit/core/context.py b/bandit/core/context.py index 8a2d4fbb..84fb7b97 100644 --- a/bandit/core/context.py +++ b/bandit/core/context.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 import ast +import linecache +import sys from bandit.core import utils @@ -311,6 +313,48 @@ def is_module_imported_like(self, module): return True return False + def get_outer_text(self): + """Get the text to the left and right of the node in context. + + Gets the text to the left and text to the right of the node in + context. This function depends on knowing the line range, col_offset, + and end_col_offset. + + :return: outer text as tuple + """ + lineno = self._context.get("linerange")[0] + end_lineno = self._context.get("linerange")[-1] + col_offset = self._context.get("col_offset") + end_col_offset = self._context.get("end_col_offset") + + if self._context.get("filename") == "": + self._context.get("file_data").seek(0) + for line_num in range(1, lineno): + self._context.get("file_data").readline() + line = self._context.get("file_data").readline() + end_line = line + if end_lineno > lineno: + for line_num in range(1, end_lineno): + self._context.get("file_data").readline() + end_line = self._context.get("file_data").readline() + else: + line = linecache.getline(self._context.get("filename"), lineno) + end_line = linecache.getline( + self._context.get("filename"), end_lineno + ) + + return (line[:col_offset], end_line[end_col_offset:]) + + def unparse(self, transformer): + """Unparse an ast node using given transformer + + :param transformer: NodeTransformer that fixes the ast + :return: node as statement string + """ + fixed_node = ast.fix_missing_locations(transformer) + outer_text = self.get_outer_text() + return outer_text[0] + ast.unparse(fixed_node) + outer_text[1] + @property def filename(self): return self._context.get("filename") diff --git a/bandit/core/issue.py b/bandit/core/issue.py index bfa58335..767ff783 100644 --- a/bandit/core/issue.py +++ b/bandit/core/issue.py @@ -87,6 +87,7 @@ def __init__( test_id="", col_offset=-1, end_col_offset=0, + fix=None, ): self.severity = severity self.cwe = Cwe(cwe) @@ -103,6 +104,7 @@ def __init__( self.col_offset = col_offset self.end_col_offset = end_col_offset self.linerange = [] + self.fix = fix def __str__(self): return ( @@ -195,7 +197,7 @@ def get_code(self, max_lines=3, tabbed=False): if not len(text): break lines.append(tmplt % (line, text)) - return "".join(lines) + return "".join(lines).rstrip() def as_dict(self, with_code=True, max_lines=3): """Convert the issue to a dict of values for outputting.""" @@ -215,6 +217,8 @@ def as_dict(self, with_code=True, max_lines=3): if with_code: out["code"] = self.get_code(max_lines=max_lines) + if self.fix: + out["fix"] = self.fix return out def from_dict(self, data, with_code=True): @@ -230,6 +234,7 @@ def from_dict(self, data, with_code=True): self.linerange = data["line_range"] self.col_offset = data.get("col_offset", 0) self.end_col_offset = data.get("end_col_offset", 0) + self.fix = data.get("fix") def cwe_from_dict(data): diff --git a/bandit/core/node_visitor.py b/bandit/core/node_visitor.py index 938e8733..38e1825f 100644 --- a/bandit/core/node_visitor.py +++ b/bandit/core/node_visitor.py @@ -13,7 +13,7 @@ LOG = logging.getLogger(__name__) -class BanditNodeVisitor: +class BanditNodeVisitor(ast.NodeTransformer): def __init__( self, fname, fdata, metaast, testset, debug, nosec_lines, metrics ): diff --git a/bandit/formatters/html.py b/bandit/formatters/html.py index fb09f835..91f0bfab 100644 --- a/bandit/formatters/html.py +++ b/bandit/formatters/html.py @@ -270,9 +270,19 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): Line number: {line_number}
More info: {url}
{code} + Suggested Fix:
+{fix} {candidates} +""" + + fix_block = """ +
+
+{fix}
+
+
""" code_block = """ @@ -357,6 +367,9 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates = candidate_block.format(candidate_list=candidates_str) url = docs_utils.get_url(issue.test_id) + fix = ( + fix_block.format(fix=html_escape(issue.fix)) if issue.fix else None + ) results_str += issue_block.format( issue_no=index, issue_class=f"issue-sev-{issue.severity.lower()}", @@ -372,6 +385,7 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates=candidates, url=url, line_number=issue.lineno, + fix=fix, ) # build the metrics string to insert in the report diff --git a/bandit/formatters/screen.py b/bandit/formatters/screen.py index 906e86b7..b4b95670 100644 --- a/bandit/formatters/screen.py +++ b/bandit/formatters/screen.py @@ -146,6 +146,13 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append( + f"{indent} {COLOR[issue.severity]}" + f"Suggested Fix:{COLOR['DEFAULT']}" + ) + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/formatters/text.py b/bandit/formatters/text.py index e6918e37..ddcf77b0 100644 --- a/bandit/formatters/text.py +++ b/bandit/formatters/text.py @@ -112,6 +112,10 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append(f"{indent} Suggested Fix:") + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/plugins/app_debug.py b/bandit/plugins/app_debug.py index 3b18996f..c64f206e 100644 --- a/bandit/plugins/app_debug.py +++ b/bandit/plugins/app_debug.py @@ -52,6 +52,8 @@ def flask_debug_true(context): if context.is_module_imported_like("flask"): if context.call_function_name_qual.endswith(".run"): if context.check_call_arg_value("debug", "True"): + context.node.keywords[0].value.value = False + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -60,4 +62,5 @@ def flask_debug_true(context): "which exposes the Werkzeug debugger and allows " "the execution of arbitrary code.", lineno=context.get_lineno_for_call_arg("debug"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/crypto_request_no_cert_validation.py b/bandit/plugins/crypto_request_no_cert_validation.py index 11791ed1..f42a4082 100644 --- a/bandit/plugins/crypto_request_no_cert_validation.py +++ b/bandit/plugins/crypto_request_no_cert_validation.py @@ -65,6 +65,8 @@ def request_with_no_cert_validation(context): and context.call_function_name in HTTPX_ATTRS ): if context.check_call_arg_value("verify", "False"): + context.node.keywords[0].value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -72,4 +74,5 @@ def request_with_no_cert_validation(context): text=f"Call to {qualname} with verify=False disabling SSL " "certificate checks, security issue.", lineno=context.get_lineno_for_call_arg("verify"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/hashlib_insecure_functions.py b/bandit/plugins/hashlib_insecure_functions.py index 710800a7..49bc4eba 100644 --- a/bandit/plugins/hashlib_insecure_functions.py +++ b/bandit/plugins/hashlib_insecure_functions.py @@ -48,6 +48,7 @@ Added check for the crypt module weak hashes """ # noqa: E501 +import ast import sys import bandit @@ -58,6 +59,18 @@ WEAK_CRYPT_HASHES = ("METHOD_CRYPT", "METHOD_MD5", "METHOD_BLOWFISH") +def transform(node): + found = False + for keyword in node.keywords: + if keyword.arg == "usedforsecurity": + keyword.value.value = False + found = True + if not found: + keyword = ast.keyword("usedforsecurity", ast.Constant(False)) + node.keywords.append(keyword) + return node + + def _hashlib_func(context, func): keywords = context.call_keywords @@ -70,6 +83,7 @@ def _hashlib_func(context, func): text=f"Use of weak {func.upper()} hash for security. " "Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(context.node), ) elif func == "new": args = context.call_args @@ -83,6 +97,7 @@ def _hashlib_func(context, func): text=f"Use of weak {name.upper()} hash for " "security. Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(context.node), ) @@ -91,6 +106,8 @@ def _hashlib_new(context, func): args = context.call_args keywords = context.call_keywords name = args[0] if args else keywords.get("name", None) + if len(context.node.args): + context.node.args[0].value = "sha224" if isinstance(name, str) and name.lower() in WEAK_HASHES: return bandit.Issue( severity=bandit.MEDIUM, @@ -98,6 +115,7 @@ def _hashlib_new(context, func): cwe=issue.Cwe.BROKEN_CRYPTO, text=f"Use of insecure {name.upper()} hash function.", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/jinja2_templates.py b/bandit/plugins/jinja2_templates.py index 667b79d6..403cac1d 100644 --- a/bandit/plugins/jinja2_templates.py +++ b/bandit/plugins/jinja2_templates.py @@ -85,6 +85,8 @@ def jinja2_autoescape_false(context): getattr(node.value, "id", None) == "False" or getattr(node.value, "value", None) is False ): + node.value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -94,6 +96,7 @@ def jinja2_autoescape_false(context): "Use autoescape=True or use the " "select_autoescape function to mitigate XSS " "vulnerabilities.", + fix=context.unparse(context.node), ) # found autoescape if getattr(node, "arg", None) == "autoescape": @@ -112,6 +115,8 @@ def jinja2_autoescape_false(context): ): return else: + node.value = ast.Constant(value=True, kind=None) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -121,9 +126,15 @@ def jinja2_autoescape_false(context): "Ensure autoescape=True or use the " "select_autoescape function to mitigate " "XSS vulnerabilities.", + fix=context.unparse(context.node), ) # We haven't found a keyword named autoescape, indicating default # behavior + keyword = ast.keyword( + "autoescape", ast.Constant(value=True, kind=None) + ) + context.node.keywords.append(keyword) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -131,4 +142,5 @@ def jinja2_autoescape_false(context): text="By default, jinja2 sets autoescape to False. Consider " "using autoescape=True or use the select_autoescape " "function to mitigate XSS vulnerabilities.", + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/ssh_no_host_key_verification.py b/bandit/plugins/ssh_no_host_key_verification.py index 51be2eb4..d7f5e980 100644 --- a/bandit/plugins/ssh_no_host_key_verification.py +++ b/bandit/plugins/ssh_no_host_key_verification.py @@ -64,6 +64,8 @@ def ssh_no_host_key_verification(context): policy_argument_value = policy_argument.func.id if policy_argument_value in ["AutoAddPolicy", "WarningPolicy"]: + context.node.args[0].attr = "RejectPolicy" + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -73,4 +75,5 @@ def ssh_no_host_key_verification(context): lineno=context.get_lineno_for_call_arg( "set_missing_host_key_policy" ), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/yaml_load.py b/bandit/plugins/yaml_load.py index 2304c1d7..4f155dec 100644 --- a/bandit/plugins/yaml_load.py +++ b/bandit/plugins/yaml_load.py @@ -66,6 +66,21 @@ def yaml_load(context): not context.get_call_arg_at_position(1) == "CSafeLoader", ] ): + if getattr(context.node.func, "attr", None) == "load": + context.node.func.attr = "safe_load" + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + elif getattr(context.node.func, "id", None) == "load": + # Suggesting a switch to safe_load won't work without the import. + # Therefore switch to a SafeLoader. + # TODO: fix this + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, @@ -73,4 +88,5 @@ def yaml_load(context): text="Use of unsafe yaml load. Allows instantiation of" " arbitrary objects. Consider yaml.safe_load().", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/examples/suggest_fix.py b/examples/suggest_fix.py new file mode 100644 index 00000000..eb71e147 --- /dev/null +++ b/examples/suggest_fix.py @@ -0,0 +1,52 @@ +import hashlib + +import flask +import jinja2 +from paramiko import client +import requests +import yaml +from yaml import load +from yaml import Loader + +app = flask.Flask(__name__) + + +@app.route('/') +def main(): + # Test call within if statement + if requests.get('https://google.com', verify=False): + + # Test complex call within dict of multiple lines + yaml_dict = { + "first": yaml.load(""" +a: 1 +b: 2 +c: 3""" +), + } + + load("{}") # Test trailing comment + + # Newer PyYAML load() requires a Loader + load("{}", Loader=Loader) + + # Test multiple calls on same line + data = b"abcd" + print(hashlib.md4(data), + hashlib.md5(data), hashlib.sha(data), + hashlib.sha1(data)) + + # Test a call over multiple lines + ssh_client = client.SSHClient() + ssh_client.set_missing_host_key_policy( + client.AutoAddPolicy # This comment will get lost + ) + + jinja2.Environment(loader=templateLoader, + load=templateLoader) + +if debug: + app.run() +else: + app.run(debug=True) +main()