This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
use anyhow::{Context, Result};
|
||||
use base64::{engine::general_purpose::STANDARD, Engine};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
pub listen: String,
|
||||
pub oidc_issuer: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub cookie_secret: [u8; 32],
|
||||
pub cookie_domain: String,
|
||||
pub routes_file: String,
|
||||
pub callback_url: String,
|
||||
pub log_level: String,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn from_env() -> Result<Self> {
|
||||
let cookie_secret_b64 =
|
||||
std::env::var("AUTH_PROXY_COOKIE_SECRET").context("AUTH_PROXY_COOKIE_SECRET required")?;
|
||||
let decoded = STANDARD
|
||||
.decode(&cookie_secret_b64)
|
||||
.context("invalid base64 in AUTH_PROXY_COOKIE_SECRET")?;
|
||||
let cookie_secret: [u8; 32] = decoded
|
||||
.try_into()
|
||||
.map_err(|v: Vec<u8>| anyhow::anyhow!("cookie secret must be 32 bytes, got {}", v.len()))?;
|
||||
|
||||
Ok(Self {
|
||||
listen: std::env::var("AUTH_PROXY_LISTEN").unwrap_or_else(|_| "0.0.0.0:8080".into()),
|
||||
oidc_issuer: std::env::var("AUTH_PROXY_OIDC_ISSUER")
|
||||
.context("AUTH_PROXY_OIDC_ISSUER required")?,
|
||||
client_id: std::env::var("AUTH_PROXY_CLIENT_ID")
|
||||
.context("AUTH_PROXY_CLIENT_ID required")?,
|
||||
client_secret: std::env::var("AUTH_PROXY_CLIENT_SECRET")
|
||||
.context("AUTH_PROXY_CLIENT_SECRET required")?,
|
||||
cookie_secret,
|
||||
cookie_domain: std::env::var("AUTH_PROXY_COOKIE_DOMAIN")
|
||||
.context("AUTH_PROXY_COOKIE_DOMAIN required")?,
|
||||
routes_file: std::env::var("AUTH_PROXY_ROUTES_FILE")
|
||||
.unwrap_or_else(|_| "/config/routes.yaml".into()),
|
||||
callback_url: std::env::var("AUTH_PROXY_CALLBACK_URL")
|
||||
.context("AUTH_PROXY_CALLBACK_URL required")?,
|
||||
log_level: std::env::var("AUTH_PROXY_LOG_LEVEL").unwrap_or_else(|_| "info".into()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn base_url(&self) -> &str {
|
||||
self.callback_url
|
||||
.strip_suffix("/callback")
|
||||
.unwrap_or(&self.callback_url)
|
||||
}
|
||||
}
|
||||
+107
@@ -0,0 +1,107 @@
|
||||
use aes_gcm::aead::{Aead, KeyInit};
|
||||
use aes_gcm::{Aes256Gcm, Nonce};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||
use rand::RngCore;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CookieCrypto {
|
||||
cipher: Aes256Gcm,
|
||||
}
|
||||
|
||||
impl CookieCrypto {
|
||||
pub fn new(key_bytes: &[u8; 32]) -> Self {
|
||||
Self {
|
||||
cipher: Aes256Gcm::new_from_slice(key_bytes).expect("valid 32-byte key"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt(&self, plaintext: &[u8]) -> anyhow::Result<String> {
|
||||
let mut nonce_bytes = [0u8; 12];
|
||||
rand::thread_rng().fill_bytes(&mut nonce_bytes);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
let ciphertext = self
|
||||
.cipher
|
||||
.encrypt(nonce, plaintext)
|
||||
.map_err(|e| anyhow::anyhow!("encryption failed: {}", e))?;
|
||||
|
||||
let mut result = Vec::with_capacity(12 + ciphertext.len());
|
||||
result.extend_from_slice(&nonce_bytes);
|
||||
result.extend_from_slice(&ciphertext);
|
||||
|
||||
Ok(URL_SAFE_NO_PAD.encode(&result))
|
||||
}
|
||||
|
||||
pub fn decrypt(&self, encoded: &str) -> anyhow::Result<Vec<u8>> {
|
||||
let data = URL_SAFE_NO_PAD.decode(encoded)?;
|
||||
if data.len() < 13 {
|
||||
anyhow::bail!("ciphertext too short");
|
||||
}
|
||||
let (nonce_bytes, ciphertext) = data.split_at(12);
|
||||
let nonce = Nonce::from_slice(nonce_bytes);
|
||||
|
||||
self.cipher
|
||||
.decrypt(nonce, ciphertext)
|
||||
.map_err(|e| anyhow::anyhow!("decryption failed: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_roundtrip() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
let plaintext = b"hello world";
|
||||
let encrypted = crypto.encrypt(plaintext).unwrap();
|
||||
let decrypted = crypto.decrypt(&encrypted).unwrap();
|
||||
assert_eq!(decrypted, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_produces_different_ciphertext_each_time() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
let a = crypto.encrypt(b"same").unwrap();
|
||||
let b = crypto.encrypt(b"same").unwrap();
|
||||
assert_ne!(a, b); // different nonces
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_key_fails() {
|
||||
let crypto1 = CookieCrypto::new(&[0x42; 32]);
|
||||
let crypto2 = CookieCrypto::new(&[0x43; 32]);
|
||||
let encrypted = crypto1.encrypt(b"hello").unwrap();
|
||||
assert!(crypto2.decrypt(&encrypted).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tampered_ciphertext_fails() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
let encrypted = crypto.encrypt(b"hello").unwrap();
|
||||
let mut data = URL_SAFE_NO_PAD.decode(&encrypted).unwrap();
|
||||
*data.last_mut().unwrap() ^= 0xFF;
|
||||
let tampered = URL_SAFE_NO_PAD.encode(&data);
|
||||
assert!(crypto.decrypt(&tampered).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_plaintext_roundtrip() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
let encrypted = crypto.encrypt(b"").unwrap();
|
||||
let decrypted = crypto.decrypt(&encrypted).unwrap();
|
||||
assert!(decrypted.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn short_ciphertext_rejected() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
assert!(crypto.decrypt("dG9vc2hvcnQ").is_err()); // "tooshort" base64
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_base64_rejected() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
assert!(crypto.decrypt("not valid base64!!!").is_err());
|
||||
}
|
||||
}
|
||||
+1160
File diff suppressed because it is too large
Load Diff
+200
@@ -0,0 +1,200 @@
|
||||
mod config;
|
||||
mod crypto;
|
||||
mod handlers;
|
||||
mod oidc;
|
||||
mod routes;
|
||||
mod session;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::crypto::CookieCrypto;
|
||||
use crate::handlers::AppState;
|
||||
use axum::routing::get;
|
||||
use axum::Router;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
const HELP: &str = "\
|
||||
rsauth2-proxy — Auth proxy for Traefik ForwardAuth with Keycloak OIDC
|
||||
|
||||
USAGE:
|
||||
rsauth2-proxy Start the proxy (configured via environment variables)
|
||||
rsauth2-proxy --help Show this help
|
||||
rsauth2-proxy --version Show version
|
||||
|
||||
ENVIRONMENT VARIABLES:
|
||||
AUTH_PROXY_OIDC_ISSUER (required) OIDC issuer URL
|
||||
Example: https://auth.example.com/realms/main
|
||||
AUTH_PROXY_CLIENT_ID (required) Keycloak client ID
|
||||
AUTH_PROXY_CLIENT_SECRET (required) Keycloak client secret
|
||||
AUTH_PROXY_COOKIE_SECRET (required) AES-256 encryption key, 32 bytes, base64-encoded
|
||||
Generate with: openssl rand -base64 32
|
||||
AUTH_PROXY_COOKIE_DOMAIN (required) Cookie domain, must cover all protected hosts
|
||||
Example: .example.com
|
||||
AUTH_PROXY_CALLBACK_URL (required) Full public URL for the OIDC callback endpoint
|
||||
Example: https://auth-proxy.example.com/callback
|
||||
AUTH_PROXY_LISTEN (optional) Listen address [default: 0.0.0.0:8080]
|
||||
AUTH_PROXY_ROUTES_FILE (optional) Path to routes config [default: /config/routes.yaml]
|
||||
AUTH_PROXY_LOG_LEVEL (optional) Log level: debug, info, warn, error [default: info]
|
||||
|
||||
ENDPOINTS:
|
||||
GET /auth ForwardAuth endpoint (called by Traefik for every request)
|
||||
GET /callback OIDC callback (receives authorization code from Keycloak)
|
||||
GET /refresh Token refresh (transparent redirect when access token expires)
|
||||
GET /sign_out Logout (clears cookie, redirects to Keycloak end_session)
|
||||
GET /health Health check (returns 200 OK)
|
||||
GET /metrics Prometheus metrics (text exposition format)
|
||||
|
||||
ROUTES FILE (routes.yaml):
|
||||
routes:
|
||||
grafana.example.com:
|
||||
allowed_groups: [\"admins\", \"developers\"] # user must be in at least one group
|
||||
wiki.example.com:
|
||||
allowed_groups: [] # any authenticated user
|
||||
# Hosts not listed are denied (403)
|
||||
|
||||
The file is polled every 5 seconds and reloaded on change.
|
||||
Works with Kubernetes ConfigMap volume mounts (symlink-based updates).
|
||||
|
||||
RESPONSE HEADERS (forwarded to upstream on 200):
|
||||
X-Auth-Request-User preferred_username from the ID token
|
||||
X-Auth-Request-Email email from the ID token
|
||||
X-Auth-Request-Groups comma-separated list of groups
|
||||
|
||||
KEYCLOAK SETUP:
|
||||
Groups must be included in the ID token via a \"Group Membership\" protocol mapper
|
||||
with claim name \"groups\" and full path disabled.
|
||||
|
||||
TRAEFIK MIDDLEWARE:
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: auth-proxy
|
||||
spec:
|
||||
forwardAuth:
|
||||
address: http://auth-proxy.auth-proxy.svc:80/auth
|
||||
trustForwardHeader: true
|
||||
authResponseHeaders:
|
||||
- X-Auth-Request-User
|
||||
- X-Auth-Request-Email
|
||||
- X-Auth-Request-Groups
|
||||
|
||||
SECURITY:
|
||||
- Sessions encrypted with AES-256-GCM (not just signed)
|
||||
- PKCE S256 on all authorization flows
|
||||
- JWT validated against Keycloak JWKS (keys refreshed hourly)
|
||||
- No open redirects (redirect URL encrypted in state parameter)
|
||||
- Deny by default (unlisted hosts get 403)
|
||||
- Fully stateless (all state in encrypted cookies), supports multiple replicas
|
||||
- Graceful shutdown on SIGTERM/SIGINT
|
||||
|
||||
SOURCE: https://github.com/ab/rsauth2-proxy";
|
||||
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.iter().any(|a| a == "--help" || a == "-h") {
|
||||
println!("{}", HELP);
|
||||
return Ok(());
|
||||
}
|
||||
if args.iter().any(|a| a == "--version" || a == "-V") {
|
||||
println!("rsauth2-proxy {}", VERSION);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let config = Config::from_env()?;
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_new(&config.log_level)
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
let metrics_handle = metrics_exporter_prometheus::PrometheusBuilder::new()
|
||||
.install_recorder()
|
||||
.expect("failed to install prometheus recorder");
|
||||
|
||||
metrics::describe_counter!("auth_requests_total", "Total auth check requests");
|
||||
metrics::describe_histogram!(
|
||||
"auth_request_duration_seconds",
|
||||
"Auth check request duration in seconds"
|
||||
);
|
||||
metrics::describe_counter!("callback_requests_total", "Total OIDC callback requests");
|
||||
metrics::describe_counter!("refresh_requests_total", "Total token refresh requests");
|
||||
metrics::describe_gauge!("routes_loaded_total", "Number of loaded routes");
|
||||
|
||||
tracing::info!(listen = %config.listen, "starting rsauth2-proxy");
|
||||
|
||||
let oidc = oidc::OidcClient::discover(
|
||||
&config.oidc_issuer,
|
||||
config.client_id.clone(),
|
||||
config.client_secret.clone(),
|
||||
config.callback_url.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let initial_routes = routes::load_routes(&config.routes_file)?;
|
||||
let shared_routes: routes::SharedRoutes = Arc::new(RwLock::new(initial_routes));
|
||||
|
||||
let crypto = CookieCrypto::new(&config.cookie_secret);
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
config: config.clone(),
|
||||
oidc,
|
||||
routes: shared_routes.clone(),
|
||||
crypto,
|
||||
metrics_handle,
|
||||
});
|
||||
|
||||
// Background: watch routes file for changes
|
||||
let routes_path = config.routes_file.clone();
|
||||
let routes_ref = shared_routes.clone();
|
||||
tokio::spawn(async move {
|
||||
routes::watch_routes(routes_path, routes_ref).await;
|
||||
});
|
||||
|
||||
// Background: periodically refresh JWKS keys
|
||||
let oidc_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(3600));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if let Err(e) = oidc_state.oidc.refresh_jwks().await {
|
||||
tracing::warn!(error = %e, "failed to refresh JWKS");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(handlers::root))
|
||||
.route("/auth", get(handlers::auth))
|
||||
.route("/callback", get(handlers::callback))
|
||||
.route("/refresh", get(handlers::refresh))
|
||||
.route("/sign_out", get(handlers::sign_out))
|
||||
.route("/health", get(handlers::health))
|
||||
.route("/metrics", get(handlers::metrics_handler))
|
||||
.with_state(state);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&config.listen).await?;
|
||||
tracing::info!(addr = %config.listen, "listening");
|
||||
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await?;
|
||||
|
||||
tracing::info!("shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
let ctrl_c = tokio::signal::ctrl_c();
|
||||
let mut sigterm =
|
||||
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).unwrap();
|
||||
tokio::select! {
|
||||
_ = ctrl_c => {},
|
||||
_ = sigterm.recv() => {},
|
||||
}
|
||||
tracing::info!("shutdown signal received");
|
||||
}
|
||||
+223
@@ -0,0 +1,223 @@
|
||||
use anyhow::{Context, Result};
|
||||
use jsonwebtoken::{decode, decode_header, DecodingKey, TokenData, Validation};
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct OidcDiscovery {
|
||||
pub issuer: String,
|
||||
pub authorization_endpoint: String,
|
||||
pub token_endpoint: String,
|
||||
pub jwks_uri: String,
|
||||
pub end_session_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct JwkKey {
|
||||
pub kid: Option<String>,
|
||||
pub kty: String,
|
||||
pub n: Option<String>,
|
||||
pub e: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct JwksResponse {
|
||||
keys: Vec<JwkKey>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct TokenResponse {
|
||||
pub id_token: Option<String>,
|
||||
pub refresh_token: Option<String>,
|
||||
pub expires_in: Option<u64>,
|
||||
pub refresh_expires_in: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct IdTokenClaims {
|
||||
pub sub: String,
|
||||
pub preferred_username: Option<String>,
|
||||
pub email: Option<String>,
|
||||
pub groups: Option<Vec<String>>,
|
||||
pub exp: i64,
|
||||
pub iat: i64,
|
||||
pub nonce: Option<String>,
|
||||
}
|
||||
|
||||
pub struct OidcClient {
|
||||
http: Client,
|
||||
pub discovery: OidcDiscovery,
|
||||
jwks: Arc<RwLock<Vec<JwkKey>>>,
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
redirect_uri: String,
|
||||
}
|
||||
|
||||
impl OidcClient {
|
||||
pub async fn discover(
|
||||
issuer: &str,
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
redirect_uri: String,
|
||||
) -> Result<Self> {
|
||||
let http = Client::new();
|
||||
let discovery_url = format!(
|
||||
"{}/.well-known/openid-configuration",
|
||||
issuer.trim_end_matches('/')
|
||||
);
|
||||
let discovery: OidcDiscovery = http
|
||||
.get(&discovery_url)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()
|
||||
.context("OIDC discovery request failed")?
|
||||
.json()
|
||||
.await
|
||||
.context("failed to parse OIDC discovery")?;
|
||||
|
||||
let jwks: JwksResponse = http
|
||||
.get(&discovery.jwks_uri)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()
|
||||
.context("JWKS request failed")?
|
||||
.json()
|
||||
.await
|
||||
.context("failed to parse JWKS")?;
|
||||
|
||||
tracing::info!(
|
||||
keys = jwks.keys.len(),
|
||||
"OIDC discovery complete"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
http,
|
||||
discovery,
|
||||
jwks: Arc::new(RwLock::new(jwks.keys)),
|
||||
client_id,
|
||||
client_secret,
|
||||
redirect_uri,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn auth_url(&self, state: &str, pkce_challenge: &str, nonce: &str) -> String {
|
||||
let mut url = url::Url::parse(&self.discovery.authorization_endpoint)
|
||||
.expect("valid authorization_endpoint URL");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("client_id", &self.client_id)
|
||||
.append_pair("redirect_uri", &self.redirect_uri)
|
||||
.append_pair("response_type", "code")
|
||||
.append_pair("scope", "openid profile email")
|
||||
.append_pair("state", state)
|
||||
.append_pair("nonce", nonce)
|
||||
.append_pair("code_challenge", pkce_challenge)
|
||||
.append_pair("code_challenge_method", "S256");
|
||||
url.to_string()
|
||||
}
|
||||
|
||||
pub async fn exchange_code(&self, code: &str, pkce_verifier: &str) -> Result<TokenResponse> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(&self.discovery.token_endpoint)
|
||||
.form(&[
|
||||
("grant_type", "authorization_code"),
|
||||
("code", code),
|
||||
("redirect_uri", self.redirect_uri.as_str()),
|
||||
("client_id", self.client_id.as_str()),
|
||||
("client_secret", self.client_secret.as_str()),
|
||||
("code_verifier", pkce_verifier),
|
||||
])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("token exchange failed ({}): {}", status, body);
|
||||
}
|
||||
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(&self.discovery.token_endpoint)
|
||||
.form(&[
|
||||
("grant_type", "refresh_token"),
|
||||
("refresh_token", refresh_token),
|
||||
("client_id", self.client_id.as_str()),
|
||||
("client_secret", self.client_secret.as_str()),
|
||||
])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("token refresh failed ({}): {}", status, body);
|
||||
}
|
||||
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn validate_id_token(
|
||||
&self,
|
||||
token: &str,
|
||||
expected_nonce: Option<&str>,
|
||||
) -> Result<IdTokenClaims> {
|
||||
let header = decode_header(token)?;
|
||||
let kid = header.kid.as_deref();
|
||||
|
||||
let jwks = self.jwks.read().await;
|
||||
let key = if let Some(kid) = kid {
|
||||
jwks.iter().find(|k| k.kid.as_deref() == Some(kid))
|
||||
} else {
|
||||
jwks.first()
|
||||
}
|
||||
.context("no matching JWK found")?;
|
||||
|
||||
let decoding_key = DecodingKey::from_rsa_components(
|
||||
key.n.as_deref().context("missing 'n' in JWK")?,
|
||||
key.e.as_deref().context("missing 'e' in JWK")?,
|
||||
)?;
|
||||
|
||||
let mut validation = Validation::new(header.alg);
|
||||
validation.set_issuer(&[&self.discovery.issuer]);
|
||||
validation.set_audience(&[&self.client_id]);
|
||||
|
||||
let token_data: TokenData<IdTokenClaims> = decode(token, &decoding_key, &validation)?;
|
||||
|
||||
if let Some(expected) = expected_nonce {
|
||||
if token_data.claims.nonce.as_deref() != Some(expected) {
|
||||
anyhow::bail!("nonce mismatch");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
|
||||
pub async fn refresh_jwks(&self) -> Result<()> {
|
||||
let jwks: JwksResponse = self
|
||||
.http
|
||||
.get(&self.discovery.jwks_uri)
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
let count = jwks.keys.len();
|
||||
*self.jwks.write().await = jwks.keys;
|
||||
tracing::debug!(keys = count, "JWKS refreshed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn end_session_url(&self) -> Option<&str> {
|
||||
self.discovery.end_session_endpoint.as_deref()
|
||||
}
|
||||
}
|
||||
+157
@@ -0,0 +1,157 @@
|
||||
use anyhow::Result;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct RouteConfig {
|
||||
pub allowed_groups: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RoutesFile {
|
||||
pub routes: HashMap<String, RouteConfig>,
|
||||
}
|
||||
|
||||
pub type SharedRoutes = Arc<RwLock<HashMap<String, RouteConfig>>>;
|
||||
|
||||
pub fn load_routes(path: &str) -> Result<HashMap<String, RouteConfig>> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let routes_file: RoutesFile = serde_yaml::from_str(&content)?;
|
||||
let count = routes_file.routes.len();
|
||||
tracing::info!(count, "routes loaded");
|
||||
metrics::gauge!("routes_loaded_total").set(count as f64);
|
||||
Ok(routes_file.routes)
|
||||
}
|
||||
|
||||
pub async fn watch_routes(path: String, routes: SharedRoutes) {
|
||||
let mut last_content = std::fs::read_to_string(&path).unwrap_or_default();
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Read through the symlink — k8s ConfigMap updates swap the symlink target,
|
||||
// so reading the original path always gives current content.
|
||||
match std::fs::read_to_string(&path) {
|
||||
Ok(content) if content != last_content => match serde_yaml::from_str::<RoutesFile>(&content) {
|
||||
Ok(routes_file) => {
|
||||
let count = routes_file.routes.len();
|
||||
*routes.write().await = routes_file.routes;
|
||||
last_content = content;
|
||||
metrics::gauge!("routes_loaded_total").set(count as f64);
|
||||
tracing::info!(count, "routes reloaded");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to parse routes file");
|
||||
}
|
||||
},
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "failed to read routes file");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_routes_yaml() {
|
||||
let yaml = r#"
|
||||
routes:
|
||||
grafana.example.com:
|
||||
allowed_groups: ["admins", "developers"]
|
||||
wiki.example.com:
|
||||
allowed_groups: []
|
||||
"#;
|
||||
let rf: RoutesFile = serde_yaml::from_str(yaml).unwrap();
|
||||
assert_eq!(rf.routes.len(), 2);
|
||||
assert_eq!(
|
||||
rf.routes["grafana.example.com"].allowed_groups,
|
||||
vec!["admins", "developers"]
|
||||
);
|
||||
assert!(rf.routes["wiki.example.com"].allowed_groups.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_routes_from_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("routes.yaml");
|
||||
std::fs::write(
|
||||
&path,
|
||||
"routes:\n app.example.com:\n allowed_groups: [\"users\"]\n",
|
||||
)
|
||||
.unwrap();
|
||||
let routes = load_routes(path.to_str().unwrap()).unwrap();
|
||||
assert_eq!(routes.len(), 1);
|
||||
assert_eq!(routes["app.example.com"].allowed_groups, vec!["users"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_routes_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("routes.yaml");
|
||||
std::fs::write(&path, "routes: {}\n").unwrap();
|
||||
let routes = load_routes(path.to_str().unwrap()).unwrap();
|
||||
assert!(routes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_yaml_returns_error() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("routes.yaml");
|
||||
std::fs::write(&path, "not: valid: yaml: [[[").unwrap();
|
||||
assert!(load_routes(path.to_str().unwrap()).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_file_returns_error() {
|
||||
assert!(load_routes("/nonexistent/routes.yaml").is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn watch_routes_detects_change() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("routes.yaml");
|
||||
std::fs::write(
|
||||
&path,
|
||||
"routes:\n a.example.com:\n allowed_groups: []\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let initial = load_routes(path.to_str().unwrap()).unwrap();
|
||||
let shared: SharedRoutes = Arc::new(RwLock::new(initial));
|
||||
assert_eq!(shared.read().await.len(), 1);
|
||||
|
||||
let watch_path = path.to_str().unwrap().to_string();
|
||||
let watch_routes = shared.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
super::watch_routes(watch_path, watch_routes).await;
|
||||
});
|
||||
|
||||
// Let the watcher initialize and read the original content
|
||||
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
|
||||
|
||||
// Update the file
|
||||
std::fs::write(
|
||||
&path,
|
||||
"routes:\n a.example.com:\n allowed_groups: []\n b.example.com:\n allowed_groups: [\"admins\"]\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Poll interval is 5s; wait up to 12s for the change to be detected
|
||||
for _ in 0..12 {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
if shared.read().await.len() == 2 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert_eq!(shared.read().await.len(), 2);
|
||||
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
+100
@@ -0,0 +1,100 @@
|
||||
use crate::crypto::CookieCrypto;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
pub sub: String,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub groups: Vec<String>,
|
||||
pub exp: i64,
|
||||
pub iat: i64,
|
||||
pub refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub fn is_access_expired(&self) -> bool {
|
||||
now_timestamp() >= self.exp
|
||||
}
|
||||
|
||||
pub fn encrypt(&self, crypto: &CookieCrypto) -> anyhow::Result<String> {
|
||||
let json = serde_json::to_vec(self)?;
|
||||
crypto.encrypt(&json)
|
||||
}
|
||||
|
||||
pub fn decrypt(crypto: &CookieCrypto, encoded: &str) -> anyhow::Result<Self> {
|
||||
let plaintext = crypto.decrypt(encoded)?;
|
||||
Ok(serde_json::from_slice(&plaintext)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn now_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::crypto::CookieCrypto;
|
||||
|
||||
fn make_session(exp_offset: i64) -> Session {
|
||||
Session {
|
||||
sub: "user-123".into(),
|
||||
username: "john".into(),
|
||||
email: "john@example.com".into(),
|
||||
groups: vec!["admins".into(), "devs".into()],
|
||||
exp: now_timestamp() + exp_offset,
|
||||
iat: now_timestamp(),
|
||||
refresh_token: Some("refresh-tok-xyz".into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_decrypt_preserves_all_fields() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
let session = make_session(3600);
|
||||
let encrypted = session.encrypt(&crypto).unwrap();
|
||||
let restored = Session::decrypt(&crypto, &encrypted).unwrap();
|
||||
|
||||
assert_eq!(restored.sub, "user-123");
|
||||
assert_eq!(restored.username, "john");
|
||||
assert_eq!(restored.email, "john@example.com");
|
||||
assert_eq!(restored.groups, vec!["admins", "devs"]);
|
||||
assert_eq!(restored.exp, session.exp);
|
||||
assert_eq!(restored.iat, session.iat);
|
||||
assert_eq!(restored.refresh_token.as_deref(), Some("refresh-tok-xyz"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_without_refresh_token() {
|
||||
let crypto = CookieCrypto::new(&[0x42; 32]);
|
||||
let mut session = make_session(3600);
|
||||
session.refresh_token = None;
|
||||
let encrypted = session.encrypt(&crypto).unwrap();
|
||||
let restored = Session::decrypt(&crypto, &encrypted).unwrap();
|
||||
assert!(restored.refresh_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn active_session_not_expired() {
|
||||
let session = make_session(3600); // +1h
|
||||
assert!(!session.is_access_expired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn past_session_is_expired() {
|
||||
let session = make_session(-1); // 1 second ago
|
||||
assert!(session.is_access_expired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_key_cannot_decrypt() {
|
||||
let crypto1 = CookieCrypto::new(&[0x42; 32]);
|
||||
let crypto2 = CookieCrypto::new(&[0x99; 32]);
|
||||
let encrypted = make_session(3600).encrypt(&crypto1).unwrap();
|
||||
assert!(Session::decrypt(&crypto2, &encrypted).is_err());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user