feat: add & use LockBox helper

This commit is contained in:
2021-03-07 19:41:49 +01:00
parent a9d393282a
commit bcbc661f0a
3 changed files with 108 additions and 118 deletions

View File

@@ -1,14 +1,15 @@
use ::anyhow::{Context, Error, Result};
use ::anyhow::{Context, Result};
use ::async_trait::async_trait;
use ::crdt_enc::{
key_cryptor::Keys,
utils::VersionBytes,
utils::{decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom},
utils::{
decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom, LockBox, VersionBytes,
},
CoreSubHandle, Info,
};
use ::crdts::{ctx::ReadCtx, CvRDT, MVReg, Orswot};
use ::serde::{Deserialize, Serialize};
use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex};
use ::std::{convert::Infallible, fmt::Debug};
use ::uuid::Uuid;
const CURRENT_VERSION: Uuid = Uuid::from_u128(0xe69cb68e_7fbb_41aa_8d22_87eace7a04c9);
@@ -31,19 +32,17 @@ struct MutData {
#[derive(Debug)]
pub struct KeyHandler {
data: SyncMutex<MutData>,
data: LockBox<MutData>,
}
impl KeyHandler {
pub fn new() -> KeyHandler {
let data = MutData {
KeyHandler {
data: LockBox::new(MutData {
info: None,
core: None,
remote_meta: MVReg::new(),
};
KeyHandler {
data: SyncMutex::new(data),
}),
}
}
}
@@ -68,12 +67,11 @@ impl CvRDT for Meta {
#[async_trait]
impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
async fn init(&self, core: &dyn CoreSubHandle) -> Result<()> {
let mut data = self
.data
.lock()
.map_err(|err| Error::msg(err.to_string()))?;
self.data.with(|data| {
data.info = Some(core.info());
data.core = Some(dyn_clone::clone_box(core));
});
Ok(())
}
@@ -81,20 +79,14 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
&self,
new_remote_meta: Option<MVReg<VersionBytes, Uuid>>,
) -> Result<()> {
let (remote_meta, core) = {
let mut data = self
.data
.lock()
.map_err(|err| Error::msg(err.to_string()))?;
let (remote_meta, core) = self.data.try_with(|data| {
if let Some(new_remote_meta) = new_remote_meta {
data.remote_meta.merge(new_remote_meta);
}
let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?);
(data.remote_meta.clone(), core)
};
Ok((data.remote_meta.clone(), core))
})?;
let keys_ctx =
decode_version_bytes_mvreg_custom(&remote_meta, SUPPORTED_VERSIONS, |buf| async move {
@@ -109,15 +101,10 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
}
async fn set_keys(&self, new_keys: ReadCtx<Keys, Uuid>) -> Result<()> {
let (mut rm, core) = {
let data = self
.data
.lock()
.map_err(|err| Error::msg(err.to_string()))?;
let (mut rm, core) = self.data.try_with(|data| {
let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?);
(data.remote_meta.clone(), core)
};
Ok((data.remote_meta.clone(), core))
})?;
encode_version_bytes_mvreg_custom(
&mut rm,

View File

@@ -7,7 +7,7 @@ use crate::{
cryptor::Cryptor,
key_cryptor::{Key, KeyCryptor, Keys},
storage::Storage,
utils::{VersionBytes, VersionBytesRef},
utils::{LockBox, VersionBytes, VersionBytesRef},
};
use ::anyhow::{Context, Error, Result};
use ::async_trait::async_trait;
@@ -20,12 +20,7 @@ use ::futures::{
};
use ::serde::{de::DeserializeOwned, Deserialize, Serialize};
use ::std::{
collections::HashSet,
convert::Infallible,
default::Default,
fmt::Debug,
mem,
sync::{Arc, Mutex as SyncMutex},
collections::HashSet, convert::Infallible, default::Default, fmt::Debug, mem, sync::Arc,
};
use ::uuid::Uuid;
@@ -211,7 +206,7 @@ pub struct Core<S, ST, C, KC> {
key_cryptor: KC,
// use sync `std::sync::Mutex` here because it has less overhead than async mutex, we are
// holding it for a very shot time and do not `.await` while the lock is held.
data: SyncMutex<CoreMutData<S>>,
data: LockBox<CoreMutData<S>>,
supported_data_versions: Vec<Uuid>,
current_data_version: Uuid,
apply_ops_lock: AsyncMutex<()>,
@@ -245,18 +240,6 @@ where
KC: KeyCryptor,
{
pub async fn open(options: OpenOptions<ST, C, KC>) -> Result<Arc<Self>> {
let core_data = SyncMutex::new(CoreMutData {
local_meta: None,
remote_meta: RemoteMeta::default(),
keys: None,
state: StateWrapper {
next_op_versions: Default::default(),
state: Default::default(),
},
read_states: HashSet::new(),
read_remote_metas: HashSet::new(),
});
let mut supported_data_versions = options.supported_data_versions;
supported_data_versions.sort_unstable();
@@ -267,7 +250,17 @@ where
key_cryptor: options.key_cryptor,
supported_data_versions,
current_data_version: options.current_data_version,
data: core_data,
data: LockBox::new(CoreMutData {
local_meta: None,
remote_meta: RemoteMeta::default(),
keys: None,
state: StateWrapper {
next_op_versions: Default::default(),
state: Default::default(),
},
read_states: HashSet::new(),
read_remote_metas: HashSet::new(),
}),
apply_ops_lock: AsyncMutex::new(()),
});
@@ -303,10 +296,9 @@ where
let actor = local_meta.local_actor_id;
core.with_mut_data(|data| {
core.data.with(|data| {
data.local_meta = Some(local_meta);
Ok(())
})?;
});
futures::try_join![
core.storage.init(&core),
@@ -316,16 +308,17 @@ where
core.read_remote_meta_(true).await?;
let insert_new_key =
core.with_mut_data(|data| Ok(data.keys.as_ref().unwrap().val.latest_key().is_none()))?;
let insert_new_key = core
.data
.with(|data| data.keys.as_ref().unwrap().val.latest_key().is_none());
if insert_new_key {
let new_key = core.cryptor.gen_key().await?;
let keys_ctx = core.with_mut_data(|data| {
let keys_ctx = core.data.with(|data| {
let mut keys_ctx = data.keys.take().unwrap();
keys_ctx.val.insert_latest_key(actor, Key::new(new_key));
Ok(keys_ctx)
})?;
keys_ctx
});
// give keys to kc, it gives us a new key ctx back
core.key_cryptor.set_keys(keys_ctx).await?;
@@ -335,27 +328,14 @@ where
}
pub fn info(self: &Arc<Self>) -> Info {
self.with_mut_data(|data| {
self.data.with(|data| {
let actor = data
.local_meta
.as_ref()
.expect("info not set, yet. Do not call this fn in the init phase")
.local_actor_id;
Ok(Info { actor })
Info { actor }
})
.unwrap()
}
fn with_mut_data<F, R>(self: &Arc<Self>, f: F) -> Result<R>
where
F: FnOnce(&mut CoreMutData<S>) -> Result<R>,
{
let mut data = self
.data
.lock()
.map_err(|err| Error::msg(format!("unable to lock `CoreMutData`: {}", err)))?;
f(&mut *data)
}
/// Locks cores data, do not call recursivl
@@ -363,13 +343,13 @@ where
where
F: FnOnce(&S) -> Result<R>,
{
self.with_mut_data(|data| f(&data.state.state))
self.data.with(|data| f(&data.state.state))
}
pub async fn compact(self: &Arc<Self>) -> Result<()> {
self.read_remote().await?;
let (clear_text, states_to_remove, ops_to_remove, key) = self.with_mut_data(|data| {
let (clear_text, states_to_remove, ops_to_remove, key) = self.data.try_with(|data| {
let clear_text = rmp_serde::to_vec_named(&data.state)?;
let states_to_remove = data.read_states.iter().cloned().collect();
@@ -405,23 +385,21 @@ where
self.storage.remove_ops(ops_to_remove),
]?;
self.with_mut_data(|data| {
self.data.with(|data| {
for removed_state in removed_states {
data.read_states.remove(&removed_state);
}
data.read_states.insert(new_state_name);
Ok(())
})?;
});
Ok(())
}
async fn set_keys(self: &Arc<Self>, keys: ReadCtx<Keys, Uuid>) -> Result<()> {
self.with_mut_data(|data| {
self.data.with(|data| {
data.keys = Some(keys);
Ok(())
})?;
});
Ok(())
}
@@ -444,7 +422,7 @@ where
.await
.context("failed getting state entry names while reading remote states")?;
let (states_to_read, key) = self.with_mut_data(|data| {
let (states_to_read, key) = self.data.try_with(|data| {
let states_to_read: Vec<_> = names
.into_iter()
.filter(|name| !data.read_states.contains(name))
@@ -493,7 +471,7 @@ where
let states_read = !new_states.is_empty();
self.with_mut_data(|data| {
self.data.with(|data| {
for (name, state_wrapper) in new_states {
data.state.state.merge(state_wrapper.state);
data.state
@@ -501,8 +479,7 @@ where
.merge(state_wrapper.next_op_versions);
data.read_states.insert(name);
}
Ok(())
})?;
});
Ok(states_read)
}
@@ -514,7 +491,7 @@ where
.await
.context("failed getting op actor entries while reading remote ops")?;
let (ops_to_read, key) = self.with_mut_data(|data| {
let (ops_to_read, key) = self.data.try_with(|data| {
let ops_to_read: Vec<_> = actors
.into_iter()
.map(|actor| (actor, data.state.next_op_versions.get(&actor)))
@@ -556,7 +533,7 @@ where
.try_collect()
.await?;
let ops_read = self.with_mut_data(|data| {
let ops_read = self.data.with(|data| {
let mut ops_read = false;
for (actor, version, ops) in new_ops {
let expected_version = data.state.next_op_versions.get(&actor);
@@ -600,13 +577,13 @@ where
.await
.context("failed getting remote meta entry names while reading remote metas")?;
let remote_metas_to_read = self.with_mut_data(|data| {
let remote_metas_to_read = self.data.with(|data| {
let remote_metas_to_read: Vec<_> = names
.into_iter()
.filter(|name| !data.read_remote_metas.contains(name))
.collect();
Ok(remote_metas_to_read)
})?;
remote_metas_to_read
});
let remote_metas = self
.storage
@@ -624,14 +601,14 @@ where
.collect::<Result<Vec<_>>>()?;
let remote_meta = if !remote_metas.is_empty() {
self.with_mut_data(|data| {
self.data.with(|data| {
for (name, meta) in remote_metas {
data.remote_meta.merge(meta);
data.read_remote_metas.insert(name);
}
Ok(Some(data.remote_meta.clone()))
})?
Some(data.remote_meta.clone())
})
} else {
None
};
@@ -658,10 +635,9 @@ where
self: &Arc<Self>,
remote_meta: MVReg<VersionBytes, Uuid>,
) -> Result<()> {
self.with_mut_data(|data| {
self.data.with(|data| {
data.remote_meta.storage.merge(remote_meta);
Ok(())
})?;
});
self.store_remote_meta().await
}
@@ -670,10 +646,9 @@ where
self: &Arc<Self>,
remote_meta: MVReg<VersionBytes, Uuid>,
) -> Result<()> {
self.with_mut_data(|data| {
self.data.with(|data| {
data.remote_meta.cryptor.merge(remote_meta);
Ok(())
})?;
});
self.store_remote_meta().await
}
@@ -682,27 +657,26 @@ where
self: &Arc<Self>,
remote_meta: MVReg<VersionBytes, Uuid>,
) -> Result<()> {
self.with_mut_data(|data| {
self.data.with(|data| {
data.remote_meta.key_cryptor.merge(remote_meta);
Ok(())
})?;
});
self.store_remote_meta().await
}
async fn store_remote_meta(self: &Arc<Self>) -> Result<()> {
let vbox = self.with_mut_data(|data| {
let vbox = self.data.try_with(|data| {
let bytes = rmp_serde::to_vec_named(&data.remote_meta)?;
Ok(VersionBytes::new(CURRENT_VERSION, bytes))
})?;
let new_name = self.storage.store_remote_meta(vbox).await?;
let names_to_remove = self.with_mut_data(|data| {
let names_to_remove = self.data.with(|data| {
let names_to_remove = data.read_remote_metas.drain().collect();
data.read_remote_metas.insert(new_name);
Ok(names_to_remove)
})?;
names_to_remove
});
self.storage.remove_remote_metas(names_to_remove).await?;
@@ -716,7 +690,7 @@ where
let clear_text = rmp_serde::to_vec_named(&ops)?;
let clear_text = VersionBytes::new(self.current_data_version, clear_text);
let key = self.with_mut_data(|data| {
let key = self.data.with(|data| {
data.keys
.as_ref()
.unwrap()
@@ -740,7 +714,7 @@ where
let data_enc = VersionBytes::new(CURRENT_VERSION, data_enc);
let (actor, version) = self.with_mut_data(|data| {
let (actor, version) = self.data.try_with(|data| {
let actor = data
.local_meta
.as_ref()
@@ -752,15 +726,14 @@ where
self.storage.store_ops(actor, version, data_enc).await?;
self.with_mut_data(|data| {
self.data.with(|data| {
for op in ops {
data.state.state.apply(op);
}
let version_inc = data.state.next_op_versions.inc(actor);
data.state.next_op_versions.apply(version_inc);
Ok(())
})?;
});
// release lock by hand to prevent an early release by accident
mem::drop(apply_ops_lock);

View File

@@ -6,7 +6,7 @@ use ::anyhow::{Context, Result};
use ::crdts::{ctx::ReadCtx, CmRDT, CvRDT, MVReg};
use ::futures::{stream, Future, FutureExt, StreamExt, TryStreamExt};
use ::serde::{de::DeserializeOwned, Deserialize, Serialize};
use ::std::convert::Infallible;
use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex};
use ::uuid::Uuid;
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
@@ -125,3 +125,33 @@ where
reg.apply(op);
Ok(())
}
/// Prevents `await`s while the lock is held. Awaiting could cause deadlocking.
#[derive(Debug)]
pub struct LockBox<T> {
inner: SyncMutex<T>,
}
impl<T> LockBox<T> {
pub fn new(val: T) -> LockBox<T> {
LockBox {
inner: SyncMutex::new(val),
}
}
pub fn with<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut T) -> R,
{
let mut data = self.inner.lock().expect("Unable to lock LockBox");
f(&mut *data)
}
/// Utility `LockBox::with` function, that enforces a `anyhow::Result` return type.
pub fn try_with<F, R>(&self, f: F) -> Result<R>
where
F: FnOnce(&mut T) -> Result<R>,
{
self.with(f)
}
}