minor refinements
This commit is contained in:
@@ -318,7 +318,7 @@ pub fn clear_auth_pass() -> Result<(), OAuthError> {
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use crate::core::services::auth::logout;
|
||||
/// use crate::services::auth::logout;
|
||||
///
|
||||
/// logout().expect("Failed to logout");
|
||||
/// ```
|
||||
@@ -342,7 +342,7 @@ pub fn logout() -> Result<(), OAuthError> {
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use crate::core::services::auth::with_auth;
|
||||
/// use crate::services::auth::with_auth;
|
||||
///
|
||||
/// let client = reqwest::Client::new();
|
||||
/// let request = client.get("https://api.example.com/user");
|
||||
@@ -370,6 +370,7 @@ pub async fn with_auth(request: reqwest::RequestBuilder) -> reqwest::RequestBuil
|
||||
///
|
||||
/// Returns `OAuthError` if the exchange fails or the server returns an error.
|
||||
pub async fn exchange_code_for_auth_pass(
|
||||
redirect_uri: &str,
|
||||
callback_params: OAuthCallbackParams,
|
||||
code_verifier: &str,
|
||||
) -> Result<AuthPass, OAuthError> {
|
||||
@@ -393,7 +394,7 @@ pub async fn exchange_code_for_auth_pass(
|
||||
let body = form_urlencoded::Serializer::new(String::new())
|
||||
.append_pair("client_id", &app_config.auth.audience)
|
||||
.append_pair("grant_type", "authorization_code")
|
||||
.append_pair("redirect_uri", &app_config.auth.redirect_uri)
|
||||
.append_pair("redirect_uri", redirect_uri)
|
||||
.append_pair("code", &callback_params.code)
|
||||
.append_pair("code_verifier", code_verifier)
|
||||
.finish();
|
||||
@@ -469,7 +470,7 @@ pub async fn exchange_code_for_auth_pass(
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use crate::core::services::auth::init_auth_code_retrieval;
|
||||
/// use crate::services::auth::init_auth_code_retrieval;
|
||||
///
|
||||
/// init_auth_code_retrieval();
|
||||
/// // User will be prompted to login in their browser
|
||||
@@ -511,24 +512,19 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
url.query_pairs_mut()
|
||||
.append_pair("client_id", &app_config.auth.audience)
|
||||
.append_pair("response_type", "code")
|
||||
.append_pair("redirect_uri", &app_config.auth.redirect_uri)
|
||||
.append_pair("scope", "openid email profile")
|
||||
.append_pair("state", &state)
|
||||
.append_pair("code_challenge", &code_challenge)
|
||||
.append_pair("code_challenge_method", "S256");
|
||||
|
||||
info!("Initiating OAuth flow");
|
||||
|
||||
// Bind the server FIRST to ensure port is open
|
||||
// We bind synchronously using std::net::TcpListener then convert to tokio::net::TcpListener
|
||||
// to ensure the port is bound before we open the browser.
|
||||
info!("Attempting to bind to: {}", app_config.auth.redirect_host);
|
||||
let std_listener = match std::net::TcpListener::bind(&app_config.auth.redirect_host) {
|
||||
|
||||
// Bind to port 0 (ephemeral port),
|
||||
// The OS will assign an available port.
|
||||
let bind_addr = "localhost:0";
|
||||
|
||||
info!("Attempting to bind to: {}", bind_addr);
|
||||
let std_listener = match std::net::TcpListener::bind(&bind_addr) {
|
||||
Ok(s) => {
|
||||
info!("Successfully bound to {}", app_config.auth.redirect_host);
|
||||
s.set_nonblocking(true).unwrap();
|
||||
s
|
||||
}
|
||||
@@ -538,68 +534,75 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
"Listening on {} for /callback",
|
||||
app_config.auth.redirect_host
|
||||
);
|
||||
// Get the actual port assigned by the OS
|
||||
let local_addr = std_listener
|
||||
.local_addr()
|
||||
.map_err(|e| OAuthError::ServerBindError(e.to_string()))?;
|
||||
let port = local_addr.port();
|
||||
info!("Successfully bound to {}", local_addr);
|
||||
info!("Listening on port {} for /callback", port);
|
||||
|
||||
let redirect_uri = format!("http://localhost:{}/callback", port);
|
||||
|
||||
url.query_pairs_mut()
|
||||
.append_pair("client_id", &app_config.auth.audience)
|
||||
.append_pair("response_type", "code")
|
||||
.append_pair("redirect_uri", &redirect_uri)
|
||||
.append_pair("scope", "openid email profile")
|
||||
.append_pair("state", &state)
|
||||
.append_pair("code_challenge", &code_challenge)
|
||||
.append_pair("code_challenge_method", "S256");
|
||||
let redirect_uri_clone = redirect_uri.clone();
|
||||
tauri::async_runtime::spawn(async move {
|
||||
info!("Starting callback listener task");
|
||||
let listener = match TcpListener::from_std(std_listener) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
error!("Failed to create async listener: {}", e);
|
||||
error!("Failed to convert listener: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match listen_for_callback(listener).await {
|
||||
Ok(callback_params) => {
|
||||
// Validate state
|
||||
let stored_state = lock_r!(FDOLL).oauth_flow.state.clone();
|
||||
Ok(params) => {
|
||||
let (stored_state, stored_verifier) = {
|
||||
let guard = lock_r!(FDOLL);
|
||||
(
|
||||
guard.oauth_flow.state.clone(),
|
||||
guard.oauth_flow.code_verifier.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
if stored_state.as_ref() != Some(&callback_params.state) {
|
||||
error!("State mismatch - possible CSRF attack!");
|
||||
if stored_state.as_deref() != Some(params.state.as_str()) {
|
||||
error!("State mismatch");
|
||||
return;
|
||||
}
|
||||
|
||||
// Retrieve code_verifier
|
||||
let code_verifier = match lock_r!(FDOLL).oauth_flow.code_verifier.clone() {
|
||||
Some(cv) => cv,
|
||||
None => {
|
||||
error!("Code verifier not found in state");
|
||||
return;
|
||||
}
|
||||
let Some(code_verifier) = stored_verifier else {
|
||||
error!("Code verifier missing");
|
||||
return;
|
||||
};
|
||||
|
||||
// Clear OAuth flow state after successful callback
|
||||
lock_w!(FDOLL).oauth_flow = Default::default();
|
||||
|
||||
match exchange_code_for_auth_pass(callback_params, &code_verifier).await {
|
||||
match exchange_code_for_auth_pass(&redirect_uri_clone, params, &code_verifier).await
|
||||
{
|
||||
Ok(auth_pass) => {
|
||||
lock_w!(FDOLL).auth_pass = Some(auth_pass.clone());
|
||||
{
|
||||
let mut guard = lock_w!(FDOLL);
|
||||
guard.auth_pass = Some(auth_pass.clone());
|
||||
guard.oauth_flow = Default::default();
|
||||
}
|
||||
if let Err(e) = save_auth_pass(&auth_pass) {
|
||||
error!("Failed to save auth pass: {}", e);
|
||||
} else {
|
||||
info!("Authentication successful!");
|
||||
crate::services::ws::init_ws_client().await;
|
||||
on_success();
|
||||
return;
|
||||
}
|
||||
on_success();
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to exchange code for tokens: {}", e);
|
||||
}
|
||||
Err(e) => error!("Token exchange failed: {}", e),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to receive callback: {}", e);
|
||||
// Clear OAuth flow state on error
|
||||
lock_w!(FDOLL).oauth_flow = Default::default();
|
||||
}
|
||||
Err(e) => error!("Callback listener error: {}", e),
|
||||
}
|
||||
});
|
||||
|
||||
info!("Opening auth URL: {}", url);
|
||||
if let Err(e) = app_handle.opener().open_url(url, None::<&str>) {
|
||||
error!("Failed to open auth portal: {}", e);
|
||||
return Err(OAuthError::OpenPortalFailed(e));
|
||||
|
||||
Reference in New Issue
Block a user