Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cost model ORM implementation #24

Merged
merged 28 commits into from
Nov 9, 2024
Merged
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
304a042
Add implementation for create_new_epoch
lanlou1554 Nov 7, 2024
99f36cf
implement get_stats methods
xx01cyx Nov 7, 2024
9f71c05
Initial draft of update_stat
lanlou1554 Nov 7, 2024
5dce3dd
Small optimization
lanlou1554 Nov 7, 2024
c05d664
Basic error handling
lanlou1554 Nov 7, 2024
3840673
Finish update_stats and store_expr_stats_mappings
lanlou1554 Nov 7, 2024
d05c493
Fix clippy
lanlou1554 Nov 7, 2024
6d1f82e
integrate w nullable columns
xx01cyx Nov 7, 2024
2d12016
Update test_create_new_epoch
lanlou1554 Nov 7, 2024
72a07d1
add: cost methods
unw9527 Nov 7, 2024
586c51f
add update_stats_from_catalog
lanlou1554 Nov 7, 2024
959610b
Add test_update_stats_from_catalog
lanlou1554 Nov 7, 2024
3cd2897
feat: enable unit test with every table initialized
unw9527 Nov 7, 2024
ea154e6
Fix clippy
lanlou1554 Nov 7, 2024
f56946c
fix description
xx01cyx Nov 7, 2024
8cc4f0e
Modify update_stat and add one related test
lanlou1554 Nov 8, 2024
6371d61
refine test infra and track init.db
xx01cyx Nov 8, 2024
fe891b0
add more initial data into stat-related tables
xx01cyx Nov 8, 2024
2956971
add new line at eof
xx01cyx Nov 8, 2024
6d70ad0
refine variant tag in init
xx01cyx Nov 8, 2024
a82d6c4
use json for stat type
xx01cyx Nov 8, 2024
4103607
add test_get_stats_for_table
xx01cyx Nov 8, 2024
7310fd2
minor fixes
xx01cyx Nov 8, 2024
3c22282
add test_get_stats_for_single_attr and test_get_stats_for_multiple_attrs
xx01cyx Nov 8, 2024
5ad0c7a
remove unused comments
xx01cyx Nov 8, 2024
1d12502
add: cost related tests
unw9527 Nov 8, 2024
6f97425
Fix update_stats and add all tests
lanlou1554 Nov 8, 2024
7a8d4df
Rebase on main
lanlou1554 Nov 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Modify update_stat and add one related test
lanlou1554 committed Nov 9, 2024
commit 8cc4f0e07c5e1ddfc4ae5a66716ba133a6e68a98
1 change: 1 addition & 0 deletions optd-persistent/src/cost_model/interface.rs
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ pub enum CatalogSource {

/// Kinds of statistic handled by the cost model storage layer.
///
/// Callers convert a variant to its integer tag with `as i32` (for example
/// `StatisticType::Count as i32`) and that tag is compared against the
/// `variant_tag` column of the `statistic` table. No explicit discriminants
/// are declared, so the tags follow declaration order starting at 0;
/// reordering or inserting variants in the middle changes the tags.
// NOTE(review): presumably previously persisted tags would then no longer match — confirm before reordering.
pub enum StatisticType {
Count,
Cardinality,
Min,
Max,
}
211 changes: 187 additions & 24 deletions optd-persistent/src/cost_model/orm.rs
Original file line number Diff line number Diff line change
@@ -8,8 +8,8 @@ use sea_orm::prelude::{Expr, Json};
use sea_orm::sea_query::Query;
use sea_orm::{sqlx::types::chrono::Utc, EntityTrait};
use sea_orm::{
ActiveModelTrait, ColumnTrait, DbErr, DeleteResult, EntityOrSelect, ModelTrait, QueryFilter,
QueryOrder, QuerySelect, RuntimeErr, TransactionTrait,
ActiveModelTrait, ColumnTrait, DbBackend, DbErr, DeleteResult, EntityOrSelect, ModelTrait,
QueryFilter, QueryOrder, QuerySelect, QueryTrait, RuntimeErr, TransactionTrait,
};

use super::catalog::mock_catalog::{self, MockCatalog};
@@ -154,27 +154,40 @@ impl CostModelStorageLayer for BackendManager {
}
}

// **IMPORTANT**: It is the caller's responsibility to ensure that the updated stat differs from the last stored stat,
// if one already exists.
async fn update_stats(&self, stat: Stat, epoch_id: Self::EpochId) -> StorageResult<()> {
let transaction = self.db.begin().await?;
// let transaction = self.db.begin().await?;
// 0. Check if the stat already exists. If exists, get stat_id, else insert into statistic table.
let stat_id = match stat.table_id {
Some(table_id) => {
// TODO(lanlou): only select needed fields
let res = Statistic::find()
.filter(statistic::Column::TableId.eq(table_id))
.filter(statistic::Column::VariantTag.eq(stat.stat_type))
// FIX_ME: Do we need the following filter?
.inner_join(versioned_statistic::Entity)
.select_also(versioned_statistic::Entity)
.order_by_desc(versioned_statistic::Column::EpochId)
/*
TODO(FIX_ME, lanlou): Do we need the following filter?
I am really not sure although I add the top comment...
Since we already increase the epoch, so we should update the stat anyway.
(In theory, we can increase the epoch without updating the stat, but it is not
a straightforward design, and the epoch table will be very large.)
But it will increase the overhead, since the caller will need to make another
query to check if the stat is the same with the last one. We cannot put everything
in one query.
Let us assume we should update the stat anyway for now.
*/
// .inner_join(versioned_statistic::Entity)
// .select_also(versioned_statistic::Entity)
// .order_by_desc(versioned_statistic::Column::EpochId)
.one(&self.db)
.await?;
match res {
Some(stat_data) => {
if stat_data.1.unwrap().statistic_value == stat.stat_value {
return Ok(());
}
stat_data.0.id
// if stat_data.1.unwrap().statistic_value == stat.stat_value {
// return Ok(());
// }
// stat_data.0.id
stat_data.id
}
None => {
let new_stat = statistic::ActiveModel {
@@ -208,18 +221,18 @@ impl CostModelStorageLayer for BackendManager {
.filter(statistic::Column::NumberOfAttributes.eq(stat.attr_ids.len() as i32))
.filter(statistic::Column::Description.eq(description.clone()))
.filter(statistic::Column::VariantTag.eq(stat.stat_type))
// FIX_ME: Do we need the following filter?
.inner_join(versioned_statistic::Entity)
.select_also(versioned_statistic::Entity)
.order_by_desc(versioned_statistic::Column::EpochId)
// .inner_join(versioned_statistic::Entity)
// .select_also(versioned_statistic::Entity)
// .order_by_desc(versioned_statistic::Column::EpochId)
.one(&self.db)
.await?;
match res {
Some(stat_data) => {
if stat_data.1.unwrap().statistic_value == stat.stat_value {
return Ok(());
}
stat_data.0.id
// if stat_data.1.unwrap().statistic_value == stat.stat_value {
// return Ok(());
// }
// stat_data.0.id
stat_data.id
}
None => {
let new_stat = statistic::ActiveModel {
@@ -269,24 +282,25 @@ impl CostModelStorageLayer for BackendManager {
let _ = plan_cost::Entity::update_many()
.col_expr(plan_cost::Column::IsValid, Expr::value(false))
.filter(plan_cost::Column::IsValid.eq(true))
.filter(plan_cost::Column::EpochId.lt(epoch_id))
.filter(
plan_cost::Column::PhysicalExpressionId.in_subquery(
(*Query::select()
Query::select()
.column(
physical_expression_to_statistic_junction::Column::PhysicalExpressionId,
)
.from(physical_expression_to_statistic_junction::Entity)
.and_where(
.cond_where(
physical_expression_to_statistic_junction::Column::StatisticId
.eq(stat_id),
))
.to_owned(),
)
.to_owned(),
),
)
.exec(&self.db)
.await;

transaction.commit().await?;
// transaction.commit().await?;
Ok(())
}

@@ -441,6 +455,7 @@ impl CostModelStorageLayer for BackendManager {

#[cfg(test)]
mod tests {
use crate::cost_model::interface::StatisticType;
use crate::{cost_model::interface::Stat, migrate, CostModelStorageLayer};
use sea_orm::sqlx::database;
use sea_orm::Statement;
@@ -532,6 +547,154 @@ mod tests {
remove_db_file(DATABASE_FILE);
}

/// Exercises `update_stats` for an attribute-level statistic (`table_id: None`).
///
/// Scenario:
/// 1. Updating a statistic that does not exist yet creates one row in
///    `statistic`, one row in `statistic_to_attribute_junction` and one row in
///    `versioned_statistic` tagged with the first epoch.
/// 2. Updating the same statistic in a later epoch keeps the single
///    `statistic` / junction rows, appends a second `versioned_statistic`
///    row, and invalidates the plan cost that was stored under the earlier
///    epoch for a physical expression linked to the statistic.
#[tokio::test]
async fn test_update_attr_stats() {
    const DATABASE_FILE: &str = "test_update_attr_stats.db";
    let database_url = copy_init_db(DATABASE_FILE).await;
    let mut binding = super::BackendManager::new(Some(&database_url)).await;
    let backend_manager = binding.as_mut().unwrap();

    // 1. Update non-existed stat
    let epoch_id1 = backend_manager
        .create_new_epoch("test".to_string(), "InsertTest".to_string())
        .await
        .unwrap();
    let stat = Stat {
        stat_type: StatisticType::Count as i32,
        stat_value: "100".to_string(),
        attr_ids: vec![1],
        table_id: None,
        name: "CountAttr1".to_string(),
    };
    // `expect` (rather than `assert!(res.is_ok())`) so a failure reports the
    // underlying storage error.
    backend_manager
        .update_stats(stat, epoch_id1)
        .await
        .expect("update_stats should create a statistic that does not exist yet");
    let stat_res = Statistic::find()
        .filter(statistic::Column::Name.eq("CountAttr1"))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(stat_res.len(), 1);
    assert_eq!(stat_res[0].number_of_attributes, 1);
    assert_eq!(stat_res[0].description, "1".to_string());
    assert_eq!(stat_res[0].variant_tag, StatisticType::Count as i32);
    let stat_attr_res = StatisticToAttributeJunction::find()
        .filter(statistic_to_attribute_junction::Column::StatisticId.eq(stat_res[0].id))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(stat_attr_res.len(), 1);
    assert_eq!(stat_attr_res[0].attribute_id, 1);
    let versioned_stat_res = VersionedStatistic::find()
        .filter(versioned_statistic::Column::StatisticId.eq(stat_res[0].id))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(versioned_stat_res.len(), 1);
    assert_eq!(
        versioned_stat_res[0].statistic_value,
        serde_json::Value::String("100".to_string())
    );
    assert_eq!(versioned_stat_res[0].epoch_id, epoch_id1);

    // 2. Normal update
    // 2.1 Insert some costs
    let expr_id = PhysicalExpression::insert(physical_expression::ActiveModel {
        group_id: sea_orm::ActiveValue::Set(1),
        fingerprint: sea_orm::ActiveValue::Set(12346),
        variant_tag: sea_orm::ActiveValue::Set(1),
        data: sea_orm::ActiveValue::Set(serde_json::Value::String("data".to_string())),
        ..Default::default()
    })
    .exec(&backend_manager.db)
    .await
    .unwrap()
    .last_insert_id;
    // Link the new physical expression to the statistic so that a later
    // update of the statistic has a cost to invalidate. Only success of the
    // insert matters here, so its result is not bound.
    PhysicalExpressionToStatisticJunction::insert(
        physical_expression_to_statistic_junction::ActiveModel {
            physical_expression_id: sea_orm::ActiveValue::Set(expr_id),
            statistic_id: sea_orm::ActiveValue::Set(stat_res[0].id),
        },
    )
    .exec(&backend_manager.db)
    .await
    .unwrap();
    backend_manager
        .store_cost(expr_id, 42, versioned_stat_res[0].epoch_id)
        .await
        .unwrap();
    let cost_res = PlanCost::find()
        .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(cost_res.len(), 1);
    assert!(cost_res[0].is_valid);

    // 2.2 Normal update
    let epoch_id2 = backend_manager
        .create_new_epoch("test".to_string(), "InsertTest".to_string())
        .await
        .unwrap();
    let stat2 = Stat {
        stat_type: StatisticType::Count as i32,
        stat_value: "200".to_string(),
        attr_ids: vec![1],
        table_id: None,
        name: "CountAttr1".to_string(),
    };
    backend_manager
        .update_stats(stat2, epoch_id2)
        .await
        .expect("update_stats should update an existing statistic");

    // 2.3 Check statistic table
    let stat_res = Statistic::find()
        .filter(statistic::Column::Name.eq("CountAttr1"))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(stat_res.len(), 1);
    assert_eq!(stat_res[0].number_of_attributes, 1);
    assert_eq!(stat_res[0].description, "1".to_string());
    assert_eq!(stat_res[0].variant_tag, StatisticType::Count as i32);

    // 2.4 Check statistic_to_attribute_junction table
    let stat_attr_res = StatisticToAttributeJunction::find()
        .filter(statistic_to_attribute_junction::Column::StatisticId.eq(stat_res[0].id))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(stat_attr_res.len(), 1);
    assert_eq!(stat_attr_res[0].attribute_id, 1);

    // 2.5 Check versioned_statistic table
    // NOTE(review): the query has no explicit ordering, so the index-based
    // assertions below assume rows come back in insertion order — confirm.
    let versioned_stat_res = VersionedStatistic::find()
        .filter(versioned_statistic::Column::StatisticId.eq(stat_res[0].id))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(versioned_stat_res.len(), 2);
    assert_eq!(
        versioned_stat_res[0].statistic_value,
        serde_json::Value::String("100".to_string())
    );
    assert_eq!(versioned_stat_res[0].epoch_id, epoch_id1);
    assert_eq!(versioned_stat_res[0].statistic_id, stat_res[0].id);
    assert_eq!(
        versioned_stat_res[1].statistic_value,
        serde_json::Value::String("200".to_string())
    );
    assert_eq!(versioned_stat_res[1].epoch_id, epoch_id2);
    assert_eq!(versioned_stat_res[1].statistic_id, stat_res[0].id);
    assert!(epoch_id1 < epoch_id2);

    // 2.6 Check plan_cost table (cost invalidation)
    let cost_res = PlanCost::find()
        .filter(plan_cost::Column::PhysicalExpressionId.eq(expr_id))
        .all(&backend_manager.db)
        .await
        .unwrap();
    assert_eq!(cost_res.len(), 1);
    assert_eq!(cost_res[0].cost, 42);
    assert_eq!(cost_res[0].epoch_id, epoch_id1);
    assert!(!cost_res[0].is_valid);

    remove_db_file(DATABASE_FILE);
}

// TODO: placeholder — the body is empty, so this test currently passes
// without exercising anything. Intended (per its name) to cover
// `update_stats` for a table-level statistic.
#[tokio::test]
async fn test_update_table_stats() {}

#[tokio::test]
#[ignore] // Need to update all tables
async fn test_store_cost() {