Init
Build and Publish / Build and Publish Docker Image (push) Successful in 6m2s

This commit is contained in:
Ultradesu
2026-05-05 13:10:16 +01:00
commit 8d4321ea1a
14 changed files with 4728 additions and 0 deletions
+4
View File
@@ -0,0 +1,4 @@
target/
.git/
AUTH_PROXY_SPEC.md
routes.yaml
+58
View File
@@ -0,0 +1,58 @@
name: Build and Publish
on:
push:
branches:
- master
- main
tags:
- 'v*.*.*'
env:
IMAGE_NAME: ultradesu/rsauth2-proxy
jobs:
build_docker:
name: Build and Publish Docker Image
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Extract metadata
id: meta
run: |
VERSION=$(grep '^version' Cargo.toml | head -1 | cut -d'"' -f2)
echo "cargo_version=${VERSION}" >> $GITHUB_OUTPUT
if [[ "${{ github.ref }}" == refs/tags/* ]]; then
TAG_NAME=${GITHUB_REF#refs/tags/}
echo "docker_tags=${IMAGE_NAME}:${TAG_NAME},${IMAGE_NAME}:${VERSION},${IMAGE_NAME}:latest" >> $GITHUB_OUTPUT
elif [[ "${{ github.ref }}" == refs/heads/* ]]; then
BRANCH=${GITHUB_REF#refs/heads/}
echo "docker_tags=${IMAGE_NAME}:${BRANCH},${IMAGE_NAME}:${VERSION},${IMAGE_NAME}:$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
else
echo "docker_tags=${IMAGE_NAME}:$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT
fi
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.docker_tags }}
cache-from: type=registry,ref=${{ IMAGE_NAME }}:buildcache
cache-to: type=registry,ref=${{ IMAGE_NAME }}:buildcache,mode=max
+3
View File
@@ -0,0 +1,3 @@
/target/
routes.yaml
AUTH_PROXY_SPEC.md
Generated
+2390
View File
File diff suppressed because it is too large Load Diff
+34
View File
@@ -0,0 +1,34 @@
[package]
name = "rsauth2-proxy"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_yaml = "0.9"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-webpki-roots", "json"] }
jsonwebtoken = "9"
aes-gcm = "0.10"
base64 = "0.22"
rand = "0.8"
sha2 = "0.10"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
anyhow = "1"
url = "2"
metrics = "0.23"
metrics-exporter-prometheus = { version = "0.15", default-features = false }
[dev-dependencies]
rsa = "0.9"
tower = { version = "0.5", features = ["util"] }
tempfile = "3"
[profile.release]
opt-level = "z"
lto = true
codegen-units = 1
strip = true
+16
View File
@@ -0,0 +1,16 @@
FROM rust:1-slim AS builder
WORKDIR /app
# Cache dependencies
COPY Cargo.toml Cargo.lock ./
RUN mkdir src && echo "fn main() {}" > src/main.rs && \
cargo build --release && \
rm -rf src
COPY src ./src
RUN touch src/main.rs && cargo build --release
FROM gcr.io/distroless/cc-debian12
COPY --from=builder /app/target/release/rsauth2-proxy /rsauth2-proxy
ENTRYPOINT ["/rsauth2-proxy"]
+224
View File
@@ -0,0 +1,224 @@
# rsauth2-proxy
Auth proxy for [Traefik ForwardAuth](https://doc.traefik.io/traefik/middlewares/http/forwardauth/) with Keycloak OIDC. Single instance protects all services in a cluster. Replaces oauth2-proxy.
## How it works
```
Browser → Traefik → ForwardAuth (/auth) → rsauth2-proxy
├── no session → 302 to Keycloak login
├── valid session → 200 + user headers
└── expired session → token refresh → 302 back
```
Traefik calls `/auth` for every request to a protected service. The proxy checks the encrypted session cookie, verifies group membership against the route config, and returns 200 (allow), 403 (deny), or 302 (login required).
Sessions are stored entirely in an AES-256-GCM encrypted cookie. No server-side state. Any number of replicas work without coordination.
## Configuration
All configuration is via environment variables.
| Variable | Required | Default | Description |
|---|---|---|---|
| `AUTH_PROXY_OIDC_ISSUER` | yes | | OIDC issuer URL |
| `AUTH_PROXY_CLIENT_ID` | yes | | Keycloak client ID |
| `AUTH_PROXY_CLIENT_SECRET` | yes | | Keycloak client secret |
| `AUTH_PROXY_COOKIE_SECRET` | yes | | AES-256 key, 32 bytes, base64-encoded |
| `AUTH_PROXY_COOKIE_DOMAIN` | yes | | Cookie domain (e.g. `.example.com`) |
| `AUTH_PROXY_CALLBACK_URL` | yes | | Full callback URL (e.g. `https://auth.example.com/callback`) |
| `AUTH_PROXY_LISTEN` | no | `0.0.0.0:8080` | Listen address |
| `AUTH_PROXY_ROUTES_FILE` | no | `/config/routes.yaml` | Path to routes config |
| `AUTH_PROXY_LOG_LEVEL` | no | `info` | Log level (`debug`, `info`, `warn`, `error`) |
Generate a cookie secret:
```sh
openssl rand -base64 32
```
## Routes file
Defines which hosts are protected and which groups have access.
```yaml
routes:
grafana.example.com:
allowed_groups: ["admins", "developers"]
wiki.example.com:
allowed_groups: []
# Empty list = any authenticated user
secret.example.com:
allowed_groups: ["admins"]
```
Rules:
- Host in routes, `allowed_groups` empty — any authenticated user is allowed
- Host in routes, `allowed_groups` set — user must be in at least one listed group
- Host not in routes — denied (403)
The file is polled every 5 seconds and reloaded on change. This works reliably with Kubernetes ConfigMap volume mounts.
## Endpoints
| Path | Purpose |
|---|---|
| `GET /auth` | ForwardAuth endpoint (called by Traefik) |
| `GET /callback` | OIDC callback (receives authorization code from Keycloak) |
| `GET /refresh` | Token refresh (transparent redirect when session expires) |
| `GET /sign_out` | Logout (clears cookie, redirects to Keycloak end_session) |
| `GET /health` | Health check (returns 200) |
## Keycloak setup
The proxy reads user groups from the `groups` claim in the ID token. Keycloak does not include this by default. Add a group membership mapper to the client:
```hcl
resource "keycloak_openid_group_membership_protocol_mapper" "groups" {
realm_id = keycloak_realm.main.id
client_id = keycloak_openid_client.auth_proxy.id
name = "groups"
claim_name = "groups"
full_path = false
}
```
Or manually: Client Scopes → your client → Mappers → Add mapper → "Group Membership", claim name `groups`, full path off.
## Kubernetes deployment
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: auth-proxy
namespace: auth-proxy
spec:
replicas: 2
selector:
matchLabels:
app: auth-proxy
template:
metadata:
labels:
app: auth-proxy
spec:
containers:
- name: auth-proxy
image: ghcr.io/your-org/rsauth2-proxy:latest
ports:
- containerPort: 8080
envFrom:
- secretRef:
name: auth-proxy-creds
volumeMounts:
- name: routes
mountPath: /config
readOnly: true
livenessProbe:
httpGet:
path: /health
port: 8080
readinessProbe:
httpGet:
path: /health
port: 8080
volumes:
- name: routes
configMap:
name: auth-proxy-routes
---
apiVersion: v1
kind: Service
metadata:
name: auth-proxy
namespace: auth-proxy
spec:
selector:
app: auth-proxy
ports:
- port: 80
targetPort: 8080
```
### Traefik ForwardAuth middleware
Create in each namespace that has protected services:
```yaml
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: auth-proxy
spec:
forwardAuth:
address: http://auth-proxy.auth-proxy.svc:80/auth
trustForwardHeader: true
authResponseHeaders:
- X-Auth-Request-User
- X-Auth-Request-Email
- X-Auth-Request-Groups
```
### Ingress for auth-proxy itself
The `/callback`, `/refresh`, and `/sign_out` endpoints must be reachable by browsers:
```yaml
apiVersion: traefik.io/v1alpha1
kind: IngressRoute
metadata:
name: auth-proxy
namespace: auth-proxy
spec:
entryPoints:
- websecure
routes:
- match: Host(`auth.example.com`) && (Path(`/callback`) || Path(`/refresh`) || Path(`/sign_out`))
kind: Rule
services:
- name: auth-proxy
port: 80
tls:
secretName: auth-proxy-tls
```
## Building
```sh
cargo build --release
```
### Docker
```sh
docker build -t rsauth2-proxy .
```
Produces a static musl binary in a `FROM scratch` image (~10MB).
## Security properties
- **Encrypted cookies** — AES-256-GCM, not just signed. Cookie contents cannot be read or tampered with without the key.
- **PKCE (S256)** — protects the authorization code exchange against interception.
- **Stateless PKCE** — the PKCE verifier is encrypted inside the `state` parameter. No server-side storage needed.
- **No open redirect** — the redirect URL after login is encrypted in `state`, not taken from user input.
- **Deny by default** — any host not listed in routes gets 403.
- **JWT validation** — ID tokens are verified against Keycloak's JWKS (keys refreshed hourly).
- **Cookie flags**`HttpOnly`, `Secure`, `SameSite=Lax`.
## Response headers
On successful authentication, the following headers are set on the request forwarded to the upstream service:
| Header | Value |
|---|---|
| `X-Auth-Request-User` | `preferred_username` from the ID token |
| `X-Auth-Request-Email` | `email` from the ID token |
| `X-Auth-Request-Groups` | Comma-separated list of groups |
## License
MIT
+52
View File
@@ -0,0 +1,52 @@
use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD, Engine};
#[derive(Clone)]
pub struct Config {
pub listen: String,
pub oidc_issuer: String,
pub client_id: String,
pub client_secret: String,
pub cookie_secret: [u8; 32],
pub cookie_domain: String,
pub routes_file: String,
pub callback_url: String,
pub log_level: String,
}
impl Config {
pub fn from_env() -> Result<Self> {
let cookie_secret_b64 =
std::env::var("AUTH_PROXY_COOKIE_SECRET").context("AUTH_PROXY_COOKIE_SECRET required")?;
let decoded = STANDARD
.decode(&cookie_secret_b64)
.context("invalid base64 in AUTH_PROXY_COOKIE_SECRET")?;
let cookie_secret: [u8; 32] = decoded
.try_into()
.map_err(|v: Vec<u8>| anyhow::anyhow!("cookie secret must be 32 bytes, got {}", v.len()))?;
Ok(Self {
listen: std::env::var("AUTH_PROXY_LISTEN").unwrap_or_else(|_| "0.0.0.0:8080".into()),
oidc_issuer: std::env::var("AUTH_PROXY_OIDC_ISSUER")
.context("AUTH_PROXY_OIDC_ISSUER required")?,
client_id: std::env::var("AUTH_PROXY_CLIENT_ID")
.context("AUTH_PROXY_CLIENT_ID required")?,
client_secret: std::env::var("AUTH_PROXY_CLIENT_SECRET")
.context("AUTH_PROXY_CLIENT_SECRET required")?,
cookie_secret,
cookie_domain: std::env::var("AUTH_PROXY_COOKIE_DOMAIN")
.context("AUTH_PROXY_COOKIE_DOMAIN required")?,
routes_file: std::env::var("AUTH_PROXY_ROUTES_FILE")
.unwrap_or_else(|_| "/config/routes.yaml".into()),
callback_url: std::env::var("AUTH_PROXY_CALLBACK_URL")
.context("AUTH_PROXY_CALLBACK_URL required")?,
log_level: std::env::var("AUTH_PROXY_LOG_LEVEL").unwrap_or_else(|_| "info".into()),
})
}
pub fn base_url(&self) -> &str {
self.callback_url
.strip_suffix("/callback")
.unwrap_or(&self.callback_url)
}
}
+107
View File
@@ -0,0 +1,107 @@
use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use rand::RngCore;
#[derive(Clone)]
pub struct CookieCrypto {
cipher: Aes256Gcm,
}
impl CookieCrypto {
pub fn new(key_bytes: &[u8; 32]) -> Self {
Self {
cipher: Aes256Gcm::new_from_slice(key_bytes).expect("valid 32-byte key"),
}
}
pub fn encrypt(&self, plaintext: &[u8]) -> anyhow::Result<String> {
let mut nonce_bytes = [0u8; 12];
rand::thread_rng().fill_bytes(&mut nonce_bytes);
let nonce = Nonce::from_slice(&nonce_bytes);
let ciphertext = self
.cipher
.encrypt(nonce, plaintext)
.map_err(|e| anyhow::anyhow!("encryption failed: {}", e))?;
let mut result = Vec::with_capacity(12 + ciphertext.len());
result.extend_from_slice(&nonce_bytes);
result.extend_from_slice(&ciphertext);
Ok(URL_SAFE_NO_PAD.encode(&result))
}
pub fn decrypt(&self, encoded: &str) -> anyhow::Result<Vec<u8>> {
let data = URL_SAFE_NO_PAD.decode(encoded)?;
if data.len() < 13 {
anyhow::bail!("ciphertext too short");
}
let (nonce_bytes, ciphertext) = data.split_at(12);
let nonce = Nonce::from_slice(nonce_bytes);
self.cipher
.decrypt(nonce, ciphertext)
.map_err(|e| anyhow::anyhow!("decryption failed: {}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn encrypt_decrypt_roundtrip() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let plaintext = b"hello world";
let encrypted = crypto.encrypt(plaintext).unwrap();
let decrypted = crypto.decrypt(&encrypted).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn encrypt_produces_different_ciphertext_each_time() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let a = crypto.encrypt(b"same").unwrap();
let b = crypto.encrypt(b"same").unwrap();
assert_ne!(a, b); // different nonces
}
#[test]
fn wrong_key_fails() {
let crypto1 = CookieCrypto::new(&[0x42; 32]);
let crypto2 = CookieCrypto::new(&[0x43; 32]);
let encrypted = crypto1.encrypt(b"hello").unwrap();
assert!(crypto2.decrypt(&encrypted).is_err());
}
#[test]
fn tampered_ciphertext_fails() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let encrypted = crypto.encrypt(b"hello").unwrap();
let mut data = URL_SAFE_NO_PAD.decode(&encrypted).unwrap();
*data.last_mut().unwrap() ^= 0xFF;
let tampered = URL_SAFE_NO_PAD.encode(&data);
assert!(crypto.decrypt(&tampered).is_err());
}
#[test]
fn empty_plaintext_roundtrip() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let encrypted = crypto.encrypt(b"").unwrap();
let decrypted = crypto.decrypt(&encrypted).unwrap();
assert!(decrypted.is_empty());
}
#[test]
fn short_ciphertext_rejected() {
let crypto = CookieCrypto::new(&[0x42; 32]);
assert!(crypto.decrypt("dG9vc2hvcnQ").is_err()); // "tooshort" base64
}
#[test]
fn invalid_base64_rejected() {
let crypto = CookieCrypto::new(&[0x42; 32]);
assert!(crypto.decrypt("not valid base64!!!").is_err());
}
}
+1160
View File
File diff suppressed because it is too large Load Diff
+200
View File
@@ -0,0 +1,200 @@
mod config;
mod crypto;
mod handlers;
mod oidc;
mod routes;
mod session;
use crate::config::Config;
use crate::crypto::CookieCrypto;
use crate::handlers::AppState;
use axum::routing::get;
use axum::Router;
use std::sync::Arc;
use tokio::sync::RwLock;
const HELP: &str = "\
rsauth2-proxy Auth proxy for Traefik ForwardAuth with Keycloak OIDC
USAGE:
rsauth2-proxy Start the proxy (configured via environment variables)
rsauth2-proxy --help Show this help
rsauth2-proxy --version Show version
ENVIRONMENT VARIABLES:
AUTH_PROXY_OIDC_ISSUER (required) OIDC issuer URL
Example: https://auth.example.com/realms/main
AUTH_PROXY_CLIENT_ID (required) Keycloak client ID
AUTH_PROXY_CLIENT_SECRET (required) Keycloak client secret
AUTH_PROXY_COOKIE_SECRET (required) AES-256 encryption key, 32 bytes, base64-encoded
Generate with: openssl rand -base64 32
AUTH_PROXY_COOKIE_DOMAIN (required) Cookie domain, must cover all protected hosts
Example: .example.com
AUTH_PROXY_CALLBACK_URL (required) Full public URL for the OIDC callback endpoint
Example: https://auth-proxy.example.com/callback
AUTH_PROXY_LISTEN (optional) Listen address [default: 0.0.0.0:8080]
AUTH_PROXY_ROUTES_FILE (optional) Path to routes config [default: /config/routes.yaml]
AUTH_PROXY_LOG_LEVEL (optional) Log level: debug, info, warn, error [default: info]
ENDPOINTS:
GET /auth ForwardAuth endpoint (called by Traefik for every request)
GET /callback OIDC callback (receives authorization code from Keycloak)
GET /refresh Token refresh (transparent redirect when access token expires)
GET /sign_out Logout (clears cookie, redirects to Keycloak end_session)
GET /health Health check (returns 200 OK)
GET /metrics Prometheus metrics (text exposition format)
ROUTES FILE (routes.yaml):
routes:
grafana.example.com:
allowed_groups: [\"admins\", \"developers\"] # user must be in at least one group
wiki.example.com:
allowed_groups: [] # any authenticated user
# Hosts not listed are denied (403)
The file is polled every 5 seconds and reloaded on change.
Works with Kubernetes ConfigMap volume mounts (symlink-based updates).
RESPONSE HEADERS (forwarded to upstream on 200):
X-Auth-Request-User preferred_username from the ID token
X-Auth-Request-Email email from the ID token
X-Auth-Request-Groups comma-separated list of groups
KEYCLOAK SETUP:
Groups must be included in the ID token via a \"Group Membership\" protocol mapper
with claim name \"groups\" and full path disabled.
TRAEFIK MIDDLEWARE:
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: auth-proxy
spec:
forwardAuth:
address: http://auth-proxy.auth-proxy.svc:80/auth
trustForwardHeader: true
authResponseHeaders:
- X-Auth-Request-User
- X-Auth-Request-Email
- X-Auth-Request-Groups
SECURITY:
- Sessions encrypted with AES-256-GCM (not just signed)
- PKCE S256 on all authorization flows
- JWT validated against Keycloak JWKS (keys refreshed hourly)
- No open redirects (redirect URL encrypted in state parameter)
- Deny by default (unlisted hosts get 403)
- Fully stateless (all state in encrypted cookies), supports multiple replicas
- Graceful shutdown on SIGTERM/SIGINT
SOURCE: https://github.com/ab/rsauth2-proxy";
const VERSION: &str = env!("CARGO_PKG_VERSION");
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args: Vec<String> = std::env::args().collect();
if args.iter().any(|a| a == "--help" || a == "-h") {
println!("{}", HELP);
return Ok(());
}
if args.iter().any(|a| a == "--version" || a == "-V") {
println!("rsauth2-proxy {}", VERSION);
return Ok(());
}
let config = Config::from_env()?;
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_new(&config.log_level)
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let metrics_handle = metrics_exporter_prometheus::PrometheusBuilder::new()
.install_recorder()
.expect("failed to install prometheus recorder");
metrics::describe_counter!("auth_requests_total", "Total auth check requests");
metrics::describe_histogram!(
"auth_request_duration_seconds",
"Auth check request duration in seconds"
);
metrics::describe_counter!("callback_requests_total", "Total OIDC callback requests");
metrics::describe_counter!("refresh_requests_total", "Total token refresh requests");
metrics::describe_gauge!("routes_loaded_total", "Number of loaded routes");
tracing::info!(listen = %config.listen, "starting rsauth2-proxy");
let oidc = oidc::OidcClient::discover(
&config.oidc_issuer,
config.client_id.clone(),
config.client_secret.clone(),
config.callback_url.clone(),
)
.await?;
let initial_routes = routes::load_routes(&config.routes_file)?;
let shared_routes: routes::SharedRoutes = Arc::new(RwLock::new(initial_routes));
let crypto = CookieCrypto::new(&config.cookie_secret);
let state = Arc::new(AppState {
config: config.clone(),
oidc,
routes: shared_routes.clone(),
crypto,
metrics_handle,
});
// Background: watch routes file for changes
let routes_path = config.routes_file.clone();
let routes_ref = shared_routes.clone();
tokio::spawn(async move {
routes::watch_routes(routes_path, routes_ref).await;
});
// Background: periodically refresh JWKS keys
let oidc_state = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(3600));
loop {
interval.tick().await;
if let Err(e) = oidc_state.oidc.refresh_jwks().await {
tracing::warn!(error = %e, "failed to refresh JWKS");
}
}
});
let app = Router::new()
.route("/", get(handlers::root))
.route("/auth", get(handlers::auth))
.route("/callback", get(handlers::callback))
.route("/refresh", get(handlers::refresh))
.route("/sign_out", get(handlers::sign_out))
.route("/health", get(handlers::health))
.route("/metrics", get(handlers::metrics_handler))
.with_state(state);
let listener = tokio::net::TcpListener::bind(&config.listen).await?;
tracing::info!(addr = %config.listen, "listening");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await?;
tracing::info!("shutdown complete");
Ok(())
}
async fn shutdown_signal() {
let ctrl_c = tokio::signal::ctrl_c();
let mut sigterm =
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()).unwrap();
tokio::select! {
_ = ctrl_c => {},
_ = sigterm.recv() => {},
}
tracing::info!("shutdown signal received");
}
+223
View File
@@ -0,0 +1,223 @@
use anyhow::{Context, Result};
use jsonwebtoken::{decode, decode_header, DecodingKey, TokenData, Validation};
use reqwest::Client;
use serde::Deserialize;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Deserialize)]
pub struct OidcDiscovery {
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub jwks_uri: String,
pub end_session_endpoint: Option<String>,
}
#[derive(Deserialize, Clone)]
#[allow(dead_code)]
pub struct JwkKey {
pub kid: Option<String>,
pub kty: String,
pub n: Option<String>,
pub e: Option<String>,
}
#[derive(Deserialize)]
struct JwksResponse {
keys: Vec<JwkKey>,
}
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct TokenResponse {
pub id_token: Option<String>,
pub refresh_token: Option<String>,
pub expires_in: Option<u64>,
pub refresh_expires_in: Option<u64>,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct IdTokenClaims {
pub sub: String,
pub preferred_username: Option<String>,
pub email: Option<String>,
pub groups: Option<Vec<String>>,
pub exp: i64,
pub iat: i64,
pub nonce: Option<String>,
}
pub struct OidcClient {
http: Client,
pub discovery: OidcDiscovery,
jwks: Arc<RwLock<Vec<JwkKey>>>,
client_id: String,
client_secret: String,
redirect_uri: String,
}
impl OidcClient {
pub async fn discover(
issuer: &str,
client_id: String,
client_secret: String,
redirect_uri: String,
) -> Result<Self> {
let http = Client::new();
let discovery_url = format!(
"{}/.well-known/openid-configuration",
issuer.trim_end_matches('/')
);
let discovery: OidcDiscovery = http
.get(&discovery_url)
.send()
.await?
.error_for_status()
.context("OIDC discovery request failed")?
.json()
.await
.context("failed to parse OIDC discovery")?;
let jwks: JwksResponse = http
.get(&discovery.jwks_uri)
.send()
.await?
.error_for_status()
.context("JWKS request failed")?
.json()
.await
.context("failed to parse JWKS")?;
tracing::info!(
keys = jwks.keys.len(),
"OIDC discovery complete"
);
Ok(Self {
http,
discovery,
jwks: Arc::new(RwLock::new(jwks.keys)),
client_id,
client_secret,
redirect_uri,
})
}
pub fn auth_url(&self, state: &str, pkce_challenge: &str, nonce: &str) -> String {
let mut url = url::Url::parse(&self.discovery.authorization_endpoint)
.expect("valid authorization_endpoint URL");
url.query_pairs_mut()
.append_pair("client_id", &self.client_id)
.append_pair("redirect_uri", &self.redirect_uri)
.append_pair("response_type", "code")
.append_pair("scope", "openid profile email")
.append_pair("state", state)
.append_pair("nonce", nonce)
.append_pair("code_challenge", pkce_challenge)
.append_pair("code_challenge_method", "S256");
url.to_string()
}
pub async fn exchange_code(&self, code: &str, pkce_verifier: &str) -> Result<TokenResponse> {
let resp = self
.http
.post(&self.discovery.token_endpoint)
.form(&[
("grant_type", "authorization_code"),
("code", code),
("redirect_uri", self.redirect_uri.as_str()),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("code_verifier", pkce_verifier),
])
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("token exchange failed ({}): {}", status, body);
}
Ok(resp.json().await?)
}
pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse> {
let resp = self
.http
.post(&self.discovery.token_endpoint)
.form(&[
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
])
.send()
.await?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
anyhow::bail!("token refresh failed ({}): {}", status, body);
}
Ok(resp.json().await?)
}
pub async fn validate_id_token(
&self,
token: &str,
expected_nonce: Option<&str>,
) -> Result<IdTokenClaims> {
let header = decode_header(token)?;
let kid = header.kid.as_deref();
let jwks = self.jwks.read().await;
let key = if let Some(kid) = kid {
jwks.iter().find(|k| k.kid.as_deref() == Some(kid))
} else {
jwks.first()
}
.context("no matching JWK found")?;
let decoding_key = DecodingKey::from_rsa_components(
key.n.as_deref().context("missing 'n' in JWK")?,
key.e.as_deref().context("missing 'e' in JWK")?,
)?;
let mut validation = Validation::new(header.alg);
validation.set_issuer(&[&self.discovery.issuer]);
validation.set_audience(&[&self.client_id]);
let token_data: TokenData<IdTokenClaims> = decode(token, &decoding_key, &validation)?;
if let Some(expected) = expected_nonce {
if token_data.claims.nonce.as_deref() != Some(expected) {
anyhow::bail!("nonce mismatch");
}
}
Ok(token_data.claims)
}
pub async fn refresh_jwks(&self) -> Result<()> {
let jwks: JwksResponse = self
.http
.get(&self.discovery.jwks_uri)
.send()
.await?
.error_for_status()?
.json()
.await?;
let count = jwks.keys.len();
*self.jwks.write().await = jwks.keys;
tracing::debug!(keys = count, "JWKS refreshed");
Ok(())
}
pub fn end_session_url(&self) -> Option<&str> {
self.discovery.end_session_endpoint.as_deref()
}
}
+157
View File
@@ -0,0 +1,157 @@
use anyhow::Result;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Debug, Deserialize, Clone)]
pub struct RouteConfig {
pub allowed_groups: Vec<String>,
}
#[derive(Debug, Deserialize)]
pub struct RoutesFile {
pub routes: HashMap<String, RouteConfig>,
}
pub type SharedRoutes = Arc<RwLock<HashMap<String, RouteConfig>>>;
pub fn load_routes(path: &str) -> Result<HashMap<String, RouteConfig>> {
let content = std::fs::read_to_string(path)?;
let routes_file: RoutesFile = serde_yaml::from_str(&content)?;
let count = routes_file.routes.len();
tracing::info!(count, "routes loaded");
metrics::gauge!("routes_loaded_total").set(count as f64);
Ok(routes_file.routes)
}
pub async fn watch_routes(path: String, routes: SharedRoutes) {
let mut last_content = std::fs::read_to_string(&path).unwrap_or_default();
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
loop {
interval.tick().await;
// Read through the symlink — k8s ConfigMap updates swap the symlink target,
// so reading the original path always gives current content.
match std::fs::read_to_string(&path) {
Ok(content) if content != last_content => match serde_yaml::from_str::<RoutesFile>(&content) {
Ok(routes_file) => {
let count = routes_file.routes.len();
*routes.write().await = routes_file.routes;
last_content = content;
metrics::gauge!("routes_loaded_total").set(count as f64);
tracing::info!(count, "routes reloaded");
}
Err(e) => {
tracing::error!(error = %e, "failed to parse routes file");
}
},
Ok(_) => {}
Err(e) => {
tracing::warn!(error = %e, "failed to read routes file");
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_routes_yaml() {
let yaml = r#"
routes:
grafana.example.com:
allowed_groups: ["admins", "developers"]
wiki.example.com:
allowed_groups: []
"#;
let rf: RoutesFile = serde_yaml::from_str(yaml).unwrap();
assert_eq!(rf.routes.len(), 2);
assert_eq!(
rf.routes["grafana.example.com"].allowed_groups,
vec!["admins", "developers"]
);
assert!(rf.routes["wiki.example.com"].allowed_groups.is_empty());
}
#[test]
fn load_routes_from_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(
&path,
"routes:\n app.example.com:\n allowed_groups: [\"users\"]\n",
)
.unwrap();
let routes = load_routes(path.to_str().unwrap()).unwrap();
assert_eq!(routes.len(), 1);
assert_eq!(routes["app.example.com"].allowed_groups, vec!["users"]);
}
#[test]
fn empty_routes_file() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(&path, "routes: {}\n").unwrap();
let routes = load_routes(path.to_str().unwrap()).unwrap();
assert!(routes.is_empty());
}
#[test]
fn invalid_yaml_returns_error() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(&path, "not: valid: yaml: [[[").unwrap();
assert!(load_routes(path.to_str().unwrap()).is_err());
}
#[test]
fn missing_file_returns_error() {
assert!(load_routes("/nonexistent/routes.yaml").is_err());
}
#[tokio::test]
async fn watch_routes_detects_change() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("routes.yaml");
std::fs::write(
&path,
"routes:\n a.example.com:\n allowed_groups: []\n",
)
.unwrap();
let initial = load_routes(path.to_str().unwrap()).unwrap();
let shared: SharedRoutes = Arc::new(RwLock::new(initial));
assert_eq!(shared.read().await.len(), 1);
let watch_path = path.to_str().unwrap().to_string();
let watch_routes = shared.clone();
let handle = tokio::spawn(async move {
super::watch_routes(watch_path, watch_routes).await;
});
// Let the watcher initialize and read the original content
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
// Update the file
std::fs::write(
&path,
"routes:\n a.example.com:\n allowed_groups: []\n b.example.com:\n allowed_groups: [\"admins\"]\n",
)
.unwrap();
// Poll interval is 5s; wait up to 12s for the change to be detected
for _ in 0..12 {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
if shared.read().await.len() == 2 {
break;
}
}
assert_eq!(shared.read().await.len(), 2);
handle.abort();
}
}
+100
View File
@@ -0,0 +1,100 @@
use crate::crypto::CookieCrypto;
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct Session {
pub sub: String,
pub username: String,
pub email: String,
pub groups: Vec<String>,
pub exp: i64,
pub iat: i64,
pub refresh_token: Option<String>,
}
impl Session {
pub fn is_access_expired(&self) -> bool {
now_timestamp() >= self.exp
}
pub fn encrypt(&self, crypto: &CookieCrypto) -> anyhow::Result<String> {
let json = serde_json::to_vec(self)?;
crypto.encrypt(&json)
}
pub fn decrypt(crypto: &CookieCrypto, encoded: &str) -> anyhow::Result<Self> {
let plaintext = crypto.decrypt(encoded)?;
Ok(serde_json::from_slice(&plaintext)?)
}
}
pub fn now_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs() as i64
}
#[cfg(test)]
mod tests {
use super::*;
use crate::crypto::CookieCrypto;
fn make_session(exp_offset: i64) -> Session {
Session {
sub: "user-123".into(),
username: "john".into(),
email: "john@example.com".into(),
groups: vec!["admins".into(), "devs".into()],
exp: now_timestamp() + exp_offset,
iat: now_timestamp(),
refresh_token: Some("refresh-tok-xyz".into()),
}
}
#[test]
fn encrypt_decrypt_preserves_all_fields() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let session = make_session(3600);
let encrypted = session.encrypt(&crypto).unwrap();
let restored = Session::decrypt(&crypto, &encrypted).unwrap();
assert_eq!(restored.sub, "user-123");
assert_eq!(restored.username, "john");
assert_eq!(restored.email, "john@example.com");
assert_eq!(restored.groups, vec!["admins", "devs"]);
assert_eq!(restored.exp, session.exp);
assert_eq!(restored.iat, session.iat);
assert_eq!(restored.refresh_token.as_deref(), Some("refresh-tok-xyz"));
}
#[test]
fn session_without_refresh_token() {
let crypto = CookieCrypto::new(&[0x42; 32]);
let mut session = make_session(3600);
session.refresh_token = None;
let encrypted = session.encrypt(&crypto).unwrap();
let restored = Session::decrypt(&crypto, &encrypted).unwrap();
assert!(restored.refresh_token.is_none());
}
#[test]
fn active_session_not_expired() {
let session = make_session(3600); // +1h
assert!(!session.is_access_expired());
}
#[test]
fn past_session_is_expired() {
let session = make_session(-1); // 1 second ago
assert!(session.is_access_expired());
}
#[test]
fn wrong_key_cannot_decrypt() {
let crypto1 = CookieCrypto::new(&[0x42; 32]);
let crypto2 = CookieCrypto::new(&[0x99; 32]);
let encrypted = make_session(3600).encrypt(&crypto1).unwrap();
assert!(Session::decrypt(&crypto2, &encrypted).is_err());
}
}