diff --git a/bandit/core/context.py b/bandit/core/context.py index 8a2d4fbb..84fb7b97 100644 --- a/bandit/core/context.py +++ b/bandit/core/context.py @@ -3,6 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 import ast +import linecache +import sys from bandit.core import utils @@ -311,6 +313,48 @@ def is_module_imported_like(self, module): return True return False + def get_outer_text(self): + """Get the text to the left and right of the node in context. + + Gets the text to the left and text to the right of the node in + context. This function depends on knowing the line range, col_offset, + and end_col_offset. + + :return: outer text as tuple + """ + lineno = self._context.get("linerange")[0] + end_lineno = self._context.get("linerange")[-1] + col_offset = self._context.get("col_offset") + end_col_offset = self._context.get("end_col_offset") + + if self._context.get("filename") == "": + self._context.get("file_data").seek(0) + for line_num in range(1, lineno): + self._context.get("file_data").readline() + line = self._context.get("file_data").readline() + end_line = line + if end_lineno > lineno: + for line_num in range(1, end_lineno): + self._context.get("file_data").readline() + end_line = self._context.get("file_data").readline() + else: + line = linecache.getline(self._context.get("filename"), lineno) + end_line = linecache.getline( + self._context.get("filename"), end_lineno + ) + + return (line[:col_offset], end_line[end_col_offset:]) + + def unparse(self, transformer): + """Unparse an ast node using given transformer + + :param transformer: NodeTransformer that fixes the ast + :return: node as statement string + """ + fixed_node = ast.fix_missing_locations(transformer) + outer_text = self.get_outer_text() + return outer_text[0] + ast.unparse(fixed_node) + outer_text[1] + @property def filename(self): return self._context.get("filename") diff --git a/bandit/core/issue.py b/bandit/core/issue.py index bfa58335..767ff783 100644 --- a/bandit/core/issue.py +++ b/bandit/core/issue.py @@ -87,6 +87,7 @@ def __init__( test_id="", col_offset=-1, end_col_offset=0, + fix=None, ): self.severity = severity self.cwe = Cwe(cwe) @@ -103,6 +104,7 @@ def __init__( self.col_offset = col_offset self.end_col_offset = end_col_offset self.linerange = [] + self.fix = fix def __str__(self): return ( @@ -195,7 +197,7 @@ def get_code(self, max_lines=3, tabbed=False): if not len(text): break lines.append(tmplt % (line, text)) - return "".join(lines) + return "".join(lines).rstrip() def as_dict(self, with_code=True, max_lines=3): """Convert the issue to a dict of values for outputting.""" @@ -215,6 +217,8 @@ def as_dict(self, with_code=True, max_lines=3): if with_code: out["code"] = self.get_code(max_lines=max_lines) + if self.fix: + out["fix"] = self.fix return out def from_dict(self, data, with_code=True): @@ -230,6 +234,7 @@ def from_dict(self, data, with_code=True): self.linerange = data["line_range"] self.col_offset = data.get("col_offset", 0) self.end_col_offset = data.get("end_col_offset", 0) + self.fix = data.get("fix") def cwe_from_dict(data): diff --git a/bandit/core/node_visitor.py b/bandit/core/node_visitor.py index 938e8733..38e1825f 100644 --- a/bandit/core/node_visitor.py +++ b/bandit/core/node_visitor.py @@ -13,7 +13,7 @@ LOG = logging.getLogger(__name__) -class BanditNodeVisitor: +class BanditNodeVisitor(ast.NodeTransformer): def __init__( self, fname, fdata, metaast, testset, debug, nosec_lines, metrics ): diff --git a/bandit/formatters/html.py b/bandit/formatters/html.py index fb09f835..91f0bfab 100644 --- a/bandit/formatters/html.py +++ b/bandit/formatters/html.py @@ -270,9 +270,19 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): Line number: {line_number}
More info: {url}
{code} + Suggested Fix:
+{fix} {candidates} +""" + + fix_block = """ +
+
+{fix}
+
+
""" code_block = """ @@ -357,6 +367,9 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates = candidate_block.format(candidate_list=candidates_str) url = docs_utils.get_url(issue.test_id) + fix = ( + fix_block.format(fix=html_escape(issue.fix)) if issue.fix else None + ) results_str += issue_block.format( issue_no=index, issue_class=f"issue-sev-{issue.severity.lower()}", @@ -372,6 +385,7 @@ def report(manager, fileobj, sev_level, conf_level, lines=-1): candidates=candidates, url=url, line_number=issue.lineno, + fix=fix, ) # build the metrics string to insert in the report diff --git a/bandit/formatters/screen.py b/bandit/formatters/screen.py index 906e86b7..b4b95670 100644 --- a/bandit/formatters/screen.py +++ b/bandit/formatters/screen.py @@ -146,6 +146,13 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append( + f"{indent} {COLOR[issue.severity]}" + f"Suggested Fix:{COLOR['DEFAULT']}" + ) + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/formatters/text.py b/bandit/formatters/text.py index e6918e37..ddcf77b0 100644 --- a/bandit/formatters/text.py +++ b/bandit/formatters/text.py @@ -112,6 +112,10 @@ def _output_issue_str( [indent + line for line in issue.get_code(lines, True).split("\n")] ) + if issue.fix: + bits.append(f"{indent} Suggested Fix:") + bits.append(f"\t{issue.fix}") + return "\n".join([bit for bit in bits]) diff --git a/bandit/plugins/app_debug.py b/bandit/plugins/app_debug.py index 3b18996f..c64f206e 100644 --- a/bandit/plugins/app_debug.py +++ b/bandit/plugins/app_debug.py @@ -52,6 +52,8 @@ def flask_debug_true(context): if context.is_module_imported_like("flask"): if context.call_function_name_qual.endswith(".run"): if context.check_call_arg_value("debug", "True"): + context.node.keywords[0].value.value = False + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -60,4 +62,5 @@ def flask_debug_true(context): "which exposes the Werkzeug debugger and allows " "the execution of arbitrary code.", lineno=context.get_lineno_for_call_arg("debug"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/crypto_request_no_cert_validation.py b/bandit/plugins/crypto_request_no_cert_validation.py index 11791ed1..f42a4082 100644 --- a/bandit/plugins/crypto_request_no_cert_validation.py +++ b/bandit/plugins/crypto_request_no_cert_validation.py @@ -65,6 +65,8 @@ def request_with_no_cert_validation(context): and context.call_function_name in HTTPX_ATTRS ): if context.check_call_arg_value("verify", "False"): + context.node.keywords[0].value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -72,4 +74,5 @@ def request_with_no_cert_validation(context): text=f"Call to {qualname} with verify=False disabling SSL " "certificate checks, security issue.", lineno=context.get_lineno_for_call_arg("verify"), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/hashlib_insecure_functions.py b/bandit/plugins/hashlib_insecure_functions.py index 710800a7..49bc4eba 100644 --- a/bandit/plugins/hashlib_insecure_functions.py +++ b/bandit/plugins/hashlib_insecure_functions.py @@ -48,6 +48,7 @@ Added check for the crypt module weak hashes """ # noqa: E501 +import ast import sys import bandit @@ -58,6 +59,18 @@ WEAK_CRYPT_HASHES = ("METHOD_CRYPT", "METHOD_MD5", "METHOD_BLOWFISH") +def transform(node): + found = False + for keyword in node.keywords: + if keyword.arg == "usedforsecurity": + keyword.value.value = False + found = True + if not found: + keyword = ast.keyword("usedforsecurity", ast.Constant(False)) + node.keywords.append(keyword) + return node + + def _hashlib_func(context, func): keywords = context.call_keywords @@ -70,6 +83,7 @@ def _hashlib_func(context, func): text=f"Use of weak {func.upper()} hash for security. " "Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(context.node), ) elif func == "new": args = context.call_args @@ -83,6 +97,7 @@ def _hashlib_func(context, func): text=f"Use of weak {name.upper()} hash for " "security. Consider usedforsecurity=False", lineno=context.node.lineno, + fix=context.unparse(context.node), ) @@ -91,6 +106,8 @@ def _hashlib_new(context, func): args = context.call_args keywords = context.call_keywords name = args[0] if args else keywords.get("name", None) + if len(context.node.args): + context.node.args[0].value = "sha224" if isinstance(name, str) and name.lower() in WEAK_HASHES: return bandit.Issue( severity=bandit.MEDIUM, @@ -98,6 +115,7 @@ def _hashlib_new(context, func): cwe=issue.Cwe.BROKEN_CRYPTO, text=f"Use of insecure {name.upper()} hash function.", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/jinja2_templates.py b/bandit/plugins/jinja2_templates.py index 667b79d6..403cac1d 100644 --- a/bandit/plugins/jinja2_templates.py +++ b/bandit/plugins/jinja2_templates.py @@ -85,6 +85,8 @@ def jinja2_autoescape_false(context): getattr(node.value, "id", None) == "False" or getattr(node.value, "value", None) is False ): + node.value.value = True + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -94,6 +96,7 @@ def jinja2_autoescape_false(context): "Use autoescape=True or use the " "select_autoescape function to mitigate XSS " "vulnerabilities.", + fix=context.unparse(context.node), ) # found autoescape if getattr(node, "arg", None) == "autoescape": @@ -112,6 +115,8 @@ def jinja2_autoescape_false(context): ): return else: + node.value = ast.Constant(value=True, kind=None) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -121,9 +126,15 @@ def jinja2_autoescape_false(context): "Ensure autoescape=True or use the " "select_autoescape function to mitigate " "XSS vulnerabilities.", + fix=context.unparse(context.node), ) # We haven't found a keyword named autoescape, indicating default # behavior + keyword = ast.keyword( + "autoescape", ast.Constant(value=True, kind=None) + ) + context.node.keywords.append(keyword) + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.HIGH, @@ -131,4 +142,5 @@ def jinja2_autoescape_false(context): text="By default, jinja2 sets autoescape to False. Consider " "using autoescape=True or use the select_autoescape " "function to mitigate XSS vulnerabilities.", + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/ssh_no_host_key_verification.py b/bandit/plugins/ssh_no_host_key_verification.py index 51be2eb4..d7f5e980 100644 --- a/bandit/plugins/ssh_no_host_key_verification.py +++ b/bandit/plugins/ssh_no_host_key_verification.py @@ -64,6 +64,8 @@ def ssh_no_host_key_verification(context): policy_argument_value = policy_argument.func.id if policy_argument_value in ["AutoAddPolicy", "WarningPolicy"]: + context.node.args[0].attr = "RejectPolicy" + return bandit.Issue( severity=bandit.HIGH, confidence=bandit.MEDIUM, @@ -73,4 +75,5 @@ def ssh_no_host_key_verification(context): lineno=context.get_lineno_for_call_arg( "set_missing_host_key_policy" ), + fix=context.unparse(context.node), ) diff --git a/bandit/plugins/yaml_load.py b/bandit/plugins/yaml_load.py index 2304c1d7..4f155dec 100644 --- a/bandit/plugins/yaml_load.py +++ b/bandit/plugins/yaml_load.py @@ -66,6 +66,21 @@ def yaml_load(context): not context.get_call_arg_at_position(1) == "CSafeLoader", ] ): + if getattr(context.node.func, "attr", None) == "load": + context.node.func.attr = "safe_load" + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + elif getattr(context.node.func, "id", None) == "load": + # Suggesting a switch to safe_load won't work without the import. + # Therefore switch to a SafeLoader. + # TODO: fix this + for keyword in context.node.keywords: + if keyword.arg == "Loader": + context.node.keywords.remove(keyword) + break + return bandit.Issue( severity=bandit.MEDIUM, confidence=bandit.HIGH, @@ -73,4 +88,5 @@ def yaml_load(context): text="Use of unsafe yaml load. Allows instantiation of" " arbitrary objects. Consider yaml.safe_load().", lineno=context.node.lineno, + fix=context.unparse(context.node), ) diff --git a/examples/suggest_fix.py b/examples/suggest_fix.py new file mode 100644 index 00000000..eb71e147 --- /dev/null +++ b/examples/suggest_fix.py @@ -0,0 +1,52 @@ +import hashlib + +import flask +import jinja2 +from paramiko import client +import requests +import yaml +from yaml import load +from yaml import Loader + +app = flask.Flask(__name__) + + +@app.route('/') +def main(): + # Test call within if statement + if requests.get('https://google.com', verify=False): + + # Test complex call within dict of multiple lines + yaml_dict = { + "first": yaml.load(""" +a: 1 +b: 2 +c: 3""" +), + } + + load("{}") # Test trailing comment + + # Newer PyYAML load() requires a Loader + load("{}", Loader=Loader) + + # Test multiple calls on same line + data = b"abcd" + print(hashlib.md4(data), + hashlib.md5(data), hashlib.sha(data), + hashlib.sha1(data)) + + # Test a call over multiple lines + ssh_client = client.SSHClient() + ssh_client.set_missing_host_key_policy( + client.AutoAddPolicy # This comment will get lost + ) + + jinja2.Environment(loader=templateLoader, + load=templateLoader) + +if debug: + app.run() +else: + app.run(debug=True) +main()