Init
Build and Publish / Build and Publish Docker Image (push) Successful in 6m2s

This commit is contained in:
Ultradesu
2026-05-05 13:10:16 +01:00
commit 8d4321ea1a
14 changed files with 4728 additions and 0 deletions
+52
View File
@@ -0,0 +1,52 @@
use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD, Engine};
#[derive(Clone)]
pub struct Config {
pub listen: String,
pub oidc_issuer: String,
pub client_id: String,
pub client_secret: String,
pub cookie_secret: [u8; 32],
pub cookie_domain: String,
pub routes_file: String,
pub callback_url: String,
pub log_level: String,
}
impl Config {
pub fn from_env() -> Result<Self> {
let cookie_secret_b64 =
std::env::var("AUTH_PROXY_COOKIE_SECRET").context("AUTH_PROXY_COOKIE_SECRET required")?;
let decoded = STANDARD
.decode(&cookie_secret_b64)
.context("invalid base64 in AUTH_PROXY_COOKIE_SECRET")?;
let cookie_secret: [u8; 32] = decoded
.try_into()
.map_err(|v: Vec<u8>| anyhow::anyhow!("cookie secret must be 32 bytes, got {}", v.len()))?;
Ok(Self {
listen: std::env::var("AUTH_PROXY_LISTEN").unwrap_or_else(|_| "0.0.0.0:8080".into()),
oidc_issuer: std::env::var("AUTH_PROXY_OIDC_ISSUER")
.context("AUTH_PROXY_OIDC_ISSUER required")?,
client_id: std::env::var("AUTH_PROXY_CLIENT_ID")
.context("AUTH_PROXY_CLIENT_ID required")?,
client_secret: std::env::var("AUTH_PROXY_CLIENT_SECRET")
.context("AUTH_PROXY_CLIENT_SECRET required")?,
cookie_secret,
cookie_domain: std::env::var("AUTH_PROXY_COOKIE_DOMAIN")
.context("AUTH_PROXY_COOKIE_DOMAIN required")?,
routes_file: std::env::var("AUTH_PROXY_ROUTES_FILE")
.unwrap_or_else(|_| "/config/routes.yaml".into()),
callback_url: std::env::var("AUTH_PROXY_CALLBACK_URL")
.context("AUTH_PROXY_CALLBACK_URL required")?,
log_level: std::env::var("AUTH_PROXY_LOG_LEVEL").unwrap_or_else(|_| "info".into()),
})
}
pub fn base_url(&self) -> &str {
self.callback_url
.strip_suffix("/callback")
.unwrap_or(&self.callback_url)
}
}
+107
View File
@@ -0,0 +1,107 @@
use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use rand::RngCore;
#[derive(Clone)]
pub struct CookieCrypto {
cipher: Aes256Gcm,
}
impl CookieCrypto {
pub fn new(key_bytes: &[u8; 32]) -> Self {
Self {
cipher: Aes256Gcm::new_from_slice(key_bytes).expect("valid 32-byte key"),
}
}
pub fn encrypt(&self, plaintext: &[u8]) -> anyhow::Result<String> {
let mut nonce_bytes = [0u8; 12];
rand::thread_rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = self
.cipher
.encrypt(nonce, plaintext)
.map_err(|e| anyhow::anyhow!("encryption failed: {}", e))?;
let mut result = Vec::with_capacity(12 + ciphertext.len());
result.extend_from_slice(&nonce_bytes);
result.extend_from_slice(&ciphertext);
Ok(URL_SAFE_NO_PAD.encode(&result))
}
pub fn decrypt(&self, encoded: &str) -> anyhow::Result<Vec<u8>> {
let data = URL_SAFE_NO_PAD.decode(encoded)?;
if data.len() < 13 {
anyhow::bail!("ciphertext too short");
}
let (nonce_bytes, ciphertext) = data.split_at(12);
let nonce = Nonce::from_slice(nonce_bytes);
self.cipher
.decrypt(nonce, ciphertext)
.map_err(|e| anyhow::anyhow!("decryption failed: {}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encrypt_decrypt_roundtrip() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let plaintext = b"hello world";
let encrypted = crypto.encrypt(plaintext).unwrap();
let decrypted = crypto.decrypt(&encrypted).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn encrypt_produces_different_ciphertext_each_time() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let a = crypto.encrypt(b"same").unwrap();
let b = crypto.encrypt(b"same").unwrap();
assert_ne!(a, b); // different nonces
}
#[test]
fn wrong_key_fails() {
let crypto1 = CookieCrypto::new(&[0x42; 32]);
let crypto2 = CookieCrypto::new(&[0x43; 32]);
let encrypted = crypto1.encrypt(b"hello").unwrap();
assert!(crypto2.decrypt(&encrypted).is_err());
}
#[test]
fn tampered_ciphertext_fails() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let encrypted = crypto.encrypt(b"hello").unwrap();
let mut data = URL_SAFE_NO_PAD.decode(&encrypted).unwrap();
*data.last_mut().unwrap() ^= 0xFF;
let tampered = URL_SAFE_NO_PAD.encode(&data);
assert!(crypto.decrypt(&tampered).is_err());
}
#[test]
fn empty_plaintext_roundtrip() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let encrypted = crypto.encrypt(b"").unwrap();
let decrypted = crypto.decrypt(&encrypted).unwrap();
assert!(decrypted.is_empty());
}
#[test]
fn short_ciphertext_rejected() {
let crypto = CookieCrypto::new(&[0x42; 32]);
assert!(crypto.decrypt("dG9vc2hvcnQ").is_err()); // "tooshort" base64
}
#[test]
fn invalid_base64_rejected() {
let crypto = CookieCrypto::new(&[0x42; 32]);
assert!(crypto.decrypt("not valid base64!!!").is_err());
}
}
+1160
View File
File diff suppressed because it is too large Load Diff
+200
View File
@@ -0,0 +1,200 @@
mod config;
mod crypto;
mod handlers;
mod oidc;
mod routes;
mod session;
use crate::config::Config;
use crate::crypto::CookieCrypto;
use crate::handlers::AppState;
use axum::routing::get;
use axum::Router;
use std::sync::Arc;
use tokio::sync::RwLock;
const HELP: &str = "\
rsauth2-proxy — Auth proxy for Traefik ForwardAuth with Keycloak OIDC
USAGE:
rsauth2-proxy Start the proxy (configured via environment variables)
rsauth2-proxy --help Show this help
rsauth2-proxy --version Show version
ENVIRONMENT VARIABLES:
AUTH_PROXY_OIDC_ISSUER (required) OIDC issuer URL
Example: https://auth.example.com/realms/main
AUTH_PROXY_CLIENT_ID (required) Keycloak client ID
AUTH_PROXY_CLIENT_SECRET (required) Keycloak client secret
AUTH_PROXY_COOKIE_SECRET (required) AES-256 encryption key, 32 bytes, base64-encoded
Generate with: openssl rand -base64 32
AUTH_PROXY_COOKIE_DOMAIN (required) Cookie domain, must cover all protected hosts
Example: .example.com
AUTH_PROXY_CALLBACK_URL (required) Full public URL for the OIDC callback endpoint
Example: https://auth-proxy.example.com/callback
AUTH_PROXY_LISTEN (optional) Listen address [default: 0.0.0.0:8080]
AUTH_PROXY_ROUTES_FILE (optional) Path to routes config [default: /config/routes.yaml]
AUTH_PROXY_LOG_LEVEL (optional) Log level: debug, info, warn, error [default: info]
ENDPOINTS:
GET /auth ForwardAuth endpoint (called by Traefik for every request)
GET /callback OIDC callback (receives authorization code from Keycloak)
GET /refresh Token refresh (transparent redirect when access token expires)
GET /sign_out Logout (clears cookie, redirects to Keycloak end_session)
GET /health Health check (returns 200 OK)
GET /metrics Prometheus metrics (text exposition format)
ROUTES FILE (routes.yaml):
routes:
grafana.example.com:
allowed_groups: [\"admins\", \"developers\"] # user must be in at least one group
wiki.example.com:
allowed_groups: [] # any authenticated user
# Hosts not listed are denied (403)
The file is polled every 5 seconds and reloaded on change.
Works with Kubernetes ConfigMap volume mounts (symlink-based updates).
RESPONSE HEADERS (forwarded to upstream on 200):
X-Auth-Request-User preferred_username from the ID token
X-Auth-Request-Email email from the ID token
X-Auth-Request-Groups comma-separated list of groups
KEYCLOAK SETUP:
Groups must be included in the ID token via a \"Group Membership\" protocol mapper
with claim name \"groups\" and full path disabled.
TRAEFIK MIDDLEWARE:
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: auth-proxy
spec:
forwardAuth:
address: http://auth-proxy.auth-proxy.svc:80/auth
trustForwardHeader: true
authResponseHeaders:
- X-Auth-Request-User
- X-Auth-Request-Email
- X-Auth-Request-Groups
SECURITY:
- Sessions encrypted with AES-256-GCM (not just signed)
- PKCE S256 on all authorization flows
- JWT validated against Keycloak JWKS (keys refreshed hourly)
- No open redirects (redirect URL encrypted in state parameter)
- Deny by default (unlisted hosts get 403)
- Fully stateless (all state in encrypted cookies), supports multiple replicas
- Graceful shutdown on SIGTERM/SIGINT
SOURCE: https://github.com/ab/rsauth2-proxy";
const VERSION: &str = env!("CARGO_PKG_VERSION");
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args: Vec<String> = std::env::args().collect();
if args.iter().any(|a| a == "--help" || a == "-h") {
println!("{}", HELP);
return Ok(());
}
if args.iter().any(|a| a == "--version" || a == "-V") {
println!("rsauth2-proxy {}", VERSION);
return Ok(());
}
let config = Config::from_env()?;
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_new(&config.log_level)
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let metrics_handle = metrics_exporter_prometheus::PrometheusBuilder::new()
.install_recorder()
.expect("failed to install prometheus recorder");
metrics::describe_counter!("auth_requests_total", "Total auth check requests");
metrics::describe_histogram!(
"auth_request_duration_seconds",
"Auth check request duration in seconds"
);
metrics::describe_counter!("callback_requests_total", "Total OIDC callback requests");
metrics::describe_counter!("refresh_requests_total", "Total token refresh requests");
metrics::describe_gauge!("routes_loaded_total", "Number of loaded routes");
tracing::info!(listen = %config.listen, "starting rsauth2-proxy");
let oidc = oidc::OidcClient::discover(
&config.oidc_issuer,
config.client_id.clone(),
config.client_secret.clone(),
config.callback_url.clone(),
)
.await?;
let initial_routes = routes::load_routes(&config.routes_file)?;
let shared_routes: routes::SharedRoutes = Arc::new(RwLock::new(initial_routes));
let crypto = CookieCrypto::new(&config.cookie_secret);
let state = Arc::new(AppState {
config: config.clone(),
oidc,
routes: shared_routes.clone(),
crypto,
metrics_handle,
});
// Background: watch routes file for changes
let routes_path = config.routes_file.clone();
let routes_ref = shared_routes.clone();
tokio::spawn(async move {
routes::watch_routes(routes_path, routes_ref).await;
});
// Background: periodically refresh JWKS keys
let oidc_state = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(3600));
loop {
interval.tick().await;
if let Err(e) = oidc_state.oidc.refresh_jwks().await {
tracing::warn!(error = %e, "failed to refresh JWKS");
}
}
});
let app = Router::new()
.route("/", get(handlers::root))
.route("/auth", get(handlers::auth))
.route("/callback", get(handlers::callback))
.route("/refresh", get(handlers::refresh))
.route("/sign_out", get(handlers::sign_out))
.route("/health", get(handlers::health))
.route("/metrics", get(handlers::metrics_handler))
.with_state(state);
let listener = tokio::net::TcpListener::bind(&config.listen).await?;
tracing::info!(addr = %config.listen, "listening");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await?;
tracing::info!("shutdown complete");
Ok(())
}
async fn shutdown_signal() {
let ctrl_c = tokio::signal::ctrl_c();
let mut sigterm =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).unwrap();
tokio::select! {
_ = ctrl_c => {},
_ = sigterm.recv() => {},
}
tracing::info!("shutdown signal received");
}
+223
View File
@@ -0,0 +1,223 @@
use anyhow::{Context, Result};
use jsonwebtoken::{decode, decode_header, DecodingKey, TokenData, Validation};
use reqwest::Client;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Deserialize)]
pub struct OidcDiscovery {
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub jwks_uri: String,
pub end_session_endpoint: Option<String>,
}
#[derive(Deserialize, Clone)]
#[allow(dead_code)]
pub struct JwkKey {
pub kid: Option<String>,
pub kty: String,
pub n: Option<String>,
pub e: Option<String>,
}
#[derive(Deserialize)]
struct JwksResponse {
keys: Vec<JwkKey>,
}
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct TokenResponse {
pub id_token: Option<String>,
pub refresh_token: Option<String>,
pub expires_in: Option<u64>,
pub refresh_expires_in: Option<u64>,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct IdTokenClaims {
pub sub: String,
pub preferred_username: Option<String>,
pub email: Option<String>,
pub groups: Option<Vec<String>>,
pub exp: i64,
pub iat: i64,
pub nonce: Option<String>,
}
pub struct OidcClient {
http: Client,
pub discovery: OidcDiscovery,
jwks: Arc<RwLock<Vec<JwkKey>>>,
client_id: String,
client_secret: String,
redirect_uri: String,
}
impl OidcClient {
pub async fn discover(
issuer: &str,
client_id: String,
client_secret: String,
redirect_uri: String,
) -> Result<Self> {
let http = Client::new();
let discovery_url = format!(
"{}/.well-known/openid-configuration",
issuer.trim_end_matches('/')
);
let discovery: OidcDiscovery = http
.get(&discovery_url)
.send()
.await?
.error_for_status()
.context("OIDC discovery request failed")?
.json()
.await
.context("failed to parse OIDC discovery")?;
let jwks: JwksResponse = http
.get(&discovery.jwks_uri)
.send()
.await?
.error_for_status()
.context("JWKS request failed")?
.json()
.await
.context("failed to parse JWKS")?;
tracing::info!(
keys = jwks.keys.len(),
"OIDC discovery complete"
);
Ok(Self {
http,
discovery,
jwks: Arc::new(RwLock::new(jwks.keys)),
client_id,
client_secret,
redirect_uri,
})
}
pub fn auth_url(&self, state: &str, pkce_challenge: &str, nonce: &str) -> String {
let mut url = url::Url::parse(&self.discovery.authorization_endpoint)
.expect("valid authorization_endpoint URL");
url.query_pairs_mut()
.append_pair("client_id", &self.client_id)
.append_pair("redirect_uri", &self.redirect_uri)
.append_pair("response_type", "code")
.append_pair("scope", "openid profile email")
.append_pair("state", state)
.append_pair("nonce", nonce)
.append_pair("code_challenge", pkce_challenge)
.append_pair("code_challenge_method", "S256");
url.to_string()
}
pub async fn exchange_code(&self, code: &str, pkce_verifier: &str) -> Result<TokenResponse> {
let resp = self
.http
.post(&self.discovery.token_endpoint)
.form(&[
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", self.redirect_uri.as_str()),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("code_verifier", pkce_verifier),
])
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("token exchange failed ({}): {}", status, body);
}
Ok(resp.json().await?)
}
pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
let resp = self
.http
.post(&self.discovery.token_endpoint)
.form(&[
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
])
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("token refresh failed ({}): {}", status, body);
}
Ok(resp.json().await?)
}
pub async fn validate_id_token(
&self,
token: &str,
expected_nonce: Option<&str>,
) -> Result<IdTokenClaims> {
let header = decode_header(token)?;
let kid = header.kid.as_deref();
let jwks = self.jwks.read().await;
let key = if let Some(kid) = kid {
jwks.iter().find(|k| k.kid.as_deref() == Some(kid))
} else {
jwks.first()
}
.context("no matching JWK found")?;
let decoding_key = DecodingKey::from_rsa_components(
key.n.as_deref().context("missing 'n' in JWK")?,
key.e.as_deref().context("missing 'e' in JWK")?,
)?;
let mut validation = Validation::new(header.alg);
validation.set_issuer(&[&self.discovery.issuer]);
validation.set_audience(&[&self.client_id]);
let token_data: TokenData<IdTokenClaims> = decode(token, &decoding_key, &validation)?;
if let Some(expected) = expected_nonce {
if token_data.claims.nonce.as_deref() != Some(expected) {
anyhow::bail!("nonce mismatch");
}
}
Ok(token_data.claims)
}
pub async fn refresh_jwks(&self) -> Result<()> {
let jwks: JwksResponse = self
.http
.get(&self.discovery.jwks_uri)
.send()
.await?
.error_for_status()?
.json()
.await?;
let count = jwks.keys.len();
*self.jwks.write().await = jwks.keys;
tracing::debug!(keys = count, "JWKS refreshed");
Ok(())
}
pub fn end_session_url(&self) -> Option<&str> {
self.discovery.end_session_endpoint.as_deref()
}
}
+157
View File
@@ -0,0 +1,157 @@
use anyhow::Result;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Debug, Deserialize, Clone)]
pub struct RouteConfig {
pub allowed_groups: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct RoutesFile {
pub routes: HashMap<String, RouteConfig>,
}
pub type SharedRoutes = Arc<RwLock<HashMap<String, RouteConfig>>>;
pub fn load_routes(path: &str) -> Result<HashMap<String, RouteConfig>> {
let content = std::fs::read_to_string(path)?;
let routes_file: RoutesFile = serde_yaml::from_str(&content)?;
let count = routes_file.routes.len();
tracing::info!(count, "routes loaded");
metrics::gauge!("routes_loaded_total").set(count as f64);
Ok(routes_file.routes)
}
pub async fn watch_routes(path: String, routes: SharedRoutes) {
let mut last_content = std::fs::read_to_string(&path).unwrap_or_default();
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
loop {
interval.tick().await;
// Read through the symlink — k8s ConfigMap updates swap the symlink target,
// so reading the original path always gives current content.
match std::fs::read_to_string(&path) {
Ok(content) if content != last_content => match serde_yaml::from_str::<RoutesFile>(&content) {
Ok(routes_file) => {
let count = routes_file.routes.len();
*routes.write().await = routes_file.routes;
last_content = content;
metrics::gauge!("routes_loaded_total").set(count as f64);
tracing::info!(count, "routes reloaded");
}
Err(e) => {
tracing::error!(error = %e, "failed to parse routes file");
}
},
Ok(_) => {}
Err(e) => {
tracing::warn!(error = %e, "failed to read routes file");
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_routes_yaml() {
let yaml = r#"
routes:
grafana.example.com:
allowed_groups: ["admins", "developers"]
wiki.example.com:
allowed_groups: []
"#;
let rf: RoutesFile = serde_yaml::from_str(yaml).unwrap();
assert_eq!(rf.routes.len(), 2);
assert_eq!(
rf.routes["grafana.example.com"].allowed_groups,
vec!["admins", "developers"]
);
assert!(rf.routes["wiki.example.com"].allowed_groups.is_empty());
}
#[test]
fn load_routes_from_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(
&path,
"routes:\n app.example.com:\n allowed_groups: [\"users\"]\n",
)
.unwrap();
let routes = load_routes(path.to_str().unwrap()).unwrap();
assert_eq!(routes.len(), 1);
assert_eq!(routes["app.example.com"].allowed_groups, vec!["users"]);
}
#[test]
fn empty_routes_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(&path, "routes: {}\n").unwrap();
let routes = load_routes(path.to_str().unwrap()).unwrap();
assert!(routes.is_empty());
}
#[test]
fn invalid_yaml_returns_error() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(&path, "not: valid: yaml: [[[").unwrap();
assert!(load_routes(path.to_str().unwrap()).is_err());
}
#[test]
fn missing_file_returns_error() {
assert!(load_routes("/nonexistent/routes.yaml").is_err());
}
#[tokio::test]
async fn watch_routes_detects_change() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(
&path,
"routes:\n a.example.com:\n allowed_groups: []\n",
)
.unwrap();
let initial = load_routes(path.to_str().unwrap()).unwrap();
let shared: SharedRoutes = Arc::new(RwLock::new(initial));
assert_eq!(shared.read().await.len(), 1);
let watch_path = path.to_str().unwrap().to_string();
let watch_routes = shared.clone();
let handle = tokio::spawn(async move {
super::watch_routes(watch_path, watch_routes).await;
});
// Let the watcher initialize and read the original content
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
// Update the file
std::fs::write(
&path,
"routes:\n a.example.com:\n allowed_groups: []\n b.example.com:\n allowed_groups: [\"admins\"]\n",
)
.unwrap();
// Poll interval is 5s; wait up to 12s for the change to be detected
for _ in 0..12 {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if shared.read().await.len() == 2 {
break;
}
}
assert_eq!(shared.read().await.len(), 2);
handle.abort();
}
}
+100
View File
@@ -0,0 +1,100 @@
use crate::crypto::CookieCrypto;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct Session {
pub sub: String,
pub username: String,
pub email: String,
pub groups: Vec<String>,
pub exp: i64,
pub iat: i64,
pub refresh_token: Option<String>,
}
impl Session {
pub fn is_access_expired(&self) -> bool {
now_timestamp() >= self.exp
}
pub fn encrypt(&self, crypto: &CookieCrypto) -> anyhow::Result<String> {
let json = serde_json::to_vec(self)?;
crypto.encrypt(&json)
}
pub fn decrypt(crypto: &CookieCrypto, encoded: &str) -> anyhow::Result<Self> {
let plaintext = crypto.decrypt(encoded)?;
Ok(serde_json::from_slice(&plaintext)?)
}
}
pub fn now_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64
}
#[cfg(test)]
mod tests {
use super::*;
use crate::crypto::CookieCrypto;
fn make_session(exp_offset: i64) -> Session {
Session {
sub: "user-123".into(),
username: "john".into(),
email: "john@example.com".into(),
groups: vec!["admins".into(), "devs".into()],
exp: now_timestamp() + exp_offset,
iat: now_timestamp(),
refresh_token: Some("refresh-tok-xyz".into()),
}
}
#[test]
fn encrypt_decrypt_preserves_all_fields() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let session = make_session(3600);
let encrypted = session.encrypt(&crypto).unwrap();
let restored = Session::decrypt(&crypto, &encrypted).unwrap();
assert_eq!(restored.sub, "user-123");
assert_eq!(restored.username, "john");
assert_eq!(restored.email, "john@example.com");
assert_eq!(restored.groups, vec!["admins", "devs"]);
assert_eq!(restored.exp, session.exp);
assert_eq!(restored.iat, session.iat);
assert_eq!(restored.refresh_token.as_deref(), Some("refresh-tok-xyz"));
}
#[test]
fn session_without_refresh_token() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let mut session = make_session(3600);
session.refresh_token = None;
let encrypted = session.encrypt(&crypto).unwrap();
let restored = Session::decrypt(&crypto, &encrypted).unwrap();
assert!(restored.refresh_token.is_none());
}
#[test]
fn active_session_not_expired() {
let session = make_session(3600); // +1h
assert!(!session.is_access_expired());
}
#[test]
fn past_session_is_expired() {
let session = make_session(-1); // 1 second ago
assert!(session.is_access_expired());
}
#[test]
fn wrong_key_cannot_decrypt() {
let crypto1 = CookieCrypto::new(&[0x42; 32]);
let crypto2 = CookieCrypto::new(&[0x99; 32]);
let encrypted = make_session(3600).encrypt(&crypto1).unwrap();
assert!(Session::decrypt(&crypto2, &encrypted).is_err());
}
}