memorystore: Fix bug where duplicate tracked users are stored

This commit is contained in:
Andy Balaam 2024-03-28 15:15:13 +00:00
parent ac0bc95c25
commit 31131146a6
2 changed files with 29 additions and 19 deletions

View File

@ -56,7 +56,7 @@ pub struct MemoryStore {
inbound_group_sessions: GroupSessionStore,
outbound_group_sessions: StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>,
private_identity: StdRwLock<Option<PrivateCrossSigningIdentity>>,
tracked_users: StdRwLock<Vec<TrackedUser>>,
tracked_users: StdRwLock<HashMap<OwnedUserId, TrackedUser>>,
olm_hashes: StdRwLock<HashMap<String, HashSet<String>>>,
devices: DeviceStore,
identities: StdRwLock<HashMap<OwnedUserId, ReadOnlyUserIdentities>>,
@ -324,12 +324,13 @@ impl CryptoStore for MemoryStore {
}
async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
Ok(self.tracked_users.read().unwrap().clone())
Ok(self.tracked_users.read().unwrap().values().cloned().collect())
}
async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> {
self.tracked_users.write().unwrap().extend(tracked_users.iter().map(|(user_id, dirty)| {
TrackedUser { user_id: user_id.to_owned().into(), dirty: *dirty }
let user_id: OwnedUserId = user_id.to_owned().into();
(user_id.clone(), TrackedUser { user_id, dirty: *dirty })
}));
Ok(())
}
@ -559,23 +560,29 @@ mod tests {
}
#[async_test]
async fn test_tracked_users_store() {
// Given some tracked users
let tracked_users =
&[(user_id!("@dirty_user:s"), true), (user_id!("@clean_user:t"), false)];
// When we save them to the store
async fn test_tracked_users_are_stored_once_per_user_id() {
// Given a store containing 2 tracked users, both dirty
let user1 = user_id!("@user1:s");
let user2 = user_id!("@user2:s");
let user3 = user_id!("@user3:s");
let store = MemoryStore::new();
store.save_tracked_users(tracked_users).await.unwrap();
store.save_tracked_users(&[(user1, true), (user2, true)]).await.unwrap();
// Then we can get them out again
// When we mark one as clean and add another
store.save_tracked_users(&[(user2, false), (user3, false)]).await.unwrap();
// Then we can get them out again and their dirty flags are correct
let loaded_tracked_users =
store.load_tracked_users().await.expect("failed to load tracked users");
assert_eq!(loaded_tracked_users[0].user_id, user_id!("@dirty_user:s"));
assert!(loaded_tracked_users[0].dirty);
assert_eq!(loaded_tracked_users[1].user_id, user_id!("@clean_user:t"));
assert!(!loaded_tracked_users[1].dirty);
assert_eq!(loaded_tracked_users.len(), 2);
let tracked_contains = |user_id, dirty| {
loaded_tracked_users.iter().any(|u| u.user_id == user_id && u.dirty == dirty)
};
assert!(tracked_contains(user1, true));
assert!(tracked_contains(user2, false));
assert!(tracked_contains(user3, false));
assert_eq!(loaded_tracked_users.len(), 3);
}
#[async_test]

View File

@ -133,11 +133,14 @@ pub trait CryptoStore: AsyncTraitDeps {
room_id: &RoomId,
) -> Result<Option<OutboundGroupSession>, Self::Error>;
/// Load the list of users whose devices we are keeping track of.
/// Provide the list of users whose devices we are keeping track of, and
/// whether they are considered dirty/outdated.
async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>, Self::Error>;
/// Save a list of users and their respective dirty/outdated flags to the
/// store.
/// Update the list of users whose devices we are keeping track of, and
/// whether they are considered dirty/outdated.
///
/// Replaces any existing entry with a matching user ID.
async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error>;
/// Get the device for the given user with the given device ID.