Skip to content

Commit

Permalink
Create temp file at start of execution of first op for op_by_op exect…
Browse files Browse the repository at this point in the history
…ion.

Remove the temp file after executing all ops.
  • Loading branch information
mmanzoorTT committed Feb 4, 2025
1 parent 4caeccf commit a432301
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ def execute_process(receiver, sender, exec_event):
binary = obj["binary"]
inputs = obj["inputs"]
file_name = obj["dump_file"]
f_stderr = open(file_name, "w")
file_stderr = open(file_name, "w")
old_stderr = sys.stderr
sys.stderr = f_stderr
sys.stderr = file_stderr
old_stdout = sys.stdout
sys.stdout = f_stderr
sys.stdout = file_stderr
outputs = tt_mlir.run(inputs, binary)
sys.stderr = old_stderr
sys.stdout = old_stdout
f_stderr.close()
file_stderr.close()
sender.put({"outputs": outputs})
exec_event.wait()
sys.exit(0)
Expand Down Expand Up @@ -137,14 +137,11 @@ def __init__(
self.execute_sender = None
self.execute_receiver = None

# Create a temp file which will be used by sub process to dump stderr
# output during execution of the mlir graph on TT device.
self.f_stderr = tempfile.NamedTemporaryFile(mode="w+t", delete=False)

# Class destructor
def __del__(self):
# Remove the temp file created.
os.unlink(self.f_stderr.name)
# Create temp file at start of execution of first op and pass the name
# of temp file to subprocess which will be used to redirect the stderr
# to capture runtime stack dump.
self.stderror_redirected = False
self.file_stderr = None

def register_intermediate_callback(self, callback):
if not is_runtime_debug_enabled():
Expand Down Expand Up @@ -339,7 +336,11 @@ def pre_process_inputs(self, *inputs):

def run_op(self, binary, *inputs):
inputs = self.pre_process_inputs(*inputs)
obj = {"binary": binary, "inputs": inputs, "dump_file": self.f_stderr.name}
if not self.stderror_redirected:
self.file_stderr = tempfile.NamedTemporaryFile(mode="w+t", delete=False)
self.stderror_redirected = True

obj = {"binary": binary, "inputs": inputs, "dump_file": self.file_stderr.name}

exec_event = mp.Event()
if self.execute_process is None:
Expand Down Expand Up @@ -375,11 +376,11 @@ def run_op(self, binary, *inputs):

stderr_data = ""
if outputs is None:
f_stderr = open(self.f_stderr.name, "r")
stderr_data = f_stderr.read()
file_stderr = open(self.file_stderr.name, "r")
stderr_data = file_stderr.read()
stderr_data = stderr_data.replace("\n", "\\n")
stderr_data = re.sub(r"[^\x20-\x7E]", "", stderr_data)
f_stderr.close()
file_stderr.close()

return outputs, stderr_data

Expand Down Expand Up @@ -473,6 +474,10 @@ def run_gm_op_by_op(self, *inputs):
if self.execute_process is not None:
self.execute_process.terminate()
self.execute_process = None
if self.stderror_redirected:
os.unlink(self.file_stderr.name)
self.stderror_redirected = False

return outputs

def __call__(self, *inputs):
Expand Down

0 comments on commit a432301

Please sign in to comment.