Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save restore function #717

Merged
merged 6 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions compiler/orchestrator_runtime/pash_declare_vars.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ vars_file="${1?File not given}"

# pash_redir_output echo "Writing vars to: $vars_file"

declare -p > "$vars_file"
## KK 2021-11-23 We don't actually need to export functions in the vars file.
## We never expand them in the compiler
## declare -f >> "$vars_file"
echo "cd ${PWD}" > "$vars_file"
declare -p >> "$vars_file"
declare -f >> "$vars_file"
50 changes: 42 additions & 8 deletions compiler/shell_ast/ast_to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from env_var_names import *
from shell_ast.ast_util import *
from shasta.ast_node import ast_match, is_empty_cmd, string_of_arg
from shasta.ast_node import ast_match, is_empty_cmd, string_of_arg, BArgChar
from shasta.json_to_ast import to_ast_node
from parse import from_ast_objects_to_shell
from speculative import util_spec
Expand Down Expand Up @@ -41,6 +41,9 @@ def __init__(self, num: int):

def add_command(self, command):
self._is_emtpy = False

def add_assignment(self, assignment):
self._is_emtpy = False

def make_non_empty(self):
self._is_emtpy = False
Expand Down Expand Up @@ -148,6 +151,9 @@ def add_break(self):

def add_command(self, command):
self.bbs[self.current_bb].add_command(command)

def add_var_assignment(self, assignment):
self.bbs[self.current_bb].add_assignment(assignment)

## Use this object to pass options inside the preprocessing
## trasnformation.
Expand All @@ -157,6 +163,8 @@ def __init__(self, mode: TransformationType):
self.node_counter = 0
self.loop_counter = 0
self.loop_contexts = []
self.var_counter = 0
self.var_contexts = []
self.prog = ShellProg()

def get_mode(self):
Expand Down Expand Up @@ -190,6 +198,12 @@ def get_current_loop_id(self):
else:
return self.loop_contexts[0]

def get_number_of_var_assignments(self):
return self.var_counter

def get_var_nodes(self):
return self.var_contexts[:]

def current_bb(self):
return self.prog.current_bb

Expand All @@ -215,6 +229,11 @@ def exit_if(self):
def visit_command(self, command):
if len(command.arguments) > 0 and string_of_arg(command.arguments[0]) == 'break':
self.prog.add_break()
elif len(command.arguments) == 0 and len(command.assignments) > 0 and not contains_command_substitution(command):
self.prog.add_var_assignment(command)
## GL: HACK to get ids right
self.var_contexts.append(self.get_current_id() + 1)
self.var_counter += 1
else:
self.prog.add_command(command)

Expand Down Expand Up @@ -289,7 +308,17 @@ def get_all_node_bb(self):
lambda ast_node: preprocess_node_case(ast_node, trans_options, last_object=last_object))
}


# Checks is var assignment value is BArgChar
def contains_command_substitution(ast_node):
if len(ast_node.assignments) == 0:
return False
for assignment in ast_node.assignments:
if len(assignment.val) == 0:
return False
for val in assignment.val:
if type(val) == BArgChar:
return True
return False

## Replace candidate dataflow AST regions with calls to PaSh's runtime.
def replace_ast_regions(ast_objects, trans_options):
Expand Down Expand Up @@ -435,16 +464,21 @@ def preprocess_node_command(ast_node, trans_options, last_object=False):
## If there are no arguments, the command is just an
## assignment (Q: or just redirections?)
trans_options : TransformationState
if(len(ast_node.arguments) == 0):

if trans_options.get_mode() is TransformationType.PASH \
and (len(ast_node.arguments) == 0):
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=False,
non_maximal=False,
something_replaced=False,
last_ast=last_object)
replace_whole=False,
non_maximal=False,
something_replaced=False,
last_ast=last_object)
return preprocessed_ast_object

## This means we have a command. Commands are always candidate dataflow
## regions.

## GL: In spec mode, we treat assignment nodes as commands
# breakpoint()
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=True,
non_maximal=False,
Expand Down Expand Up @@ -704,7 +738,7 @@ def preprocess_node_case(ast_node, trans_options, last_object=False):
##
## If we are need to disable parallel pipelines, e.g., if we are in the context of an if,
## or if we are in the end of a script, then we set a variable.
def replace_df_region(asts, trans_options, disable_parallel_pipelines=False, ast_text=None) -> AstNode:
def replace_df_region(asts, trans_options: TransformationState, disable_parallel_pipelines=False, ast_text=None, var_assignment=False) -> AstNode:

transformation_mode = trans_options.get_mode()
if transformation_mode is TransformationType.PASH:
Expand Down
23 changes: 23 additions & 0 deletions compiler/speculative/util_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,18 @@ def serialize_edge(from_id: int, to_id: int) -> str:
def serialize_number_of_nodes(number_of_ids: int) -> str:
return f'{number_of_ids}\n'

def serialize_number_of_var_assignments(number_of_var_assignments: int) -> str:
return f'{number_of_var_assignments}\n'

def serialize_loop_context(node_id: int, bb_id) -> str:
## Galaxy brain serialization
# loop_contexts_str = ",".join([str(loop_ctx) for loop_ctx in loop_contexts])
bb_id_str = str(bb_id)
return f'{node_id}-loop_ctx-{bb_id_str}\n'

def serialize_var_assignments(node_id: int) -> str:
return f'{node_id}-var\n'

def save_current_env_to_file(trans_options):
initial_env_file = ptempfile()
subprocess.check_output([f"{os.getenv('PASH_TOP')}/compiler/orchestrator_runtime/pash_declare_vars.sh", initial_env_file])
Expand All @@ -89,6 +95,19 @@ def save_loop_contexts(trans_options):
bb_id = node_bb_dict[node_id]
po_file.write(serialize_loop_context(node_id, bb_id))

def save_var_assignment_contexts(trans_options):
var_nodes = trans_options.get_var_nodes()
partial_order_file_path = trans_options.get_partial_order_file()
with open(partial_order_file_path, "a") as po_file:
for node_id in var_nodes:
po_file.write(serialize_var_assignments(node_id))

def save_number_of_var_assignments(trans_options):
number_of_var_assignments = trans_options.get_number_of_var_assignments()
partial_order_file_path = trans_options.get_partial_order_file()
with open(partial_order_file_path, "a") as po_file:
po_file.write(serialize_number_of_var_assignments(number_of_var_assignments))

def serialize_partial_order(trans_options):
## Initialize the po file
dir_path = partial_order_directory()
Expand All @@ -110,6 +129,10 @@ def serialize_partial_order(trans_options):

## Save loop contexts
save_loop_contexts(trans_options)

save_number_of_var_assignments(trans_options)

save_var_assignment_contexts(trans_options)

# Save the edges in the partial order file
edges = trans_options.get_all_edges()
Expand Down
Loading