diff --git a/Cargo.toml b/Cargo.toml index 5104eeb..4167265 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rsauth2-proxy" -version = "0.1.0" +version = "0.1.1" edition = "2021" [dependencies] diff --git a/src/handlers.rs b/src/handlers.rs index c273a43..e7fbfd1 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -120,15 +120,53 @@ async fn auth_inner(state: &AppState, headers: &HeaderMap) -> Response { #[derive(Deserialize)] pub struct CallbackParams { - pub code: String, - pub state: String, + pub code: Option, + pub state: Option, + pub error: Option, + pub error_description: Option, } pub async fn callback( State(state): State>, Query(params): Query, ) -> Response { - let auth_state = match decrypt_state::(&state.crypto, ¶ms.state) { + // Handle OIDC error responses (e.g. authentication_expired, access_denied) + if let Some(error) = ¶ms.error { + let desc = params.error_description.as_deref().unwrap_or(""); + tracing::warn!(error, description = desc, "OIDC provider returned error"); + metrics::counter!("callback_requests_total", "result" => "error").increment(1); + + // If we have state, try to redirect back to the original URL to retry + if let Some(state_str) = ¶ms.state { + if let Ok(auth_state) = decrypt_state::(&state.crypto, state_str) { + return redirect_to_login(&state, &auth_state.original_url); + } + } + + return ( + StatusCode::BAD_GATEWAY, + format!("Authentication failed: {} {}", error, desc), + ) + .into_response(); + } + + let code = match ¶ms.code { + Some(c) => c.as_str(), + None => { + metrics::counter!("callback_requests_total", "result" => "error").increment(1); + return (StatusCode::BAD_REQUEST, "missing code parameter").into_response(); + } + }; + + let state_str = match ¶ms.state { + Some(s) => s.as_str(), + None => { + metrics::counter!("callback_requests_total", "result" => "error").increment(1); + return (StatusCode::BAD_REQUEST, "missing state parameter").into_response(); + } + }; + + let auth_state = match decrypt_state::(&state.crypto, state_str) { Ok(s) => s, Err(e) => { tracing::warn!(error = %e, "invalid callback state"); @@ -139,7 +177,7 @@ pub async fn callback( let token_response = match state .oidc - .exchange_code(¶ms.code, &auth_state.pkce_verifier) + .exchange_code(code, &auth_state.pkce_verifier) .await { Ok(t) => t, @@ -354,7 +392,7 @@ async fn authorize_request(state: &AppState, session: &Session, host: &str) -> R Some(r) => r, None => { tracing::debug!(host, "host not found in routes, denying"); - return StatusCode::FORBIDDEN.into_response(); + return forbidden_page(host, &session.username, "This service is not registered."); } }; @@ -369,7 +407,14 @@ async fn authorize_request(state: &AppState, session: &Session, host: &str) -> R user = session.username, "user not in allowed groups" ); - return StatusCode::FORBIDDEN.into_response(); + return forbidden_page( + host, + &session.username, + &format!( + "You need to be a member of one of these groups: {}", + route.allowed_groups.join(", ") + ), + ); } } @@ -386,6 +431,92 @@ async fn authorize_request(state: &AppState, session: &Session, host: &str) -> R .into_response() } +fn forbidden_page(host: &str, username: &str, reason: &str) -> Response { + let html = format!( + r#" + + + + +403 - Access Denied + + + +
+
🚫
+

Access Denied

+
{host}
+
{reason}
+
Logged in as {username}
+
Contact your administrator if you think this is a mistake.
+
+ +"#, + host = host, + reason = reason, + username = username, + ); + + ( + StatusCode::FORBIDDEN, + [(header::CONTENT_TYPE, "text/html; charset=utf-8")], + html, + ) + .into_response() +} + fn redirect_to_login(state: &AppState, original_url: &str) -> Response { let (pkce_verifier, pkce_challenge) = generate_pkce();