diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index f902f3a..6b9ba1f 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -47,12 +47,6 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" -[[package]] -name = "ascii" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" - [[package]] name = "async-broadcast" version = "0.7.2" @@ -462,12 +456,6 @@ dependencies = [ "windows-link 0.2.1", ] -[[package]] -name = "chunked_transfer" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" - [[package]] name = "combine" version = "4.6.7" @@ -1072,7 +1060,6 @@ dependencies = [ "tauri-plugin-opener", "tauri-plugin-positioner", "thiserror 1.0.69", - "tiny_http", "tokio", "tracing", "tracing-subscriber", @@ -1611,12 +1598,6 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" -[[package]] -name = "httpdate" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" - [[package]] name = "hyper" version = "1.7.0" @@ -4423,18 +4404,6 @@ dependencies = [ "time-core", ] -[[package]] -name = "tiny_http" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82" -dependencies = [ - "ascii", - "chunked_transfer", - "httpdate", - "log", -] - [[package]] name = "tinystr" version = "0.8.2" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index f0b2c3d..0fc4464 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -34,7 +34,6 @@ url = "2.5.7" rand = "0.9.2" sha2 = "0.10.9" base64 = "0.22.1" -tiny_http = "0.12.0" thiserror = "1" tracing = "0.1" tracing-subscriber = "0.3" diff --git a/src-tauri/src/core/services/auth.rs b/src-tauri/src/core/services/auth.rs index b9872ab..74c72ac 100644 --- a/src-tauri/src/core/services/auth.rs +++ b/src-tauri/src/core/services/auth.rs @@ -1,15 +1,16 @@ use crate::{core::state::FDOLL, lock_r, lock_w, APP_HANDLE}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; -use flate2::{write::GzEncoder, read::GzDecoder, Compression}; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; use keyring::Entry; use rand::{distr::Alphanumeric, Rng}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use std::thread; use std::io::{Read, Write}; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use tauri_plugin_opener::OpenerExt; use thiserror::Error; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; use tokio::sync::Mutex; use tracing::{error, info, warn}; use url::form_urlencoded; @@ -18,6 +19,7 @@ static REFRESH_LOCK: once_cell::sync::Lazy> = once_cell::sync::Lazy::new(|| Mutex::new(())); static AUTH_SUCCESS_HTML: &str = include_str!("../../assets/auth-success.html"); +const SERVICE_NAME: &str = "friendolls"; /// Errors that can occur during OAuth authentication flow. #[derive(Debug, Error)] @@ -54,6 +56,9 @@ pub enum OAuthError { #[error("OAuth state expired or not initialized")] StateExpired, + + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), } /// Parameters received from the OAuth callback. @@ -168,42 +173,50 @@ pub fn save_auth_pass(auth_pass: &AuthPass) -> Result<(), OAuthError> { let json = serde_json::to_string(auth_pass)?; info!("Original JSON length: {}", json.len()); let mut encoder = GzEncoder::new(Vec::new(), Compression::best()); - encoder.write_all(json.as_bytes()).map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?; - let compressed = encoder.finish().map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?; + encoder + .write_all(json.as_bytes()) + .map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?; + let compressed = encoder + .finish() + .map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?; info!("Compressed length: {}", compressed.len()); let encoded = URL_SAFE_NO_PAD.encode(&compressed); info!("Encoded length: {}", encoded.len()); - + // Windows keyring has a 2560-byte UTF-16 limit, which means 1280 chars max // Split into chunks of 1200 chars to be safe const CHUNK_SIZE: usize = 1200; - let chunks: Vec<&str> = encoded.as_bytes() + let chunks: Vec<&str> = encoded + .as_bytes() .chunks(CHUNK_SIZE) .map(|chunk| std::str::from_utf8(chunk).unwrap()) .collect(); - + info!("Splitting auth pass into {} chunks", chunks.len()); - + // Save chunk count - let count_entry = Entry::new("friendolls", "auth_pass_count")?; + let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?; count_entry.set_password(&chunks.len().to_string())?; - + // Save each chunk for (i, chunk) in chunks.iter().enumerate() { - let entry = Entry::new("friendolls", &format!("auth_pass_{}", i))?; + let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?; entry.set_password(chunk)?; } - - info!("Auth pass saved to keyring successfully in {} chunks", chunks.len()); + + info!( + "Auth pass saved to keyring successfully in {} chunks", + chunks.len() + ); Ok(()) } /// Load auth_pass from secure storage (keyring). pub fn load_auth_pass() -> Result, OAuthError> { info!("Reading credentials from keyring"); - + // Get chunk count - let count_entry = Entry::new("friendolls", "auth_pass_count")?; + let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?; let chunk_count = match count_entry.get_password() { Ok(count_str) => match count_str.parse::() { Ok(count) => count, @@ -221,13 +234,13 @@ pub fn load_auth_pass() -> Result, OAuthError> { return Err(OAuthError::KeyringError(e)); } }; - + info!("Loading {} auth pass chunks from keyring", chunk_count); - + // Reassemble chunks let mut encoded = String::new(); for i in 0..chunk_count { - let entry = Entry::new("friendolls", &format!("auth_pass_{}", i))?; + let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?; match entry.get_password() { Ok(chunk) => encoded.push_str(&chunk), Err(e) => { @@ -236,9 +249,9 @@ pub fn load_auth_pass() -> Result, OAuthError> { } } } - + info!("Reassembled encoded length: {}", encoded.len()); - + let compressed = match URL_SAFE_NO_PAD.decode(&encoded) { Ok(c) => c, Err(e) => { @@ -246,14 +259,14 @@ pub fn load_auth_pass() -> Result, OAuthError> { return Ok(None); } }; - + let mut decoder = GzDecoder::new(&compressed[..]); let mut json = String::new(); if let Err(e) = decoder.read_to_string(&mut json) { error!("Failed to decompress auth pass from keyring: {}", e); return Ok(None); } - + let auth_pass: AuthPass = match serde_json::from_str(&json) { Ok(v) => { info!("Deserialized auth pass from keyring"); @@ -264,7 +277,7 @@ pub fn load_auth_pass() -> Result, OAuthError> { return Ok(None); } }; - + info!("Auth pass loaded from keyring"); Ok(Some(auth_pass)) } @@ -272,21 +285,21 @@ pub fn load_auth_pass() -> Result, OAuthError> { /// Clear auth_pass from secure storage and app state. pub fn clear_auth_pass() -> Result<(), OAuthError> { // Try to get chunk count - let count_entry = Entry::new("friendolls", "auth_pass_count")?; + let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?; let chunk_count = match count_entry.get_password() { Ok(count_str) => count_str.parse::().unwrap_or(0), Err(_) => 0, }; - + // Delete all chunks for i in 0..chunk_count { - let entry = Entry::new("friendolls", &format!("auth_pass_{}", i))?; + let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?; let _ = entry.delete_credential(); } - + // Delete chunk count let _ = count_entry.delete_credential(); - + info!("Auth pass cleared from keyring successfully"); Ok(()) } @@ -485,56 +498,77 @@ where info!("Initiating OAuth flow"); - thread::spawn(move || { - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap() - .block_on(async move { - match listen_for_callback().await { - Ok(callback_params) => { - // Validate state - let stored_state = lock_r!(FDOLL).oauth_flow.state.clone(); + // Bind the server FIRST to ensure port is open + // We bind synchronously using std::net::TcpListener then convert to tokio::net::TcpListener + // to ensure the port is bound before we open the browser. + let std_listener = match std::net::TcpListener::bind(&app_config.auth.redirect_host) { + Ok(s) => { + s.set_nonblocking(true).unwrap(); + s + } + Err(e) => { + error!("Failed to bind callback server: {}", e); + return; + } + }; - if stored_state.as_ref() != Some(&callback_params.state) { - error!("State mismatch - possible CSRF attack!"); - return; - } + info!( + "Listening on {} for /callback", + &app_config.auth.redirect_host + ); - // Retrieve code_verifier - let code_verifier = match lock_r!(FDOLL).oauth_flow.code_verifier.clone() { - Some(cv) => cv, - None => { - error!("Code verifier not found in state"); - return; - } - }; + tauri::async_runtime::spawn(async move { + let listener = match TcpListener::from_std(std_listener) { + Ok(l) => l, + Err(e) => { + error!("Failed to create async listener: {}", e); + return; + } + }; - // Clear OAuth flow state after successful callback - lock_w!(FDOLL).oauth_flow = Default::default(); + match listen_for_callback(listener).await { + Ok(callback_params) => { + // Validate state + let stored_state = lock_r!(FDOLL).oauth_flow.state.clone(); - match exchange_code_for_auth_pass(callback_params, &code_verifier).await { - Ok(auth_pass) => { - lock_w!(FDOLL).auth_pass = Some(auth_pass.clone()); - if let Err(e) = save_auth_pass(&auth_pass) { - error!("Failed to save auth pass: {}", e); - } else { - info!("Authentication successful!"); - on_success(); - } - } - Err(e) => { - error!("Failed to exchange code for tokens: {}", e); - } + if stored_state.as_ref() != Some(&callback_params.state) { + error!("State mismatch - possible CSRF attack!"); + return; + } + + // Retrieve code_verifier + let code_verifier = match lock_r!(FDOLL).oauth_flow.code_verifier.clone() { + Some(cv) => cv, + None => { + error!("Code verifier not found in state"); + return; + } + }; + + // Clear OAuth flow state after successful callback + lock_w!(FDOLL).oauth_flow = Default::default(); + + match exchange_code_for_auth_pass(callback_params, &code_verifier).await { + Ok(auth_pass) => { + lock_w!(FDOLL).auth_pass = Some(auth_pass.clone()); + if let Err(e) = save_auth_pass(&auth_pass) { + error!("Failed to save auth pass: {}", e); + } else { + info!("Authentication successful!"); + on_success(); } } Err(e) => { - error!("Failed to receive callback: {}", e); - // Clear OAuth flow state on error - lock_w!(FDOLL).oauth_flow = Default::default(); + error!("Failed to exchange code for tokens: {}", e); } } - }); + } + Err(e) => { + error!("Failed to receive callback: {}", e); + // Clear OAuth flow state on error + lock_w!(FDOLL).oauth_flow = Default::default(); + } + } }); if let Err(e) = opener.open_url(url, None::<&str>) { @@ -628,82 +662,91 @@ pub async fn refresh_token(refresh_token: &str) -> Result /// # Errors /// /// Returns `OAuthError` if: -/// - Server fails to bind to the configured port /// - Required callback parameters are missing /// - Timeout is reached before callback is received -async fn listen_for_callback() -> Result { - let app_config = lock_r!(FDOLL) - .app_config - .clone() - .ok_or(OAuthError::InvalidConfig)?; - - let server = tiny_http::Server::http(&app_config.auth.redirect_host) - .map_err(|e| OAuthError::ServerBindError(e.to_string()))?; - - info!( - "Listening on {} for /callback", - &app_config.auth.redirect_host - ); - +async fn listen_for_callback(listener: TcpListener) -> Result { // Set a 5-minute timeout let timeout = Duration::from_secs(300); - let start_time = SystemTime::now(); + let start_time = Instant::now(); - for request in server.incoming_requests() { - // Check timeout - if SystemTime::now() - .duration_since(start_time) - .unwrap_or(Duration::ZERO) - > timeout - { + loop { + let elapsed = start_time.elapsed(); + if elapsed > timeout { warn!("Callback listener timed out after 5 minutes"); return Err(OAuthError::CallbackTimeout); } - let url = request.url().to_string(); + let accept_result = tokio::time::timeout(timeout - elapsed, listener.accept()).await; - if url.starts_with("/callback") { - let query = url.split('?').nth(1).unwrap_or(""); - let params = form_urlencoded::parse(query.as_bytes()) - .map(|(k, v)| (k.into_owned(), v.into_owned())) - .collect::>(); + let (mut stream, _) = match accept_result { + Ok(Ok(res)) => res, + Ok(Err(e)) => { + warn!("Accept error: {}", e); + continue; + } + Err(_) => { + warn!("Callback listener timed out after 5 minutes"); + return Err(OAuthError::CallbackTimeout); + } + }; - info!("Received OAuth callback"); + let mut buffer = [0; 4096]; + let n = match stream.read(&mut buffer).await { + Ok(n) if n > 0 => n, + _ => continue, + }; - let find_param = |key: &str| -> Result { - params - .iter() - .find(|(k, _)| k == key) - .map(|(_, v)| v.clone()) - .ok_or_else(|| OAuthError::MissingParameter(key.to_string())) - }; + let request = String::from_utf8_lossy(&buffer[..n]); + let first_line = request.lines().next().unwrap_or(""); + let mut parts = first_line.split_whitespace(); - let callback_params = OAuthCallbackParams { - state: find_param("state")?, - session_state: find_param("session_state")?, - iss: find_param("iss")?, - code: find_param("code")?, - }; + match (parts.next(), parts.next()) { + (Some("GET"), Some(path)) if path.starts_with("/callback") => { + let full_url = format!("http://localhost{}", path); + let url = match url::Url::parse(&full_url) { + Ok(u) => u, + Err(_) => continue, + }; - let response = tiny_http::Response::from_string(AUTH_SUCCESS_HTML).with_header( - tiny_http::Header::from_bytes( - &b"Content-Type"[..], - &b"text/html; charset=utf-8"[..], - ) - .map_err(|_| OAuthError::ServerBindError("Header creation failed".to_string()))?, - ); + let params: std::collections::HashMap<_, _> = + url.query_pairs().into_owned().collect(); - let _ = request.respond(response); + info!("Received OAuth callback"); - info!("Callback processed, stopping listener"); - return Ok(callback_params); - } else if url == "/health" { - // Health check endpoint - let _ = request.respond(tiny_http::Response::from_string("OK")); - } else { - let _ = request.respond(tiny_http::Response::empty(404)); + let find_param = |key: &str| -> Result { + params + .get(key) + .cloned() + .ok_or_else(|| OAuthError::MissingParameter(key.to_string())) + }; + + let callback_params = OAuthCallbackParams { + state: find_param("state")?, + session_state: find_param("session_state")?, + iss: find_param("iss")?, + code: find_param("code")?, + }; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\n\r\n{}", + AUTH_SUCCESS_HTML.len(), + AUTH_SUCCESS_HTML + ); + + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.flush().await; + + info!("Callback processed, stopping listener"); + return Ok(callback_params); + } + (Some("GET"), Some("/health")) => { + let response = "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"; + let _ = stream.write_all(response.as_bytes()).await; + } + _ => { + let response = "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; + let _ = stream.write_all(response.as_bytes()).await; + } } } - - Err(OAuthError::CallbackTimeout) }