diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs index 9ea73c09a..ef5cc73db 100644 --- a/iris-mpc-store/src/lib.rs +++ b/iris-mpc-store/src/lib.rs @@ -184,6 +184,34 @@ impl Store { Ok(()) } + /// Update existing iris with given shares. + pub async fn update_iris( + &self, + id: i64, + left_iris_share: &GaloisRingIrisCodeShare, + left_mask_share: &GaloisRingTrimmedMaskCodeShare, + right_iris_share: &GaloisRingIrisCodeShare, + right_mask_share: &GaloisRingTrimmedMaskCodeShare, + ) -> Result<()> { + let mut tx = self.pool.begin().await?; + + let query = sqlx::query( + r#" +UPDATE irises SET (left_code, left_mask, right_code, right_mask) = ($2, $3, $4, $5) +WHERE id = $1; +"#, + ) + .bind(id) + .bind(cast_slice::(&left_iris_share.coefs[..])) + .bind(cast_slice::(&left_mask_share.coefs[..])) + .bind(cast_slice::(&right_iris_share.coefs[..])) + .bind(cast_slice::(&right_mask_share.coefs[..])); + + query.execute(&mut *tx).await?; + tx.commit().await?; + Ok(()) + } + pub async fn insert_or_update_left_iris( &self, id: i64, @@ -659,6 +687,68 @@ mod tests { Ok(()) } + #[tokio::test] + #[cfg(feature = "db_dependent")] + async fn test_update_iris() -> Result<()> { + let schema_name = temporary_name(); + let store = Store::new(&test_db_url()?, &schema_name).await?; + + // insert two irises into db + let iris = StoredIrisRef { + left_code: &[123_u16; 12800], + left_mask: &[456_u16; 6400], + right_code: &[789_u16; 12800], + right_mask: &[101_u16; 6400], + }; + let mut tx = store.tx().await?; + store.insert_irises(&mut tx, &vec![iris.clone(); 2]).await?; + tx.commit().await?; + + // update iris with id 1 in db + let updated_left_code = GaloisRingIrisCodeShare { + id: 0, + coefs: [666_u16; 12800], + }; + let updated_left_mask = GaloisRingTrimmedMaskCodeShare { + id: 0, + coefs: [777_u16; 6400], + }; + let updated_right_code = GaloisRingIrisCodeShare { + id: 0, + coefs: [888_u16; 12800], + }; + let updated_right_mask = GaloisRingTrimmedMaskCodeShare { + id: 0, + coefs: [999_u16; 6400], + }; + store + .update_iris( + 1, + &updated_left_code, + &updated_left_mask, + &updated_right_code, + &updated_right_mask, + ) + .await?; + + // assert iris updated in db with new values + let got: Vec = store.stream_irises().await.try_collect().await?; + assert_eq!(got.len(), 2); + assert_eq!(cast_u8_to_u16(&got[0].left_code), updated_left_code.coefs); + assert_eq!(cast_u8_to_u16(&got[0].left_mask), updated_left_mask.coefs); + assert_eq!(cast_u8_to_u16(&got[0].right_code), updated_right_code.coefs); + assert_eq!(cast_u8_to_u16(&got[0].right_mask), updated_right_mask.coefs); + + // assert the other iris in db is not updated + assert_eq!(cast_u8_to_u16(&got[1].left_code), iris.left_code); + assert_eq!(cast_u8_to_u16(&got[1].left_mask), iris.left_mask); + assert_eq!(cast_u8_to_u16(&got[1].right_code), iris.right_code); + assert_eq!(cast_u8_to_u16(&got[1].right_mask), iris.right_mask); + + cleanup(&store, &schema_name).await?; + Ok(()) + } + fn test_db_url() -> Result { dotenvy::from_filename(DOTENV_TEST)?; Ok(Config::load_config(APP_NAME)? diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index f6d17a2e4..4003d7ff9 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -21,6 +21,7 @@ use iris_mpc_common::{ sync::SyncState, task_monitor::TaskMonitor, }, + iris_db::iris::IrisCode, IrisCodeDb, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, }; use iris_mpc_gpu::{ @@ -31,6 +32,7 @@ use iris_mpc_gpu::{ }, }; use iris_mpc_store::{Store, StoredIrisRef}; +use rand::{rngs::StdRng, SeedableRng}; use static_assertions::const_assert; use std::{ mem, @@ -828,12 +830,20 @@ async fn server_main(config: Config) -> eyre::Result<()> { shares_encryption_key_pair.clone(), ); + let dummy_shares_for_deletions = get_dummy_shares_for_deletion(party_id); + loop { let now = Instant::now(); let batch = next_batch.await?; - process_identity_deletions(&batch); + process_identity_deletions( + &batch, + &store, + &dummy_shares_for_deletions.0, + &dummy_shares_for_deletions.1, + ) + .await?; // Iterate over a list of tracing payloads, and create logs with mappings to // payloads Log at least a "start" event using a log with trace.id and @@ -893,9 +903,14 @@ async fn server_main(config: Config) -> eyre::Result<()> { Ok(()) } -fn process_identity_deletions(batch: &BatchQuery) { +async fn process_identity_deletions( + batch: &BatchQuery, + store: &Store, + dummy_iris_share: &GaloisRingIrisCodeShare, + dummy_mask_share: &GaloisRingTrimmedMaskCodeShare, +) -> eyre::Result<()> { if batch.deletion_requests.is_empty() { - return; + return Ok(()); } for (serial_id, tracing_payload) in batch @@ -910,7 +925,17 @@ fn process_identity_deletions(batch: &BatchQuery) { "Started processing deletion request", ); - // TODO: implement deletion logic here + // overwrite postgres db with dummy values. + // note that both serial_id and postgres db are 1-indexed. + store + .update_iris( + *serial_id as i64, + dummy_iris_share, + dummy_mask_share, + dummy_iris_share, + dummy_mask_share, + ) + .await?; tracing::info!( node_id = tracing_payload.node_id, @@ -920,4 +945,21 @@ fn process_identity_deletions(batch: &BatchQuery) { serial_id, ); } + + Ok(()) +} + +fn get_dummy_shares_for_deletion( + party_id: usize, +) -> (GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare) { + let mut rng: StdRng = StdRng::seed_from_u64(0); + let dummy: IrisCode = IrisCode::default(); + let iris_share: GaloisRingIrisCodeShare = + GaloisRingIrisCodeShare::encode_iris_code(&dummy.code, &dummy.mask, &mut rng)[party_id] + .clone(); + let mask_share: GaloisRingTrimmedMaskCodeShare = + GaloisRingIrisCodeShare::encode_mask_code(&dummy.mask, &mut rng)[party_id] + .clone() + .into(); + (iris_share, mask_share) }