Skip to content

Commit

Permalink
Add --unfused option (designed for fma)
Browse files Browse the repository at this point in the history
  • Loading branch information
lathuili-home committed Dec 16, 2024
1 parent db54265 commit bece845
Show file tree
Hide file tree
Showing 27 changed files with 10,436 additions and 1,514 deletions.
203 changes: 131 additions & 72 deletions generateBackendInterOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,74 @@
import sys
import re

FctNameRegExp=re.compile("(.*)FCTNAME\(([^,]*),([^)]*)\)(.*)")
FctConvNameRegExp=re.compile("(.*)FCTCONVNAME\(([^,]*),([^)]*)\)(.*)")
FctNameUnfusedRegExp=re.compile("(.*)FCTNAMEUNFUSED\(([^,]*),([^)]*)\)(.*)")
FctConvNameUnfusedRegExp=re.compile("(.*)FCTCONVNAMEUNFUSED\(([^,]*),([^)]*)\)(.*)")
BckNameRegExp=re.compile("(.*)BACKENDFUNC\(([^)]*)\)(.*)")
BckNameFirstRegExp=re.compile("(.*)BACKEND_FIRST_FUNC\(([^)]*)\)(.*)")
BckNameSecondRegExp=re.compile("(.*)BACKEND_SECOND_FUNC\(([^)]*)\)(.*)")
BckNameNearestRegExp=re.compile("(.*)BACKEND_NEAREST_FUNC\(([^)]*)\)(.*)")


def mergeFused(tab,tmpVar="res_temp"):
fusedTab=[]
for i in range(len(tab)):
line=tab[i]
if tmpVar+";" in line:
tab[i]= "//"+line

if "_FIRST_" in line:
prefix=tab[0:max(0,i-1)]
first=tab[i]
second=tab[i+1]
postfix=tab[i+2:]
break
if not "_SECOND_" in second:
print("Generation failure")
return [line for line in prefix] + [mergeTwoLines(first,second, tmpVar)] +[line for line in postfix]

def mergeTwoLines(first,second, varInter):
first=first.replace("_FIRST_","")
second=second.replace("_SECOND_","")
res=(first.partition("&"+varInter))[0] + (second.partition(varInter+","))[2]
return res


def treatBackend(tab, soft):
if soft==False:
return tab
else:
res=["if(vr.instrument_soft){\n"]
res+=tab
res+=["}else{\n"]
res+=[line.replace("BACKENDFUNC", "BACKEND_NEAREST_FUNC").replace("CONTEXT","BACKEND_NEAREST_CONTEXT") for line in tab]
if any(["BACKEND_FIRST" in line for line in tab]):
res+=mergeFused(tab)
else:
res+=[line.replace("BACKENDFUNC", "BACKEND_NEAREST_FUNC").replace("CONTEXT","BACKEND_NEAREST_CONTEXT") for line in tab]
res+=["}\n"]
return res


def transformTemplateForSoftStopStart(nameRegExp,convNameRegExp,
lineTab, soft=False, only64=False):
def transformTemplateForSoftStopStart(lineTab, soft=False, only64=False):
func=[]
dicOfFunc={}

splitActive=False
tab=[]
for line in lineTab:
result=nameRegExp.match(line) #"(.*)FCTNAME\(([^,]*),([^)]*)\)(.*)")
if result!=None:
typeVal=result.group(2)
opt=result.group(3)
newName="FCTNAME("+typeVal+","+opt+")"
resultRegExp=None
for (FCTNAME, regExp) in [("FCTNAME", FctNameRegExp),
("FCTCONVNAME", FctConvNameRegExp),
("FCTNAMEUNFUSED", FctNameUnfusedRegExp),
("FCTCONVNAMEUNFUSED", FctConvNameUnfusedRegExp)]:
resultRegExp=regExp.match(line) #"(.*)FCTNAME\(([^,]*),([^)]*)\)(.*)")
if resultRegExp!=None:
break
if resultRegExp!=None:
typeVal=resultRegExp.group(2)
opt=resultRegExp.group(3)
newName=FCTNAME+"("+typeVal+","+opt+")"
if splitActive:
dicOfFunc[currentName]=tab
currentName=newName
Expand All @@ -61,19 +104,6 @@ def transformTemplateForSoftStopStart(nameRegExp,convNameRegExp,
splitActive=True
currentName=newName

result=convNameRegExp.match(line) #"(.*)FCTCONVNAME\(([^,]*),([^)]*)\)(.*)")
if result!=None:
typeVal=result.group(2)
opt=result.group(3)
newName="FCTCONVNAME("+typeVal+","+opt+")"
if splitActive:
dicOfFunc[currentName]=tab
currentName=newName
tab=[line]
continue
else:
splitActive=True
currentName=newName
if splitActive:
tab+=[line]
dicOfFunc[currentName]=tab
Expand All @@ -92,17 +122,17 @@ def transformTemplateForSoftStopStart(nameRegExp,convNameRegExp,
def transformTemplateForSoftStopStartOneFunction(lineTab, soft):
res=[]
back=[]
status="pre"
status="around"
for line in lineTab:
if "PREBACKEND" in line:
status="back"
continue
if "POSTBACKEND" in line:
res+=treatBackend(back,soft)
back=[]
status="pre"
status="around"
continue
if status=="pre":
if status=="around":
res+=[line]
if status=="back":
back+=[line]
Expand All @@ -111,7 +141,6 @@ def transformTemplateForSoftStopStartOneFunction(lineTab, soft):
def transformTemplateForSoftStopStartOneConvFunction(lineTabConv, lineTab, soft):
if soft==False:
return transformTemplateForSoftStopStartOneFunction(lineTabConv,soft)

res=[lineTabConv[0]]
remain=lineTabConv[1:]
while remain[-1].strip()!="}":
Expand All @@ -128,12 +157,14 @@ def transformTemplateForSoftStopStartOneConvFunction(lineTabConv, lineTab, soft)
remain=lineTab[1:]
while remain[-1].strip()!="}":
remain=remain[0:-1]

elseInstrumentBock=[]
for x in remain[0:-1]:
if ("PREBACKEND" in x) or ("POSTBACKEND" in x):
continue
res+=[x.replace("BACKENDFUNC", "BACKEND_NEAREST_FUNC").replace("CONTEXT","BACKEND_NEAREST_CONTEXT")]

elseInstrumentBock+=[x]
if any(["BACKEND_FIRST" in line for line in elseInstrumentBock]):
elseInstrumentBock=mergeFused(elseInstrumentBock)
res+=[x.replace("BACKENDFUNC", "BACKEND_NEAREST_FUNC").replace("CONTEXT","BACKEND_NEAREST_CONTEXT") for x in elseInstrumentBock]
res+=["}\n}\n\n"]
return res

Expand All @@ -144,13 +175,7 @@ def generateNargs(fileOut, fileNameTemplate, listOfBackend, listOfOp, nargs, pos
commentConv=False
templateStr=open(fileNameTemplate, "r").readlines()

FctNameRegExp=re.compile("(.*)FCTNAME\(([^,]*),([^)]*)\)(.*)")
FctConvNameRegExp=re.compile("(.*)FCTCONVNAME\(([^,]*),([^)]*)\)(.*)")
templateStr=transformTemplateForSoftStopStart(FctNameRegExp, FctConvNameRegExp, templateStr, soft, only64)

BckNameRegExp=re.compile("(.*)BACKENDFUNC\(([^)]*)\)(.*)")
BckNameNearestRegExp=re.compile("(.*)BACKEND_NEAREST_FUNC\(([^)]*)\)(.*)")

templateStr=transformTemplateForSoftStopStart(templateStr, soft, only64)
if post in ["check_float_max"]:
commentConv=True

Expand All @@ -161,40 +186,52 @@ def generateNargs(fileOut, fileNameTemplate, listOfBackend, listOfOp, nargs, pos
for op in listOfOp:
for rounding in roundingTab:
if nargs in [1,2]:
applyTemplate(fileOut, templateStr, FctNameRegExp, FctConvNameRegExp, BckNameRegExp, BckNameNearestRegExp, backend,op, post, sign=None, rounding=rounding, soft=soft, commentConv=commentConv)
applyTemplate(fileOut, templateStr, backend,op, post, sign=None, rounding=rounding, soft=soft, commentConv=commentConv)
if nargs==3:
sign=""
if "msub" in op:
sign="-"
applyTemplate(fileOut, templateStr, FctNameRegExp, FctConvNameRegExp, BckNameRegExp,BckNameNearestRegExp, backend, op, post, sign, rounding=rounding, soft=soft, commentConv=commentConv)
applyTemplate(fileOut, templateStr, backend, op, post, sign, rounding=rounding, soft=soft, commentConv=commentConv)
if backend=="mcaquad":
fileOut.write("#endif //USE_VERROU_QUADMATH\n")


def applyTemplate(fileOut, templateStr, FctRegExp, FctConvRegExp, BckRegExp, BckNearestRegExp, backend, op, post, sign=None, rounding=None, soft=False, commentConv=False):
def applyTemplate(fileOut, templateStr, backend, op, post, sign=None, rounding=None, soft=False, commentConv=False):
fileOut.write("// generation of operation %s backend %s\n"%(op,backend))
backendFunc=backend
if rounding!=None:
backendFunc=backend+"_"+rounding

def fctName(conv,typeVal,opt):
def fctName(unfused,conv,typeVal,opt):
vrPrefix="vr_"
if unfused:
vrPrefix+="unfused_"
if conv:
vrPrefix="vr_conv_"
vrPrefix+="conv_"
if soft:
return vrPrefix+backendFunc+post+"_soft"+op+typeVal+opt
else:
return vrPrefix+backendFunc+post+op+typeVal+opt

def bckName(typeVal):
if sign!="-":
if rounding!=None:
return "interflop_"+backend+"_"+op+"_"+typeVal+"_"+rounding
return "interflop_"+backend+"_"+op+"_"+typeVal
else:
if rounding!=None:
return "interflop_"+backend+"_"+op.replace("sub","add")+"_"+typeVal+"_"+rounding
return "interflop_"+backend+"_"+op.replace("sub","add")+"_"+typeVal
def localOp(first):
localOp=op
if sign=="-":
localOp=op.replace("sub","add")
if op in ["madd","msub"] and first!=None:
if first==True:
localOp="mul"
if first==False:
if op=="madd":
localOp="add"
else:
localOp="sub"
return localOp

def bckName(typeVal, first=None):
if rounding!=None:
return "interflop_"+backend+"_"+localOp(first)+"_"+typeVal+"_"+rounding
return "interflop_"+backend+"_"+localOp(first)+"_"+typeVal

def bckNearestName(typeVal):
if rounding!=None:
return (bckName(typeVal)).replace(rounding,"NEAREST")
Expand All @@ -203,11 +240,11 @@ def bckNearestName(typeVal):
else:
return "interflop_verrou_"+op.replace("sub","add")+"_"+typeVal+"_NEAREST"

def bckNamePost(typeVal):
if sign!="-":
return "interflop_"+post+"_"+op+"_"+typeVal
else:
return "interflop_"+post+"_"+op.replace("sub","add")+"_"+typeVal
def bckNamePost(typeVal,first=None):
bop=localOp(first)
if bop in ["mul"] and post=="checkcancellation":
return None
return "interflop_"+post+"_"+bop+"_"+typeVal


contextName="backend_"+backend+"_context"
Expand Down Expand Up @@ -241,39 +278,61 @@ def outputRes(res,localComment):
else:
print("Generation failed")
sys.exit()
result=FctRegExp.match(line)
if result!=None:
comment=False

typeVal=result.group(2)
opt=result.group(3)
if rounding=="NEAREST" and soft:
comment=True
res=result.group(1) + fctName(False,typeVal, opt) + result.group(4)
outputRes(res,comment)
continue
result=FctConvRegExp.match(line)

result=FctConvNameRegExp.match(line)
fused=False
if result==None:
result=FctConvNameUnfusedRegExp.match(line)
fused=True
if result!=None:
if commentConv:
comment=True
else:
comment=False
typeVal=result.group(2)
opt=result.group(3)
res=result.group(1) + fctName(True, typeVal, opt) + result.group(4)
res=result.group(1) + fctName(fused,True, typeVal, opt) + result.group(4)
outputRes(res,comment)
continue

result=FctNameRegExp.match(line)
unfused=False
if result==None:
result=FctNameUnfusedRegExp.match(line)
unfused=True
if result!=None:
comment=False
typeVal=result.group(2)
opt=result.group(3)
if rounding=="NEAREST" and soft and (not unfused):
comment=True
res=result.group(1) + fctName(unfused,False,typeVal, opt) + result.group(4)
outputRes(res,comment)
continue
result=BckRegExp.match(line)

result=None
first=None
for (rgExp,firstVar) in [(BckNameFirstRegExp, True), (BckNameSecondRegExp,False) , (BckNameRegExp,None)]:
result=rgExp.match(line)
if result!=None:
first=firstVar
break
if result!=None:
res=result.group(1) + bckName(result.group(2)) + result.group(3)
group3=result.group(3)
if first==False:
group3=group3.replace(" - "," ")
res=result.group(1) + bckName(result.group(2),first) + group3
outputRes(res,comment)
if post!="":
res=result.group(1) + bckNamePost(result.group(2)) + result.group(3)
res=res.replace(contextName, contextNamePost)
outputRes(res,comment)
bnPost=bckNamePost(result.group(2),first)
if bnPost!=None:
res=result.group(1) + bnPost+ group3
res=res.replace(contextName, contextNamePost)
outputRes(res,comment)
continue

result=BckNearestRegExp.match(line)
result=BckNameNearestRegExp.match(line)
if result!=None:
res=result.group(1) + bckNearestName(result.group(2)) + result.group(3)
outputRes(res,comment)
Expand Down
Loading

0 comments on commit bece845

Please sign in to comment.