diff --git a/crdt-enc-gpgme/src/lib.rs b/crdt-enc-gpgme/src/lib.rs index f8d07f3..c9cceae 100644 --- a/crdt-enc-gpgme/src/lib.rs +++ b/crdt-enc-gpgme/src/lib.rs @@ -1,14 +1,15 @@ -use ::anyhow::{Context, Error, Result}; +use ::anyhow::{Context, Result}; use ::async_trait::async_trait; use ::crdt_enc::{ key_cryptor::Keys, - utils::VersionBytes, - utils::{decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom}, + utils::{ + decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom, LockBox, VersionBytes, + }, CoreSubHandle, Info, }; use ::crdts::{ctx::ReadCtx, CvRDT, MVReg, Orswot}; use ::serde::{Deserialize, Serialize}; -use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex}; +use ::std::{convert::Infallible, fmt::Debug}; use ::uuid::Uuid; const CURRENT_VERSION: Uuid = Uuid::from_u128(0xe69cb68e_7fbb_41aa_8d22_87eace7a04c9); @@ -31,19 +32,17 @@ struct MutData { #[derive(Debug)] pub struct KeyHandler { - data: SyncMutex, + data: LockBox, } impl KeyHandler { pub fn new() -> KeyHandler { - let data = MutData { - info: None, - core: None, - remote_meta: MVReg::new(), - }; - KeyHandler { - data: SyncMutex::new(data), + data: LockBox::new(MutData { + info: None, + core: None, + remote_meta: MVReg::new(), + }), } } } @@ -68,12 +67,11 @@ impl CvRDT for Meta { #[async_trait] impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler { async fn init(&self, core: &dyn CoreSubHandle) -> Result<()> { - let mut data = self - .data - .lock() - .map_err(|err| Error::msg(err.to_string()))?; - data.info = Some(core.info()); - data.core = Some(dyn_clone::clone_box(core)); + self.data.with(|data| { + data.info = Some(core.info()); + data.core = Some(dyn_clone::clone_box(core)); + }); + Ok(()) } @@ -81,20 +79,14 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler { &self, new_remote_meta: Option>, ) -> Result<()> { - let (remote_meta, core) = { - let mut data = self - .data - .lock() - .map_err(|err| Error::msg(err.to_string()))?; - + let (remote_meta, core) = self.data.try_with(|data| { if let Some(new_remote_meta) = new_remote_meta { data.remote_meta.merge(new_remote_meta); } let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?); - - (data.remote_meta.clone(), core) - }; + Ok((data.remote_meta.clone(), core)) + })?; let keys_ctx = decode_version_bytes_mvreg_custom(&remote_meta, SUPPORTED_VERSIONS, |buf| async move { @@ -109,15 +101,10 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler { } async fn set_keys(&self, new_keys: ReadCtx) -> Result<()> { - let (mut rm, core) = { - let data = self - .data - .lock() - .map_err(|err| Error::msg(err.to_string()))?; - + let (mut rm, core) = self.data.try_with(|data| { let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?); - (data.remote_meta.clone(), core) - }; + Ok((data.remote_meta.clone(), core)) + })?; encode_version_bytes_mvreg_custom( &mut rm, diff --git a/crdt-enc/src/lib.rs b/crdt-enc/src/lib.rs index 65a4f5f..21a681e 100644 --- a/crdt-enc/src/lib.rs +++ b/crdt-enc/src/lib.rs @@ -7,7 +7,7 @@ use crate::{ cryptor::Cryptor, key_cryptor::{Key, KeyCryptor, Keys}, storage::Storage, - utils::{VersionBytes, VersionBytesRef}, + utils::{LockBox, VersionBytes, VersionBytesRef}, }; use ::anyhow::{Context, Error, Result}; use ::async_trait::async_trait; @@ -20,12 +20,7 @@ use ::futures::{ }; use ::serde::{de::DeserializeOwned, Deserialize, Serialize}; use ::std::{ - collections::HashSet, - convert::Infallible, - default::Default, - fmt::Debug, - mem, - sync::{Arc, Mutex as SyncMutex}, + collections::HashSet, convert::Infallible, default::Default, fmt::Debug, mem, sync::Arc, }; use ::uuid::Uuid; @@ -211,7 +206,7 @@ pub struct Core { key_cryptor: KC, // use sync `std::sync::Mutex` here because it has less overhead than async mutex, we are // holding it for a very shot time and do not `.await` while the lock is held. - data: SyncMutex>, + data: LockBox>, supported_data_versions: Vec, current_data_version: Uuid, apply_ops_lock: AsyncMutex<()>, @@ -245,18 +240,6 @@ where KC: KeyCryptor, { pub async fn open(options: OpenOptions) -> Result> { - let core_data = SyncMutex::new(CoreMutData { - local_meta: None, - remote_meta: RemoteMeta::default(), - keys: None, - state: StateWrapper { - next_op_versions: Default::default(), - state: Default::default(), - }, - read_states: HashSet::new(), - read_remote_metas: HashSet::new(), - }); - let mut supported_data_versions = options.supported_data_versions; supported_data_versions.sort_unstable(); @@ -267,7 +250,17 @@ where key_cryptor: options.key_cryptor, supported_data_versions, current_data_version: options.current_data_version, - data: core_data, + data: LockBox::new(CoreMutData { + local_meta: None, + remote_meta: RemoteMeta::default(), + keys: None, + state: StateWrapper { + next_op_versions: Default::default(), + state: Default::default(), + }, + read_states: HashSet::new(), + read_remote_metas: HashSet::new(), + }), apply_ops_lock: AsyncMutex::new(()), }); @@ -303,10 +296,9 @@ where let actor = local_meta.local_actor_id; - core.with_mut_data(|data| { + core.data.with(|data| { data.local_meta = Some(local_meta); - Ok(()) - })?; + }); futures::try_join![ core.storage.init(&core), @@ -316,16 +308,17 @@ where core.read_remote_meta_(true).await?; - let insert_new_key = - core.with_mut_data(|data| Ok(data.keys.as_ref().unwrap().val.latest_key().is_none()))?; + let insert_new_key = core + .data + .with(|data| data.keys.as_ref().unwrap().val.latest_key().is_none()); if insert_new_key { let new_key = core.cryptor.gen_key().await?; - let keys_ctx = core.with_mut_data(|data| { + let keys_ctx = core.data.with(|data| { let mut keys_ctx = data.keys.take().unwrap(); keys_ctx.val.insert_latest_key(actor, Key::new(new_key)); - Ok(keys_ctx) - })?; + keys_ctx + }); // give keys to kc, it gives us a new key ctx back core.key_cryptor.set_keys(keys_ctx).await?; @@ -335,27 +328,14 @@ where } pub fn info(self: &Arc) -> Info { - self.with_mut_data(|data| { + self.data.with(|data| { let actor = data .local_meta .as_ref() .expect("info not set, yet. Do not call this fn in the init phase") .local_actor_id; - Ok(Info { actor }) + Info { actor } }) - .unwrap() - } - - fn with_mut_data(self: &Arc, f: F) -> Result - where - F: FnOnce(&mut CoreMutData) -> Result, - { - let mut data = self - .data - .lock() - .map_err(|err| Error::msg(format!("unable to lock `CoreMutData`: {}", err)))?; - - f(&mut *data) } /// Locks cores data, do not call recursivl @@ -363,13 +343,13 @@ where where F: FnOnce(&S) -> Result, { - self.with_mut_data(|data| f(&data.state.state)) + self.data.with(|data| f(&data.state.state)) } pub async fn compact(self: &Arc) -> Result<()> { self.read_remote().await?; - let (clear_text, states_to_remove, ops_to_remove, key) = self.with_mut_data(|data| { + let (clear_text, states_to_remove, ops_to_remove, key) = self.data.try_with(|data| { let clear_text = rmp_serde::to_vec_named(&data.state)?; let states_to_remove = data.read_states.iter().cloned().collect(); @@ -405,23 +385,21 @@ where self.storage.remove_ops(ops_to_remove), ]?; - self.with_mut_data(|data| { + self.data.with(|data| { for removed_state in removed_states { data.read_states.remove(&removed_state); } data.read_states.insert(new_state_name); - Ok(()) - })?; + }); Ok(()) } async fn set_keys(self: &Arc, keys: ReadCtx) -> Result<()> { - self.with_mut_data(|data| { + self.data.with(|data| { data.keys = Some(keys); - Ok(()) - })?; + }); Ok(()) } @@ -444,7 +422,7 @@ where .await .context("failed getting state entry names while reading remote states")?; - let (states_to_read, key) = self.with_mut_data(|data| { + let (states_to_read, key) = self.data.try_with(|data| { let states_to_read: Vec<_> = names .into_iter() .filter(|name| !data.read_states.contains(name)) @@ -493,7 +471,7 @@ where let states_read = !new_states.is_empty(); - self.with_mut_data(|data| { + self.data.with(|data| { for (name, state_wrapper) in new_states { data.state.state.merge(state_wrapper.state); data.state @@ -501,8 +479,7 @@ where .merge(state_wrapper.next_op_versions); data.read_states.insert(name); } - Ok(()) - })?; + }); Ok(states_read) } @@ -514,7 +491,7 @@ where .await .context("failed getting op actor entries while reading remote ops")?; - let (ops_to_read, key) = self.with_mut_data(|data| { + let (ops_to_read, key) = self.data.try_with(|data| { let ops_to_read: Vec<_> = actors .into_iter() .map(|actor| (actor, data.state.next_op_versions.get(&actor))) @@ -556,7 +533,7 @@ where .try_collect() .await?; - let ops_read = self.with_mut_data(|data| { + let ops_read = self.data.with(|data| { let mut ops_read = false; for (actor, version, ops) in new_ops { let expected_version = data.state.next_op_versions.get(&actor); @@ -600,13 +577,13 @@ where .await .context("failed getting remote meta entry names while reading remote metas")?; - let remote_metas_to_read = self.with_mut_data(|data| { + let remote_metas_to_read = self.data.with(|data| { let remote_metas_to_read: Vec<_> = names .into_iter() .filter(|name| !data.read_remote_metas.contains(name)) .collect(); - Ok(remote_metas_to_read) - })?; + remote_metas_to_read + }); let remote_metas = self .storage @@ -624,14 +601,14 @@ where .collect::>>()?; let remote_meta = if !remote_metas.is_empty() { - self.with_mut_data(|data| { + self.data.with(|data| { for (name, meta) in remote_metas { data.remote_meta.merge(meta); data.read_remote_metas.insert(name); } - Ok(Some(data.remote_meta.clone())) - })? + Some(data.remote_meta.clone()) + }) } else { None }; @@ -658,10 +635,9 @@ where self: &Arc, remote_meta: MVReg, ) -> Result<()> { - self.with_mut_data(|data| { + self.data.with(|data| { data.remote_meta.storage.merge(remote_meta); - Ok(()) - })?; + }); self.store_remote_meta().await } @@ -670,10 +646,9 @@ where self: &Arc, remote_meta: MVReg, ) -> Result<()> { - self.with_mut_data(|data| { + self.data.with(|data| { data.remote_meta.cryptor.merge(remote_meta); - Ok(()) - })?; + }); self.store_remote_meta().await } @@ -682,27 +657,26 @@ where self: &Arc, remote_meta: MVReg, ) -> Result<()> { - self.with_mut_data(|data| { + self.data.with(|data| { data.remote_meta.key_cryptor.merge(remote_meta); - Ok(()) - })?; + }); self.store_remote_meta().await } async fn store_remote_meta(self: &Arc) -> Result<()> { - let vbox = self.with_mut_data(|data| { + let vbox = self.data.try_with(|data| { let bytes = rmp_serde::to_vec_named(&data.remote_meta)?; Ok(VersionBytes::new(CURRENT_VERSION, bytes)) })?; let new_name = self.storage.store_remote_meta(vbox).await?; - let names_to_remove = self.with_mut_data(|data| { + let names_to_remove = self.data.with(|data| { let names_to_remove = data.read_remote_metas.drain().collect(); data.read_remote_metas.insert(new_name); - Ok(names_to_remove) - })?; + names_to_remove + }); self.storage.remove_remote_metas(names_to_remove).await?; @@ -716,7 +690,7 @@ where let clear_text = rmp_serde::to_vec_named(&ops)?; let clear_text = VersionBytes::new(self.current_data_version, clear_text); - let key = self.with_mut_data(|data| { + let key = self.data.with(|data| { data.keys .as_ref() .unwrap() @@ -740,7 +714,7 @@ where let data_enc = VersionBytes::new(CURRENT_VERSION, data_enc); - let (actor, version) = self.with_mut_data(|data| { + let (actor, version) = self.data.try_with(|data| { let actor = data .local_meta .as_ref() @@ -752,15 +726,14 @@ where self.storage.store_ops(actor, version, data_enc).await?; - self.with_mut_data(|data| { + self.data.with(|data| { for op in ops { data.state.state.apply(op); } let version_inc = data.state.next_op_versions.inc(actor); data.state.next_op_versions.apply(version_inc); - Ok(()) - })?; + }); // release lock by hand to prevent an early release by accident mem::drop(apply_ops_lock); diff --git a/crdt-enc/src/utils.rs b/crdt-enc/src/utils.rs index 36b287c..9627917 100644 --- a/crdt-enc/src/utils.rs +++ b/crdt-enc/src/utils.rs @@ -6,7 +6,7 @@ use ::anyhow::{Context, Result}; use ::crdts::{ctx::ReadCtx, CmRDT, CvRDT, MVReg}; use ::futures::{stream, Future, FutureExt, StreamExt, TryStreamExt}; use ::serde::{de::DeserializeOwned, Deserialize, Serialize}; -use ::std::convert::Infallible; +use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex}; use ::uuid::Uuid; #[derive(Debug, Clone, Default, Serialize, Deserialize)] @@ -125,3 +125,33 @@ where reg.apply(op); Ok(()) } + +/// Prevents `await`s while the lock is held. Awaiting could cause deadlocking. +#[derive(Debug)] +pub struct LockBox { + inner: SyncMutex, +} + +impl LockBox { + pub fn new(val: T) -> LockBox { + LockBox { + inner: SyncMutex::new(val), + } + } + + pub fn with(&self, f: F) -> R + where + F: FnOnce(&mut T) -> R, + { + let mut data = self.inner.lock().expect("Unable to lock LockBox"); + f(&mut *data) + } + + /// Utility `LockBox::with` function, that enforces a `anyhow::Result` return type. + pub fn try_with(&self, f: F) -> Result + where + F: FnOnce(&mut T) -> Result, + { + self.with(f) + } +}