feat: add & use LockBox helper

This commit is contained in:
2021-03-07 19:41:49 +01:00
parent a9d393282a
commit bcbc661f0a
3 changed files with 108 additions and 118 deletions

View File

@@ -1,14 +1,15 @@
use ::anyhow::{Context, Error, Result}; use ::anyhow::{Context, Result};
use ::async_trait::async_trait; use ::async_trait::async_trait;
use ::crdt_enc::{ use ::crdt_enc::{
key_cryptor::Keys, key_cryptor::Keys,
utils::VersionBytes, utils::{
utils::{decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom}, decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom, LockBox, VersionBytes,
},
CoreSubHandle, Info, CoreSubHandle, Info,
}; };
use ::crdts::{ctx::ReadCtx, CvRDT, MVReg, Orswot}; use ::crdts::{ctx::ReadCtx, CvRDT, MVReg, Orswot};
use ::serde::{Deserialize, Serialize}; use ::serde::{Deserialize, Serialize};
use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex}; use ::std::{convert::Infallible, fmt::Debug};
use ::uuid::Uuid; use ::uuid::Uuid;
const CURRENT_VERSION: Uuid = Uuid::from_u128(0xe69cb68e_7fbb_41aa_8d22_87eace7a04c9); const CURRENT_VERSION: Uuid = Uuid::from_u128(0xe69cb68e_7fbb_41aa_8d22_87eace7a04c9);
@@ -31,19 +32,17 @@ struct MutData {
#[derive(Debug)] #[derive(Debug)]
pub struct KeyHandler { pub struct KeyHandler {
data: SyncMutex<MutData>, data: LockBox<MutData>,
} }
impl KeyHandler { impl KeyHandler {
pub fn new() -> KeyHandler { pub fn new() -> KeyHandler {
let data = MutData { KeyHandler {
data: LockBox::new(MutData {
info: None, info: None,
core: None, core: None,
remote_meta: MVReg::new(), remote_meta: MVReg::new(),
}; }),
KeyHandler {
data: SyncMutex::new(data),
} }
} }
} }
@@ -68,12 +67,11 @@ impl CvRDT for Meta {
#[async_trait] #[async_trait]
impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler { impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
async fn init(&self, core: &dyn CoreSubHandle) -> Result<()> { async fn init(&self, core: &dyn CoreSubHandle) -> Result<()> {
let mut data = self self.data.with(|data| {
.data
.lock()
.map_err(|err| Error::msg(err.to_string()))?;
data.info = Some(core.info()); data.info = Some(core.info());
data.core = Some(dyn_clone::clone_box(core)); data.core = Some(dyn_clone::clone_box(core));
});
Ok(()) Ok(())
} }
@@ -81,20 +79,14 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
&self, &self,
new_remote_meta: Option<MVReg<VersionBytes, Uuid>>, new_remote_meta: Option<MVReg<VersionBytes, Uuid>>,
) -> Result<()> { ) -> Result<()> {
let (remote_meta, core) = { let (remote_meta, core) = self.data.try_with(|data| {
let mut data = self
.data
.lock()
.map_err(|err| Error::msg(err.to_string()))?;
if let Some(new_remote_meta) = new_remote_meta { if let Some(new_remote_meta) = new_remote_meta {
data.remote_meta.merge(new_remote_meta); data.remote_meta.merge(new_remote_meta);
} }
let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?); let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?);
Ok((data.remote_meta.clone(), core))
(data.remote_meta.clone(), core) })?;
};
let keys_ctx = let keys_ctx =
decode_version_bytes_mvreg_custom(&remote_meta, SUPPORTED_VERSIONS, |buf| async move { decode_version_bytes_mvreg_custom(&remote_meta, SUPPORTED_VERSIONS, |buf| async move {
@@ -109,15 +101,10 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
} }
async fn set_keys(&self, new_keys: ReadCtx<Keys, Uuid>) -> Result<()> { async fn set_keys(&self, new_keys: ReadCtx<Keys, Uuid>) -> Result<()> {
let (mut rm, core) = { let (mut rm, core) = self.data.try_with(|data| {
let data = self
.data
.lock()
.map_err(|err| Error::msg(err.to_string()))?;
let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?); let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?);
(data.remote_meta.clone(), core) Ok((data.remote_meta.clone(), core))
}; })?;
encode_version_bytes_mvreg_custom( encode_version_bytes_mvreg_custom(
&mut rm, &mut rm,

View File

@@ -7,7 +7,7 @@ use crate::{
cryptor::Cryptor, cryptor::Cryptor,
key_cryptor::{Key, KeyCryptor, Keys}, key_cryptor::{Key, KeyCryptor, Keys},
storage::Storage, storage::Storage,
utils::{VersionBytes, VersionBytesRef}, utils::{LockBox, VersionBytes, VersionBytesRef},
}; };
use ::anyhow::{Context, Error, Result}; use ::anyhow::{Context, Error, Result};
use ::async_trait::async_trait; use ::async_trait::async_trait;
@@ -20,12 +20,7 @@ use ::futures::{
}; };
use ::serde::{de::DeserializeOwned, Deserialize, Serialize}; use ::serde::{de::DeserializeOwned, Deserialize, Serialize};
use ::std::{ use ::std::{
collections::HashSet, collections::HashSet, convert::Infallible, default::Default, fmt::Debug, mem, sync::Arc,
convert::Infallible,
default::Default,
fmt::Debug,
mem,
sync::{Arc, Mutex as SyncMutex},
}; };
use ::uuid::Uuid; use ::uuid::Uuid;
@@ -211,7 +206,7 @@ pub struct Core<S, ST, C, KC> {
key_cryptor: KC, key_cryptor: KC,
// use sync `std::sync::Mutex` here because it has less overhead than async mutex, we are // use sync `std::sync::Mutex` here because it has less overhead than async mutex, we are
// holding it for a very shot time and do not `.await` while the lock is held. // holding it for a very shot time and do not `.await` while the lock is held.
data: SyncMutex<CoreMutData<S>>, data: LockBox<CoreMutData<S>>,
supported_data_versions: Vec<Uuid>, supported_data_versions: Vec<Uuid>,
current_data_version: Uuid, current_data_version: Uuid,
apply_ops_lock: AsyncMutex<()>, apply_ops_lock: AsyncMutex<()>,
@@ -245,18 +240,6 @@ where
KC: KeyCryptor, KC: KeyCryptor,
{ {
pub async fn open(options: OpenOptions<ST, C, KC>) -> Result<Arc<Self>> { pub async fn open(options: OpenOptions<ST, C, KC>) -> Result<Arc<Self>> {
let core_data = SyncMutex::new(CoreMutData {
local_meta: None,
remote_meta: RemoteMeta::default(),
keys: None,
state: StateWrapper {
next_op_versions: Default::default(),
state: Default::default(),
},
read_states: HashSet::new(),
read_remote_metas: HashSet::new(),
});
let mut supported_data_versions = options.supported_data_versions; let mut supported_data_versions = options.supported_data_versions;
supported_data_versions.sort_unstable(); supported_data_versions.sort_unstable();
@@ -267,7 +250,17 @@ where
key_cryptor: options.key_cryptor, key_cryptor: options.key_cryptor,
supported_data_versions, supported_data_versions,
current_data_version: options.current_data_version, current_data_version: options.current_data_version,
data: core_data, data: LockBox::new(CoreMutData {
local_meta: None,
remote_meta: RemoteMeta::default(),
keys: None,
state: StateWrapper {
next_op_versions: Default::default(),
state: Default::default(),
},
read_states: HashSet::new(),
read_remote_metas: HashSet::new(),
}),
apply_ops_lock: AsyncMutex::new(()), apply_ops_lock: AsyncMutex::new(()),
}); });
@@ -303,10 +296,9 @@ where
let actor = local_meta.local_actor_id; let actor = local_meta.local_actor_id;
core.with_mut_data(|data| { core.data.with(|data| {
data.local_meta = Some(local_meta); data.local_meta = Some(local_meta);
Ok(()) });
})?;
futures::try_join![ futures::try_join![
core.storage.init(&core), core.storage.init(&core),
@@ -316,16 +308,17 @@ where
core.read_remote_meta_(true).await?; core.read_remote_meta_(true).await?;
let insert_new_key = let insert_new_key = core
core.with_mut_data(|data| Ok(data.keys.as_ref().unwrap().val.latest_key().is_none()))?; .data
.with(|data| data.keys.as_ref().unwrap().val.latest_key().is_none());
if insert_new_key { if insert_new_key {
let new_key = core.cryptor.gen_key().await?; let new_key = core.cryptor.gen_key().await?;
let keys_ctx = core.with_mut_data(|data| { let keys_ctx = core.data.with(|data| {
let mut keys_ctx = data.keys.take().unwrap(); let mut keys_ctx = data.keys.take().unwrap();
keys_ctx.val.insert_latest_key(actor, Key::new(new_key)); keys_ctx.val.insert_latest_key(actor, Key::new(new_key));
Ok(keys_ctx) keys_ctx
})?; });
// give keys to kc, it gives us a new key ctx back // give keys to kc, it gives us a new key ctx back
core.key_cryptor.set_keys(keys_ctx).await?; core.key_cryptor.set_keys(keys_ctx).await?;
@@ -335,27 +328,14 @@ where
} }
pub fn info(self: &Arc<Self>) -> Info { pub fn info(self: &Arc<Self>) -> Info {
self.with_mut_data(|data| { self.data.with(|data| {
let actor = data let actor = data
.local_meta .local_meta
.as_ref() .as_ref()
.expect("info not set, yet. Do not call this fn in the init phase") .expect("info not set, yet. Do not call this fn in the init phase")
.local_actor_id; .local_actor_id;
Ok(Info { actor }) Info { actor }
}) })
.unwrap()
}
fn with_mut_data<F, R>(self: &Arc<Self>, f: F) -> Result<R>
where
F: FnOnce(&mut CoreMutData<S>) -> Result<R>,
{
let mut data = self
.data
.lock()
.map_err(|err| Error::msg(format!("unable to lock `CoreMutData`: {}", err)))?;
f(&mut *data)
} }
/// Locks cores data, do not call recursivl /// Locks cores data, do not call recursivl
@@ -363,13 +343,13 @@ where
where where
F: FnOnce(&S) -> Result<R>, F: FnOnce(&S) -> Result<R>,
{ {
self.with_mut_data(|data| f(&data.state.state)) self.data.with(|data| f(&data.state.state))
} }
pub async fn compact(self: &Arc<Self>) -> Result<()> { pub async fn compact(self: &Arc<Self>) -> Result<()> {
self.read_remote().await?; self.read_remote().await?;
let (clear_text, states_to_remove, ops_to_remove, key) = self.with_mut_data(|data| { let (clear_text, states_to_remove, ops_to_remove, key) = self.data.try_with(|data| {
let clear_text = rmp_serde::to_vec_named(&data.state)?; let clear_text = rmp_serde::to_vec_named(&data.state)?;
let states_to_remove = data.read_states.iter().cloned().collect(); let states_to_remove = data.read_states.iter().cloned().collect();
@@ -405,23 +385,21 @@ where
self.storage.remove_ops(ops_to_remove), self.storage.remove_ops(ops_to_remove),
]?; ]?;
self.with_mut_data(|data| { self.data.with(|data| {
for removed_state in removed_states { for removed_state in removed_states {
data.read_states.remove(&removed_state); data.read_states.remove(&removed_state);
} }
data.read_states.insert(new_state_name); data.read_states.insert(new_state_name);
Ok(()) });
})?;
Ok(()) Ok(())
} }
async fn set_keys(self: &Arc<Self>, keys: ReadCtx<Keys, Uuid>) -> Result<()> { async fn set_keys(self: &Arc<Self>, keys: ReadCtx<Keys, Uuid>) -> Result<()> {
self.with_mut_data(|data| { self.data.with(|data| {
data.keys = Some(keys); data.keys = Some(keys);
Ok(()) });
})?;
Ok(()) Ok(())
} }
@@ -444,7 +422,7 @@ where
.await .await
.context("failed getting state entry names while reading remote states")?; .context("failed getting state entry names while reading remote states")?;
let (states_to_read, key) = self.with_mut_data(|data| { let (states_to_read, key) = self.data.try_with(|data| {
let states_to_read: Vec<_> = names let states_to_read: Vec<_> = names
.into_iter() .into_iter()
.filter(|name| !data.read_states.contains(name)) .filter(|name| !data.read_states.contains(name))
@@ -493,7 +471,7 @@ where
let states_read = !new_states.is_empty(); let states_read = !new_states.is_empty();
self.with_mut_data(|data| { self.data.with(|data| {
for (name, state_wrapper) in new_states { for (name, state_wrapper) in new_states {
data.state.state.merge(state_wrapper.state); data.state.state.merge(state_wrapper.state);
data.state data.state
@@ -501,8 +479,7 @@ where
.merge(state_wrapper.next_op_versions); .merge(state_wrapper.next_op_versions);
data.read_states.insert(name); data.read_states.insert(name);
} }
Ok(()) });
})?;
Ok(states_read) Ok(states_read)
} }
@@ -514,7 +491,7 @@ where
.await .await
.context("failed getting op actor entries while reading remote ops")?; .context("failed getting op actor entries while reading remote ops")?;
let (ops_to_read, key) = self.with_mut_data(|data| { let (ops_to_read, key) = self.data.try_with(|data| {
let ops_to_read: Vec<_> = actors let ops_to_read: Vec<_> = actors
.into_iter() .into_iter()
.map(|actor| (actor, data.state.next_op_versions.get(&actor))) .map(|actor| (actor, data.state.next_op_versions.get(&actor)))
@@ -556,7 +533,7 @@ where
.try_collect() .try_collect()
.await?; .await?;
let ops_read = self.with_mut_data(|data| { let ops_read = self.data.with(|data| {
let mut ops_read = false; let mut ops_read = false;
for (actor, version, ops) in new_ops { for (actor, version, ops) in new_ops {
let expected_version = data.state.next_op_versions.get(&actor); let expected_version = data.state.next_op_versions.get(&actor);
@@ -600,13 +577,13 @@ where
.await .await
.context("failed getting remote meta entry names while reading remote metas")?; .context("failed getting remote meta entry names while reading remote metas")?;
let remote_metas_to_read = self.with_mut_data(|data| { let remote_metas_to_read = self.data.with(|data| {
let remote_metas_to_read: Vec<_> = names let remote_metas_to_read: Vec<_> = names
.into_iter() .into_iter()
.filter(|name| !data.read_remote_metas.contains(name)) .filter(|name| !data.read_remote_metas.contains(name))
.collect(); .collect();
Ok(remote_metas_to_read) remote_metas_to_read
})?; });
let remote_metas = self let remote_metas = self
.storage .storage
@@ -624,14 +601,14 @@ where
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let remote_meta = if !remote_metas.is_empty() { let remote_meta = if !remote_metas.is_empty() {
self.with_mut_data(|data| { self.data.with(|data| {
for (name, meta) in remote_metas { for (name, meta) in remote_metas {
data.remote_meta.merge(meta); data.remote_meta.merge(meta);
data.read_remote_metas.insert(name); data.read_remote_metas.insert(name);
} }
Ok(Some(data.remote_meta.clone())) Some(data.remote_meta.clone())
})? })
} else { } else {
None None
}; };
@@ -658,10 +635,9 @@ where
self: &Arc<Self>, self: &Arc<Self>,
remote_meta: MVReg<VersionBytes, Uuid>, remote_meta: MVReg<VersionBytes, Uuid>,
) -> Result<()> { ) -> Result<()> {
self.with_mut_data(|data| { self.data.with(|data| {
data.remote_meta.storage.merge(remote_meta); data.remote_meta.storage.merge(remote_meta);
Ok(()) });
})?;
self.store_remote_meta().await self.store_remote_meta().await
} }
@@ -670,10 +646,9 @@ where
self: &Arc<Self>, self: &Arc<Self>,
remote_meta: MVReg<VersionBytes, Uuid>, remote_meta: MVReg<VersionBytes, Uuid>,
) -> Result<()> { ) -> Result<()> {
self.with_mut_data(|data| { self.data.with(|data| {
data.remote_meta.cryptor.merge(remote_meta); data.remote_meta.cryptor.merge(remote_meta);
Ok(()) });
})?;
self.store_remote_meta().await self.store_remote_meta().await
} }
@@ -682,27 +657,26 @@ where
self: &Arc<Self>, self: &Arc<Self>,
remote_meta: MVReg<VersionBytes, Uuid>, remote_meta: MVReg<VersionBytes, Uuid>,
) -> Result<()> { ) -> Result<()> {
self.with_mut_data(|data| { self.data.with(|data| {
data.remote_meta.key_cryptor.merge(remote_meta); data.remote_meta.key_cryptor.merge(remote_meta);
Ok(()) });
})?;
self.store_remote_meta().await self.store_remote_meta().await
} }
async fn store_remote_meta(self: &Arc<Self>) -> Result<()> { async fn store_remote_meta(self: &Arc<Self>) -> Result<()> {
let vbox = self.with_mut_data(|data| { let vbox = self.data.try_with(|data| {
let bytes = rmp_serde::to_vec_named(&data.remote_meta)?; let bytes = rmp_serde::to_vec_named(&data.remote_meta)?;
Ok(VersionBytes::new(CURRENT_VERSION, bytes)) Ok(VersionBytes::new(CURRENT_VERSION, bytes))
})?; })?;
let new_name = self.storage.store_remote_meta(vbox).await?; let new_name = self.storage.store_remote_meta(vbox).await?;
let names_to_remove = self.with_mut_data(|data| { let names_to_remove = self.data.with(|data| {
let names_to_remove = data.read_remote_metas.drain().collect(); let names_to_remove = data.read_remote_metas.drain().collect();
data.read_remote_metas.insert(new_name); data.read_remote_metas.insert(new_name);
Ok(names_to_remove) names_to_remove
})?; });
self.storage.remove_remote_metas(names_to_remove).await?; self.storage.remove_remote_metas(names_to_remove).await?;
@@ -716,7 +690,7 @@ where
let clear_text = rmp_serde::to_vec_named(&ops)?; let clear_text = rmp_serde::to_vec_named(&ops)?;
let clear_text = VersionBytes::new(self.current_data_version, clear_text); let clear_text = VersionBytes::new(self.current_data_version, clear_text);
let key = self.with_mut_data(|data| { let key = self.data.with(|data| {
data.keys data.keys
.as_ref() .as_ref()
.unwrap() .unwrap()
@@ -740,7 +714,7 @@ where
let data_enc = VersionBytes::new(CURRENT_VERSION, data_enc); let data_enc = VersionBytes::new(CURRENT_VERSION, data_enc);
let (actor, version) = self.with_mut_data(|data| { let (actor, version) = self.data.try_with(|data| {
let actor = data let actor = data
.local_meta .local_meta
.as_ref() .as_ref()
@@ -752,15 +726,14 @@ where
self.storage.store_ops(actor, version, data_enc).await?; self.storage.store_ops(actor, version, data_enc).await?;
self.with_mut_data(|data| { self.data.with(|data| {
for op in ops { for op in ops {
data.state.state.apply(op); data.state.state.apply(op);
} }
let version_inc = data.state.next_op_versions.inc(actor); let version_inc = data.state.next_op_versions.inc(actor);
data.state.next_op_versions.apply(version_inc); data.state.next_op_versions.apply(version_inc);
Ok(()) });
})?;
// release lock by hand to prevent an early release by accident // release lock by hand to prevent an early release by accident
mem::drop(apply_ops_lock); mem::drop(apply_ops_lock);

View File

@@ -6,7 +6,7 @@ use ::anyhow::{Context, Result};
use ::crdts::{ctx::ReadCtx, CmRDT, CvRDT, MVReg}; use ::crdts::{ctx::ReadCtx, CmRDT, CvRDT, MVReg};
use ::futures::{stream, Future, FutureExt, StreamExt, TryStreamExt}; use ::futures::{stream, Future, FutureExt, StreamExt, TryStreamExt};
use ::serde::{de::DeserializeOwned, Deserialize, Serialize}; use ::serde::{de::DeserializeOwned, Deserialize, Serialize};
use ::std::convert::Infallible; use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex};
use ::uuid::Uuid; use ::uuid::Uuid;
#[derive(Debug, Clone, Default, Serialize, Deserialize)] #[derive(Debug, Clone, Default, Serialize, Deserialize)]
@@ -125,3 +125,33 @@ where
reg.apply(op); reg.apply(op);
Ok(()) Ok(())
} }
/// Prevents `await`s while the lock is held. Awaiting could cause deadlocking.
#[derive(Debug)]
pub struct LockBox<T> {
inner: SyncMutex<T>,
}
impl<T> LockBox<T> {
pub fn new(val: T) -> LockBox<T> {
LockBox {
inner: SyncMutex::new(val),
}
}
pub fn with<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut T) -> R,
{
let mut data = self.inner.lock().expect("Unable to lock LockBox");
f(&mut *data)
}
/// Utility `LockBox::with` function, that enforces a `anyhow::Result` return type.
pub fn try_with<F, R>(&self, f: F) -> Result<R>
where
F: FnOnce(&mut T) -> Result<R>,
{
self.with(f)
}
}