diff --git a/optd-mvp/DESIGN.md b/optd-mvp/DESIGN.md new file mode 100644 index 0000000..e50a56c --- /dev/null +++ b/optd-mvp/DESIGN.md @@ -0,0 +1,67 @@ +# Duplicate Elimination Memo Table + +Note that most of the details are in `src/memo/persistent/implementation.rs`. + +For this document, we are assuming that the memo table is backed by a database / ORM. A lot of these +problems would likely not be an issue if everything were in memory. + +## Group Merging + +During logical exploration, there will be rules that create cycles between groups. The easy solution +for this is to immediately merge two groups together when the engine determines that adding an +expression would result in a duplicate expression from another group. + +However, if we want to support parallel exploration, this could be prone to high contention. By +definition, merging group G1 into group G2 would mean that _every expression_ that has a child of +group G1 would need to be rewritten to point to group G2 instead. + +This is unacceptable in a parallel setting, as that would mean every single task that gets affected +would need to either wait for the rewrites to happen before resuming work, or need to abort their +work because data has changed underneath them. + +So immediate / eager group merging is not a great idea for parallel exploration. However, if we do +not ever merge two groups that are identical, we are subject to doing duplicate work for every +duplicate expression in the memo table during physical optimization. + +Instead of merging groups together immediately, we can maintain an auxiliary data structure +that records the groups that _eventually_ need to get merged, and "lazily" merge those groups +together once every group has finished exploration. + +## Union-Find Group Sets + +We use the well-known Union-Find algorithm and corresponding data structure as the auxiliary data +structure that tracks the to-be-merged groups. 
+ +Union-Find supports `Union` and `Find` operations, where `Union` merges sets and `Find` searches for +a "canonical" or "root" element that is shared between all elements in a given set. + +For more information about Union-Find, see these +[15-451 lecture notes](https://www.cs.cmu.edu/~15451-f24/lectures/lecture08-union-find.pdf). + +Here, we make the elements the groups themselves (really the Group IDs), which allows us to merge +group sets together and also determine a "root group" that all groups in a set can agree on. + +When every group in a group set has finished exploration, we can safely begin to merge them +together by moving all expressions from every group in the group set into a single large group. +Other than making sure that any reference to an old group in the group set points to this new large +group, exploration of all groups is done and physical optimization can start. + +RFC: Do we need to support incremental search? + +Note that since we are now waiting for exploration of all groups to finish, this algorithm is much +closer to the Volcano framework than the Cascades' incremental search. However, since we eventually +will want to store trails / breadcrumbs of decisions made to skip work in the future, and since we +essentially have unlimited space due to the memo table being backed by a DBMS, this is not as much +of a problem. + +## Duplicate Detection + +TODO explain the fingerprinting algorithm and how it relates to group merging + +When taking the fingerprint of an expression, the child groups of an expression need to be root groups. If they are not, we need to try again. +Assuming that all children are root groups, the fingerprint we make for any such expression is valid and can be looked up for duplicates. +In order to maintain that correctness, on a merge of two sets, we must generate a new fingerprint for every expression that has a child group in the smaller set. 
+For example, let's say we need to merge { 1, 2 } (root group 1) with { 3, 4, 5, 6, 7, 8 } (root group 3). We need to find every single expression that has a child group of 1 or 2 and we need to generate a new fingerprint for each where the child groups have been "rewritten" to 3 + +TODO this is incredibly expensive, but is potentially easily parallelizable? + diff --git a/optd-mvp/README.md b/optd-mvp/entities.md similarity index 100% rename from optd-mvp/README.md rename to optd-mvp/entities.md diff --git a/optd-mvp/src/entities/cascades_group.rs b/optd-mvp/src/entities/cascades_group.rs index 9c2ba83..62e1835 100644 --- a/optd-mvp/src/entities/cascades_group.rs +++ b/optd-mvp/src/entities/cascades_group.rs @@ -7,9 +7,9 @@ use sea_orm::entity::prelude::*; pub struct Model { #[sea_orm(primary_key)] pub id: i32, + pub status: i8, pub winner: Option, pub cost: Option, - pub is_optimized: bool, pub parent_id: Option, } diff --git a/optd-mvp/src/memo/mod.rs b/optd-mvp/src/memo/mod.rs index fbf23a2..83a821f 100644 --- a/optd-mvp/src/memo/mod.rs +++ b/optd-mvp/src/memo/mod.rs @@ -19,6 +19,14 @@ pub struct LogicalExpressionId(pub i32); #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct PhysicalExpressionId(pub i32); +/// A status enum representing the different states a group can be during query optimization. +#[repr(u8)] +pub enum GroupStatus { + InProgress = 0, + Explored = 1, + Optimized = 2, +} + /// The different kinds of errors that might occur while running operations on a memo table. 
#[derive(Error, Debug)] pub enum MemoError { diff --git a/optd-mvp/src/memo/persistent/implementation.rs b/optd-mvp/src/memo/persistent/implementation.rs index 4fc7048..002893a 100644 --- a/optd-mvp/src/memo/persistent/implementation.rs +++ b/optd-mvp/src/memo/persistent/implementation.rs @@ -10,7 +10,7 @@ use super::PersistentMemo; use crate::{ entities::*, expression::{LogicalExpression, PhysicalExpression}, - memo::{GroupId, LogicalExpressionId, MemoError, PhysicalExpressionId}, + memo::{GroupId, GroupStatus, LogicalExpressionId, MemoError, PhysicalExpressionId}, OptimizerResult, DATABASE_URL, }; use sea_orm::{ @@ -66,6 +66,40 @@ impl PersistentMemo { .ok_or(MemoError::UnknownGroup(group_id))?) } + /// Retrieves the root / canonical group ID of the given group ID. + /// + /// The groups form a union find / disjoint set parent pointer forest, where group merging + /// causes two trees to merge. + /// + /// This function uses the path compression optimization, which amortizes the cost to a single + /// lookup (theoretically in constant time, but we must be wary of the I/O roundtrip). + pub async fn get_root_group(&self, group_id: GroupId) -> OptimizerResult { + let mut curr_group = self.get_group(group_id).await?; + + // Traverse up the path and find the root group, keeping track of groups we have visited. + let mut path = vec![]; + loop { + let Some(parent_id) = curr_group.parent_id else { + break; + }; + + let next_group = self.get_group(GroupId(parent_id)).await?; + path.push(curr_group); + curr_group = next_group; + } + + let root_id = GroupId(curr_group.id); + + // Path Compression Optimization: + // For every group along the path that we walked, set their parent id pointer to the root. + // This allows for an amortized O(1) cost for `get_root_group`. + for group in path { + self.update_group_parent(GroupId(group.id), root_id).await?; + } + + Ok(root_id) + } + /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`]. 
/// /// If the physical expression does not exist, returns a @@ -146,6 +180,32 @@ impl PersistentMemo { Ok(children) } + /// Updates / replaces a group's status. Returns the previous group status. + /// + /// If the group does not exist, returns a [`MemoError::UnknownGroup`] error. + pub async fn update_group_status( + &self, + group_id: GroupId, + status: GroupStatus, + ) -> OptimizerResult { + // First retrieve the group record. + let mut group = self.get_group(group_id).await?.into_active_model(); + + // Update the group's status. + let old_status = group.status; + group.status = Set(status as u8 as i8); + group.update(&self.db).await?; + + let old_status = match old_status.unwrap() { + 0 => GroupStatus::InProgress, + 1 => GroupStatus::Explored, + 2 => GroupStatus::Optimized, + _ => panic!("encountered an invalid group status"), + }; + + Ok(old_status) + } + /// Updates / replaces a group's best physical plan (winner). Optionally returns the previous /// winner's physical expression ID. /// @@ -167,8 +227,32 @@ impl PersistentMemo { group.update(&self.db).await?; // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. - let old = old_id.unwrap().map(PhysicalExpressionId); - Ok(old) + let old_id = old_id.unwrap().map(PhysicalExpressionId); + Ok(old_id) + } + + /// Updates / replaces a group's parent group. Optionally returns the previous parent. + /// + /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error. + pub async fn update_group_parent( + &self, + group_id: GroupId, + parent_id: GroupId, + ) -> OptimizerResult> { + // First retrieve the group record. + let mut group = self.get_group(group_id).await?.into_active_model(); + + // Check that the parent group exists. + let _ = self.get_group(parent_id).await?; + + // Update the group to point to the new parent. 
+ let old_parent = group.parent_id; + group.parent_id = Set(Some(parent_id.0)); + group.update(&self.db).await?; + + // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. + let old_parent = old_parent.unwrap().map(GroupId); + Ok(old_parent) } /// Adds a logical expression to an existing group via its ID. @@ -192,10 +276,10 @@ impl PersistentMemo { group_id: GroupId, logical_expression: LogicalExpression, children: &[GroupId], - ) -> OptimizerResult> { + ) -> OptimizerResult> { // Check if the expression already exists anywhere in the memo table. if let Some(existing_id) = self - .is_duplicate_logical_expression(&logical_expression) + .is_duplicate_logical_expression(&logical_expression, children) .await? { return Ok(Err(existing_id)); @@ -227,7 +311,15 @@ impl PersistentMemo { // Finally, insert the fingerprint of the logical expression as well. let new_expr: LogicalExpression = new_model.into(); let kind = new_expr.kind(); - let hash = new_expr.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child ID themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let hash = new_expr.fingerprint_with_rewrite(&rewrites); let fingerprint = fingerprint::ActiveModel { id: NotSet, @@ -285,8 +377,8 @@ impl PersistentMemo { /// In order to prevent a large amount of duplicate work, the memo table must support duplicate /// expression detection. /// - /// Returns `Some(expression_id)` if the memo table detects that the expression already exists, - /// and `None` otherwise. + /// Returns `Some((group_id, expression_id))` if the memo table detects that the expression + /// already exists, and `None` otherwise. /// /// This function assumes that the child groups of the expression are currently roots of their /// group sets. 
For example, if G1 and G2 should be merged, and G1 is the root, then the input @@ -296,13 +388,22 @@ impl PersistentMemo { pub async fn is_duplicate_logical_expression( &self, logical_expression: &LogicalExpression, - ) -> OptimizerResult> { + children: &[GroupId], + ) -> OptimizerResult> { let model: logical_expression::Model = logical_expression.clone().into(); // Lookup all expressions that have the same fingerprint and kind. There may be false // positives, but we will check for those next. let kind = model.kind; - let fingerprint = logical_expression.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child ID themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let fingerprint = logical_expression.fingerprint_with_rewrite(&rewrites); // Filter first by the fingerprint, and then the kind. // FIXME: The kind is already embedded into the fingerprint, so we may not actually need the @@ -323,11 +424,11 @@ impl PersistentMemo { let mut match_id = None; for potential_match in potential_matches { let expr_id = LogicalExpressionId(potential_match.logical_expression_id); - let (_, expr) = self.get_logical_expression(expr_id).await?; + let (group_id, expr) = self.get_logical_expression(expr_id).await?; // Check for an exact match. if &expr == logical_expression { - match_id = Some(expr_id); + match_id = Some((group_id, expr_id)); // There should be at most one duplicate expression, so we can break here. break; @@ -359,18 +460,17 @@ impl PersistentMemo { ) -> OptimizerResult> { // Check if the expression already exists in the memo table. - if let Some(existing_id) = self - .is_duplicate_logical_expression(&logical_expression) + if let Some((group_id, existing_id)) = self + .is_duplicate_logical_expression(&logical_expression, children) .await? 
{ - let (group_id, _expr) = self.get_logical_expression(existing_id).await?; return Ok(Err((group_id, existing_id))); } // The expression does not exist yet, so we need to create a new group and new expression. let group = cascades_group::ActiveModel { winner: Set(None), - is_optimized: Set(false), + status: Set(0), // `GroupStatus::InProgress` status. ..Default::default() }; @@ -401,7 +501,15 @@ impl PersistentMemo { // Finally, insert the fingerprint of the logical expression as well. let new_expr: LogicalExpression = new_model.into(); let kind = new_expr.kind(); - let hash = new_expr.fingerprint(); + + // In order to calculate a correct fingerprint, we will want to use the IDs of the root + // groups of the children instead of the child ID themselves. + let mut rewrites = vec![]; + for &child_id in children { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + let hash = new_expr.fingerprint_with_rewrite(&rewrites); let fingerprint = fingerprint::ActiveModel { id: NotSet, diff --git a/optd-mvp/src/memo/persistent/tests.rs b/optd-mvp/src/memo/persistent/tests.rs index f3afea6..3dcddd6 100644 --- a/optd-mvp/src/memo/persistent/tests.rs +++ b/optd-mvp/src/memo/persistent/tests.rs @@ -34,20 +34,22 @@ async fn test_simple_logical_duplicates() { // Test `add_logical_expression_to_group`. { // Attempting to add a duplicate expression into the same group should also fail every time. 
- let logical_expression_id_2a = memo + let (group_id_2a, logical_expression_id_2a) = memo .add_logical_expression_to_group(group_id, scan2a, &[]) .await .unwrap() .err() .unwrap(); + assert_eq!(group_id, group_id_2a); assert_eq!(logical_expression_id, logical_expression_id_2a); - let logical_expression_id_2b = memo + let (group_id_2b, logical_expression_id_2b) = memo .add_logical_expression_to_group(group_id, scan2b, &[]) .await .unwrap() .err() .unwrap(); + assert_eq!(group_id, group_id_2b); assert_eq!(logical_expression_id, logical_expression_id_2b); } @@ -140,3 +142,71 @@ async fn test_simple_tree() { memo.cleanup().await; } + +/// Tests basic group merging. See comments in the test itself for more information. +#[ignore] +#[tokio::test] +async fn test_simple_group_link() { + let memo = PersistentMemo::new().await; + memo.cleanup().await; + + // Create two scan groups. + let scan1 = scan("t1".to_string()); + let scan2 = scan("t2".to_string()); + let (scan_id_1, _) = memo.add_group(scan1, &[]).await.unwrap().ok().unwrap(); + let (scan_id_2, _) = memo.add_group(scan2, &[]).await.unwrap().ok().unwrap(); + + // Create two join expression that should be in the same group. + // Even though these are obviously the same expression (to humans), the fingerprints will be + // different, and so they will be put into different groups. + let join1 = join(scan_id_1, scan_id_2, "t1.a = t2.b".to_string()); + let join2 = join(scan_id_2, scan_id_1, "t2.b = t1.a".to_string()); + let join_unknown = join2.clone(); + + let (join_group_1, _) = memo + .add_group(join1, &[scan_id_1, scan_id_2]) + .await + .unwrap() + .ok() + .unwrap(); + let (join_group_2, join_expr_2) = memo + .add_group(join2, &[scan_id_2, scan_id_1]) + .await + .unwrap() + .ok() + .unwrap(); + assert_ne!(join_group_1, join_group_2); + + // Assume that some rule was applied to `join1`, and it outputs something like `join_unknown`. + // The memo table will tell us that `join_unknown == join2`. 
+ // Take note here that `join_unknown` is a clone of `join2`, not `join1`. + let (existing_group, not_actually_new_expr_id) = memo + .add_logical_expression_to_group(join_group_1, join_unknown, &[scan_id_2, scan_id_1]) + .await + .unwrap() + .err() + .unwrap(); + assert_eq!(existing_group, join_group_2); + assert_eq!(not_actually_new_expr_id, join_expr_2); + + // The above tells the application that the expression already exists in the memo, specifically + // under `existing_group`. Thus, we should link these two groups together. + // Here, we arbitrarily choose to link group 1 into group 2. + memo.update_group_parent(join_group_1, join_group_2) + .await + .unwrap(); + + let test_root_1 = memo.get_root_group(join_group_1).await.unwrap(); + let test_root_2 = memo.get_root_group(join_group_2).await.unwrap(); + assert_eq!(test_root_1, test_root_2); + + // TODO(Connor) + // + // We now need to find all logical expressions that had group 1 (or whatever the root group of + // the set that group 1 belongs to is, in this case it is just group 1) as a child, and add a + // new fingerprint for each one that uses group 2 as a child instead. + // + // In order to do this, we need to iterate through every group in group 1's set. 
+ + memo.cleanup().await; +} diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs b/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs index 3a0e7d0..abaa829 100644 --- a/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_cascades_group.rs @@ -74,9 +74,9 @@ use sea_orm_migration::{prelude::*, schema::*}; pub enum CascadesGroup { Table, Id, + Status, Winner, Cost, - IsOptimized, ParentId, } @@ -92,8 +92,9 @@ impl MigrationTrait for Migration { .table(CascadesGroup::Table) .if_not_exists() .col(pk_auto(CascadesGroup::Id)) + .col(tiny_integer(CascadesGroup::Status)) .col(integer_null(CascadesGroup::Winner)) - .col(big_unsigned_null(CascadesGroup::Cost)) + .col(big_integer_null(CascadesGroup::Cost)) .foreign_key( ForeignKey::create() .from(CascadesGroup::Table, CascadesGroup::Winner) @@ -101,7 +102,6 @@ impl MigrationTrait for Migration { .on_delete(ForeignKeyAction::SetNull) .on_update(ForeignKeyAction::Cascade), ) - .col(boolean(CascadesGroup::IsOptimized)) .col(integer_null(CascadesGroup::ParentId)) .foreign_key( ForeignKey::create() diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs b/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs index 4a828b8..e153b9e 100644 --- a/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_fingerprint.rs @@ -26,7 +26,7 @@ impl MigrationTrait for Migration { .table(Fingerprint::Table) .if_not_exists() .col(pk_auto(Fingerprint::Id)) - .col(unsigned(Fingerprint::LogicalExpressionId)) + .col(integer(Fingerprint::LogicalExpressionId)) .foreign_key( ForeignKey::create() .from(Fingerprint::Table, Fingerprint::LogicalExpressionId) @@ -34,8 +34,8 @@ impl MigrationTrait for Migration { .on_delete(ForeignKeyAction::Cascade) .on_update(ForeignKeyAction::Cascade), ) - .col(small_unsigned(Fingerprint::Kind)) - .col(big_unsigned(Fingerprint::Hash)) + 
.col(small_integer(Fingerprint::Kind)) + .col(big_integer(Fingerprint::Hash)) .to_owned(), ) .await