use axum::{ body::Body, extract::{Request, State}, http::{header, HeaderMap, StatusCode}, middleware::Next, response::{IntoResponse, Redirect, Response}, }; use openidconnect::{ core::{CoreClient, CoreProviderMetadata, CoreResponseType}, reqwest::async_http_client, AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, }; use rand::RngCore; use serde::Deserialize; use base64::Engine; use hmac::{Hmac, Mac}; use jsonwebtoken::{decode, decode_header, DecodingKey, Validation as JwtValidation}; use jsonwebtoken::jwk::JwkSet; use std::time::{Duration, Instant}; use tokio::sync::RwLock; use super::AppState; use std::sync::Arc; const SESSION_COOKIE: &str = "furumi_session"; const JWKS_CACHE_TTL: Duration = Duration::from_secs(3600); type HmacSha256 = Hmac; pub struct OidcState { pub client: CoreClient, pub session_secret: Vec, jwks_uri: String, issuer_url: String, jwks_cache: RwLock>, http_client: reqwest::Client, } pub async fn oidc_init( issuer: String, client_id: String, client_secret: String, redirect: String, session_secret_override: Option, ) -> anyhow::Result { let provider_metadata = CoreProviderMetadata::discover_async( IssuerUrl::new(issuer)?, async_http_client, ) .await?; let jwks_uri = provider_metadata.jwks_uri().to_string(); let issuer_url = provider_metadata.issuer().to_string(); let client = CoreClient::from_provider_metadata( provider_metadata, ClientId::new(client_id), Some(ClientSecret::new(client_secret)), ) .set_auth_type(openidconnect::AuthType::RequestBody) .set_redirect_uri(RedirectUrl::new(redirect)?); let session_secret = if let Some(s) = session_secret_override { let mut b = s.into_bytes(); b.resize(32, 0); b } else { let mut b = vec![0u8; 32]; rand::thread_rng().fill_bytes(&mut b); b }; let http_client = reqwest::Client::new(); tracing::info!("JWKS URI: {}", jwks_uri); Ok(OidcState { client, session_secret, jwks_uri, issuer_url, 
jwks_cache: RwLock::new(None), http_client, }) } impl OidcState { async fn get_jwks(&self) -> anyhow::Result { { let cache = self.jwks_cache.read().await; if let Some((ref jwks, fetched_at)) = *cache { if fetched_at.elapsed() < JWKS_CACHE_TTL { return Ok(jwks.clone()); } } } self.refresh_jwks().await } async fn refresh_jwks(&self) -> anyhow::Result { tracing::debug!("Fetching JWKS from {}", self.jwks_uri); let jwks: JwkSet = self.http_client.get(&self.jwks_uri).send().await?.json().await?; let mut cache = self.jwks_cache.write().await; *cache = Some((jwks.clone(), Instant::now())); Ok(jwks) } } #[derive(Debug, Clone)] pub struct AuthUser { pub id: String, pub username: String, pub display_name: Option, pub email: Option, } #[derive(Debug, serde::Deserialize)] struct BearerClaims { sub: String, preferred_username: Option, name: Option, email: Option, } async fn validate_bearer_token(oidc: &OidcState, token: &str) -> Option { let header = decode_header(token).ok()?; let kid = header.kid.as_ref()?; let mut jwks = oidc.get_jwks().await.ok()?; let mut jwk = jwks.find(kid); // Handle key rotation: refresh JWKS if kid not found if jwk.is_none() { jwks = oidc.refresh_jwks().await.ok()?; jwk = jwks.find(kid); } let key = DecodingKey::from_jwk(jwk?).ok()?; let mut validation = JwtValidation::new(header.alg); validation.set_issuer(&[&oidc.issuer_url]); validation.validate_aud = false; let data = decode::(token, &key, &validation).ok()?; let c = data.claims; Some(AuthUser { id: c.sub.clone(), username: c.preferred_username.unwrap_or(c.sub), display_name: c.name, email: c.email, }) } fn generate_sso_cookie(secret: &[u8], user_id: &str) -> String { let mut mac = HmacSha256::new_from_slice(secret).unwrap(); mac.update(user_id.as_bytes()); let sig = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes()); format!("sso:{}:{}", user_id, sig) } fn verify_sso_cookie(secret: &[u8], cookie_val: &str) -> Option { let parts: Vec<&str> = 
cookie_val.split(':').collect(); if parts.len() != 3 || parts[0] != "sso" { return None; } let user_id = parts[1]; let sig = parts[2]; let mut mac = HmacSha256::new_from_slice(secret).unwrap(); mac.update(user_id.as_bytes()); let expected_sig = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes()); if sig == expected_sig { Some(user_id.to_string()) } else { None } } /// Auth middleware: requires valid Bearer JWT or SSO session cookie. /// Inserts AuthUser into request extensions and upserts user in DB. pub async fn require_auth( State(state): State>, mut req: Request, next: Next, ) -> Response { let mut auth_user: Option = None; // 1. Check Bearer token — JWT from OIDC provider if let Some(ref oidc) = state.oidc { if let Some(token) = req .headers() .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()) .and_then(|v| v.strip_prefix("Bearer ")) { auth_user = validate_bearer_token(oidc, token).await; } } // 2. Check SSO session cookie (if OIDC configured) if auth_user.is_none() { if let Some(ref oidc) = state.oidc { let cookies = req .headers() .get(header::COOKIE) .and_then(|v| v.to_str().ok()) .unwrap_or(""); for c in cookies.split(';') { let c = c.trim(); if let Some(val) = c.strip_prefix(&format!("{}=", SESSION_COOKIE)) { if let Some(user_id) = verify_sso_cookie(&oidc.session_secret, val) { auth_user = Some(AuthUser { id: user_id.clone(), username: user_id, display_name: None, email: None, }); break; } } } } } match auth_user { Some(user) => { tracing::debug!("Auth OK for user: {}", user.username); // Upsert user in background let pool = state.pool.clone(); let u = user.clone(); tokio::spawn(async move { if let Err(e) = crate::db::upsert_user( &pool, &u.id, &u.username, u.display_name.as_deref(), u.email.as_deref(), ).await { tracing::warn!("Failed to upsert user: {}", e); } }); req.extensions_mut().insert(user); next.run(req).await } None => (StatusCode::UNAUTHORIZED, "Unauthorized").into_response(), } } #[derive(Deserialize)] 
pub struct LoginQuery {
    /// Path to return to after a successful login; defaults to `/`.
    pub next: Option<String>,
}

/// GET /auth/login — initiate OIDC flow.
///
/// Builds the provider authorization URL (scopes `openid` + `profile`, PKCE S256) and
/// replies `302 Found` to it. The CSRF token, nonce, PKCE verifier and URL-encoded
/// `next` target are stored in the `furumi_oidc_state` cookie as
/// `csrf:nonce:verifier:next`, for `oidc_callback` to consume. Without OIDC configured,
/// the caller is sent to `/`.
pub async fn oidc_login(
    State(state): State<Arc<AppState>>,
    axum::extract::Query(query): axum::extract::Query<LoginQuery>,
    req: Request,
) -> impl IntoResponse {
    let provider = match &state.oidc {
        Some(provider) => provider,
        None => return Redirect::to("/").into_response(),
    };

    // The challenge goes to the provider; the verifier waits in the cookie.
    let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();
    let flow = AuthenticationFlow::<CoreResponseType>::AuthorizationCode;
    let (authorize_url, csrf, nonce) = provider
        .client
        .authorize_url(flow, CsrfToken::new_random, Nonce::new_random)
        .add_scope(Scope::new("openid".to_string()))
        .add_scope(Scope::new("profile".to_string()))
        .set_pkce_challenge(challenge)
        .url();

    let return_to = match query.next {
        Some(path) => path,
        None => "/".to_string(),
    };
    let state_payload = format!(
        "{}:{}:{}:{}",
        csrf.secret(),
        nonce.secret(),
        verifier.secret(),
        urlencoding::encode(&return_to)
    );

    // Behind a TLS-terminating proxy the scheme is only visible via this header.
    let forwarded_proto = req
        .headers()
        .get("x-forwarded-proto")
        .and_then(|value| value.to_str().ok());
    let same_site = if forwarded_proto == Some("https") {
        "SameSite=None; Secure"
    } else {
        "SameSite=Lax"
    };
    let state_cookie = format!(
        "furumi_oidc_state={}; HttpOnly; {}; Path=/; Max-Age=3600",
        state_payload, same_site
    );

    let mut response_headers = HeaderMap::new();
    // All three values are built from printable ASCII, so parsing cannot fail.
    response_headers.insert(
        header::SET_COOKIE,
        state_cookie.parse().expect("state cookie is valid ASCII"),
    );
    response_headers.insert(
        header::LOCATION,
        authorize_url
            .as_str()
            .parse()
            .expect("authorization URL is valid ASCII"),
    );
    response_headers.insert(
        header::CACHE_CONTROL,
        "no-store, no-cache, must-revalidate"
            .parse()
            .expect("static header value"),
    );

    (StatusCode::FOUND, response_headers, Body::empty()).into_response()
}

/// Query parameters the provider sends back to `/auth/callback`.
#[derive(Deserialize)]
pub struct AuthCallbackQuery {
    code: String,
    state: String,
}

/// GET /auth/callback — handle OIDC callback.
pub async fn oidc_callback( State(state): State>, axum::extract::Query(query): axum::extract::Query, req: Request, ) -> impl IntoResponse { let oidc = match &state.oidc { Some(o) => o, None => return Redirect::to("/").into_response(), }; let cookies = req .headers() .get(header::COOKIE) .and_then(|v| v.to_str().ok()) .unwrap_or(""); let mut matching_val = None; for c in cookies.split(';') { let c = c.trim(); if let Some(val) = c.strip_prefix("furumi_oidc_state=") { let parts: Vec<&str> = val.split(':').collect(); if parts.len() >= 3 && parts[0] == query.state { matching_val = Some(val.to_string()); break; } } } let cookie_val = match matching_val { Some(c) => c, None => { tracing::warn!("OIDC callback: invalid state or missing cookie"); return (StatusCode::BAD_REQUEST, "Invalid state").into_response(); } }; let parts: Vec<&str> = cookie_val.split(':').collect(); let nonce = Nonce::new(parts[1].to_string()); let pkce_verifier = PkceCodeVerifier::new(parts[2].to_string()); let token_response = oidc .client .exchange_code(AuthorizationCode::new(query.code)) .set_pkce_verifier(pkce_verifier) .request_async(async_http_client) .await; let token_response = match token_response { Ok(tr) => tr, Err(e) => { tracing::error!("OIDC token exchange error: {:?}", e); return (StatusCode::INTERNAL_SERVER_ERROR, format!("OIDC error: {}", e)) .into_response(); } }; let id_token = match token_response.id_token() { Some(t) => t, None => { return (StatusCode::INTERNAL_SERVER_ERROR, "No ID token").into_response(); } }; let claims = match id_token.claims(&oidc.client.id_token_verifier(), &nonce) { Ok(c) => c, Err(e) => { return (StatusCode::UNAUTHORIZED, format!("Invalid ID token: {}", e)).into_response(); } }; let user_id = claims .preferred_username() .map(|u| u.to_string()) .or_else(|| claims.email().map(|e| e.to_string())) .unwrap_or_else(|| claims.subject().to_string()); let session_val = generate_sso_cookie(&oidc.session_secret, &user_id); let redirect_to = parts .get(3) 
.and_then(|&s| urlencoding::decode(s).ok()) .map(|v| v.into_owned()) .unwrap_or_else(|| "/".to_string()); let redirect_to = if redirect_to.is_empty() { "/".to_string() } else { redirect_to }; let is_https = req .headers() .get("x-forwarded-proto") .and_then(|v| v.to_str().ok()) .map(|s| s == "https") .unwrap_or(false); let session_attrs = if is_https { "SameSite=Lax; Secure" } else { "SameSite=Lax" }; let session_cookie = format!( "{}={}; HttpOnly; {}; Path=/; Max-Age=604800", SESSION_COOKIE, session_val, session_attrs ); let clear_state = "furumi_oidc_state=; HttpOnly; Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT"; let mut headers = HeaderMap::new(); headers.insert(header::SET_COOKIE, session_cookie.parse().unwrap()); headers.append(header::SET_COOKIE, clear_state.parse().unwrap()); headers.insert(header::LOCATION, redirect_to.parse().unwrap()); (StatusCode::FOUND, headers, Body::empty()).into_response() }