Skip to content

Commit

Permalink
Abstract over id_matches in sql-entity-graph (#1705)
Browse files Browse the repository at this point in the history
Simplify the graph code by allowing similar types to be described
by a behavior they both implement. Abstractions can sometimes
clarify things! Who'd have thought?
  • Loading branch information
workingjubilee authored May 16, 2024
1 parent 99c2d9d commit 38646f9
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 180 deletions.
38 changes: 11 additions & 27 deletions pgrx-sql-entity-graph/src/aggregate/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::metadata::SqlMapping;
use crate::pgrx_sql::PgrxSql;
use crate::to_sql::entity::ToSqlConfigEntity;
use crate::to_sql::ToSql;
use crate::type_keyed;
use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
use core::any::TypeId;
use eyre::{eyre, WrapErr};
Expand Down Expand Up @@ -277,21 +278,14 @@ impl ToSql for PgAggregateEntity {
};

let stype_sql = map_ty(&self.stype.used_ty).wrap_err("Mapping state type")?;
let mut stype_schema = String::from("");
for (ty_item, ty_index) in context.types.iter() {
if ty_item.id_matches(&self.stype.used_ty.ty_id) {
stype_schema = context.schema_prefix_for(ty_index);
break;
}
}
if String::is_empty(&stype_schema) {
for (ty_item, ty_index) in context.enums.iter() {
if ty_item.id_matches(&self.stype.used_ty.ty_id) {
stype_schema = context.schema_prefix_for(ty_index);
break;
}
}
}
let stype_schema = context
.types
.iter()
.map(type_keyed)
.chain(context.enums.iter().map(type_keyed))
.find(|(ty, _)| ty.id_matches(&self.stype.used_ty.ty_id))
.map(|(_, ty_index)| context.schema_prefix_for(ty_index))
.unwrap_or_default();

if let Some(value) = &self.mstype {
let mstype_sql = map_ty(value).wrap_err("Mapping moving state type")?;
Expand Down Expand Up @@ -319,12 +313,7 @@ impl ToSql for PgAggregateEntity {
let graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == arg.used_ty.full_path,
_ => false,
})
.find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
.ok_or_else(|| {
eyre!("Could not find arg type in graph. Got: {:?}", arg.used_ty)
})?;
Expand Down Expand Up @@ -372,12 +361,7 @@ impl ToSql for PgAggregateEntity {
let graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == arg.used_ty.full_path,
_ => false,
})
.find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
.ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
let needs_comma = idx < (direct_args.len() - 1);
let buf = format!(
Expand Down
26 changes: 26 additions & 0 deletions pgrx-sql-entity-graph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,32 @@ impl SqlGraphEntity {
rust_identifier = self.rust_identifier(),
)
}

pub fn id_or_name_matches(&self, ty_id: &::core::any::TypeId, name: &str) -> bool {
match self {
SqlGraphEntity::Enum(entity) => entity.id_matches(ty_id),
SqlGraphEntity::Type(entity) => entity.id_matches(ty_id),
SqlGraphEntity::BuiltinType(string) => string == name,
_ => false,
}
}

pub fn type_matches(&self, arg: &dyn TypeIdentifiable) -> bool {
self.id_or_name_matches(arg.ty_id(), arg.ty_name())
}
}

pub trait TypeMatch {
fn id_matches(&self, arg: &core::any::TypeId) -> bool;
}

pub fn type_keyed<'a, 'b, A: TypeMatch, B>((a, b): (&'a A, &'b B)) -> (&'a dyn TypeMatch, &'b B) {
(a, b)
}

pub trait TypeIdentifiable {
fn ty_id(&self) -> &core::any::TypeId;
fn ty_name(&self) -> &str;
}

impl SqlGraphIdentifier for SqlGraphEntity {
Expand Down
59 changes: 14 additions & 45 deletions pgrx-sql-entity-graph/src/pg_extern/entity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ use crate::metadata::{Returns, SqlMapping};
use crate::pgrx_sql::PgrxSql;
use crate::to_sql::entity::ToSqlConfigEntity;
use crate::to_sql::ToSql;
use crate::ExternArgs;
use crate::{SqlGraphEntity, SqlGraphIdentifier};
use crate::{ExternArgs, SqlGraphEntity, SqlGraphIdentifier, TypeMatch};

use eyre::{eyre, WrapErr};

Expand Down Expand Up @@ -118,12 +117,7 @@ impl ToSql for PgExternEntity {
let graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == arg.used_ty.full_path,
_ => false,
})
.find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
.ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
let needs_comma = idx < (metadata_without_arg_skips.len().saturating_sub(1));
let metadata_argument = &self.metadata.arguments[idx];
Expand Down Expand Up @@ -182,12 +176,7 @@ impl ToSql for PgExternEntity {
let graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(neighbor_ty) => neighbor_ty.id_matches(&ty.ty_id),
SqlGraphEntity::Enum(neighbor_en) => neighbor_en.id_matches(&ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == ty.full_path,
_ => false,
})
.find(|neighbor| context.graph[*neighbor].type_matches(ty))
.ok_or_else(|| eyre!("Could not find return type in graph."))?;
let metadata_retval = self.metadata.retval.clone();
let sql_type = match metadata_retval.return_sql {
Expand All @@ -206,12 +195,7 @@ impl ToSql for PgExternEntity {
let graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(neighbor_ty) => neighbor_ty.id_matches(&ty.ty_id),
SqlGraphEntity::Enum(neighbor_en) => neighbor_en.id_matches(&ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == ty.full_path,
_ => false,
})
.find(|neighbor| context.graph[*neighbor].type_matches(ty))
.ok_or_else(|| eyre!("Could not find return type in graph."))?;
let metadata_retval = self.metadata.retval.clone();
let sql_type = match metadata_retval.return_sql {
Expand Down Expand Up @@ -251,16 +235,7 @@ impl ToSql for PgExternEntity {
{
let graph_index =
context.graph.neighbors_undirected(self_index).find(|neighbor| {
match &context.graph[*neighbor] {
SqlGraphEntity::Type(neighbor_ty) => {
neighbor_ty.id_matches(&ty.ty_id)
}
SqlGraphEntity::Enum(neighbor_en) => {
neighbor_en.id_matches(&ty.ty_id)
}
SqlGraphEntity::BuiltinType(defined) => defined == ty.ty_source,
_ => false,
}
context.graph[*neighbor].id_or_name_matches(&ty.ty_id, ty.ty_source)
});

let needs_comma = idx < (table_items.len() - 1);
Expand Down Expand Up @@ -381,11 +356,9 @@ impl ToSql for PgExternEntity {
let left_arg_graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&left_fn_arg.used_ty.ty_id),
SqlGraphEntity::Enum(en) => en.id_matches(&left_fn_arg.used_ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == left_arg.type_name,
_ => false,
.find(|neighbor| {
context.graph[*neighbor]
.id_or_name_matches(&left_fn_arg.used_ty.ty_id, left_arg.type_name)
})
.ok_or_else(|| {
eyre!("Could not find left arg type in graph. Got: {:?}", left_arg)
Expand Down Expand Up @@ -421,11 +394,9 @@ impl ToSql for PgExternEntity {
let right_arg_graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&right_fn_arg.used_ty.ty_id),
SqlGraphEntity::Enum(en) => en.id_matches(&right_fn_arg.used_ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == right_arg.type_name,
_ => false,
.find(|neighbor| {
context.graph[*neighbor]
.id_or_name_matches(&right_fn_arg.used_ty.ty_id, right_arg.type_name)
})
.ok_or_else(|| {
eyre!("Could not find right arg type in graph. Got: {:?}", right_arg)
Expand Down Expand Up @@ -537,11 +508,9 @@ impl ToSql for PgExternEntity {
let source_arg_graph_index = context
.graph
.neighbors_undirected(self_index)
.find(|neighbor| match &context.graph[*neighbor] {
SqlGraphEntity::Type(ty) => ty.id_matches(&source_fn_arg.used_ty.ty_id),
SqlGraphEntity::Enum(en) => en.id_matches(&source_fn_arg.used_ty.ty_id),
SqlGraphEntity::BuiltinType(defined) => defined == source_arg.type_name,
_ => false,
.find(|neighbor| {
context.graph[*neighbor]
.id_or_name_matches(&source_fn_arg.used_ty.ty_id, source_arg.type_name)
})
.ok_or_else(|| {
eyre!("Could not find source type in graph. Got: {:?}", source_arg)
Expand Down
Loading

0 comments on commit 38646f9

Please sign in to comment.