crypto: Save outbound sessions in MemoryStore

This commit is contained in:
Andy Balaam 2024-03-15 11:43:37 +00:00
parent 32edfb1a9f
commit 486b6d6e2b
1 changed files with 38 additions and 3 deletions

View File

@ -54,6 +54,7 @@ pub struct MemoryStore {
account: StdRwLock<Option<Account>>,
sessions: SessionStore,
inbound_group_sessions: GroupSessionStore,
outbound_group_sessions: StdRwLock<Vec<OutboundGroupSession>>,
olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
devices: DeviceStore,
identities: StdRwLock<HashMap<OwnedUserId, ReadOnlyUserIdentities>>,
@ -74,6 +75,7 @@ impl Default for MemoryStore {
account: Default::default(),
sessions: SessionStore::new(),
inbound_group_sessions: GroupSessionStore::new(),
outbound_group_sessions: Default::default(),
olm_hashes: Default::default(),
devices: DeviceStore::new(),
identities: Default::default(),
@ -119,6 +121,10 @@ impl MemoryStore {
self.inbound_group_sessions.add(session);
}
}
fn save_outbound_group_sessions(&self, mut sessions: Vec<OutboundGroupSession>) {
self.outbound_group_sessions.write().unwrap().append(&mut sessions);
}
}
type Result<T> = std::result::Result<T, Infallible>;
@ -151,6 +157,7 @@ impl CryptoStore for MemoryStore {
async fn save_changes(&self, changes: Changes) -> Result<()> {
self.save_sessions(changes.sessions).await;
self.save_inbound_group_sessions(changes.inbound_group_sessions);
self.save_outbound_group_sessions(changes.outbound_group_sessions);
self.save_devices(changes.devices.new);
self.save_devices(changes.devices.changed);
@ -297,8 +304,17 @@ impl CryptoStore for MemoryStore {
Ok(self.backup_keys.read().await.to_owned())
}
async fn get_outbound_group_session(&self, _: &RoomId) -> Result<Option<OutboundGroupSession>> {
Ok(None)
async fn get_outbound_group_session(
&self,
room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>> {
Ok(self
.outbound_group_sessions
.read()
.unwrap()
.iter()
.find(|session| session.room_id() == room_id)
.cloned())
}
async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
@ -487,7 +503,7 @@ mod tests {
}
#[async_test]
async fn test_group_session_store() {
async fn test_inbound_group_session_store() {
let (account, _) = get_account_and_session_test_helper();
let room_id = room_id!("!test:localhost");
let curve_key = "Nn0L2hkcCMFKqynTjyGsJbth7QrVmX3lbrksMkrGOAw";
@ -511,6 +527,25 @@ mod tests {
assert_eq!(inbound, loaded_session);
}
#[async_test]
async fn test_outbound_group_session_store() {
// Given an outbound sessions
let (account, _) = get_account_and_session_test_helper();
let room_id = room_id!("!test:localhost");
let (outbound, _) = account.create_group_session_pair_with_defaults(room_id).await;
// When we save it to the store
let store = MemoryStore::new();
store.save_outbound_group_sessions(vec![outbound.clone()]);
// Then we can get it out again
let loaded_session = store.get_outbound_group_session(room_id).await.unwrap().unwrap();
assert_eq!(
serde_json::to_string(&outbound.pickle().await).unwrap(),
serde_json::to_string(&loaded_session.pickle().await).unwrap()
);
}
#[async_test]
async fn test_device_store() {
let device = get_device();