Skip to content

Commit

Permalink
add unsafe marker for async functions without drop safe
Browse files Browse the repository at this point in the history
  • Loading branch information
ihciah committed Dec 17, 2023
1 parent 5d6a8c3 commit a4e920b
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn async_call() {
}],
};
let call = Instant::now();
let pass = DemoCallImpl::demo_check_async(&req).await.pass;
let pass = unsafe { DemoCallImpl::demo_check_async(&req).await }.pass;
println!(
"[async] User pass: {pass}, time cost: {}sec",
call.elapsed().as_secs()
Expand Down
19 changes: 15 additions & 4 deletions rust2go-common/src/raw_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,17 +342,22 @@ impl TryFrom<&ItemTrait> for TraitRepr {
drop_safe_ret_params = true;
}

if (drop_safe || drop_safe_ret_params)
&& params.iter().any(|param| param.ty.is_reference)
{
let mut safe = true;
let has_reference = params.iter().any(|param| param.ty.is_reference);

if (drop_safe || drop_safe_ret_params) && has_reference {
sbail!("drop_safe function cannot have reference parameters")
}
if is_async && !drop_safe && !drop_safe_ret_params {
safe = false;
}

fns.push(FnRepr {
name: fn_name,
is_async,
params,
ret,
safe,
drop_safe_ret_params,
});
}
Expand All @@ -368,6 +373,7 @@ pub struct FnRepr {
is_async: bool,
params: Vec<Param>,
ret: Option<ParamType>,
safe: bool,
drop_safe_ret_params: bool,
}

Expand Down Expand Up @@ -696,6 +702,10 @@ impl FnRepr {
self.drop_safe_ret_params
}

pub fn safe(&self) -> bool {
self.safe
}

pub fn params(&self) -> &[Param] {
&self.params
}
Expand Down Expand Up @@ -853,8 +863,9 @@ inline void {fn_name}_cb(const void *f_ptr, {c_resp_type} resp, const void *slot
let func_name = &self.name;
let func_param_names: Vec<_> = self.params.iter().map(|p| &p.name).collect();
let func_param_types: Vec<_> = self.params.iter().map(|p| &p.ty).collect();
let unsafe_marker = (!self.safe).then(syn::token::Unsafe::default);
out.extend(quote! {
fn #func_name(#(#func_param_names: #func_param_types)*)
#unsafe_marker fn #func_name(#(#func_param_names: #func_param_types)*)
});

let ref_marks = self.params.iter().map(|p| {
Expand Down
15 changes: 15 additions & 0 deletions rust2go-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,21 @@ fn r2g_trait(
}
}

// for all functions with safe=false, add unsafe
for (_, trat_fn) in trat_repr
.fns()
.iter()
.zip(trat.items.iter_mut())
.filter(|(fn_repr, _)| !fn_repr.safe())
{
match trat_fn {
syn::TraitItem::Fn(f) => {
f.sig.unsafety = Some(syn::token::Unsafe::default());
}
_ => sbail!("only fn is supported"),
}
}

let mut out = quote! {#trat};
out.extend(trat_repr.generate_rs(binding_path.as_ref())?);
Ok(out.into())
Expand Down

0 comments on commit a4e920b

Please sign in to comment.