From 6f1c4fd548e380b5105f90316d2dd30510a45f2c Mon Sep 17 00:00:00 2001 From: Jim Carciofini Date: Mon, 20 Jan 2025 11:50:19 -0600 Subject: [PATCH] GUI: Fix graph node focus logic to properly resolve when multiple function call invocations. --- pate_binja/pate.py | 44 +++++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/pate_binja/pate.py b/pate_binja/pate.py index 9348adda..f3b3ec1f 100644 --- a/pate_binja/pate.py +++ b/pate_binja/pate.py @@ -234,15 +234,16 @@ def markFocusNodes(self, cfar_graph: CFARGraph) -> None: rec = self.next_json() while True: - this = rec.get('this') - if this == '': + if rec.get('this') == '': break + desc = get_node_desc(rec) + # TODO: desc? Or use data? - nodes = (n for n in cfar_graph.nodes.values() if n.desc == this) + nodes = (n for n in cfar_graph.nodes.values() if n.desc == desc) node = next(nodes, None) if next(nodes, None): - print(f'WARNING: Multiple nodes match: {this}') + print(f'WARNING: Multiple nodes match: {desc}') if rec.get('trace_node_kind') == 'blocktarget': if node: @@ -426,7 +427,7 @@ def extract_graph_rec(self, if rec['trace_node_kind'] == 'node': id = get_graph_node_id(trace_node) existing_cfar_node = cfar_graph.get(id) - cfar_node = cfar_graph.add_node(id, this, rec) + cfar_node = cfar_graph.add_node(id, rec) # Look for observable difference trace for this node for n in rec['trace_node_contents']: @@ -438,7 +439,7 @@ def extract_graph_rec(self, elif rec['trace_node_kind'] == 'blocktarget': id = get_blocktarget_id(rec, context, cfar_parent) existing_cfar_node = cfar_graph.get(id) - cfar_node = cfar_graph.add_node(id, this, rec) + cfar_node = cfar_graph.add_node(id, rec) # connect block target (ie exit) to parent cfar_exit = cfar_node @@ -484,7 +485,7 @@ def extract_graph_rec(self, exit_id = get_exit_id(trace_node, context) # TODO: Better way to detect this? if not(exit_id.startswith('None') or exit_id.startswith('return_target')): - exit_node = cfar_graph.add_node(exit_id, 'junk', {}) + exit_node = cfar_graph.add_node(exit_id, {}) cfar_node.addExit(exit_node) if self.debug_cfar: print('CFAR ID (exit):', exit_id) @@ -815,12 +816,12 @@ def prettyLoc(self, loc: dict) -> str: class CFARNode: exits: list[CFARNode] - def __init__(self, id: str, desc: str, data: dict): + def __init__(self, id: str, data: dict): self.id = id (self.original_addr, self.patched_addr) = get_cfar_addr(id) self.exits = [] self.exit_meta_data = {} - self.desc = desc + self.desc = None self.data = data self.predomain = None self.postdomain = None @@ -835,10 +836,10 @@ def __init__(self, id: str, desc: str, data: dict): self.assumedConditionTrace = None # After default initializations above, update the node - self.update_node(desc, data) + self.update_node(data) - def update_node(self, desc: str, data: dict): - self.desc = desc + def update_node(self, data: dict): + self.desc = get_node_desc(data) self.data = data self.predomain = get_domain(data, 'Predomain') self.postdomain = get_domain(data, 'Postdomain') @@ -945,16 +946,16 @@ def getEqCondNodes(self): nodes.append(n) return nodes - def add_node(self, id: str, desc: str, data) -> CFARNode: + def add_node(self, id: str, data) -> CFARNode: """Add node, creating if node with ID does not exist.""" node = self.nodes.get(id) if not node: - node = CFARNode(id, desc, data) + node = CFARNode(id, data) self.nodes[node.id] = node else: # update with most recent data - node.update_node(desc, data) + node.update_node(data) return node def pprint(self): @@ -1040,6 +1041,19 @@ def extractTraceVars(condition: ConditionTrace) -> list[TraceVar]: traceVars.reverse() return traceVars + +def get_node_desc(rec: dict) -> str: + desc = rec.get('this', 'no description') + flag = False + for tnc in rec.get('trace_node_contents', []): + if flag: + desc = desc + ' via ' + tnc.get('pretty', 'no ancestor path') + break + if tnc.get('pretty') == 'Instruction Paths to Exit': + flag = True + return desc + + def get_cfar_addr(cfar_id: str) -> tuple[Optional[int], Optional[int]]: """Get CFAR original and patched address""" parts = cfar_id.split(' vs ')