[explore] stop submitting runs if sensor is stopped mid-iteration #26924
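This draft explores stopping mid-tick run submission once a sensor has been turned off: rather than mapping a submit function over every run request at once, both daemons slice the pending run requests into batches of 25 and re-check the sensor's instigator state before submitting each batch. A batch submitter returns `None` when the sensor is no longer running, and the consuming loop treats that as the signal to stop.

A minimal standalone sketch of that guard-per-batch pattern (`is_sensor_running` and `submit_one` are hypothetical stand-ins for illustration, not Dagster APIs):

```python
from typing import Callable, List, Optional, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")

def make_batch_submitter(
    is_sensor_running: Callable[[], bool],  # re-checked before every batch
    submit_one: Callable[[T], R],
) -> Callable[[Sequence[T]], Optional[List[R]]]:
    def submit_batch(batch: Sequence[T]) -> Optional[List[R]]:
        # None signals "sensor was stopped mid-iteration; submit no more runs"
        if not is_sensor_running():
            return None
        return [submit_one(item) for item in batch]

    return submit_batch
```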

Draft · wants to merge 2 commits into base: master
64 changes: 60 additions & 4 deletions python_modules/dagster/dagster/_daemon/asset_daemon.py
@@ -785,6 +785,8 @@ def _process_auto_materialize_tick_generator(
             # or crashed partway through execution and needs to be resumed
             # Don't resume very old ticks though in case the daemon crashed for a long time and
             # then restarted
+
+            # TODO - add check here if the interrupted bool was set? don't retry if the tick was manually interrupted
             if can_resume and previous_cursor_written:
                 if latest_tick.status == TickStatus.STARTED:
                     self._logger.warn(
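A purely hypothetical sketch of what the TODO above could look like, assuming the tick were to record an `interrupted` flag when it is stopped manually (no such field is established by this diff):

```python
def should_resume_tick(latest_tick, can_resume: bool, previous_cursor_written: bool) -> bool:
    # Hypothetical: only resume ticks that crashed partway through, not ticks
    # that were deliberately interrupted by turning the sensor off. The
    # `interrupted` attribute is an assumed field, not an existing Dagster API.
    was_interrupted = bool(getattr(latest_tick, "interrupted", False))
    return can_resume and previous_cursor_written and not was_interrupted
```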
@@ -1015,6 +1017,7 @@ def _evaluate_auto_materialize_tick(
             reserved_run_ids=reserved_run_ids,
             debug_crash_flags=debug_crash_flags,
             submit_threadpool_executor=submit_threadpool_executor,
+            remote_sensor=sensor,
         )

         if schedule_storage.supports_auto_materialize_asset_evaluations:
@@ -1098,12 +1101,48 @@ def _submit_run_requests_and_update_evaluations(
         reserved_run_ids: Sequence[str],
         debug_crash_flags: SingleInstigatorDebugCrashFlags,
         submit_threadpool_executor: Optional[ThreadPoolExecutor],
+        remote_sensor: Optional[RemoteSensor],
     ):
         updated_evaluation_keys = set()
         run_request_execution_data_cache = {}

         check.invariant(len(run_requests) == len(reserved_run_ids))
-        to_submit = zip(range(len(run_requests)), reserved_run_ids, run_requests)
+        to_submit = list(
+            zip(range(len(run_requests)), reserved_run_ids, run_requests)
+        )  # TODO see if i can do the batched iteration on an iterable....
+
+        def submit_run_request_batch(
+            run_id_with_run_request_batch: Sequence[Tuple[int, str, RunRequest]],
+        ) -> Optional[Sequence[Tuple[str, AbstractSet[EntityKey]]]]:
+            # check if the sensor is still enabled:
+            if remote_sensor:
+                all_sensor_states = {
+                    sensor_state.selector_id: sensor_state
+                    for sensor_state in instance.all_instigator_state(
+                        instigator_type=InstigatorType.SENSOR
+                    )
+                }
+                if not remote_sensor.get_current_instigator_state(
+                    all_sensor_states.get(remote_sensor.selector_id)
+                ).is_running:
+                    return None
+
+            # if so, submit the run requests
+            results = []
+            for i, run_id, run_request in run_id_with_run_request_batch:
+                results.append(
+                    self._submit_run_request(
+                        i=i,
+                        instance=instance,
+                        run_request=run_request,
+                        reserved_run_id=run_id,
+                        evaluation_id=evaluation_id,
+                        run_request_execution_data_cache=run_request_execution_data_cache,
+                        workspace_process_context=workspace_process_context,
+                        debug_crash_flags=debug_crash_flags,
+                    )
+                )
+            return results

         def submit_run_request(
             run_id_with_run_request: Tuple[int, str, RunRequest],
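The TODO on `to_submit` above asks whether the batched iteration could work over a lazy iterable rather than a materialized list. A standard-library sketch using `itertools.islice` (on Python 3.12+, `itertools.batched` does the same job directly); this is a general-purpose illustration, not code from the PR:

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def iter_batches(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield successive lists of up to batch_size items without
    materializing the whole input up front."""
    it = iter(iterable)
    while batch := list(islice(it, batch_size)):
        yield batch

# e.g. the zip(...) iterator could then be consumed lazily:
# batches = iter_batches(zip(range(len(run_requests)), reserved_run_ids, run_requests), 25)
```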
@@ -1120,15 +1159,31 @@ def submit_run_request(
                 debug_crash_flags=debug_crash_flags,
             )

+        batch_size = 25
+
+        batches = []
+        for i in range(0, len(to_submit), batch_size):
+            # TODO - double check the math here
+            batches.append(to_submit[i : i + batch_size])
+
         if submit_threadpool_executor:
-            gen_run_request_results = submit_threadpool_executor.map(submit_run_request, to_submit)
+            gen_run_request_results = submit_threadpool_executor.map(
+                submit_run_request_batch, batches
+            )
         else:
-            gen_run_request_results = map(submit_run_request, to_submit)
+            gen_run_request_results = map(submit_run_request_batch, batches)

-        for submitted_run_id, entity_keys in gen_run_request_results:
-            # heartbeat after each submitted run
-            yield
+        for run_request_result in gen_run_request_results:
+            if run_request_result is None:
+                # sensor is no longer running
+                # TODO - cleanup work
+                break  # maybe return
+
+            # each batch result is a sequence of (submitted_run_id, entity_keys) pairs
+            for submitted_run_id, entity_keys in run_request_result:
+                # heartbeat after each submitted run
+                yield

                 tick_context.add_run_info(run_id=submitted_run_id)

                 # write the submitted run ID to any evaluations
@@ -1152,6 +1207,7 @@

         check_for_debug_crash(debug_crash_flags, "RUN_IDS_ADDED_TO_EVALUATIONS")

+        # TODO - need to make sure not to update the tick state if the iteration is interrupted
         tick_context.update_state(
             TickStatus.SUCCESS if len(run_requests) > 0 else TickStatus.SKIPPED,
         )
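On the "TODO - double check the math here" comments: the `range(0, len(xs), batch_size)` slicing does partition the list correctly, since Python clamps a slice end that runs past the list. A quick self-contained check (illustrative only):

```python
def chunk(xs, batch_size):
    batches = []
    for i in range(0, len(xs), batch_size):
        batches.append(xs[i : i + batch_size])  # a slice end past len(xs) is safely clamped
    return batches

xs = list(range(103))
batches = chunk(xs, 25)
assert [x for batch in batches for x in batch] == xs  # nothing lost or duplicated
assert len(batches) == 5 and len(batches[-1]) == 3  # 4 full batches of 25, then the remainder
assert chunk([], 25) == []  # empty input yields no batches
```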
105 changes: 65 additions & 40 deletions python_modules/dagster/dagster/_daemon/sensor.py
@@ -1093,55 +1093,80 @@ def _submit_run_requests(
         instance, remote_sensor, [request for _, request in resolved_run_ids_with_requests]
     )

-    def submit_run_request(
-        run_id_with_run_request: Tuple[str, RunRequest],
-    ) -> SubmitRunRequestResult:
-        run_id, run_request = run_id_with_run_request
-        if run_request.requires_backfill_daemon():
-            return _submit_backfill_request(run_id, run_request, instance)
-        else:
-            return _submit_run_request(
-                run_id,
-                run_request,
-                workspace_process_context,
-                remote_sensor,
-                existing_runs_by_key,
-                context.logger,
-                sensor_debug_crash_flags,
-            )
+    def submit_run_request_batch(
+        run_id_with_run_request_batch: Sequence[Tuple[str, RunRequest]],
+    ) -> Optional[Sequence[SubmitRunRequestResult]]:
+        # check if the sensor is still enabled:
+        all_sensor_states = {
+            sensor_state.selector_id: sensor_state
+            for sensor_state in instance.all_instigator_state(instigator_type=InstigatorType.SENSOR)
+        }
+        if not remote_sensor.get_current_instigator_state(
+            all_sensor_states.get(remote_sensor.selector_id)
+        ).is_running:
+            return None
+
+        # if so, submit the run requests
+        results = []
+        for run_id, run_request in run_id_with_run_request_batch:
+            if run_request.requires_backfill_daemon():
+                results.append(_submit_backfill_request(run_id, run_request, instance))
+            else:
+                results.append(
+                    _submit_run_request(
+                        run_id,
+                        run_request,
+                        workspace_process_context,
+                        remote_sensor,
+                        existing_runs_by_key,
+                        context.logger,
+                        sensor_debug_crash_flags,
+                    )
+                )
+        return results

+    batch_size = 25
+
+    batches = []
+    for i in range(0, len(resolved_run_ids_with_requests), batch_size):
+        # TODO - double check the math here
+        batches.append(resolved_run_ids_with_requests[i : i + batch_size])

     if submit_threadpool_executor:
-        gen_run_request_results = submit_threadpool_executor.map(
-            submit_run_request, resolved_run_ids_with_requests
-        )
+        gen_run_request_results = submit_threadpool_executor.map(submit_run_request_batch, batches)
     else:
-        gen_run_request_results = map(submit_run_request, resolved_run_ids_with_requests)
+        gen_run_request_results = map(submit_run_request_batch, batches)

     skipped_runs: List[SkippedSensorRun] = []
     evaluations_by_key = {
         evaluation.key: evaluation for evaluation in automation_condition_evaluations
     }
     updated_evaluation_keys = set()
-    for run_request_result in gen_run_request_results:
-        yield run_request_result.error_info
-
-        run = run_request_result.run
-
-        if isinstance(run, SkippedSensorRun):
-            skipped_runs.append(run)
-            context.add_run_info(run_id=None, run_key=run_request_result.run_key)
-        elif isinstance(run, BackfillSubmission):
-            context.add_run_info(run_id=run.backfill_id)
-        else:
-            context.add_run_info(run_id=run.run_id, run_key=run_request_result.run_key)
-            entity_keys = [*(run.asset_selection or []), *(run.asset_check_selection or [])]
-            for key in entity_keys:
-                if key in evaluations_by_key:
-                    evaluation = evaluations_by_key[key]
-                    evaluations_by_key[key] = dataclasses.replace(
-                        evaluation, run_ids=evaluation.run_ids | {run.run_id}
-                    )
-                    updated_evaluation_keys.add(key)
+    for run_request_results in gen_run_request_results:
+        if run_request_results is None:
+            # sensor is no longer running
+            # TODO - cleanup work
+            break  # maybe return
+        for run_request_result in run_request_results:
+            yield run_request_result.error_info
+
+            run = run_request_result.run
+
+            if isinstance(run, SkippedSensorRun):
+                skipped_runs.append(run)
+                context.add_run_info(run_id=None, run_key=run_request_result.run_key)
+            elif isinstance(run, BackfillSubmission):
+                context.add_run_info(run_id=run.backfill_id)
+            else:
+                context.add_run_info(run_id=run.run_id, run_key=run_request_result.run_key)
+                entity_keys = [*(run.asset_selection or []), *(run.asset_check_selection or [])]
+                for key in entity_keys:
+                    if key in evaluations_by_key:
+                        evaluation = evaluations_by_key[key]
+                        evaluations_by_key[key] = dataclasses.replace(
+                            evaluation, run_ids=evaluation.run_ids | {run.run_id}
+                        )
+                        updated_evaluation_keys.add(key)

     if (
         updated_evaluation_keys
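One property of this design worth keeping in mind: `ThreadPoolExecutor.map` schedules all of the batches as soon as it is called, so the `break` in the consuming loop does not by itself stop batches that are already queued; it is the sensor-state check inside each `submit_run_request_batch` call that actually prevents further submissions. A toy sketch of that interaction (`submit_batch` and `sensor_running` are illustrative stand-ins, not Dagster code):

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

sensor_running = True  # in the daemon this is the sensor's instigator state

def submit_batch(batch: List[int]) -> Optional[List[int]]:
    # Re-check the flag inside the worker: map() has already queued every
    # batch, so this per-batch guard is what stops late submissions.
    if not sensor_running:
        return None
    return [n * 2 for n in batch]  # toy stand-in for submitting runs

batches = [[1, 2], [3, 4], [5, 6]]
with ThreadPoolExecutor(max_workers=2) as executor:
    for result in executor.map(submit_batch, batches):
        if result is None:
            break  # sensor stopped; queued batches still ran their own guard
        print(result)
```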