feat: add & use LockBox helper
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
use ::anyhow::{Context, Error, Result};
|
||||
use ::anyhow::{Context, Result};
|
||||
use ::async_trait::async_trait;
|
||||
use ::crdt_enc::{
|
||||
key_cryptor::Keys,
|
||||
utils::VersionBytes,
|
||||
utils::{decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom},
|
||||
utils::{
|
||||
decode_version_bytes_mvreg_custom, encode_version_bytes_mvreg_custom, LockBox, VersionBytes,
|
||||
},
|
||||
CoreSubHandle, Info,
|
||||
};
|
||||
use ::crdts::{ctx::ReadCtx, CvRDT, MVReg, Orswot};
|
||||
use ::serde::{Deserialize, Serialize};
|
||||
use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex};
|
||||
use ::std::{convert::Infallible, fmt::Debug};
|
||||
use ::uuid::Uuid;
|
||||
|
||||
const CURRENT_VERSION: Uuid = Uuid::from_u128(0xe69cb68e_7fbb_41aa_8d22_87eace7a04c9);
|
||||
@@ -31,19 +32,17 @@ struct MutData {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct KeyHandler {
|
||||
data: SyncMutex<MutData>,
|
||||
data: LockBox<MutData>,
|
||||
}
|
||||
|
||||
impl KeyHandler {
|
||||
pub fn new() -> KeyHandler {
|
||||
let data = MutData {
|
||||
KeyHandler {
|
||||
data: LockBox::new(MutData {
|
||||
info: None,
|
||||
core: None,
|
||||
remote_meta: MVReg::new(),
|
||||
};
|
||||
|
||||
KeyHandler {
|
||||
data: SyncMutex::new(data),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -68,12 +67,11 @@ impl CvRDT for Meta {
|
||||
#[async_trait]
|
||||
impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
|
||||
async fn init(&self, core: &dyn CoreSubHandle) -> Result<()> {
|
||||
let mut data = self
|
||||
.data
|
||||
.lock()
|
||||
.map_err(|err| Error::msg(err.to_string()))?;
|
||||
self.data.with(|data| {
|
||||
data.info = Some(core.info());
|
||||
data.core = Some(dyn_clone::clone_box(core));
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -81,20 +79,14 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
|
||||
&self,
|
||||
new_remote_meta: Option<MVReg<VersionBytes, Uuid>>,
|
||||
) -> Result<()> {
|
||||
let (remote_meta, core) = {
|
||||
let mut data = self
|
||||
.data
|
||||
.lock()
|
||||
.map_err(|err| Error::msg(err.to_string()))?;
|
||||
|
||||
let (remote_meta, core) = self.data.try_with(|data| {
|
||||
if let Some(new_remote_meta) = new_remote_meta {
|
||||
data.remote_meta.merge(new_remote_meta);
|
||||
}
|
||||
|
||||
let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?);
|
||||
|
||||
(data.remote_meta.clone(), core)
|
||||
};
|
||||
Ok((data.remote_meta.clone(), core))
|
||||
})?;
|
||||
|
||||
let keys_ctx =
|
||||
decode_version_bytes_mvreg_custom(&remote_meta, SUPPORTED_VERSIONS, |buf| async move {
|
||||
@@ -109,15 +101,10 @@ impl crdt_enc::key_cryptor::KeyCryptor for KeyHandler {
|
||||
}
|
||||
|
||||
async fn set_keys(&self, new_keys: ReadCtx<Keys, Uuid>) -> Result<()> {
|
||||
let (mut rm, core) = {
|
||||
let data = self
|
||||
.data
|
||||
.lock()
|
||||
.map_err(|err| Error::msg(err.to_string()))?;
|
||||
|
||||
let (mut rm, core) = self.data.try_with(|data| {
|
||||
let core = dyn_clone::clone_box(&**data.core.as_ref().context("core is none")?);
|
||||
(data.remote_meta.clone(), core)
|
||||
};
|
||||
Ok((data.remote_meta.clone(), core))
|
||||
})?;
|
||||
|
||||
encode_version_bytes_mvreg_custom(
|
||||
&mut rm,
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::{
|
||||
cryptor::Cryptor,
|
||||
key_cryptor::{Key, KeyCryptor, Keys},
|
||||
storage::Storage,
|
||||
utils::{VersionBytes, VersionBytesRef},
|
||||
utils::{LockBox, VersionBytes, VersionBytesRef},
|
||||
};
|
||||
use ::anyhow::{Context, Error, Result};
|
||||
use ::async_trait::async_trait;
|
||||
@@ -20,12 +20,7 @@ use ::futures::{
|
||||
};
|
||||
use ::serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use ::std::{
|
||||
collections::HashSet,
|
||||
convert::Infallible,
|
||||
default::Default,
|
||||
fmt::Debug,
|
||||
mem,
|
||||
sync::{Arc, Mutex as SyncMutex},
|
||||
collections::HashSet, convert::Infallible, default::Default, fmt::Debug, mem, sync::Arc,
|
||||
};
|
||||
use ::uuid::Uuid;
|
||||
|
||||
@@ -211,7 +206,7 @@ pub struct Core<S, ST, C, KC> {
|
||||
key_cryptor: KC,
|
||||
// use sync `std::sync::Mutex` here because it has less overhead than async mutex, we are
|
||||
// holding it for a very shot time and do not `.await` while the lock is held.
|
||||
data: SyncMutex<CoreMutData<S>>,
|
||||
data: LockBox<CoreMutData<S>>,
|
||||
supported_data_versions: Vec<Uuid>,
|
||||
current_data_version: Uuid,
|
||||
apply_ops_lock: AsyncMutex<()>,
|
||||
@@ -245,18 +240,6 @@ where
|
||||
KC: KeyCryptor,
|
||||
{
|
||||
pub async fn open(options: OpenOptions<ST, C, KC>) -> Result<Arc<Self>> {
|
||||
let core_data = SyncMutex::new(CoreMutData {
|
||||
local_meta: None,
|
||||
remote_meta: RemoteMeta::default(),
|
||||
keys: None,
|
||||
state: StateWrapper {
|
||||
next_op_versions: Default::default(),
|
||||
state: Default::default(),
|
||||
},
|
||||
read_states: HashSet::new(),
|
||||
read_remote_metas: HashSet::new(),
|
||||
});
|
||||
|
||||
let mut supported_data_versions = options.supported_data_versions;
|
||||
supported_data_versions.sort_unstable();
|
||||
|
||||
@@ -267,7 +250,17 @@ where
|
||||
key_cryptor: options.key_cryptor,
|
||||
supported_data_versions,
|
||||
current_data_version: options.current_data_version,
|
||||
data: core_data,
|
||||
data: LockBox::new(CoreMutData {
|
||||
local_meta: None,
|
||||
remote_meta: RemoteMeta::default(),
|
||||
keys: None,
|
||||
state: StateWrapper {
|
||||
next_op_versions: Default::default(),
|
||||
state: Default::default(),
|
||||
},
|
||||
read_states: HashSet::new(),
|
||||
read_remote_metas: HashSet::new(),
|
||||
}),
|
||||
apply_ops_lock: AsyncMutex::new(()),
|
||||
});
|
||||
|
||||
@@ -303,10 +296,9 @@ where
|
||||
|
||||
let actor = local_meta.local_actor_id;
|
||||
|
||||
core.with_mut_data(|data| {
|
||||
core.data.with(|data| {
|
||||
data.local_meta = Some(local_meta);
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
futures::try_join![
|
||||
core.storage.init(&core),
|
||||
@@ -316,16 +308,17 @@ where
|
||||
|
||||
core.read_remote_meta_(true).await?;
|
||||
|
||||
let insert_new_key =
|
||||
core.with_mut_data(|data| Ok(data.keys.as_ref().unwrap().val.latest_key().is_none()))?;
|
||||
let insert_new_key = core
|
||||
.data
|
||||
.with(|data| data.keys.as_ref().unwrap().val.latest_key().is_none());
|
||||
if insert_new_key {
|
||||
let new_key = core.cryptor.gen_key().await?;
|
||||
|
||||
let keys_ctx = core.with_mut_data(|data| {
|
||||
let keys_ctx = core.data.with(|data| {
|
||||
let mut keys_ctx = data.keys.take().unwrap();
|
||||
keys_ctx.val.insert_latest_key(actor, Key::new(new_key));
|
||||
Ok(keys_ctx)
|
||||
})?;
|
||||
keys_ctx
|
||||
});
|
||||
|
||||
// give keys to kc, it gives us a new key ctx back
|
||||
core.key_cryptor.set_keys(keys_ctx).await?;
|
||||
@@ -335,27 +328,14 @@ where
|
||||
}
|
||||
|
||||
pub fn info(self: &Arc<Self>) -> Info {
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
let actor = data
|
||||
.local_meta
|
||||
.as_ref()
|
||||
.expect("info not set, yet. Do not call this fn in the init phase")
|
||||
.local_actor_id;
|
||||
Ok(Info { actor })
|
||||
Info { actor }
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn with_mut_data<F, R>(self: &Arc<Self>, f: F) -> Result<R>
|
||||
where
|
||||
F: FnOnce(&mut CoreMutData<S>) -> Result<R>,
|
||||
{
|
||||
let mut data = self
|
||||
.data
|
||||
.lock()
|
||||
.map_err(|err| Error::msg(format!("unable to lock `CoreMutData`: {}", err)))?;
|
||||
|
||||
f(&mut *data)
|
||||
}
|
||||
|
||||
/// Locks cores data, do not call recursivl
|
||||
@@ -363,13 +343,13 @@ where
|
||||
where
|
||||
F: FnOnce(&S) -> Result<R>,
|
||||
{
|
||||
self.with_mut_data(|data| f(&data.state.state))
|
||||
self.data.with(|data| f(&data.state.state))
|
||||
}
|
||||
|
||||
pub async fn compact(self: &Arc<Self>) -> Result<()> {
|
||||
self.read_remote().await?;
|
||||
|
||||
let (clear_text, states_to_remove, ops_to_remove, key) = self.with_mut_data(|data| {
|
||||
let (clear_text, states_to_remove, ops_to_remove, key) = self.data.try_with(|data| {
|
||||
let clear_text = rmp_serde::to_vec_named(&data.state)?;
|
||||
|
||||
let states_to_remove = data.read_states.iter().cloned().collect();
|
||||
@@ -405,23 +385,21 @@ where
|
||||
self.storage.remove_ops(ops_to_remove),
|
||||
]?;
|
||||
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
for removed_state in removed_states {
|
||||
data.read_states.remove(&removed_state);
|
||||
}
|
||||
|
||||
data.read_states.insert(new_state_name);
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn set_keys(self: &Arc<Self>, keys: ReadCtx<Keys, Uuid>) -> Result<()> {
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
data.keys = Some(keys);
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -444,7 +422,7 @@ where
|
||||
.await
|
||||
.context("failed getting state entry names while reading remote states")?;
|
||||
|
||||
let (states_to_read, key) = self.with_mut_data(|data| {
|
||||
let (states_to_read, key) = self.data.try_with(|data| {
|
||||
let states_to_read: Vec<_> = names
|
||||
.into_iter()
|
||||
.filter(|name| !data.read_states.contains(name))
|
||||
@@ -493,7 +471,7 @@ where
|
||||
|
||||
let states_read = !new_states.is_empty();
|
||||
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
for (name, state_wrapper) in new_states {
|
||||
data.state.state.merge(state_wrapper.state);
|
||||
data.state
|
||||
@@ -501,8 +479,7 @@ where
|
||||
.merge(state_wrapper.next_op_versions);
|
||||
data.read_states.insert(name);
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
Ok(states_read)
|
||||
}
|
||||
@@ -514,7 +491,7 @@ where
|
||||
.await
|
||||
.context("failed getting op actor entries while reading remote ops")?;
|
||||
|
||||
let (ops_to_read, key) = self.with_mut_data(|data| {
|
||||
let (ops_to_read, key) = self.data.try_with(|data| {
|
||||
let ops_to_read: Vec<_> = actors
|
||||
.into_iter()
|
||||
.map(|actor| (actor, data.state.next_op_versions.get(&actor)))
|
||||
@@ -556,7 +533,7 @@ where
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
let ops_read = self.with_mut_data(|data| {
|
||||
let ops_read = self.data.with(|data| {
|
||||
let mut ops_read = false;
|
||||
for (actor, version, ops) in new_ops {
|
||||
let expected_version = data.state.next_op_versions.get(&actor);
|
||||
@@ -600,13 +577,13 @@ where
|
||||
.await
|
||||
.context("failed getting remote meta entry names while reading remote metas")?;
|
||||
|
||||
let remote_metas_to_read = self.with_mut_data(|data| {
|
||||
let remote_metas_to_read = self.data.with(|data| {
|
||||
let remote_metas_to_read: Vec<_> = names
|
||||
.into_iter()
|
||||
.filter(|name| !data.read_remote_metas.contains(name))
|
||||
.collect();
|
||||
Ok(remote_metas_to_read)
|
||||
})?;
|
||||
remote_metas_to_read
|
||||
});
|
||||
|
||||
let remote_metas = self
|
||||
.storage
|
||||
@@ -624,14 +601,14 @@ where
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let remote_meta = if !remote_metas.is_empty() {
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
for (name, meta) in remote_metas {
|
||||
data.remote_meta.merge(meta);
|
||||
data.read_remote_metas.insert(name);
|
||||
}
|
||||
|
||||
Ok(Some(data.remote_meta.clone()))
|
||||
})?
|
||||
Some(data.remote_meta.clone())
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -658,10 +635,9 @@ where
|
||||
self: &Arc<Self>,
|
||||
remote_meta: MVReg<VersionBytes, Uuid>,
|
||||
) -> Result<()> {
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
data.remote_meta.storage.merge(remote_meta);
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
self.store_remote_meta().await
|
||||
}
|
||||
@@ -670,10 +646,9 @@ where
|
||||
self: &Arc<Self>,
|
||||
remote_meta: MVReg<VersionBytes, Uuid>,
|
||||
) -> Result<()> {
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
data.remote_meta.cryptor.merge(remote_meta);
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
self.store_remote_meta().await
|
||||
}
|
||||
@@ -682,27 +657,26 @@ where
|
||||
self: &Arc<Self>,
|
||||
remote_meta: MVReg<VersionBytes, Uuid>,
|
||||
) -> Result<()> {
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
data.remote_meta.key_cryptor.merge(remote_meta);
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
self.store_remote_meta().await
|
||||
}
|
||||
|
||||
async fn store_remote_meta(self: &Arc<Self>) -> Result<()> {
|
||||
let vbox = self.with_mut_data(|data| {
|
||||
let vbox = self.data.try_with(|data| {
|
||||
let bytes = rmp_serde::to_vec_named(&data.remote_meta)?;
|
||||
Ok(VersionBytes::new(CURRENT_VERSION, bytes))
|
||||
})?;
|
||||
|
||||
let new_name = self.storage.store_remote_meta(vbox).await?;
|
||||
|
||||
let names_to_remove = self.with_mut_data(|data| {
|
||||
let names_to_remove = self.data.with(|data| {
|
||||
let names_to_remove = data.read_remote_metas.drain().collect();
|
||||
data.read_remote_metas.insert(new_name);
|
||||
Ok(names_to_remove)
|
||||
})?;
|
||||
names_to_remove
|
||||
});
|
||||
|
||||
self.storage.remove_remote_metas(names_to_remove).await?;
|
||||
|
||||
@@ -716,7 +690,7 @@ where
|
||||
let clear_text = rmp_serde::to_vec_named(&ops)?;
|
||||
let clear_text = VersionBytes::new(self.current_data_version, clear_text);
|
||||
|
||||
let key = self.with_mut_data(|data| {
|
||||
let key = self.data.with(|data| {
|
||||
data.keys
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
@@ -740,7 +714,7 @@ where
|
||||
|
||||
let data_enc = VersionBytes::new(CURRENT_VERSION, data_enc);
|
||||
|
||||
let (actor, version) = self.with_mut_data(|data| {
|
||||
let (actor, version) = self.data.try_with(|data| {
|
||||
let actor = data
|
||||
.local_meta
|
||||
.as_ref()
|
||||
@@ -752,15 +726,14 @@ where
|
||||
|
||||
self.storage.store_ops(actor, version, data_enc).await?;
|
||||
|
||||
self.with_mut_data(|data| {
|
||||
self.data.with(|data| {
|
||||
for op in ops {
|
||||
data.state.state.apply(op);
|
||||
}
|
||||
|
||||
let version_inc = data.state.next_op_versions.inc(actor);
|
||||
data.state.next_op_versions.apply(version_inc);
|
||||
Ok(())
|
||||
})?;
|
||||
});
|
||||
|
||||
// release lock by hand to prevent an early release by accident
|
||||
mem::drop(apply_ops_lock);
|
||||
|
||||
@@ -6,7 +6,7 @@ use ::anyhow::{Context, Result};
|
||||
use ::crdts::{ctx::ReadCtx, CmRDT, CvRDT, MVReg};
|
||||
use ::futures::{stream, Future, FutureExt, StreamExt, TryStreamExt};
|
||||
use ::serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use ::std::convert::Infallible;
|
||||
use ::std::{convert::Infallible, fmt::Debug, sync::Mutex as SyncMutex};
|
||||
use ::uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
@@ -125,3 +125,33 @@ where
|
||||
reg.apply(op);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prevents `await`s while the lock is held. Awaiting could cause deadlocking.
|
||||
#[derive(Debug)]
|
||||
pub struct LockBox<T> {
|
||||
inner: SyncMutex<T>,
|
||||
}
|
||||
|
||||
impl<T> LockBox<T> {
|
||||
pub fn new(val: T) -> LockBox<T> {
|
||||
LockBox {
|
||||
inner: SyncMutex::new(val),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with<F, R>(&self, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&mut T) -> R,
|
||||
{
|
||||
let mut data = self.inner.lock().expect("Unable to lock LockBox");
|
||||
f(&mut *data)
|
||||
}
|
||||
|
||||
/// Utility `LockBox::with` function, that enforces a `anyhow::Result` return type.
|
||||
pub fn try_with<F, R>(&self, f: F) -> Result<R>
|
||||
where
|
||||
F: FnOnce(&mut T) -> Result<R>,
|
||||
{
|
||||
self.with(f)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user