use crate::{lock_r, lock_w, state::FDOLL, APP_HANDLE}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use flate2::{read::GzDecoder, write::GzEncoder, Compression}; use keyring::Entry; use rand::{distr::Alphanumeric, Rng}; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::io::{Read, Write}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use tauri_plugin_opener::OpenerExt; use thiserror::Error; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::sync::Mutex; use tracing::{error, info, warn}; use url::form_urlencoded; static REFRESH_LOCK: once_cell::sync::Lazy> = once_cell::sync::Lazy::new(|| Mutex::new(())); static AUTH_SUCCESS_HTML: &str = include_str!("../assets/auth-success.html"); const SERVICE_NAME: &str = "friendolls"; /// Errors that can occur during OAuth authentication flow. #[derive(Debug, Error)] pub enum OAuthError { #[error("Failed to exchange code: {0}")] ExchangeFailed(String), #[error("Invalid callback state - possible CSRF attack")] InvalidState, #[error("Missing callback parameter: {0}")] MissingParameter(String), #[error("Keyring error: {0}")] KeyringError(#[from] keyring::Error), #[error("Network error: {0}")] NetworkError(#[from] reqwest::Error), #[error("JSON serialization error: {0}")] SerializationError(#[from] serde_json::Error), #[error("Server binding failed: {0}")] ServerBindError(String), #[error("Callback timeout - no response received")] CallbackTimeout, #[error("Invalid app configuration")] InvalidConfig, #[error("Failed to refresh token")] RefreshFailed, #[error("OAuth state expired or not initialized")] StateExpired, #[error("IO error: {0}")] IoError(#[from] std::io::Error), } /// Parameters received from the OAuth callback. pub struct OAuthCallbackParams { state: String, session_state: String, iss: String, code: String, } /// Authentication pass containing access token, refresh token, and metadata. #[derive(Debug, Clone, Deserialize, Serialize)] pub struct AuthPass { pub access_token: String, pub expires_in: u64, pub refresh_expires_in: u64, pub refresh_token: String, pub token_type: String, pub session_state: String, pub scope: String, pub issued_at: Option, } /// Generate a random code verifier for PKCE. /// /// Per PKCE spec (RFC 7636), the code verifier should be 43-128 characters. fn generate_code_verifier(length: usize) -> String { rand::rng() .sample_iter(&Alphanumeric) .take(length) .map(char::from) .collect() } /// Generate code challenge from a code verifier using SHA-256. /// /// This implements the S256 method as specified in RFC 7636. fn generate_code_challenge(code_verifier: &str) -> String { let mut hasher = Sha256::new(); hasher.update(code_verifier.as_bytes()); let result = hasher.finalize(); URL_SAFE_NO_PAD.encode(&result) } /// Returns the auth pass object, including /// access token, refresh token, expire time etc. /// Automatically refreshes if expired. pub async fn get_tokens() -> Option { info!("Retrieving tokens"); let Some(auth_pass) = ({ lock_r!(FDOLL).auth_pass.clone() }) else { return None; }; let Some(issued_at) = auth_pass.issued_at else { warn!("Auth pass missing issued_at timestamp, clearing"); lock_w!(FDOLL).auth_pass = None; return None; }; let current_time = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); let expired = current_time - issued_at >= auth_pass.expires_in; let refresh_expired = current_time - issued_at >= auth_pass.refresh_expires_in; if !expired { return Some(auth_pass); } if refresh_expired { info!("Refresh token expired, clearing auth state"); lock_w!(FDOLL).auth_pass = None; if let Err(e) = clear_auth_pass() { error!("Failed to clear expired auth pass: {}", e); } return None; } // Use mutex to prevent concurrent refresh let _guard = REFRESH_LOCK.lock().await; // Double-check after acquiring lock let auth_pass = lock_r!(FDOLL).auth_pass.clone()?; let current_time = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); let expired = current_time - auth_pass.issued_at? >= auth_pass.expires_in; if !expired { // Another thread already refreshed return Some(auth_pass); } info!("Access token expired, attempting refresh"); match refresh_token(&auth_pass.refresh_token).await { Ok(new_pass) => Some(new_pass), Err(e) => { error!("Failed to refresh token: {}", e); lock_w!(FDOLL).auth_pass = None; if let Err(e) = clear_auth_pass() { error!("Failed to clear auth pass after refresh failure: {}", e); } None } } } /// Helper function to get the current access token. pub async fn get_access_token() -> Option { get_tokens().await.map(|pass| pass.access_token) } /// Save auth_pass to secure storage (keyring) and update app state. pub fn save_auth_pass(auth_pass: &AuthPass) -> Result<(), OAuthError> { let json = serde_json::to_string(auth_pass)?; info!("Original JSON length: {}", json.len()); let mut encoder = GzEncoder::new(Vec::new(), Compression::best()); encoder .write_all(json.as_bytes()) .map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?; let compressed = encoder .finish() .map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?; info!("Compressed length: {}", compressed.len()); let encoded = URL_SAFE_NO_PAD.encode(&compressed); info!("Encoded length: {}", encoded.len()); // Windows keyring has a 2560-byte UTF-16 limit, which means 1280 chars max // Split into chunks of 1200 chars to be safe const CHUNK_SIZE: usize = 1200; let chunks: Vec<&str> = encoded .as_bytes() .chunks(CHUNK_SIZE) .map(|chunk| std::str::from_utf8(chunk).unwrap()) .collect(); info!("Splitting auth pass into {} chunks", chunks.len()); // Save chunk count let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?; count_entry.set_password(&chunks.len().to_string())?; // Save each chunk for (i, chunk) in chunks.iter().enumerate() { let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?; entry.set_password(chunk)?; } info!( "Auth pass saved to keyring successfully in {} chunks", chunks.len() ); Ok(()) } /// Load auth_pass from secure storage (keyring). pub fn load_auth_pass() -> Result, OAuthError> { info!("Reading credentials from keyring"); // Get chunk count let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?; let chunk_count = match count_entry.get_password() { Ok(count_str) => match count_str.parse::() { Ok(count) => count, Err(_) => { error!("Invalid chunk count in keyring"); return Ok(None); } }, Err(keyring::Error::NoEntry) => { info!("No auth pass found in keyring"); return Ok(None); } Err(e) => { error!("Failed to load chunk count from keyring"); return Err(OAuthError::KeyringError(e)); } }; info!("Loading {} auth pass chunks from keyring", chunk_count); // Reassemble chunks let mut encoded = String::new(); for i in 0..chunk_count { let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?; match entry.get_password() { Ok(chunk) => encoded.push_str(&chunk), Err(e) => { error!("Failed to load chunk {} from keyring", i); return Err(OAuthError::KeyringError(e)); } } } info!("Reassembled encoded length: {}", encoded.len()); let compressed = match URL_SAFE_NO_PAD.decode(&encoded) { Ok(c) => c, Err(e) => { error!("Failed to base64 decode auth pass from keyring: {}", e); return Ok(None); } }; let mut decoder = GzDecoder::new(&compressed[..]); let mut json = String::new(); if let Err(e) = decoder.read_to_string(&mut json) { error!("Failed to decompress auth pass from keyring: {}", e); return Ok(None); } let auth_pass: AuthPass = match serde_json::from_str(&json) { Ok(v) => { info!("Deserialized auth pass from keyring"); v } Err(_e) => { error!("Failed to decode auth pass from keyring"); return Ok(None); } }; info!("Auth pass loaded from keyring"); Ok(Some(auth_pass)) } /// Clear auth_pass from secure storage and app state. pub fn clear_auth_pass() -> Result<(), OAuthError> { // Try to get chunk count let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?; let chunk_count = match count_entry.get_password() { Ok(count_str) => count_str.parse::().unwrap_or(0), Err(_) => 0, }; // Delete all chunks for i in 0..chunk_count { let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?; let _ = entry.delete_credential(); } // Delete chunk count let _ = count_entry.delete_credential(); info!("Auth pass cleared from keyring successfully"); Ok(()) } /// Logout the current user by clearing tokens from storage and state. /// /// # Note /// /// This currently only clears local tokens. For complete logout, you should also /// call the OAuth provider's token revocation endpoint if available. /// /// # Example /// /// ```rust,no_run /// use crate::core::services::auth::logout; /// /// logout().expect("Failed to logout"); /// ``` pub fn logout() -> Result<(), OAuthError> { info!("Logging out user"); lock_w!(FDOLL).auth_pass = None; clear_auth_pass()?; // Clear OAuth flow state as well lock_w!(FDOLL).oauth_flow = Default::default(); // TODO: Call OAuth provider's revocation endpoint // This would require adding a revoke_token() function that calls: // POST {auth_url}/revoke with the refresh_token Ok(()) } /// Helper to add authentication header to a request builder if tokens are available. /// /// # Example /// /// ```rust,no_run /// use crate::core::services::auth::with_auth; /// /// let client = reqwest::Client::new(); /// let request = client.get("https://api.example.com/user"); /// let authenticated_request = with_auth(request).await; /// ``` pub async fn with_auth(request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { if let Some(token) = get_access_token().await { request.header("Authorization", format!("Bearer {}", token)) } else { request } } /// Exchange authorization code for tokens. /// /// This is called after receiving the OAuth callback with an authorization code. /// It exchanges the code for an access token and refresh token. /// /// # Arguments /// /// * `callback_params` - Parameters received from the OAuth callback /// * `code_verifier` - The PKCE code verifier that was used to generate the code challenge /// /// # Errors /// /// Returns `OAuthError` if the exchange fails or the server returns an error. pub async fn exchange_code_for_auth_pass( callback_params: OAuthCallbackParams, code_verifier: &str, ) -> Result { let (app_config, http_client) = { let guard = lock_r!(FDOLL); let clients = guard.clients.as_ref(); if clients.is_none() { error!("Clients not initialized yet!"); return Err(OAuthError::InvalidConfig); } info!("HTTP client retrieved successfully for token exchange"); ( guard.app_config.clone(), clients.unwrap().http_client.clone(), ) }; let url = url::Url::parse(&format!("{}/token", &app_config.auth.auth_url)) .map_err(|_| OAuthError::InvalidConfig)?; let body = form_urlencoded::Serializer::new(String::new()) .append_pair("client_id", &app_config.auth.audience) .append_pair("grant_type", "authorization_code") .append_pair("redirect_uri", &app_config.auth.redirect_uri) .append_pair("code", &callback_params.code) .append_pair("code_verifier", code_verifier) .finish(); info!("Exchanging authorization code for tokens"); info!("Token endpoint URL: {}", url); info!("Request body length: {} bytes", body.len()); let exchange_request = http_client .post(url.clone()) .header("Content-Type", "application/x-www-form-urlencoded") .body(body); info!("Sending token exchange request..."); let exchange_request_response = match exchange_request.send().await { Ok(resp) => { info!("Received response with status: {}", resp.status()); resp } Err(e) => { error!("Failed to send token exchange request: {}", e); error!("Error details: {:?}", e); if e.is_timeout() { error!("Request timed out"); } if e.is_connect() { error!("Connection error - check network and DNS"); } if e.is_request() { error!("Request error - check request format"); } return Err(OAuthError::NetworkError(e)); } }; if !exchange_request_response.status().is_success() { let status = exchange_request_response.status(); let error_text = exchange_request_response.text().await.unwrap_or_default(); error!( "Token exchange failed with status {}: {}", status, error_text ); return Err(OAuthError::ExchangeFailed(format!( "Status: {}, Body: {}", status, error_text ))); } let mut auth_pass: AuthPass = exchange_request_response.json().await?; auth_pass.issued_at = Some( SystemTime::now() .duration_since(UNIX_EPOCH) .map_err(|_| OAuthError::ExchangeFailed("System time error".to_string()))? .as_secs(), ); info!("Successfully exchanged code for tokens"); Ok(auth_pass) } /// Initialize the OAuth authorization code flow. /// /// This function: /// 1. Generates PKCE code verifier and challenge /// 2. Generates state parameter for CSRF protection /// 3. Stores state and code verifier in app state /// 4. Opens the OAuth authorization URL in the user's browser /// 5. Starts a background listener for the callback /// /// The user will be redirected to the OAuth provider's login page, and after /// successful authentication, will be redirected back to the local callback server. /// /// # Example /// /// ```rust,no_run /// use crate::core::services::auth::init_auth_code_retrieval; /// /// init_auth_code_retrieval(); /// // User will be prompted to login in their browser /// ``` pub fn init_auth_code_retrieval(on_success: F) where F: FnOnce() + Send + 'static, { info!("init_auth_code_retrieval called"); let app_config = lock_r!(FDOLL).app_config.clone(); let opener = match APP_HANDLE.get() { Some(handle) => { info!("APP_HANDLE retrieved successfully"); handle.opener() } None => { error!("Cannot initialize auth: app handle not available"); return; } }; let code_verifier = generate_code_verifier(64); let code_challenge = generate_code_challenge(&code_verifier); let state = generate_code_verifier(16); // Store state and code_verifier for validation let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); { let mut guard = lock_w!(FDOLL); guard.oauth_flow.state = Some(state.clone()); guard.oauth_flow.code_verifier = Some(code_verifier.clone()); guard.oauth_flow.initiated_at = Some(current_time); } let mut url = match url::Url::parse(&format!("{}/auth", &app_config.auth.auth_url)) { Ok(url) => { info!("Parsed auth URL successfully"); url } Err(e) => { error!("Invalid auth URL configuration: {}", e); return; } }; url.query_pairs_mut() .append_pair("client_id", &app_config.auth.audience) .append_pair("response_type", "code") .append_pair("redirect_uri", &app_config.auth.redirect_uri) .append_pair("scope", "openid email profile") .append_pair("state", &state) .append_pair("code_challenge", &code_challenge) .append_pair("code_challenge_method", "S256"); info!("Initiating OAuth flow"); // Bind the server FIRST to ensure port is open // We bind synchronously using std::net::TcpListener then convert to tokio::net::TcpListener // to ensure the port is bound before we open the browser. info!("Attempting to bind to: {}", app_config.auth.redirect_host); let std_listener = match std::net::TcpListener::bind(&app_config.auth.redirect_host) { Ok(s) => { info!("Successfully bound to {}", app_config.auth.redirect_host); s.set_nonblocking(true).unwrap(); s } Err(e) => { error!("Failed to bind callback server: {}", e); return; } }; info!( "Listening on {} for /callback", app_config.auth.redirect_host ); tauri::async_runtime::spawn(async move { let listener = match TcpListener::from_std(std_listener) { Ok(l) => l, Err(e) => { error!("Failed to create async listener: {}", e); return; } }; match listen_for_callback(listener).await { Ok(callback_params) => { // Validate state let stored_state = lock_r!(FDOLL).oauth_flow.state.clone(); if stored_state.as_ref() != Some(&callback_params.state) { error!("State mismatch - possible CSRF attack!"); return; } // Retrieve code_verifier let code_verifier = match lock_r!(FDOLL).oauth_flow.code_verifier.clone() { Some(cv) => cv, None => { error!("Code verifier not found in state"); return; } }; // Clear OAuth flow state after successful callback lock_w!(FDOLL).oauth_flow = Default::default(); match exchange_code_for_auth_pass(callback_params, &code_verifier).await { Ok(auth_pass) => { lock_w!(FDOLL).auth_pass = Some(auth_pass.clone()); if let Err(e) = save_auth_pass(&auth_pass) { error!("Failed to save auth pass: {}", e); } else { info!("Authentication successful!"); crate::services::ws::init_ws_client().await; on_success(); } } Err(e) => { error!("Failed to exchange code for tokens: {}", e); } } } Err(e) => { error!("Failed to receive callback: {}", e); // Clear OAuth flow state on error lock_w!(FDOLL).oauth_flow = Default::default(); } } }); info!("Opening auth URL: {}", url); if let Err(e) = opener.open_url(url, None::<&str>) { error!("Failed to open auth portal: {}", e); } else { info!("Successfully called open_url for auth portal"); } } /// Refresh the access token using a refresh token. /// /// This is called automatically by `get_tokens()` when the access token is expired /// but the refresh token is still valid. /// /// # Arguments /// /// * `refresh_token` - The refresh token to use /// /// # Errors /// /// Returns `OAuthError::RefreshFailed` if the refresh fails. pub async fn refresh_token(refresh_token: &str) -> Result { let (app_config, http_client) = { let guard = lock_r!(FDOLL); ( guard.app_config.clone(), guard .clients .as_ref() .expect("clients present") .http_client .clone(), ) }; let url = url::Url::parse(&format!("{}/token", &app_config.auth.auth_url)) .map_err(|_| OAuthError::InvalidConfig)?; let body = form_urlencoded::Serializer::new(String::new()) .append_pair("client_id", &app_config.auth.audience) .append_pair("grant_type", "refresh_token") .append_pair("refresh_token", refresh_token) .finish(); info!("Refreshing access token"); let refresh_request = http_client .post(url) .header("Content-Type", "application/x-www-form-urlencoded") .body(body); let refresh_response = refresh_request.send().await?; if !refresh_response.status().is_success() { let status = refresh_response.status(); let error_text = refresh_response.text().await.unwrap_or_default(); error!( "Token refresh failed with status {}: {}", status, error_text ); return Err(OAuthError::RefreshFailed); } let mut auth_pass: AuthPass = refresh_response.json().await?; auth_pass.issued_at = Some( SystemTime::now() .duration_since(UNIX_EPOCH) .map_err(|_| OAuthError::RefreshFailed)? .as_secs(), ); // Update state and storage lock_w!(FDOLL).auth_pass = Some(auth_pass.clone()); if let Err(e) = save_auth_pass(&auth_pass) { error!("Failed to save refreshed auth pass: {}", e); } else { info!("Token refreshed successfully"); } Ok(auth_pass) } /// Start a local HTTP server to listen for the OAuth callback. /// /// This function starts a mini web server that listens on the configured redirect host /// for the OAuth callback. It: /// - Listens on the `/callback` endpoint /// - Validates all required parameters are present /// - Returns a nice HTML page to the user /// - Has a 5-minute timeout to prevent hanging indefinitely /// - Also provides a `/health` endpoint for health checks /// /// # Timeout /// /// The server will timeout after 5 minutes if no callback is received, /// preventing the server from running indefinitely if the user abandons the flow. /// /// # Errors /// /// Returns `OAuthError` if: /// - Required callback parameters are missing /// - Timeout is reached before callback is received async fn listen_for_callback(listener: TcpListener) -> Result { // Set a 5-minute timeout let timeout = Duration::from_secs(300); let start_time = Instant::now(); loop { let elapsed = start_time.elapsed(); if elapsed > timeout { warn!("Callback listener timed out after 5 minutes"); return Err(OAuthError::CallbackTimeout); } let accept_result = tokio::time::timeout(timeout - elapsed, listener.accept()).await; let (mut stream, _) = match accept_result { Ok(Ok(res)) => res, Ok(Err(e)) => { warn!("Accept error: {}", e); continue; } Err(_) => { warn!("Callback listener timed out after 5 minutes"); return Err(OAuthError::CallbackTimeout); } }; let mut buffer = [0; 4096]; let n = match stream.read(&mut buffer).await { Ok(n) if n > 0 => n, _ => continue, }; let request = String::from_utf8_lossy(&buffer[..n]); let first_line = request.lines().next().unwrap_or(""); let mut parts = first_line.split_whitespace(); match (parts.next(), parts.next()) { (Some("GET"), Some(path)) if path.starts_with("/callback") => { let full_url = format!("http://localhost{}", path); let url = match url::Url::parse(&full_url) { Ok(u) => u, Err(_) => continue, }; let params: std::collections::HashMap<_, _> = url.query_pairs().into_owned().collect(); info!("Received OAuth callback"); let find_param = |key: &str| -> Result { params .get(key) .cloned() .ok_or_else(|| OAuthError::MissingParameter(key.to_string())) }; let callback_params = OAuthCallbackParams { state: find_param("state")?, session_state: find_param("session_state")?, iss: find_param("iss")?, code: find_param("code")?, }; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\n\r\n{}", AUTH_SUCCESS_HTML.len(), AUTH_SUCCESS_HTML ); let _ = stream.write_all(response.as_bytes()).await; let _ = stream.flush().await; info!("Callback processed, stopping listener"); return Ok(callback_params); } (Some("GET"), Some("/health")) => { let response = "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK"; let _ = stream.write_all(response.as_bytes()).await; } _ => { let response = "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; let _ = stream.write_all(response.as_bytes()).await; } } } }