Skip to content

Commit

Permalink
Implement RetAbi and BoxRet (#1701)
Browse files Browse the repository at this point in the history
This completely reworks the way we handle returning types from
functions, so that we no longer have to rely on a macro expansion
behavior that has to already know the types at expansion time (and thus
has to parse them somehow, which it cannot realistically do because type
aliases exist). Instead, we simply expand into code that asks the types
themselves to modify the FunctionCallInfo appropriately and then return
a raw Datum to Postgres.

This breaks support for certain returns because it was difficult to do
this and also support arbitrary nesting, because Postgres does not
support arbitrary nesting. For instance, you can no longer return:
- `SetOfIterator<'a, Result<T, E>>`
- `TableIterator<'a, (Result<T, E>, Result<U, D>)>`
- `Option<TableIterator<'a, Tuple>>`

It's expected that this will improve in the near-ish future.

This also breaks returning values from `#[pg_extern]` functions that
were relying on `IntoDatum` implementations being enough. It is not
expected this will improve, for the reasons described on the
documentation of the new traits, which can be summarized as "`IntoDatum`
should never have been used for that bound". This change blocks off
several latent correctness problems from affecting pgrx going forward.

Fixes #1484
  • Loading branch information
workingjubilee authored May 20, 2024
1 parent 38646f9 commit 6bd4f0e
Show file tree
Hide file tree
Showing 14 changed files with 653 additions and 361 deletions.
7 changes: 7 additions & 0 deletions pgrx-examples/custom_types/src/hexint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
use pgrx::callconv::BoxRet;
use pgrx::pg_sys::{Datum, Oid};
use pgrx::pgrx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
Expand Down Expand Up @@ -93,6 +94,12 @@ impl IntoDatum for HexInt {
}
}

unsafe impl BoxRet for HexInt {
unsafe fn box_in_fcinfo(self, _fcinfo: pg_sys::FunctionCallInfo) -> pg_sys::Datum {
Datum::from(self.value)
}
}

/// Input function for `HexInt`. Parses any valid "radix(16)" text string, with or without a leading
/// `0x` (or `0X`) into a `HexInt` type. Parse errors are returned and handled by pgrx
///
Expand Down
14 changes: 4 additions & 10 deletions pgrx-examples/spi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,7 @@ INSERT INTO spi_example (title) VALUES ('I like pudding');

#[pg_extern]
fn spi_return_query() -> Result<
TableIterator<
'static,
(
name!(oid, Result<Option<pg_sys::Oid>, pgrx::spi::Error>),
name!(name, Result<Option<String>, pgrx::spi::Error>),
),
>,
TableIterator<'static, (name!(oid, Option<pg_sys::Oid>), name!(name, Option<String>))>,
spi::Error,
> {
#[cfg(feature = "pg12")]
Expand All @@ -51,10 +45,10 @@ fn spi_return_query() -> Result<
let query = "SELECT oid, relname::text || '-pg16' FROM pg_class";

Spi::connect(|client| {
Ok(client
client
.select(query, None, None)?
.map(|row| (row["oid"].value(), row[2].value()))
.collect::<Vec<_>>())
.map(|row| Ok((row["oid"].value()?, row[2].value()?)))
.collect::<Result<Vec<_>, _>>()
})
.map(TableIterator::new)
}
Expand Down
40 changes: 20 additions & 20 deletions pgrx-examples/spi_srf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ fn calculate_human_years() -> Result<
TableIterator<
'static,
(
name!(dog_name, Result<Option<String>, pgrx::spi::Error>),
name!(dog_name, Option<String>),
name!(dog_age, i32),
name!(dog_breed, Result<Option<String>, pgrx::spi::Error>),
name!(dog_breed, Option<String>),
name!(human_age, i32),
),
>,
Expand All @@ -62,7 +62,7 @@ fn calculate_human_years() -> Result<
let dog_age = row["dog_age"].value::<i32>()?.expect("dog_age was null");
let dog_breed = row["dog_breed"].value::<String>();
let human_age = dog_age * 7;
results.push((dog_name, dog_age, dog_breed, human_age));
results.push((dog_name?, dog_age, dog_breed?, human_age));
}

Ok(TableIterator::new(results))
Expand All @@ -76,9 +76,9 @@ fn filter_by_breed(
TableIterator<
'static,
(
name!(dog_name, Result<Option<String>, pgrx::spi::Error>),
name!(dog_age, Result<Option<i32>, pgrx::spi::Error>),
name!(dog_breed, Result<Option<String>, pgrx::spi::Error>),
name!(dog_name, Option<String>),
name!(dog_age, Option<i32>),
name!(dog_breed, Option<String>),
),
>,
spi::Error,
Expand All @@ -95,9 +95,11 @@ fn filter_by_breed(
let tup_table = client.select(query, None, Some(args))?;

let filtered = tup_table
.map(|row| (row["dog_name"].value(), row["dog_age"].value(), row["dog_breed"].value()))
.collect::<Vec<_>>();
Ok(TableIterator::new(filtered))
.map(|row| {
Ok((row["dog_name"].value()?, row["dog_age"].value()?, row["dog_breed"].value()?))
})
.collect::<Result<Vec<_>, _>>();
filtered.map(|v| TableIterator::new(v))
})
}

Expand All @@ -107,19 +109,17 @@ mod tests {
use crate::calculate_human_years;
use pgrx::prelude::*;

#[rustfmt::skip]
#[pg_test]
fn test_calculate_human_years() -> Result<(), pgrx::spi::Error> {
let mut results: Vec<(Result<Option<String>, _>, i32, Result<Option<String>, _>, i32)> =
Vec::new();

results.push((Ok(Some("Fido".to_string())), 3, Ok(Some("Labrador".to_string())), 21));
results.push((Ok(Some("Spot".to_string())), 5, Ok(Some("Poodle".to_string())), 35));
results.push((Ok(Some("Rover".to_string())), 7, Ok(Some("Golden Retriever".to_string())), 49));
results.push((Ok(Some("Snoopy".to_string())), 9, Ok(Some("Beagle".to_string())), 63));
results.push((Ok(Some("Lassie".to_string())), 11, Ok(Some("Collie".to_string())), 77));
results.push((Ok(Some("Scooby".to_string())), 13, Ok(Some("Great Dane".to_string())), 91));
results.push((Ok(Some("Moomba".to_string())), 15, Ok(Some("Labrador".to_string())), 105));
let mut results = Vec::new();

results.push((Some("Fido".to_string()), 3, Some("Labrador".to_string()), 21));
results.push((Some("Spot".to_string()), 5, Some("Poodle".to_string()), 35));
results.push((Some("Rover".to_string()), 7, Some("Golden Retriever".to_string()), 49));
results.push((Some("Snoopy".to_string()), 9, Some("Beagle".to_string()), 63));
results.push((Some("Lassie".to_string()), 11, Some("Collie".to_string()), 77));
results.push((Some("Scooby".to_string()), 13, Some("Great Dane".to_string()), 91));
results.push((Some("Moomba".to_string()), 15, Some("Labrador".to_string()), 105));
let func_results = calculate_human_years()?;

for (expected, actual) in results.iter().zip(func_results) {
Expand Down
4 changes: 2 additions & 2 deletions pgrx-examples/srf/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ fn random_values(num_rows: i32) -> TableIterator<'static, (name!(index, i32), na

#[pg_extern]
fn result_table() -> Result<
Option<::pgrx::iter::TableIterator<'static, (name!(a, Option<i32>), name!(b, Option<i32>))>>,
::pgrx::iter::TableIterator<'static, (name!(a, Option<i32>), name!(b, Option<i32>))>,
Box<dyn std::error::Error + Send + Sync + 'static>,
> {
Ok(Some(TableIterator::new(vec![(Some(1), Some(2))])))
Ok(TableIterator::new(vec![(Some(1), Some(2))]))
}

#[pg_extern]
Expand Down
15 changes: 15 additions & 0 deletions pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,12 @@ fn impl_postgres_enum(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
}

}

unsafe impl ::pgrx::callconv::BoxRet for #enum_ident {
unsafe fn box_in_fcinfo(self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pg_sys::Datum {
::pgrx::datum::IntoDatum::into_datum(self).unwrap()
}
}
});

let sql_graph_entity_item = PostgresEnum::from_derive_input(sql_graph_entity_ast)?;
Expand Down Expand Up @@ -808,6 +814,15 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream>
}
}

unsafe impl #generics ::pgrx::callconv::BoxRet for #name #generics {
unsafe fn box_in_fcinfo(self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pg_sys::Datum {
match ::pgrx::datum::IntoDatum::into_datum(self) {
None => ::pgrx::fcinfo::pg_return_null(fcinfo),
Some(datum) => datum,
}
}
}

impl #generics ::pgrx::datum::FromDatum for #name #generics {
unsafe fn from_polymorphic_datum(
datum: ::pgrx::pg_sys::Datum,
Expand Down
1 change: 0 additions & 1 deletion pgrx-pg-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#![allow(non_upper_case_globals)]
#![allow(improper_ctypes)]
#![allow(clippy::unneeded_field_pattern)]
#![cfg_attr(nightly, feature(strict_provenance))]

#[cfg(
// no features at all will cause problems
Expand Down
16 changes: 3 additions & 13 deletions pgrx-sql-entity-graph/src/pg_extern/entity/returning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,9 @@ use crate::UsedTypeEntity;
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum PgExternReturnEntity {
None,
Type {
ty: UsedTypeEntity,
},
SetOf {
ty: UsedTypeEntity,
is_option: bool, /* Eg `Option<SetOfIterator<T>>` */
is_result: bool, /* Eg `Result<SetOfIterator<T>, E>` */
},
Iterated {
tys: Vec<PgExternReturnEntityIteratedItem>,
is_option: bool, /* Eg `Option<TableIterator<T>>` */
is_result: bool, /* Eg `Result<TableIterator<T>, E>` */
},
Type { ty: UsedTypeEntity },
SetOf { ty: UsedTypeEntity },
Iterated { tys: Vec<PgExternReturnEntityIteratedItem> },
Trigger,
}

Expand Down
152 changes: 22 additions & 130 deletions pgrx-sql-entity-graph/src/pg_extern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::ToSqlConfig;
use operator::{PgrxOperatorAttributeWithIdent, PgrxOperatorOpName};
use search_path::SearchPathList;

use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{format_ident, quote, quote_spanned};
use syn::parse::{Parse, ParseStream, Parser};
use syn::punctuated::Punctuated;
Expand Down Expand Up @@ -406,141 +406,33 @@ impl PgExtern {
}
});

// Iterators require fancy handling for their retvals
let emit_result_handler = |span: Span, optional: bool, result: bool| {
let mut ret_expr = quote! { #func_name(#(#arg_pats),*) };
if result {
// If it's a result, we need to report it.
ret_expr = quote! { #ret_expr.unwrap_or_report() };
}
if !optional {
// If it's not already an option, we need to wrap it.
ret_expr = quote! { Some(#ret_expr) };
}
let import = result.then(|| quote! { use ::pgrx::pg_sys::panic::ErrorReportable; });
quote_spanned! { span =>
#import
#ret_expr
}
};

match &self.returns {
Returning::None => {
let fn_contents = quote! {
#(#arg_fetches)*
#[allow(unused_unsafe)]
unsafe { #func_name(#(#arg_pats),*) };
// -> () means always returning the zero Datum
::pgrx::pg_sys::Datum::from(0)
};
finfo_v1_extern_c(&self.func, fcinfo_ident, fn_contents)
}
Returning::Type(retval_ty) => {
let result_ident = syn::Ident::new("result", self.func.sig.span());
let retval_transform = if retval_ty.resolved_ty == syn::parse_quote!(()) {
quote_spanned! { self.func.sig.output.span() =>
unsafe { ::pgrx::fcinfo::pg_return_void() }
}
} else if retval_ty.result && retval_ty.optional.is_some() {
// returning `Result<Option<T>>`
quote_spanned! { self.func.sig.output.span() =>
match ::pgrx::datum::IntoDatum::into_datum(#result_ident) {
Some(datum) => datum,
None => unsafe { ::pgrx::fcinfo::pg_return_null(#fcinfo_ident) },
}
}
} else if retval_ty.result {
// returning Result<T>
quote_spanned! { self.func.sig.output.span() =>
::pgrx::datum::IntoDatum::into_datum(#result_ident).unwrap_or_else(|| panic!("returned Datum was NULL"))
}
} else if retval_ty.resolved_ty.last_ident_is("Datum") {
// As before, we can just throw this in because it must typecheck
quote_spanned! { self.func.sig.output.span() =>
#result_ident
}
} else if retval_ty.optional.is_some() {
quote_spanned! { self.func.sig.output.span() =>
match #result_ident {
Some(result) => {
::pgrx::datum::IntoDatum::into_datum(result).unwrap_or_else(|| panic!("returned Option<T> was NULL"))
},
None => unsafe { ::pgrx::fcinfo::pg_return_null(#fcinfo_ident) }
}
}
} else {
quote_spanned! { self.func.sig.output.span() =>
::pgrx::datum::IntoDatum::into_datum(#result_ident).unwrap_or_else(|| panic!("returned Datum was NULL"))
}
};

let fn_contents = quote! {
#(#arg_fetches)*

#[allow(unused_unsafe)] // unwrapped fn might be unsafe
let #result_ident = unsafe { #func_name(#(#arg_pats),*) };

#retval_transform
Returning::None
| Returning::Type(_)
| Returning::SetOf { .. }
| Returning::Iterated { .. } => {
let ret_ty = match &self.func.sig.output {
syn::ReturnType::Default => syn::parse_quote! { () },
syn::ReturnType::Type(_, ret_ty) => ret_ty.clone(),
};
finfo_v1_extern_c(&self.func, fcinfo_ident, fn_contents)
}
Returning::SetOf { ty: _retval_ty, is_option, is_result } => {
let result_handler =
emit_result_handler(self.func.sig.span(), *is_option, *is_result);
let setof_closure = quote! {
let wrapper_code = quote_spanned! { self.func.block.span() =>
#[allow(unused_unsafe)]
unsafe {
// SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
// with all its fields properly setup. Unless the user is calling this wrapper function directly, this
// will always be the case
::pgrx::iter::SetOfIterator::srf_next(#fcinfo_ident, || {
#( #arg_fetches )*
#result_handler
})
}
};
finfo_v1_extern_c(&self.func, fcinfo_ident, setof_closure)
}
Returning::Iterated { tys: retval_tys, is_option, is_result } => {
let result_handler =
emit_result_handler(self.func.sig.span(), *is_option, *is_result);

let iter_closure = if retval_tys.len() == 1 {
// Postgres considers functions returning a 1-field table (`RETURNS TABLE (T)`) to be
// a function that `RETURNS SETOF T`. So we write a different wrapper implementation
// that transparently transforms the `TableIterator` returned by the user into a `SetOfIterator`
quote! {
#[allow(unused_unsafe)]
unsafe {
// SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
// with all its fields properly setup. Unless the user is calling this wrapper function directly, this
// will always be the case
::pgrx::iter::SetOfIterator::srf_next(#fcinfo_ident, || {
#( #arg_fetches )*
let table_iterator = { #result_handler };

// we need to convert the 1-field `TableIterator` provided by the user
// into a SetOfIterator in order to properly handle the case of `RETURNS TABLE (T)`,
// which is a table that returns only 1 field.
table_iterator.map(|i| ::pgrx::iter::SetOfIterator::new(i.into_iter().map(|(v,)| v)))
})
}
}
} else {
quote! {
#[allow(unused_unsafe)]
unsafe {
// SAFETY: the caller has asserted that `fcinfo` is a valid FunctionCallInfo pointer, allocated by Postgres
// with all its fields properly setup. Unless the user is calling this wrapper function directly, this
// will always be the case
::pgrx::iter::TableIterator::srf_next(#fcinfo_ident, || {
#( #arg_fetches )*
#result_handler
})
}
let fcinfo = #fcinfo_ident;
let result = match <#ret_ty as ::pgrx::callconv::RetAbi>::check_fcinfo_and_prepare(fcinfo) {
::pgrx::callconv::CallCx::WrappedFn(mcx) => {
let mut mcx = ::pgrx::PgMemoryContexts::For(mcx);
::pgrx::callconv::RetAbi::to_ret(mcx.switch_to(|_| {
#(#arg_fetches)*
#func_name( #(#arg_pats),* )
}))
}
::pgrx::callconv::CallCx::RestoreCx => <#ret_ty as ::pgrx::callconv::RetAbi>::ret_from_fcinfo_fcx(fcinfo),
};
unsafe { <#ret_ty as ::pgrx::callconv::RetAbi>::box_ret_in_fcinfo(fcinfo, result) }
}
};
finfo_v1_extern_c(&self.func, fcinfo_ident, iter_closure)
finfo_v1_extern_c(&self.func, fcinfo_ident, wrapper_code)
}
}
}
Expand Down
Loading

0 comments on commit 6bd4f0e

Please sign in to comment.