diff --git a/src-tauri/src/init/lifecycle.rs b/src-tauri/src/init/lifecycle.rs index cbc1963..9d95fa0 100644 --- a/src-tauri/src/init/lifecycle.rs +++ b/src-tauri/src/init/lifecycle.rs @@ -12,7 +12,10 @@ use crate::{ scene::open_scene_window, ws::client::{clear_ws_client, establish_websocket_connection}, }, - state::{clear_app_data, init_app_data_scoped, AppDataRefreshScope}, + state::{ + auth::{start_background_token_refresh, stop_background_token_refresh}, + clear_app_data, init_app_data_scoped, AppDataRefreshScope, + }, system_tray::update_system_tray, }; @@ -35,12 +38,14 @@ pub async fn destruct_user_session() { async fn connect_user_profile() { init_app_data_scoped(AppDataRefreshScope::All).await; establish_websocket_connection().await; + start_background_token_refresh().await; } /// Clears the user profile and WebSocket connection. async fn disconnect_user_profile() { clear_app_data(); clear_ws_client().await; + stop_background_token_refresh(); } /// Destructs the user session and show health manager window diff --git a/src-tauri/src/services/auth.rs b/src-tauri/src/services/auth.rs index ff0cb0e..3a2bb97 100644 --- a/src-tauri/src/services/auth.rs +++ b/src-tauri/src/services/auth.rs @@ -1,4 +1,5 @@ use crate::get_app_handle; +use crate::state::auth::get_auth_pass_with_refresh; use crate::{lock_r, lock_w, state::FDOLL}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use flate2::{read::GzDecoder, write::GzEncoder, Compression}; @@ -12,14 +13,10 @@ use tauri_plugin_opener::OpenerExt; use thiserror::Error; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; use tracing::{error, info, warn}; use url::form_urlencoded; -static REFRESH_LOCK: once_cell::sync::Lazy> = - once_cell::sync::Lazy::new(|| Mutex::new(())); - static AUTH_SUCCESS_HTML: &str = include_str!("../assets/auth-success.html"); const SERVICE_NAME: &str = "friendolls"; @@ -115,60 +112,7 @@ fn generate_code_challenge(code_verifier: &str) -> String { /// access token, refresh token, expire time etc. /// Automatically refreshes if expired. pub async fn get_session_token() -> Option { - info!("Retrieving tokens"); - let Some(auth_pass) = ({ lock_r!(FDOLL).auth.auth_pass.clone() }) else { - return None; - }; - - let Some(issued_at) = auth_pass.issued_at else { - warn!("Auth pass missing issued_at timestamp, clearing"); - lock_w!(FDOLL).auth.auth_pass = None; - return None; - }; - - let current_time = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); - - let expired = current_time - issued_at >= auth_pass.expires_in; - let refresh_expired = current_time - issued_at >= auth_pass.refresh_expires_in; - - if !expired { - return Some(auth_pass); - } - - if refresh_expired { - info!("Refresh token expired, clearing auth state"); - lock_w!(FDOLL).auth.auth_pass = None; - if let Err(e) = clear_auth_pass() { - error!("Failed to clear expired auth pass: {}", e); - } - return None; - } - - // Use mutex to prevent concurrent refresh - let _guard = REFRESH_LOCK.lock().await; - - // Double-check after acquiring lock - let auth_pass = lock_r!(FDOLL).auth.auth_pass.clone()?; - let current_time = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); - let expired = current_time - auth_pass.issued_at? >= auth_pass.expires_in; - - if !expired { - // Another thread already refreshed - return Some(auth_pass); - } - - info!("Access token expired, attempting refresh"); - match refresh_token(&auth_pass.refresh_token).await { - Ok(new_pass) => Some(new_pass), - Err(e) => { - error!("Failed to refresh token: {}", e); - lock_w!(FDOLL).auth.auth_pass = None; - if let Err(e) = clear_auth_pass() { - error!("Failed to clear auth pass after refresh failure: {}", e); - } - None - } - } + get_auth_pass_with_refresh().await } /// Helper function to get the current access token. diff --git a/src-tauri/src/state/auth.rs b/src-tauri/src/state/auth.rs index a5fe5d1..280f284 100644 --- a/src-tauri/src/state/auth.rs +++ b/src-tauri/src/state/auth.rs @@ -1,5 +1,16 @@ -use crate::services::auth::{load_auth_pass, AuthPass}; -use tracing::{info, warn}; +use crate::init::lifecycle::destruct_user_session; +use crate::services::auth::{clear_auth_pass, load_auth_pass, refresh_token, AuthPass}; +use crate::services::welcome::open_welcome_window; +use crate::{lock_r, lock_w, state::FDOLL}; +use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::sync::Mutex; +use tokio::time; +use tokio::time::Duration; +use tokio_util::sync::CancellationToken; +use tracing::{error, info, warn}; + +static REFRESH_LOCK: once_cell::sync::Lazy> = + once_cell::sync::Lazy::new(|| Mutex::new(())); #[derive(Default, Clone)] pub struct OAuthFlowTracker { @@ -12,6 +23,7 @@ pub struct OAuthFlowTracker { pub struct AuthState { pub auth_pass: Option, pub oauth_flow: OAuthFlowTracker, + pub background_refresh_token: Option, } impl Default for AuthState { @@ -19,6 +31,7 @@ impl Default for AuthState { Self { auth_pass: None, oauth_flow: OAuthFlowTracker::default(), + background_refresh_token: None, } } } @@ -36,5 +49,176 @@ pub fn init_auth_state() -> AuthState { AuthState { auth_pass, oauth_flow: OAuthFlowTracker::default(), + background_refresh_token: None, } -} \ No newline at end of file +} + +/// Returns the auth pass object, including access token, refresh token, and metadata. +/// Automatically refreshes if expired and clears session if refresh token is expired. +pub async fn get_auth_pass_with_refresh() -> Option { + info!("Retrieving tokens"); + let Some(auth_pass) = ({ lock_r!(FDOLL).auth.auth_pass.clone() }) else { + return None; + }; + + let Some(issued_at) = auth_pass.issued_at else { + warn!("Auth pass missing issued_at timestamp, clearing"); + lock_w!(FDOLL).auth.auth_pass = None; + return None; + }; + + let current_time = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); + let expired = current_time - issued_at >= auth_pass.expires_in; + let refresh_expired = current_time - issued_at >= auth_pass.refresh_expires_in; + + if !expired { + return Some(auth_pass); + } + + if refresh_expired { + info!("Refresh token expired, clearing auth state"); + lock_w!(FDOLL).auth.auth_pass = None; + if let Err(e) = clear_auth_pass() { + error!("Failed to clear expired auth pass: {}", e); + } + destruct_user_session().await; + open_welcome_window(); + return None; + } + + let _guard = REFRESH_LOCK.lock().await; + + let auth_pass = lock_r!(FDOLL).auth.auth_pass.clone()?; + let current_time = SystemTime::now().duration_since(UNIX_EPOCH).ok()?.as_secs(); + let Some(issued_at) = auth_pass.issued_at else { + warn!("Auth pass missing issued_at timestamp after refresh lock, clearing"); + lock_w!(FDOLL).auth.auth_pass = None; + return None; + }; + let expired = current_time - issued_at >= auth_pass.expires_in; + let refresh_expired = current_time - issued_at >= auth_pass.refresh_expires_in; + + if refresh_expired { + info!("Refresh token expired, clearing auth state after refresh lock"); + lock_w!(FDOLL).auth.auth_pass = None; + if let Err(e) = clear_auth_pass() { + error!("Failed to clear expired auth pass: {}", e); + } + destruct_user_session().await; + open_welcome_window(); + return None; + } + + if !expired { + return Some(auth_pass); + } + + info!("Access token expired, attempting refresh"); + match refresh_token(&auth_pass.refresh_token).await { + Ok(new_pass) => Some(new_pass), + Err(e) => { + error!("Failed to refresh token: {}", e); + lock_w!(FDOLL).auth.auth_pass = None; + if let Err(e) = clear_auth_pass() { + error!("Failed to clear auth pass after refresh failure: {}", e); + } + None + } + } +} + +async fn refresh_if_expiring_soon() { + let Some(auth_pass) = ({ lock_r!(FDOLL).auth.auth_pass.clone() }) else { + return; + }; + + let Some(issued_at) = auth_pass.issued_at else { + return; + }; + + let current_time = match SystemTime::now().duration_since(UNIX_EPOCH) { + Ok(value) => value.as_secs(), + Err(_) => return, + }; + + let refresh_expires_at = issued_at.saturating_add(auth_pass.refresh_expires_in); + if current_time >= refresh_expires_at { + lock_w!(FDOLL).auth.auth_pass = None; + if let Err(e) = clear_auth_pass() { + error!("Failed to clear expired auth pass: {}", e); + } + destruct_user_session().await; + open_welcome_window(); + return; + } + + let access_expires_at = issued_at.saturating_add(auth_pass.expires_in); + if access_expires_at.saturating_sub(current_time) >= 60 { + return; + } + + let _guard = REFRESH_LOCK.lock().await; + + let Some(latest_pass) = ({ lock_r!(FDOLL).auth.auth_pass.clone() }) else { + return; + }; + + let Some(latest_issued_at) = latest_pass.issued_at else { + return; + }; + + let current_time = match SystemTime::now().duration_since(UNIX_EPOCH) { + Ok(value) => value.as_secs(), + Err(_) => return, + }; + + let refresh_expires_at = latest_issued_at.saturating_add(latest_pass.refresh_expires_in); + if current_time >= refresh_expires_at { + lock_w!(FDOLL).auth.auth_pass = None; + if let Err(e) = clear_auth_pass() { + error!("Failed to clear expired auth pass: {}", e); + } + destruct_user_session().await; + open_welcome_window(); + return; + } + + let access_expires_at = latest_issued_at.saturating_add(latest_pass.expires_in); + if access_expires_at.saturating_sub(current_time) >= 60 { + return; + } + + if let Err(e) = refresh_token(&latest_pass.refresh_token).await { + warn!("Background refresh failed: {}", e); + } +} + +/// Starts a background loop to periodically refresh tokens when authenticated. +pub async fn start_background_token_refresh() { + stop_background_token_refresh(); + let cancel_token = CancellationToken::new(); + { + let mut guard = lock_w!(FDOLL); + guard.auth.background_refresh_token = Some(cancel_token.clone()); + } + tokio::spawn(async move { + let mut interval = time::interval(Duration::from_secs(60)); + loop { + tokio::select! { + _ = cancel_token.cancelled() => { + break; + } + _ = interval.tick() => { + refresh_if_expiring_soon().await; + } + } + } + }); +} + +/// Stops the background token refresh loop. +pub fn stop_background_token_refresh() { + if let Some(token) = lock_w!(FDOLL).auth.background_refresh_token.take() { + token.cancel(); + } +} diff --git a/src-tauri/src/state/mod.rs b/src-tauri/src/state/mod.rs index 2c71244..5c2e91a 100644 --- a/src-tauri/src/state/mod.rs +++ b/src-tauri/src/state/mod.rs @@ -4,7 +4,7 @@ use std::sync::{Arc, LazyLock, RwLock}; use tauri::tray::TrayIcon; use tracing::info; -mod auth; +pub mod auth; mod network; mod ui;