From eeea4f23bcd81ff66d5fafda9b7e1c597f42c1ae Mon Sep 17 00:00:00 2001 From: Doug <6060466+pixlwave@users.noreply.github.com> Date: Tue, 14 Mar 2023 14:40:19 +0000 Subject: [PATCH] fix(bindings): More authentication service server name fixes - Trim any trailing slashes - If server name parsing fails, try as a URL instead of throwing - Add tests - Fix typo & clippy --- .gitignore | 1 + .../AuthenticationServiceTests.swift | 64 ++++++++++++++++++ .../src/authentication_service.rs | 65 ++++++++++--------- crates/matrix-sdk/src/lib.rs | 34 +++++++++- 4 files changed, 132 insertions(+), 32 deletions(-) create mode 100644 bindings/apple/Tests/MatrixRustSDKTests/AuthenticationServiceTests.swift diff --git a/.gitignore b/.gitignore index b9a0e4f38..d202f963a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ emsdk-* .idea/ .env .build +.swiftpm /Package.swift ## User settings diff --git a/bindings/apple/Tests/MatrixRustSDKTests/AuthenticationServiceTests.swift b/bindings/apple/Tests/MatrixRustSDKTests/AuthenticationServiceTests.swift new file mode 100644 index 000000000..df92b6a32 --- /dev/null +++ b/bindings/apple/Tests/MatrixRustSDKTests/AuthenticationServiceTests.swift @@ -0,0 +1,64 @@ +@testable import MatrixRustSDK +import XCTest + +class AuthenticationServiceTests: XCTestCase { + var service: AuthenticationService! + + override func setUp() { + service = AuthenticationService(basePath: FileManager.default.temporaryDirectory.path, + passphrase: nil, + customSlidingSyncProxy: nil) + } + + func testValidServers() { + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "matrix.org")) + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix.org")) + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix.org/")) + } + + func testInvalidCharacters() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "hello!@$£%^world"), + "A server name with invalid characters should not succeed to build.") { error in + guard case AuthenticationError.InvalidServerName = error else { XCTFail("Expected invalid name error."); return } + } + } + + func textNonExistentDomain() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "somesillylinkthatdoesntexist.com"), + "A server name that doesn't exist should not succeed.") { error in + guard case AuthenticationError.Generic = error else { XCTFail("Expected generic error."); return } + } + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://somesillylinkthatdoesntexist.com"), + "A server URL that doesn't exist should not succeed.") { error in + guard case AuthenticationError.Generic = error else { XCTFail("Expected generic error."); return } + } + } + + func testValidDomainWithoutServer() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://google.com"), + "Google should not succeed as it doesn't host a homeserver.") { error in + guard case AuthenticationError.Generic = error else { XCTFail("Expected generic error."); return } + } + } + + func testServerWithoutSlidingSync() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "envs.net"), + "Envs should not succeed as it doesn't advertise a sliding sync proxy.") { error in + guard case AuthenticationError.SlidingSyncNotAvailable = error else { XCTFail("Expected sliding sync error."); return } + } + } + + func testHomeserverURL() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix-client.matrix.org"), + "Directly using a homeserver should not succeed as a sliding sync proxy won't be found.") { error in + guard case AuthenticationError.SlidingSyncNotAvailable = error else { XCTFail("Expected sliding sync error."); return } + } + } + + func testHomeserverURLWithProxyOverride() { + service = AuthenticationService(basePath: FileManager.default.temporaryDirectory.path, + passphrase: nil, customSlidingSyncProxy: "https://slidingsync.proxy") + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix-client.matrix.org"), + "Directly using a homeserver should succeed what a custom sliding sync proxy has been set.") + } +} diff --git a/bindings/matrix-sdk-ffi/src/authentication_service.rs b/bindings/matrix-sdk-ffi/src/authentication_service.rs index b6313cdc6..50ed05702 100644 --- a/bindings/matrix-sdk-ffi/src/authentication_service.rs +++ b/bindings/matrix-sdk-ffi/src/authentication_service.rs @@ -5,6 +5,7 @@ use matrix_sdk::{ ruma::{IdParseError, OwnedDeviceId, UserId}, Session, }; +use url::Url; use zeroize::Zeroize; use super::{client::Client, client_builder::ClientBuilder, RUNTIME}; @@ -106,7 +107,7 @@ impl AuthenticationService { let url = login_details.0; let authentication_issuer = login_details.1; - let supports_password_login = login_details.2.map_err(AuthenticationError::from)?; + let supports_password_login = login_details.2?; Ok(HomeserverLoginDetails { url, authentication_issuer, supports_password_login }) } @@ -126,26 +127,34 @@ impl AuthenticationService { ) -> Result<(), AuthenticationError> { let mut builder = Arc::new(ClientBuilder::new()).base_path(self.base_path.clone()); - // Remove any URL scheme from the name to attempt discovery first. - let server_name = matrix_sdk::sanitize_server_name(&server_name_or_homeserver_url) - .map_err(AuthenticationError::from)?; - - builder = builder.server_name(server_name.to_string()); - - let client = builder - .build() - .or_else(|e| { - if !server_name_or_homeserver_url.starts_with("http://") - && !server_name_or_homeserver_url.starts_with("http://") - { - return Err(e); + // Attempt discovery as a server name first. + let result = matrix_sdk::sanitize_server_name(&server_name_or_homeserver_url); + match result { + Ok(server_name) => { + builder = builder.server_name(server_name.to_string()); + } + Err(e) => { + // When the input isn't a valid server name check it is a URL. + // If this is the case, build the client with a homeserver URL. + if let Ok(_url) = Url::parse(&server_name_or_homeserver_url) { + builder = builder.homeserver_url(server_name_or_homeserver_url.clone()); + } else { + return Err(e.into()); } - // When discovery fails, fallback to the homeserver URL if supplied. - let mut builder = Arc::new(ClientBuilder::new()).base_path(self.base_path.clone()); - builder = builder.homeserver_url(server_name_or_homeserver_url); - builder.build() - }) - .map_err(AuthenticationError::from)?; + } + } + + let client = builder.build().or_else(|e| { + if !server_name_or_homeserver_url.starts_with("http://") + && !server_name_or_homeserver_url.starts_with("https://") + { + return Err(e); + } + // When discovery fails, fallback to the homeserver URL if supplied. + let mut builder = Arc::new(ClientBuilder::new()).base_path(self.base_path.clone()); + builder = builder.homeserver_url(server_name_or_homeserver_url); + builder.build() + })?; let details = RUNTIME.block_on(self.details_from_client(&client))?; @@ -177,9 +186,7 @@ impl AuthenticationService { // Login and ask the server for the full user ID as this could be different from // the username that was entered. - client - .login(username, password, initial_device_name, device_id) - .map_err(AuthenticationError::from)?; + client.login(username, password, initial_device_name, device_id)?; let whoami = client.whoami()?; // Create a new client to setup the store path now the user ID is known. @@ -201,11 +208,10 @@ impl AuthenticationService { .homeserver_url(homeserver_url) .sliding_sync_proxy(sliding_sync_proxy) .username(whoami.user_id.to_string()) - .build() - .map_err(AuthenticationError::from)?; + .build()?; // Restore the client using the session from the login request. - client.restore_session_inner(session).map_err(AuthenticationError::from)?; + client.restore_session_inner(session)?; Ok(client) } @@ -238,7 +244,7 @@ impl AuthenticationService { device_id: device_id.clone(), }; - client.restore_session_inner(discovery_session).map_err(AuthenticationError::from)?; + client.restore_session_inner(discovery_session)?; let whoami = client.whoami()?; // Create the actual client with a store path from the user ID. @@ -254,11 +260,10 @@ impl AuthenticationService { .passphrase(self.passphrase.clone()) .homeserver_url(homeserver_url) .username(whoami.user_id.to_string()) - .build() - .map_err(AuthenticationError::from)?; + .build()?; // Restore the client using the session. - client.restore_session_inner(session).map_err(AuthenticationError::from)?; + client.restore_session_inner(session)?; Ok(client) } } diff --git a/crates/matrix-sdk/src/lib.rs b/crates/matrix-sdk/src/lib.rs index 12f923632..d5bcafee3 100644 --- a/crates/matrix-sdk/src/lib.rs +++ b/crates/matrix-sdk/src/lib.rs @@ -76,7 +76,37 @@ fn init_logging() { } /// Creates a server name from a user supplied string. The string is first -/// sanitized by removing the http(s) scheme before being parsed. +/// sanitized by removing whitespace, the http(s) scheme and any trailing +/// slashes before being parsed. pub fn sanitize_server_name(s: &str) -> Result { - ServerName::parse(s.trim_start_matches("http://").trim_start_matches("https://")) + ServerName::parse( + s.trim().trim_start_matches("http://").trim_start_matches("https://").trim_end_matches('/'), + ) +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + + use crate::sanitize_server_name; + + #[test] + fn test_sanitize_server_name() { + assert_eq!(sanitize_server_name("matrix.org").unwrap().as_str(), "matrix.org"); + assert_eq!(sanitize_server_name("https://matrix.org").unwrap().as_str(), "matrix.org"); + assert_eq!(sanitize_server_name("http://matrix.org").unwrap().as_str(), "matrix.org"); + assert_eq!( + sanitize_server_name("https://matrix.server.org").unwrap().as_str(), + "matrix.server.org" + ); + assert_eq!( + sanitize_server_name("https://matrix.server.org/").unwrap().as_str(), + "matrix.server.org" + ); + assert_eq!( + sanitize_server_name(" https://matrix.server.org// ").unwrap().as_str(), + "matrix.server.org" + ); + assert_matches!(sanitize_server_name("https://matrix.server.org/something"), Err(_)) + } }