minor touchups to enhance auth service
This commit is contained in:
31
src-tauri/Cargo.lock
generated
31
src-tauri/Cargo.lock
generated
@@ -47,12 +47,6 @@ version = "1.0.100"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
|
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "ascii"
|
|
||||||
version = "1.1.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-broadcast"
|
name = "async-broadcast"
|
||||||
version = "0.7.2"
|
version = "0.7.2"
|
||||||
@@ -462,12 +456,6 @@ dependencies = [
|
|||||||
"windows-link 0.2.1",
|
"windows-link 0.2.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "chunked_transfer"
|
|
||||||
version = "1.5.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "combine"
|
name = "combine"
|
||||||
version = "4.6.7"
|
version = "4.6.7"
|
||||||
@@ -1072,7 +1060,6 @@ dependencies = [
|
|||||||
"tauri-plugin-opener",
|
"tauri-plugin-opener",
|
||||||
"tauri-plugin-positioner",
|
"tauri-plugin-positioner",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"tiny_http",
|
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
@@ -1611,12 +1598,6 @@ version = "1.10.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "httpdate"
|
|
||||||
version = "1.0.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hyper"
|
name = "hyper"
|
||||||
version = "1.7.0"
|
version = "1.7.0"
|
||||||
@@ -4423,18 +4404,6 @@ dependencies = [
|
|||||||
"time-core",
|
"time-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tiny_http"
|
|
||||||
version = "0.12.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82"
|
|
||||||
dependencies = [
|
|
||||||
"ascii",
|
|
||||||
"chunked_transfer",
|
|
||||||
"httpdate",
|
|
||||||
"log",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tinystr"
|
name = "tinystr"
|
||||||
version = "0.8.2"
|
version = "0.8.2"
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ url = "2.5.7"
|
|||||||
rand = "0.9.2"
|
rand = "0.9.2"
|
||||||
sha2 = "0.10.9"
|
sha2 = "0.10.9"
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
tiny_http = "0.12.0"
|
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = "0.3"
|
tracing-subscriber = "0.3"
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
use crate::{core::state::FDOLL, lock_r, lock_w, APP_HANDLE};
|
use crate::{core::state::FDOLL, lock_r, lock_w, APP_HANDLE};
|
||||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||||
use flate2::{write::GzEncoder, read::GzDecoder, Compression};
|
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
|
||||||
use keyring::Entry;
|
use keyring::Entry;
|
||||||
use rand::{distr::Alphanumeric, Rng};
|
use rand::{distr::Alphanumeric, Rng};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::thread;
|
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||||
use tauri_plugin_opener::OpenerExt;
|
use tauri_plugin_opener::OpenerExt;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
use url::form_urlencoded;
|
use url::form_urlencoded;
|
||||||
@@ -18,6 +19,7 @@ static REFRESH_LOCK: once_cell::sync::Lazy<Mutex<()>> =
|
|||||||
once_cell::sync::Lazy::new(|| Mutex::new(()));
|
once_cell::sync::Lazy::new(|| Mutex::new(()));
|
||||||
|
|
||||||
static AUTH_SUCCESS_HTML: &str = include_str!("../../assets/auth-success.html");
|
static AUTH_SUCCESS_HTML: &str = include_str!("../../assets/auth-success.html");
|
||||||
|
const SERVICE_NAME: &str = "friendolls";
|
||||||
|
|
||||||
/// Errors that can occur during OAuth authentication flow.
|
/// Errors that can occur during OAuth authentication flow.
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
@@ -54,6 +56,9 @@ pub enum OAuthError {
|
|||||||
|
|
||||||
#[error("OAuth state expired or not initialized")]
|
#[error("OAuth state expired or not initialized")]
|
||||||
StateExpired,
|
StateExpired,
|
||||||
|
|
||||||
|
#[error("IO error: {0}")]
|
||||||
|
IoError(#[from] std::io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parameters received from the OAuth callback.
|
/// Parameters received from the OAuth callback.
|
||||||
@@ -168,42 +173,50 @@ pub fn save_auth_pass(auth_pass: &AuthPass) -> Result<(), OAuthError> {
|
|||||||
let json = serde_json::to_string(auth_pass)?;
|
let json = serde_json::to_string(auth_pass)?;
|
||||||
info!("Original JSON length: {}", json.len());
|
info!("Original JSON length: {}", json.len());
|
||||||
let mut encoder = GzEncoder::new(Vec::new(), Compression::best());
|
let mut encoder = GzEncoder::new(Vec::new(), Compression::best());
|
||||||
encoder.write_all(json.as_bytes()).map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?;
|
encoder
|
||||||
let compressed = encoder.finish().map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?;
|
.write_all(json.as_bytes())
|
||||||
|
.map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?;
|
||||||
|
let compressed = encoder
|
||||||
|
.finish()
|
||||||
|
.map_err(|e| OAuthError::SerializationError(serde_json::Error::io(e)))?;
|
||||||
info!("Compressed length: {}", compressed.len());
|
info!("Compressed length: {}", compressed.len());
|
||||||
let encoded = URL_SAFE_NO_PAD.encode(&compressed);
|
let encoded = URL_SAFE_NO_PAD.encode(&compressed);
|
||||||
info!("Encoded length: {}", encoded.len());
|
info!("Encoded length: {}", encoded.len());
|
||||||
|
|
||||||
// Windows keyring has a 2560-byte UTF-16 limit, which means 1280 chars max
|
// Windows keyring has a 2560-byte UTF-16 limit, which means 1280 chars max
|
||||||
// Split into chunks of 1200 chars to be safe
|
// Split into chunks of 1200 chars to be safe
|
||||||
const CHUNK_SIZE: usize = 1200;
|
const CHUNK_SIZE: usize = 1200;
|
||||||
let chunks: Vec<&str> = encoded.as_bytes()
|
let chunks: Vec<&str> = encoded
|
||||||
|
.as_bytes()
|
||||||
.chunks(CHUNK_SIZE)
|
.chunks(CHUNK_SIZE)
|
||||||
.map(|chunk| std::str::from_utf8(chunk).unwrap())
|
.map(|chunk| std::str::from_utf8(chunk).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
info!("Splitting auth pass into {} chunks", chunks.len());
|
info!("Splitting auth pass into {} chunks", chunks.len());
|
||||||
|
|
||||||
// Save chunk count
|
// Save chunk count
|
||||||
let count_entry = Entry::new("friendolls", "auth_pass_count")?;
|
let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?;
|
||||||
count_entry.set_password(&chunks.len().to_string())?;
|
count_entry.set_password(&chunks.len().to_string())?;
|
||||||
|
|
||||||
// Save each chunk
|
// Save each chunk
|
||||||
for (i, chunk) in chunks.iter().enumerate() {
|
for (i, chunk) in chunks.iter().enumerate() {
|
||||||
let entry = Entry::new("friendolls", &format!("auth_pass_{}", i))?;
|
let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?;
|
||||||
entry.set_password(chunk)?;
|
entry.set_password(chunk)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Auth pass saved to keyring successfully in {} chunks", chunks.len());
|
info!(
|
||||||
|
"Auth pass saved to keyring successfully in {} chunks",
|
||||||
|
chunks.len()
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load auth_pass from secure storage (keyring).
|
/// Load auth_pass from secure storage (keyring).
|
||||||
pub fn load_auth_pass() -> Result<Option<AuthPass>, OAuthError> {
|
pub fn load_auth_pass() -> Result<Option<AuthPass>, OAuthError> {
|
||||||
info!("Reading credentials from keyring");
|
info!("Reading credentials from keyring");
|
||||||
|
|
||||||
// Get chunk count
|
// Get chunk count
|
||||||
let count_entry = Entry::new("friendolls", "auth_pass_count")?;
|
let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?;
|
||||||
let chunk_count = match count_entry.get_password() {
|
let chunk_count = match count_entry.get_password() {
|
||||||
Ok(count_str) => match count_str.parse::<usize>() {
|
Ok(count_str) => match count_str.parse::<usize>() {
|
||||||
Ok(count) => count,
|
Ok(count) => count,
|
||||||
@@ -221,13 +234,13 @@ pub fn load_auth_pass() -> Result<Option<AuthPass>, OAuthError> {
|
|||||||
return Err(OAuthError::KeyringError(e));
|
return Err(OAuthError::KeyringError(e));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
info!("Loading {} auth pass chunks from keyring", chunk_count);
|
info!("Loading {} auth pass chunks from keyring", chunk_count);
|
||||||
|
|
||||||
// Reassemble chunks
|
// Reassemble chunks
|
||||||
let mut encoded = String::new();
|
let mut encoded = String::new();
|
||||||
for i in 0..chunk_count {
|
for i in 0..chunk_count {
|
||||||
let entry = Entry::new("friendolls", &format!("auth_pass_{}", i))?;
|
let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?;
|
||||||
match entry.get_password() {
|
match entry.get_password() {
|
||||||
Ok(chunk) => encoded.push_str(&chunk),
|
Ok(chunk) => encoded.push_str(&chunk),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -236,9 +249,9 @@ pub fn load_auth_pass() -> Result<Option<AuthPass>, OAuthError> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Reassembled encoded length: {}", encoded.len());
|
info!("Reassembled encoded length: {}", encoded.len());
|
||||||
|
|
||||||
let compressed = match URL_SAFE_NO_PAD.decode(&encoded) {
|
let compressed = match URL_SAFE_NO_PAD.decode(&encoded) {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -246,14 +259,14 @@ pub fn load_auth_pass() -> Result<Option<AuthPass>, OAuthError> {
|
|||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut decoder = GzDecoder::new(&compressed[..]);
|
let mut decoder = GzDecoder::new(&compressed[..]);
|
||||||
let mut json = String::new();
|
let mut json = String::new();
|
||||||
if let Err(e) = decoder.read_to_string(&mut json) {
|
if let Err(e) = decoder.read_to_string(&mut json) {
|
||||||
error!("Failed to decompress auth pass from keyring: {}", e);
|
error!("Failed to decompress auth pass from keyring: {}", e);
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
let auth_pass: AuthPass = match serde_json::from_str(&json) {
|
let auth_pass: AuthPass = match serde_json::from_str(&json) {
|
||||||
Ok(v) => {
|
Ok(v) => {
|
||||||
info!("Deserialized auth pass from keyring");
|
info!("Deserialized auth pass from keyring");
|
||||||
@@ -264,7 +277,7 @@ pub fn load_auth_pass() -> Result<Option<AuthPass>, OAuthError> {
|
|||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
info!("Auth pass loaded from keyring");
|
info!("Auth pass loaded from keyring");
|
||||||
Ok(Some(auth_pass))
|
Ok(Some(auth_pass))
|
||||||
}
|
}
|
||||||
@@ -272,21 +285,21 @@ pub fn load_auth_pass() -> Result<Option<AuthPass>, OAuthError> {
|
|||||||
/// Clear auth_pass from secure storage and app state.
|
/// Clear auth_pass from secure storage and app state.
|
||||||
pub fn clear_auth_pass() -> Result<(), OAuthError> {
|
pub fn clear_auth_pass() -> Result<(), OAuthError> {
|
||||||
// Try to get chunk count
|
// Try to get chunk count
|
||||||
let count_entry = Entry::new("friendolls", "auth_pass_count")?;
|
let count_entry = Entry::new(SERVICE_NAME, "auth_pass_count")?;
|
||||||
let chunk_count = match count_entry.get_password() {
|
let chunk_count = match count_entry.get_password() {
|
||||||
Ok(count_str) => count_str.parse::<usize>().unwrap_or(0),
|
Ok(count_str) => count_str.parse::<usize>().unwrap_or(0),
|
||||||
Err(_) => 0,
|
Err(_) => 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Delete all chunks
|
// Delete all chunks
|
||||||
for i in 0..chunk_count {
|
for i in 0..chunk_count {
|
||||||
let entry = Entry::new("friendolls", &format!("auth_pass_{}", i))?;
|
let entry = Entry::new(SERVICE_NAME, &format!("auth_pass_{}", i))?;
|
||||||
let _ = entry.delete_credential();
|
let _ = entry.delete_credential();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete chunk count
|
// Delete chunk count
|
||||||
let _ = count_entry.delete_credential();
|
let _ = count_entry.delete_credential();
|
||||||
|
|
||||||
info!("Auth pass cleared from keyring successfully");
|
info!("Auth pass cleared from keyring successfully");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -485,56 +498,77 @@ where
|
|||||||
|
|
||||||
info!("Initiating OAuth flow");
|
info!("Initiating OAuth flow");
|
||||||
|
|
||||||
thread::spawn(move || {
|
// Bind the server FIRST to ensure port is open
|
||||||
tokio::runtime::Builder::new_current_thread()
|
// We bind synchronously using std::net::TcpListener then convert to tokio::net::TcpListener
|
||||||
.enable_all()
|
// to ensure the port is bound before we open the browser.
|
||||||
.build()
|
let std_listener = match std::net::TcpListener::bind(&app_config.auth.redirect_host) {
|
||||||
.unwrap()
|
Ok(s) => {
|
||||||
.block_on(async move {
|
s.set_nonblocking(true).unwrap();
|
||||||
match listen_for_callback().await {
|
s
|
||||||
Ok(callback_params) => {
|
}
|
||||||
// Validate state
|
Err(e) => {
|
||||||
let stored_state = lock_r!(FDOLL).oauth_flow.state.clone();
|
error!("Failed to bind callback server: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
if stored_state.as_ref() != Some(&callback_params.state) {
|
info!(
|
||||||
error!("State mismatch - possible CSRF attack!");
|
"Listening on {} for /callback",
|
||||||
return;
|
&app_config.auth.redirect_host
|
||||||
}
|
);
|
||||||
|
|
||||||
// Retrieve code_verifier
|
tauri::async_runtime::spawn(async move {
|
||||||
let code_verifier = match lock_r!(FDOLL).oauth_flow.code_verifier.clone() {
|
let listener = match TcpListener::from_std(std_listener) {
|
||||||
Some(cv) => cv,
|
Ok(l) => l,
|
||||||
None => {
|
Err(e) => {
|
||||||
error!("Code verifier not found in state");
|
error!("Failed to create async listener: {}", e);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Clear OAuth flow state after successful callback
|
match listen_for_callback(listener).await {
|
||||||
lock_w!(FDOLL).oauth_flow = Default::default();
|
Ok(callback_params) => {
|
||||||
|
// Validate state
|
||||||
|
let stored_state = lock_r!(FDOLL).oauth_flow.state.clone();
|
||||||
|
|
||||||
match exchange_code_for_auth_pass(callback_params, &code_verifier).await {
|
if stored_state.as_ref() != Some(&callback_params.state) {
|
||||||
Ok(auth_pass) => {
|
error!("State mismatch - possible CSRF attack!");
|
||||||
lock_w!(FDOLL).auth_pass = Some(auth_pass.clone());
|
return;
|
||||||
if let Err(e) = save_auth_pass(&auth_pass) {
|
}
|
||||||
error!("Failed to save auth pass: {}", e);
|
|
||||||
} else {
|
// Retrieve code_verifier
|
||||||
info!("Authentication successful!");
|
let code_verifier = match lock_r!(FDOLL).oauth_flow.code_verifier.clone() {
|
||||||
on_success();
|
Some(cv) => cv,
|
||||||
}
|
None => {
|
||||||
}
|
error!("Code verifier not found in state");
|
||||||
Err(e) => {
|
return;
|
||||||
error!("Failed to exchange code for tokens: {}", e);
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
// Clear OAuth flow state after successful callback
|
||||||
|
lock_w!(FDOLL).oauth_flow = Default::default();
|
||||||
|
|
||||||
|
match exchange_code_for_auth_pass(callback_params, &code_verifier).await {
|
||||||
|
Ok(auth_pass) => {
|
||||||
|
lock_w!(FDOLL).auth_pass = Some(auth_pass.clone());
|
||||||
|
if let Err(e) = save_auth_pass(&auth_pass) {
|
||||||
|
error!("Failed to save auth pass: {}", e);
|
||||||
|
} else {
|
||||||
|
info!("Authentication successful!");
|
||||||
|
on_success();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to receive callback: {}", e);
|
error!("Failed to exchange code for tokens: {}", e);
|
||||||
// Clear OAuth flow state on error
|
|
||||||
lock_w!(FDOLL).oauth_flow = Default::default();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to receive callback: {}", e);
|
||||||
|
// Clear OAuth flow state on error
|
||||||
|
lock_w!(FDOLL).oauth_flow = Default::default();
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Err(e) = opener.open_url(url, None::<&str>) {
|
if let Err(e) = opener.open_url(url, None::<&str>) {
|
||||||
@@ -628,82 +662,91 @@ pub async fn refresh_token(refresh_token: &str) -> Result<AuthPass, OAuthError>
|
|||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns `OAuthError` if:
|
/// Returns `OAuthError` if:
|
||||||
/// - Server fails to bind to the configured port
|
|
||||||
/// - Required callback parameters are missing
|
/// - Required callback parameters are missing
|
||||||
/// - Timeout is reached before callback is received
|
/// - Timeout is reached before callback is received
|
||||||
async fn listen_for_callback() -> Result<OAuthCallbackParams, OAuthError> {
|
async fn listen_for_callback(listener: TcpListener) -> Result<OAuthCallbackParams, OAuthError> {
|
||||||
let app_config = lock_r!(FDOLL)
|
|
||||||
.app_config
|
|
||||||
.clone()
|
|
||||||
.ok_or(OAuthError::InvalidConfig)?;
|
|
||||||
|
|
||||||
let server = tiny_http::Server::http(&app_config.auth.redirect_host)
|
|
||||||
.map_err(|e| OAuthError::ServerBindError(e.to_string()))?;
|
|
||||||
|
|
||||||
info!(
|
|
||||||
"Listening on {} for /callback",
|
|
||||||
&app_config.auth.redirect_host
|
|
||||||
);
|
|
||||||
|
|
||||||
// Set a 5-minute timeout
|
// Set a 5-minute timeout
|
||||||
let timeout = Duration::from_secs(300);
|
let timeout = Duration::from_secs(300);
|
||||||
let start_time = SystemTime::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
for request in server.incoming_requests() {
|
loop {
|
||||||
// Check timeout
|
let elapsed = start_time.elapsed();
|
||||||
if SystemTime::now()
|
if elapsed > timeout {
|
||||||
.duration_since(start_time)
|
|
||||||
.unwrap_or(Duration::ZERO)
|
|
||||||
> timeout
|
|
||||||
{
|
|
||||||
warn!("Callback listener timed out after 5 minutes");
|
warn!("Callback listener timed out after 5 minutes");
|
||||||
return Err(OAuthError::CallbackTimeout);
|
return Err(OAuthError::CallbackTimeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
let url = request.url().to_string();
|
let accept_result = tokio::time::timeout(timeout - elapsed, listener.accept()).await;
|
||||||
|
|
||||||
if url.starts_with("/callback") {
|
let (mut stream, _) = match accept_result {
|
||||||
let query = url.split('?').nth(1).unwrap_or("");
|
Ok(Ok(res)) => res,
|
||||||
let params = form_urlencoded::parse(query.as_bytes())
|
Ok(Err(e)) => {
|
||||||
.map(|(k, v)| (k.into_owned(), v.into_owned()))
|
warn!("Accept error: {}", e);
|
||||||
.collect::<Vec<(String, String)>>();
|
continue;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn!("Callback listener timed out after 5 minutes");
|
||||||
|
return Err(OAuthError::CallbackTimeout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
info!("Received OAuth callback");
|
let mut buffer = [0; 4096];
|
||||||
|
let n = match stream.read(&mut buffer).await {
|
||||||
|
Ok(n) if n > 0 => n,
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
let find_param = |key: &str| -> Result<String, OAuthError> {
|
let request = String::from_utf8_lossy(&buffer[..n]);
|
||||||
params
|
let first_line = request.lines().next().unwrap_or("");
|
||||||
.iter()
|
let mut parts = first_line.split_whitespace();
|
||||||
.find(|(k, _)| k == key)
|
|
||||||
.map(|(_, v)| v.clone())
|
|
||||||
.ok_or_else(|| OAuthError::MissingParameter(key.to_string()))
|
|
||||||
};
|
|
||||||
|
|
||||||
let callback_params = OAuthCallbackParams {
|
match (parts.next(), parts.next()) {
|
||||||
state: find_param("state")?,
|
(Some("GET"), Some(path)) if path.starts_with("/callback") => {
|
||||||
session_state: find_param("session_state")?,
|
let full_url = format!("http://localhost{}", path);
|
||||||
iss: find_param("iss")?,
|
let url = match url::Url::parse(&full_url) {
|
||||||
code: find_param("code")?,
|
Ok(u) => u,
|
||||||
};
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
|
||||||
let response = tiny_http::Response::from_string(AUTH_SUCCESS_HTML).with_header(
|
let params: std::collections::HashMap<_, _> =
|
||||||
tiny_http::Header::from_bytes(
|
url.query_pairs().into_owned().collect();
|
||||||
&b"Content-Type"[..],
|
|
||||||
&b"text/html; charset=utf-8"[..],
|
|
||||||
)
|
|
||||||
.map_err(|_| OAuthError::ServerBindError("Header creation failed".to_string()))?,
|
|
||||||
);
|
|
||||||
|
|
||||||
let _ = request.respond(response);
|
info!("Received OAuth callback");
|
||||||
|
|
||||||
info!("Callback processed, stopping listener");
|
let find_param = |key: &str| -> Result<String, OAuthError> {
|
||||||
return Ok(callback_params);
|
params
|
||||||
} else if url == "/health" {
|
.get(key)
|
||||||
// Health check endpoint
|
.cloned()
|
||||||
let _ = request.respond(tiny_http::Response::from_string("OK"));
|
.ok_or_else(|| OAuthError::MissingParameter(key.to_string()))
|
||||||
} else {
|
};
|
||||||
let _ = request.respond(tiny_http::Response::empty(404));
|
|
||||||
|
let callback_params = OAuthCallbackParams {
|
||||||
|
state: find_param("state")?,
|
||||||
|
session_state: find_param("session_state")?,
|
||||||
|
iss: find_param("iss")?,
|
||||||
|
code: find_param("code")?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = format!(
|
||||||
|
"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\n\r\n{}",
|
||||||
|
AUTH_SUCCESS_HTML.len(),
|
||||||
|
AUTH_SUCCESS_HTML
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = stream.write_all(response.as_bytes()).await;
|
||||||
|
let _ = stream.flush().await;
|
||||||
|
|
||||||
|
info!("Callback processed, stopping listener");
|
||||||
|
return Ok(callback_params);
|
||||||
|
}
|
||||||
|
(Some("GET"), Some("/health")) => {
|
||||||
|
let response = "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK";
|
||||||
|
let _ = stream.write_all(response.as_bytes()).await;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let response = "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
|
||||||
|
let _ = stream.write_all(response.as_bytes()).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(OAuthError::CallbackTimeout)
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user