Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add riscv Architecture and cfg show address & source code #37

Open
wants to merge 3 commits into
base: development-update-readme
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
x213212
Add Tarjan's algorithm for dominator-tree detection to find loops

Add experimental colored CFG
x213212 committed Dec 2, 2022
commit edf2e2692461221c7dc3b0f05e87d1b442ef749f
350 changes: 346 additions & 4 deletions src/asm2cfg/asm2cfg.py
Original file line number Diff line number Diff line change
@@ -5,15 +5,106 @@
import re
import sys
import tempfile


from . import utils
from collections import defaultdict
from graphviz import Digraph

from colour import Color

# TODO: make this a command-line flag
VERBOSE = 0


# This class represents a directed graph
# using an adjacency-list representation
class Graph:
    """Directed graph over integer vertices ``0 .. vertices-1``.

    Edges are stored in an adjacency list.  ``SCC()`` runs Tarjan's
    algorithm and stores every strongly connected component it finds in
    ``self.ans`` (one list of vertex indices per component).
    """

    def __init__(self, vertices,):
        # No. of vertices
        self.V = vertices

        # Components found by the most recent SCC() run
        self.ans = []

        # Default dictionary to store the graph: vertex -> list of successors
        self.graph = defaultdict(list)

        # Discovery-time counter shared by the recursive DFS
        self.Time = 0

    def addEdge(self, u, v):
        """Add the directed edge ``u -> v`` to the graph."""
        self.graph[u].append(v)

    def SCCUtil(self, u, low, disc, stackMember, st):
        """Recursive DFS step that finds and prints the components reachable from ``u``.

        u             --> the vertex to be visited next
        disc[]        --> discovery times of visited vertices (-1 = unvisited)
        low[]         --> earliest visited vertex (minimum discovery time)
                          reachable from the subtree rooted at the current
                          vertex
        st            --> stack of vertices that may still belong to the
                          component being built
        stackMember[] --> flag array for a fast "is this vertex on ``st``" check

        Every completed component is printed on its own line and appended to
        ``self.ans``.  The traversal is recursive, so very deep graphs can hit
        the interpreter's recursion limit.
        """
        # Initialize discovery time and low value
        disc[u] = self.Time
        low[u] = self.Time
        self.Time += 1
        stackMember[u] = True
        st.append(u)

        # Go through all vertices adjacent to this one
        for v in self.graph[u]:
            if disc[v] == -1:
                # v is not visited yet: recurse, then pull up whatever the
                # subtree rooted at v can reach.
                self.SCCUtil(v, low, disc, stackMember, st)
                low[u] = min(low[u], low[v])
            elif stackMember[v]:
                # v is still on the stack, so u -> v is a back edge (not a
                # cross edge) and may lower u's low value.
                low[u] = min(low[u], disc[v])

        # Head node found: pop the stack down to u and record/print the SCC
        if low[u] == disc[u]:
            component = []
            w = -1  # last vertex popped from the stack
            while w != u:
                w = st.pop()
                print(w, end=" ")
                component.append(w)
                stackMember[w] = False
            self.ans.append(component)
            print()

    def SCC(self):
        """Find all strongly connected components of the graph.

        Results of any previous run are discarded first, so calling this
        method repeatedly does not accumulate duplicate components.

        Returns:
            ``self.ans``: the list of components, each a list of vertex indices.
        """
        # Start from a clean state so repeated calls give the same answer
        self.ans = []
        self.Time = 0

        # Mark all the vertices as not visited
        disc = [-1] * self.V
        low = [-1] * self.V
        stackMember = [False] * self.V
        st = []

        # Run the recursive helper from every vertex not yet discovered
        for i in range(self.V):
            if disc[i] == -1:
                self.SCCUtil(i, low, disc, stackMember, st)

        return self.ans

def escape(instruction):
"""
Escape used dot graph characters in given instruction so they will be
@@ -739,4 +830,255 @@ def draw_cfg(function_name, basic_blocks, view):
else:
dot.format = 'pdf'
dot.render(filename=function_name, cleanup=True)
print(f'Saved CFG to a file {function_name}.{dot.format}')
print(f'Saved CFG to a file {function_name}.{dot.format}')


# Module-level lookup tables filled in by draw_cfgdark().

# str(block index) -> str(basic block key)
bb_index_mapping = dict()

# str(block index) -> basic block object
bb_block_mapping = dict()

# str(block address) -> list (created empty for every block)
graph = dict()
def draw_cfgdark(function_name,basic_blocks, view):
# Draw a dark-themed CFG for `function_name` with graphviz.
# basic_blocks is used as a mapping (.items()/.values()) of address -> block;
# each block is accessed through .key, .instructions, .jump_edge,
# .no_jump_edge and .get_label().  `view` selects between the "view" branch
# and rendering an SVG to disk at the end of the function.
# NOTE(review): leading indentation is missing in this copy of the code, so
# the exact nesting of the loops/branches below cannot be confirmed from here.
dot=None
dot = Digraph(name=function_name, comment=function_name, engine='dot')
# dot.graph_attr['rankdir'] = 'LR'
dot.attr('graph', label=function_name, fontsize="60", color="white", labelloc="t",bgcolor="#1e1e1e",splines= "polyline",nodesep="0.5",ranksep="2")

total_node =0
index =0
cur_bb_index=0
# str(block key) -> list of [opcode, count] pairs collected below
possible_unroll_block={}
# The `for get_child in basic_blocks:` loops in this function are each
# followed later by a `break`, and get_child is never read; presumably the
# loop is only a wrapper that runs its body once — TODO confirm.
for get_child in basic_blocks:
# Create an empty entry in the module-level `graph` for every block address
# and count the blocks; total_node later sizes the Graph used for SCC.
for address, basic_block in basic_blocks.items():
key = str(address)
graph[key] =[]
total_node+=1
for basic_block in basic_blocks.values():
# Remember which positional index belongs to which block (key and object)
# in the module-level mapping tables.
bb_index_mapping[str(cur_bb_index)] =str(basic_block.key)
bb_block_mapping[str(cur_bb_index)]=basic_block
##############################
total_inst_count =0
# opcode -> number of occurrences in this block
unrollloop_inst={}
possible_byte = 0
max_inst_count = 0
print("==================")
print("next")
print("cur_bb:"+str(cur_bb_index ))
for i in basic_block.instructions:
total_inst_count+=1

# Presumably only instructions whose text lacks "debug" are counted —
# nesting not visible here.
if(str(i.text).find("debug")) <0:
if(i.opcode not in unrollloop_inst):
unrollloop_inst[i.opcode]=1
else:
unrollloop_inst[i.opcode]+=1
# Set possible_byte from the load opcode: ld -> 8, lw -> 4, lh/lhu -> 2, lb -> 1.
if(str(i.opcode) == str("ld")):
possible_byte=8

elif(str(i.opcode) == str("lw") ):
possible_byte=4

elif(str(i.opcode) == str("lh") or str(i.opcode) == str("lhu") ):
possible_byte=2

elif(str(i.opcode)== str("lb")):
possible_byte=1

possible_unroll_time=0
max_unroll_inst=0
inst_list=[]

# When a load opcode was seen, look for opcodes whose count is a multiple of
# possible_byte; count / possible_byte becomes possible_unroll_time.
if(possible_byte >0):
for inst in unrollloop_inst :
if(int(unrollloop_inst[inst]%possible_byte)==0 and int(unrollloop_inst[inst]/possible_byte)>= max_unroll_inst):
max_unroll_inst=unrollloop_inst[inst]
possible_unroll_time=unrollloop_inst[inst]/possible_byte
print(inst)
if(int(unrollloop_inst[inst]%possible_byte)==0 and unrollloop_inst[inst] > max_inst_count):
max_inst_count=unrollloop_inst[inst]
# Keep opcodes that occur at least 4 times or as often as max_inst_count.
for inst in unrollloop_inst :
if(unrollloop_inst[inst]>=4 or unrollloop_inst[inst] == max_inst_count):
inst_list.append([inst,unrollloop_inst[inst]])

if possible_unroll_time>1 :
# NOTE(review): no module-level definition of info_str is visible in this
# chunk; the `+=` below needs one to exist before its first use — confirm.
global info_str
info_str+=f"max_unroll_inst: {max_unroll_inst}\npossible_unroll_time:{possible_unroll_time}\nunrollloop_inst:{unrollloop_inst}\nunroll loop check"
print(info_str)

# Attach an annotation node to the block and link it with a yellow edge.
dot.node(str(basic_block.key)+str(cur_bb_index), shape='record', label=f'in bb {cur_bb_index} possible unrolling {int(possible_unroll_time)} time. \ntotal inst {inst_list}',style="filled",fillcolor="#1e1e1e",color="white",fontcolor='white')
dot.edge(f'{basic_block.key}:s0',str(basic_block.key)+str(cur_bb_index),color="#f5d166", penwidth="5")
info_str=""
possible_unroll_block[str(basic_block.key)]=inst_list
cur_bb_index+=1
cur_bb_index=0
break

cur_bb_index=0
tarjansans=[]

# Create a graph given in the above diagram
# Vertices of this Graph are the positional block indices stored as keys of
# bb_index_mapping; edges come from jump_edge / no_jump_edge.
tarjans = Graph(total_node)
for get_child in basic_blocks:
for basic_block in basic_blocks.values():
if basic_block.jump_edge:
key=0
jkey=0
nojkey=0
if basic_block.no_jump_edge is not None:
# Linear scan of bb_index_mapping to translate block keys back to indices.
for x in bb_index_mapping:
if(bb_index_mapping[x] == str(basic_block.key)):
key=int(x)
elif(bb_index_mapping[x] == str(basic_block.no_jump_edge)):
nojkey=int(x)
tarjans.addEdge(key, nojkey)

for x in bb_index_mapping:
if(bb_index_mapping[x] == str(basic_block.key)):
key=int(x)
elif(bb_index_mapping[x] == str(basic_block.jump_edge)):
jkey=int(x)

# NOTE(review): cur_bb_index appears to be 0 at this point, so a jump edge
# whose target resolves to index 0 (or is not found, leaving jkey at 0)
# would not be added — confirm intent.
if(jkey>cur_bb_index ):
tarjans.addEdge(key, jkey)

elif basic_block.no_jump_edge:
key=0
nojkey=0
for x in bb_index_mapping:
if(bb_index_mapping[x] == str(basic_block.key)):
key=int(x)
elif(bb_index_mapping[x] == str(basic_block.no_jump_edge)):
nojkey=int(x)

tarjans.addEdge(key, nojkey)
cur_bb_index=0
break

print ("SSC in first graph ")
# Runs Tarjan's SCC search; components end up in tarjans.ans.
tarjans.SCC()
print(total_node)
print ("tarjans")
print(tarjans.ans)
count_tarjans_ans_index=0
count_tarjans_ans_index_max=0

# Count components with at least two members.
# NOTE(review): count_tarjans_ans_index_max is not read again in this
# function, and count_tarjans_ans_index is reassigned to 0 further down.
for x in tarjans.ans:
if(len (x)>=2):
count_tarjans_ans_index+=1
count_tarjans_ans_index_max+=1

# Two-step color ramps built with the `colour` package.  Only `colors` is used
# below; colors2 and colorro are not referenced again in this function.
green = Color("#a0671d")
green2 = Color("#4d7825")
colors = list(green.range_to(Color("#21286e"),2))
colors2 = list(green2.range_to(Color("#782525"),2))
colorro = 0
for get_child in basic_blocks:
for address, basic_block in basic_blocks.items():
key = str(address)
tarjankey=0
find =0

# Look up this block's positional index.
for x in bb_index_mapping:
if(bb_index_mapping[x] == str(basic_block.key)):
tarjankey=int(x)
break
count_tarjans_ans_index=0
tarjans_same_block_list=[]
# Find the multi-member component (if any) that contains this block;
# count_tarjans_ans_index counts the multi-member components checked first.
for x in tarjans.ans:
if(len (x)>=2):
if(tarjankey in x):
find=1
# print("tarjanskey:"+str(tarjankey))
tarjans_same_block_list=x
break
count_tarjans_ans_index+=1

# For blocks with recorded opcode counts that sit in a component, compare
# their opcode list against the other blocks of the same component.
if(key in possible_unroll_block):
if tarjankey in tarjans_same_block_list:
possible_unloop_block=0
for x in tarjans_same_block_list:
check_loop=-1
for y in bb_index_mapping:
if(str(x)==str(y)):
check_loop=y
print("same block:"+str(x)+":"+str(y))
unrollloop_inst={}

# Opcode histogram of the component member being compared.
for inst in bb_block_mapping[str(check_loop)].instructions:
print(inst)
if(inst.opcode not in unrollloop_inst):
unrollloop_inst[inst.opcode]=1
else:
unrollloop_inst[inst.opcode]+=1
same_inst_check=0
same_inst_list=[]
# NOTE(review): this loop reuses the name `x`, which is also the loop
# variable of the enclosing `for x in tarjans_same_block_list` — confirm
# this shadowing is harmless.
for x in possible_unroll_block[str(basic_block.key)]:
print(x[0])
if(str(x[0]) in unrollloop_inst):
same_inst_check+=1
print(unrollloop_inst[str(x[0])])
same_inst_list.append([str(x[0]),unrollloop_inst[str(x[0])]])
# if (unrollloop_inst[str(x[0])]==x[1]):
# same_inst_check+=1
# loop same rate
if(bb_block_mapping[str(check_loop)].key !=basic_block.key ):
if(len(possible_unroll_block[str(basic_block.key)])>=2 and len(same_inst_list)>=2):
print("in loop possiable unrolling")
# for inst in bb_block_mapping[str(check_loop)].instructions:
# print(inst)
# if(inst.opcode in same_inst_list):
# inst.text="loop same inst"+inst.text
# print("same"+ inst.text)

dot.node(str(check_loop)+str(basic_block.key), shape='record', label=f'in bb {check_loop} find same dom bb {tarjankey} inst{ same_inst_list} .',style="filled",fillcolor="#1e1e1e",color="white",fontcolor='white')
dot.edge(f'{bb_block_mapping[str(check_loop)].key}:s0',str(check_loop)+str(basic_block.key),color="#f5d166", penwidth="5")
possible_unloop_block+=1

print(same_inst_check)
print(len(possible_unroll_block[str(basic_block.key)]))
print("==========================")
break
dot.node(str(tarjankey)+str(basic_block.key), shape='record', label=f'in bb {tarjankey} possible unrolling {int(possible_unloop_block)} time(same block). ',style="filled",fillcolor="#1e1e1e",color="white",fontcolor='white')
dot.edge(f'{basic_block.key}:s0',str(tarjankey)+str(basic_block.key),color="#f5d166", penwidth="5")

# Draw the block itself: dark fill when it is not in a multi-member
# component, otherwise colors[1] or colors[0] depending on the parity of
# count_tarjans_ans_index.
if(find==0):
dot.node(key, shape='record', label=basic_block.get_label(),style="filled",fillcolor="#1e1e1e", color="white",fontcolor='white')
else:
if(count_tarjans_ans_index%2==0):

dot.node(key, shape='record', label=basic_block.get_label(),style="filled",fillcolor=str(colors[1]),fontcolor='white')

else:

dot.node(key, shape='record', label=basic_block.get_label(),style="filled",fillcolor=str(colors[0]),fontcolor='white')


# Control-flow edges: no_jump_edge is drawn solid white from port s0,
# jump_edge dashed "#379cb0" from port s1.
for basic_block in basic_blocks.values():
if basic_block.jump_edge:
if basic_block.no_jump_edge is not None:
dot.edge(f'{basic_block.key}:s0', str(basic_block.no_jump_edge),color="white", penwidth="5")
dot.edge(f'{basic_block.key}:s1', str(basic_block.jump_edge),color="#379cb0", penwidth="5",style="dashed")
elif basic_block.no_jump_edge:
dot.edge(f'{basic_block.key}:s0', str(basic_block.no_jump_edge),color="white", penwidth="5")
break

if view:
# The dot.view() call is commented out, so this branch only creates a
# temporary file and prints the message below.
dot.format = 'gv'
with tempfile.NamedTemporaryFile(mode='w+b', prefix=function_name) as filename:
# dot.view(filename.name)
print(f'Opening a file {filename.name}.{dot.format} with default viewer. Don\'t forget to delete it later.')
else:
# Render to SVG, then have utils.replaceline_0 write a post-processed copy
# named back<function_name>.svg.
dot.format = 'svg'
dot.render(filename=function_name, cleanup=True)
print(f'Saved CFG to a file {function_name}.{dot.format}')
# redundant(f'back{function_name}.{dot.format}', f'new{function_name}.{dot.format}')
utils.replaceline_0(f'{function_name}.{dot.format}', f'back{function_name}.{dot.format}')
# utils.replaceline(f'back{function_name}.{dot.format}', f'back2{function_name}.{dot.format}')
# utils.replaceline2(f'back2{function_name}.{dot.format}', f'new{function_name}.{dot.format}')
print(f'Saved new CFG to a file back{function_name}.{dot.format}')
18 changes: 14 additions & 4 deletions src/asm2cfg/command_line.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

import argparse
from . import asm2cfg
from . import utils


def main():
@@ -16,12 +17,21 @@ def main():
help='File to contain one function assembly dump')
parser.add_argument('-c', '--skip-calls', action='store_true',
help='Skip function calls from dividing code to blocks')
parser.add_argument('--target', choices=['x86', 'arm','riscv'], default='riscv',
parser.add_argument('--target', choices=['x86', 'arm', 'riscv'], default='riscv',
help='Specify target platform for assembly')
parser.add_argument('-v', '--view', action='store_true',
help='View as a dot graph instead of saving to a file')
parser.add_argument('-er', '--eriscv', action='store_true',
help='If you use the RiscV architecture, you can try the new output format, you can figure out the cfg loop from the diagram.')
args = parser.parse_args()
print('If function CFG rendering takes too long, try to skip function calls with -c flag')
lines = asm2cfg.read_lines(args.assembly_file)
function_name, basic_blocks = asm2cfg.parse_lines(lines, args.skip_calls, args.target)
asm2cfg.draw_cfg(function_name, basic_blocks, args.view)
if(args.eriscv):
utils.delblankline(args.assembly_file, "./preproceessa.asm")
lines = asm2cfg.read_lines("./preproceessa.asm")
function_name, basic_blocks = asm2cfg.parse_lines(lines, args.skip_calls, args.target)
asm2cfg.draw_cfgdark(function_name, basic_blocks, args.view)
else:
lines = asm2cfg.read_lines(args.assembly_file)
function_name, basic_blocks = asm2cfg.parse_lines(lines, args.skip_calls, args.target)
asm2cfg.draw_cfg(function_name, basic_blocks, args.view)

77 changes: 77 additions & 0 deletions src/asm2cfg/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
asm format adjustment
"""

import re
import sys
import tempfile
from collections import defaultdict
from graphviz import Digraph
from colour import Color

def replaceline(infile, outfile):
    """Copy ``infile`` to ``outfile``, rewriting ``>debug`` markers.

    Every occurrence of ``>debug`` is replaced with `` fill='green'>``.
    Lines that are empty or contain only whitespace are dropped.

    Args:
        infile: path of the UTF-8 text file to read.
        outfile: path of the UTF-8 text file to (over)write.
    """
    # Context managers guarantee both handles are closed even if reading or
    # writing raises part-way through.
    with open(infile, 'r', encoding="utf-8") as src, \
            open(outfile, 'w', encoding="utf-8") as dst:
        for line in src:
            if line.split():
                dst.write(line.replace(">debug", " fill='green'>"))
def replaceline_0(infile, outfile):
    """Copy ``infile`` to ``outfile`` with black/white fill attributes removed.

    The exact substrings ``fill="#000000"`` and ``fill="#ffffff"`` are deleted
    from every line.  Lines that are empty or contain only whitespace are
    dropped.

    Args:
        infile: path of the UTF-8 text file to read.
        outfile: path of the UTF-8 text file to (over)write.
    """
    # Context managers guarantee both handles are closed even if reading or
    # writing raises part-way through.
    with open(infile, 'r', encoding="utf-8") as src, \
            open(outfile, 'w', encoding="utf-8") as dst:
        for line in src:
            if not line.split():
                continue
            for attribute in ('fill="#000000"', 'fill="#ffffff"'):
                line = line.replace(attribute, "")
            dst.write(line)

def replaceline2(infile, outfile):
    """Copy ``infile`` to ``outfile``, adding a white fill before ``>0x``.

    On lines that do not contain the text ``green``, every ``>0x`` is
    replaced with ``fill="white">0x``.  Lines containing ``green`` are copied
    unchanged.  Lines that are empty or contain only whitespace are dropped.

    Args:
        infile: path of the UTF-8 text file to read.
        outfile: path of the UTF-8 text file to (over)write.
    """
    # Context managers guarantee both handles are closed even if reading or
    # writing raises part-way through.
    with open(infile, 'r', encoding="utf-8") as src, \
            open(outfile, 'w', encoding="utf-8") as dst:
        for line in src:
            if not line.split():
                continue
            # A membership test is used instead of `find(...) <= 0`, which
            # wrongly treated a line *starting* with "green" (index 0) as if
            # it did not contain the word.
            if "green" not in line:
                line = line.replace(">0x", "fill=\"white\">0x")
            dst.write(line)

def delblankline(infile, outfile):
    """Copy ``infile`` to ``outfile`` without blank lines.

    A line counts as blank when it is empty or contains only whitespace.
    All other lines are written unchanged.

    Args:
        infile: path of the UTF-8 text file to read.
        outfile: path of the UTF-8 text file to (over)write.
    """
    # Context managers guarantee both handles are closed even if reading or
    # writing raises part-way through.
    with open(infile, 'r', encoding="utf-8") as src, \
            open(outfile, 'w', encoding="utf-8") as dst:
        dst.writelines(line for line in src if line.split())