Skip to content

Commit

Permalink
py/emitnative: Fix native async with.
Browse files Browse the repository at this point in the history
The code generating the entry to the finally handler of an async-with
statement was simply wrong for the case of the native emitter.  Among other
things the layout of the stack was incorrect.

This is fixed by this commit.  The setup of the async-with finally handler
is now put in a dedicated emit function, for both the bytecode and native
emitters to implement in their own way (the bytecode emitter is unchanged,
just factored to a function).

With this fix all of the async-with tests now work when using the native
emitter.

Signed-off-by: Damien George <[email protected]>
  • Loading branch information
dpgeorge committed Jun 21, 2024
1 parent a19214d commit 038125b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 20 deletions.
18 changes: 5 additions & 13 deletions py/compile.c
Original file line number Diff line number Diff line change
Expand Up @@ -1899,19 +1899,7 @@ static void compile_async_with_stmt_helper(compiler_t *comp, size_t n, mp_parse_

// Handle case 1: call __aexit__
// Stack: (..., ctx_mgr)
EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // to tell end_finally there's no exception
EMIT(rot_two);
EMIT_ARG(jump, l_aexit_no_exc); // jump to code below to call __aexit__

// Start of "finally" block
// At this point we have case 2 or 3, we detect which one by the TOS being an exception or not
EMIT_ARG(label_assign, l_finally_block);

// Detect if TOS an exception or not
EMIT(dup_top);
EMIT_LOAD_GLOBAL(MP_QSTR_BaseException);
EMIT_ARG(binary_op, MP_BINARY_OP_EXCEPTION_MATCH);
EMIT_ARG(pop_jump_if, false, l_ret_unwind_jump); // if not an exception then we have case 3
EMIT_ARG(async_with_setup_finally, l_aexit_no_exc, l_finally_block, l_ret_unwind_jump);

// Handle case 2: call __aexit__ and either swallow or re-raise the exception
// Stack: (..., ctx_mgr, exc)
Expand All @@ -1937,6 +1925,7 @@ static void compile_async_with_stmt_helper(compiler_t *comp, size_t n, mp_parse_
EMIT_ARG(pop_jump_if, false, l_end);
EMIT(pop_top); // pop exception
EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE); // replace with None to swallow exception
// Stack: (..., None)
EMIT_ARG(jump, l_end);
EMIT_ARG(adjust_stack_size, 2);

Expand All @@ -1946,13 +1935,16 @@ static void compile_async_with_stmt_helper(compiler_t *comp, size_t n, mp_parse_
EMIT(rot_three);
EMIT(rot_three);
EMIT_ARG(label_assign, l_aexit_no_exc);
// We arrive here from either case 1 (a jump) or case 3 (fall through)
// Stack: case 1: (..., None, ctx_mgr) or case 3: (..., X, INT, ctx_mgr)
EMIT_ARG(load_method, MP_QSTR___aexit__, false);
EMIT_ARG(load_const_tok, MP_TOKEN_KW_NONE);
EMIT(dup_top);
EMIT(dup_top);
EMIT_ARG(call_method, 3, 0, 0);
compile_yield_from(comp);
EMIT(pop_top);
// Stack: case 1: (..., None) or case 3: (..., X, INT)
EMIT_ARG(adjust_stack_size, -1);

// End of "finally" block
Expand Down
6 changes: 6 additions & 0 deletions py/emit.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ typedef struct _emit_method_table_t {
void (*unwind_jump)(emit_t *emit, mp_uint_t label, mp_uint_t except_depth);
void (*setup_block)(emit_t *emit, mp_uint_t label, int kind);
void (*with_cleanup)(emit_t *emit, mp_uint_t label);
#if MICROPY_PY_ASYNC_AWAIT
void (*async_with_setup_finally)(emit_t *emit, mp_uint_t label_aexit_no_exc, mp_uint_t label_finally_block, mp_uint_t label_ret_unwind_jump);
#endif
void (*end_finally)(emit_t *emit);
void (*get_iter)(emit_t *emit, bool use_stack);
void (*for_iter)(emit_t *emit, mp_uint_t label);
Expand Down Expand Up @@ -264,6 +267,9 @@ void mp_emit_bc_jump_if_or_pop(emit_t *emit, bool cond, mp_uint_t label);
void mp_emit_bc_unwind_jump(emit_t *emit, mp_uint_t label, mp_uint_t except_depth);
void mp_emit_bc_setup_block(emit_t *emit, mp_uint_t label, int kind);
void mp_emit_bc_with_cleanup(emit_t *emit, mp_uint_t label);
#if MICROPY_PY_ASYNC_AWAIT
void mp_emit_bc_async_with_setup_finally(emit_t *emit, mp_uint_t label_aexit_no_exc, mp_uint_t label_finally_block, mp_uint_t label_ret_unwind_jump);
#endif
void mp_emit_bc_end_finally(emit_t *emit);
void mp_emit_bc_get_iter(emit_t *emit, bool use_stack);
void mp_emit_bc_for_iter(emit_t *emit, mp_uint_t label);
Expand Down
24 changes: 24 additions & 0 deletions py/emitbc.c
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,27 @@ void mp_emit_bc_with_cleanup(emit_t *emit, mp_uint_t label) {
mp_emit_bc_adjust_stack_size(emit, -4);
}

#if MICROPY_PY_ASYNC_AWAIT
void mp_emit_bc_async_with_setup_finally(emit_t *emit, mp_uint_t label_aexit_no_exc, mp_uint_t label_finally_block, mp_uint_t label_ret_unwind_jump) {
// The async-with body has executed and no exception was raised, the execution fell through to this point.
// Stack: (..., ctx_mgr)

// Finish async-with body and prepare to enter "finally" block.
mp_emit_bc_load_const_tok(emit, MP_TOKEN_KW_NONE); // to tell end_finally there's no exception
mp_emit_bc_rot_two(emit);
mp_emit_bc_jump(emit, label_aexit_no_exc); // jump to code to call __aexit__

// Start of "finally" block which is entered via one of: an exception propagating out, a return, an unwind jump.
mp_emit_bc_label_assign(emit, label_finally_block);

// Detect which case we have by the TOS being an exception or not.
mp_emit_bc_dup_top(emit);
mp_emit_bc_load_global(emit, MP_QSTR_BaseException, MP_EMIT_IDOP_GLOBAL_GLOBAL);
mp_emit_bc_binary_op(emit, MP_BINARY_OP_EXCEPTION_MATCH);
mp_emit_bc_pop_jump_if(emit, false, label_ret_unwind_jump); // if not an exception then we have return or unwind jump.
}
#endif

void mp_emit_bc_end_finally(emit_t *emit) {
emit_write_bytecode_byte(emit, -1, MP_BC_END_FINALLY);
}
Expand Down Expand Up @@ -862,6 +883,9 @@ const emit_method_table_t emit_bc_method_table = {
mp_emit_bc_unwind_jump,
mp_emit_bc_setup_block,
mp_emit_bc_with_cleanup,
#if MICROPY_PY_ASYNC_AWAIT
mp_emit_bc_async_with_setup_finally,
#endif
mp_emit_bc_end_finally,
mp_emit_bc_get_iter,
mp_emit_bc_for_iter,
Expand Down
38 changes: 35 additions & 3 deletions py/emitnative.c
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ static void emit_native_label_assign(emit_t *emit, mp_uint_t l) {
if (is_finally) {
// Label is at start of finally handler: store TOS into exception slot
vtype_kind_t vtype;
emit_pre_pop_reg(emit, &vtype, REG_TEMP0);
emit_access_stack(emit, 1, &vtype, REG_TEMP0);
ASM_MOV_LOCAL_REG(emit->as, LOCAL_IDX_EXC_VAL(emit), REG_TEMP0);
}

Expand Down Expand Up @@ -1201,6 +1201,10 @@ static void emit_native_global_exc_entry(emit_t *emit) {
ASM_XOR_REG_REG(emit->as, REG_TEMP0, REG_TEMP0);
ASM_MOV_LOCAL_REG(emit->as, LOCAL_IDX_EXC_HANDLER_UNWIND(emit), REG_TEMP0);

// clear nlr.ret_val, because it's passed to mp_native_raise regardless
// of whether there was an exception or not
ASM_MOV_LOCAL_REG(emit->as, LOCAL_IDX_EXC_VAL(emit), REG_TEMP0);

// Put PC of start code block into REG_LOCAL_1
ASM_MOV_REG_PCREL(emit->as, REG_LOCAL_1, start_label);

Expand Down Expand Up @@ -2235,8 +2239,34 @@ static void emit_native_with_cleanup(emit_t *emit, mp_uint_t label) {
emit_native_label_assign(emit, *emit->label_slot + 1);

// Exception is in nlr_buf.ret_val slot
adjust_stack(emit, 1);
}

#if MICROPY_PY_ASYNC_AWAIT
static void emit_native_async_with_setup_finally(emit_t *emit, mp_uint_t label_aexit_no_exc, mp_uint_t label_finally_block, mp_uint_t label_ret_unwind_jump) {
// The async-with body has executed and no exception was raised, the execution fell through to this point.
// Stack: (..., ctx_mgr)

// Insert a dummy value into the stack so the stack has the same layout to execute the code starting at label_aexit_no_exc
emit_native_adjust_stack_size(emit, 1); // push dummy value, it won't ever be used
emit_native_rot_two(emit);
emit_native_load_const_tok(emit, MP_TOKEN_KW_NONE); // to tell end_finally there's no exception
emit_native_rot_two(emit);
// Stack: (..., <dummy>, None, ctx_mgr)
emit_native_jump(emit, label_aexit_no_exc); // jump to code to call __aexit__
emit_native_adjust_stack_size(emit, -1);

// Start of "finally" block which is entered via one of: an exception propagating out, a return, an unwind jump.
emit_native_label_assign(emit, label_finally_block);

// Detect which case we have by the local exception slot holding an exception or not.
emit_pre_pop_discard(emit);
ASM_MOV_REG_LOCAL(emit->as, REG_ARG_1, LOCAL_IDX_EXC_VAL(emit)); // get exception
emit_post_push_reg(emit, VTYPE_PYOBJ, REG_ARG_1);
ASM_JUMP_IF_REG_ZERO(emit->as, REG_ARG_1, label_ret_unwind_jump, false); // if not an exception then we have return or unwind jump.
}
#endif

static void emit_native_end_finally(emit_t *emit) {
// logic:
// exc = pop_stack
Expand All @@ -2245,7 +2275,7 @@ static void emit_native_end_finally(emit_t *emit) {
// the check if exc is None is done in the MP_F_NATIVE_RAISE stub
DEBUG_printf("end_finally\n");

emit_native_pre(emit);
emit_pre_pop_discard(emit);
ASM_MOV_REG_LOCAL(emit->as, REG_ARG_1, LOCAL_IDX_EXC_VAL(emit));
emit_call(emit, MP_F_NATIVE_RAISE);

Expand Down Expand Up @@ -3033,7 +3063,6 @@ static void emit_native_start_except_handler(emit_t *emit) {
}

static void emit_native_end_except_handler(emit_t *emit) {
adjust_stack(emit, -1); // pop the exception (end_finally didn't use it)
}

const emit_method_table_t EXPORT_FUN(method_table) = {
Expand Down Expand Up @@ -3082,6 +3111,9 @@ const emit_method_table_t EXPORT_FUN(method_table) = {
emit_native_unwind_jump,
emit_native_setup_block,
emit_native_with_cleanup,
#if MICROPY_PY_ASYNC_AWAIT
emit_native_async_with_setup_finally,
#endif
emit_native_end_finally,
emit_native_get_iter,
emit_native_for_iter,
Expand Down
4 changes: 0 additions & 4 deletions tests/run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,6 @@ def run_tests(pyb, tests, args, result_dir, num_threads=1):
# Remove them from the below when they work
if args.emit == "native":
skip_tests.add("basics/gen_yield_from_close.py") # require raise_varargs
skip_tests.update(
{"basics/async_%s.py" % t for t in "with with2 with_break with_return".split()}
) # require async_with
skip_tests.update(
{"basics/%s.py" % t for t in "try_reraise try_reraise2".split()}
) # require raise_varargs
Expand All @@ -731,7 +728,6 @@ def run_tests(pyb, tests, args, result_dir, num_threads=1):
skip_tests.add("basics/sys_tracebacklimit.py") # requires traceback info
skip_tests.add("basics/try_finally_return2.py") # requires raise_varargs
skip_tests.add("basics/unboundlocal.py") # requires checking for unbound local
skip_tests.add("extmod/asyncio_lock.py") # requires async with
skip_tests.add("misc/features.py") # requires raise_varargs
skip_tests.add(
"misc/print_exception.py"
Expand Down

0 comments on commit 038125b

Please sign in to comment.