Added auto deprecation feature

This commit is contained in:
Ultradesu
2025-07-20 17:26:44 +03:00
parent 1eccc0e0f7
commit 9c5518b39e
8 changed files with 937 additions and 9 deletions

View File

@@ -467,6 +467,71 @@ impl DbClient {
Ok(affected)
}
pub async fn bulk_deprecate_keys_by_servers(
&self,
server_names: &[String],
flow_name: &str,
) -> Result<u64, tokio_postgres::Error> {
if server_names.is_empty() {
return Ok(0);
}
// Update keys to deprecated status for multiple servers in one query
let result = self
.client
.execute(
"UPDATE public.keys
SET deprecated = TRUE, updated = NOW()
WHERE host = ANY($1)
AND key_id IN (
SELECT key_id FROM public.flows WHERE name = $2
)",
&[&server_names, &flow_name],
)
.await;
let affected = Self::handle_db_error(result, "bulk deprecating keys")?;
info!(
"Bulk deprecated {} key(s) for {} servers in flow '{}'",
affected, server_names.len(), flow_name
);
Ok(affected)
}
pub async fn bulk_restore_keys_by_servers(
&self,
server_names: &[String],
flow_name: &str,
) -> Result<u64, tokio_postgres::Error> {
if server_names.is_empty() {
return Ok(0);
}
// Update keys to active status for multiple servers in one query
let result = self
.client
.execute(
"UPDATE public.keys
SET deprecated = FALSE, updated = NOW()
WHERE host = ANY($1)
AND deprecated = TRUE
AND key_id IN (
SELECT key_id FROM public.flows WHERE name = $2
)",
&[&server_names, &flow_name],
)
.await;
let affected = Self::handle_db_error(result, "bulk restoring keys")?;
info!(
"Bulk restored {} key(s) for {} servers in flow '{}'",
affected, server_names.len(), flow_name
);
Ok(affected)
}
pub async fn restore_key_by_server(
&self,
server_name: &str,
@@ -648,6 +713,36 @@ impl ReconnectingDbClient {
}
}
pub async fn bulk_deprecate_keys_by_servers_reconnecting(
&self,
server_names: Vec<String>,
flow_name: String,
) -> Result<u64, tokio_postgres::Error> {
match &self.inner {
Some(client) => {
client
.bulk_deprecate_keys_by_servers(&server_names, &flow_name)
.await
}
None => panic!("Database client not initialized"),
}
}
pub async fn bulk_restore_keys_by_servers_reconnecting(
&self,
server_names: Vec<String>,
flow_name: String,
) -> Result<u64, tokio_postgres::Error> {
match &self.inner {
Some(client) => {
client
.bulk_restore_keys_by_servers(&server_names, &flow_name)
.await
}
None => panic!("Database client not initialized"),
}
}
pub async fn restore_key_by_server_reconnecting(
&self,
server_name: String,

View File

@@ -306,6 +306,18 @@ pub async fn run_server(args: crate::Args) -> std::io::Result<()> {
.app_data(allowed_flows.clone())
// API routes
.route("/api/flows", web::get().to(crate::web::get_flows_api))
.route(
"/{flow_id}/scan-dns",
web::post().to(crate::web::scan_dns_resolution),
)
.route(
"/{flow_id}/bulk-deprecate",
web::post().to(crate::web::bulk_deprecate_servers),
)
.route(
"/{flow_id}/bulk-restore",
web::post().to(crate::web::bulk_restore_servers),
)
.route(
"/{flow_id}/keys/{server}",
web::delete().to(crate::web::delete_key_by_server),

View File

@@ -3,6 +3,12 @@ use log::info;
use rust_embed::RustEmbed;
use serde_json::json;
use std::sync::Arc;
use trust_dns_resolver::TokioAsyncResolver;
use trust_dns_resolver::config::*;
use serde::{Deserialize, Serialize};
use futures::future;
use tokio::sync::Semaphore;
use tokio::time::{timeout, Duration};
use crate::db::ReconnectingDbClient;
use crate::server::Flows;
@@ -11,12 +17,231 @@ use crate::server::Flows;
#[folder = "static/"]
struct StaticAssets;
#[derive(Serialize, Deserialize, Debug)]
pub struct DnsResolutionResult {
pub server: String,
pub resolved: bool,
pub error: Option<String>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct BulkDeprecateRequest {
pub servers: Vec<String>,
}
async fn check_dns_resolution(hostname: String, semaphore: Arc<Semaphore>) -> DnsResolutionResult {
let _permit = match semaphore.acquire().await {
Ok(permit) => permit,
Err(_) => {
return DnsResolutionResult {
server: hostname,
resolved: false,
error: Some("Failed to acquire semaphore".to_string()),
};
}
};
let resolver = TokioAsyncResolver::tokio(
ResolverConfig::default(),
ResolverOpts::default(),
);
let lookup_result = timeout(Duration::from_secs(5), resolver.lookup_ip(&hostname)).await;
match lookup_result {
Ok(Ok(_)) => DnsResolutionResult {
server: hostname,
resolved: true,
error: None,
},
Ok(Err(e)) => DnsResolutionResult {
server: hostname,
resolved: false,
error: Some(e.to_string()),
},
Err(_) => DnsResolutionResult {
server: hostname,
resolved: false,
error: Some("DNS lookup timeout (5s)".to_string()),
},
}
}
// API endpoint to get list of available flows
pub async fn get_flows_api(allowed_flows: web::Data<Vec<String>>) -> Result<HttpResponse> {
info!("API request for available flows");
Ok(HttpResponse::Ok().json(&**allowed_flows))
}
// API endpoint to scan DNS resolution for all hosts in a flow
pub async fn scan_dns_resolution(
flows: web::Data<Flows>,
path: web::Path<String>,
allowed_flows: web::Data<Vec<String>>,
) -> Result<HttpResponse> {
let flow_id_str = path.into_inner();
info!("API request to scan DNS resolution for flow '{}'" , flow_id_str);
if !allowed_flows.contains(&flow_id_str) {
return Ok(HttpResponse::Forbidden().json(json!({
"error": "Flow ID not allowed"
})));
}
let flows_guard = flows.lock().unwrap();
let flow = match flows_guard.iter().find(|flow| flow.name == flow_id_str) {
Some(flow) => flow,
None => {
return Ok(HttpResponse::NotFound().json(json!({
"error": "Flow ID not found"
})));
}
};
// Get unique hostnames
let mut hostnames: std::collections::HashSet<String> = std::collections::HashSet::new();
for key in &flow.servers {
hostnames.insert(key.server.clone());
}
drop(flows_guard);
info!("Scanning DNS resolution for {} unique hosts", hostnames.len());
// Limit concurrent DNS requests to prevent "too many open files" error
let semaphore = Arc::new(Semaphore::new(20));
// Scan all hostnames concurrently with rate limiting
let mut scan_futures = Vec::new();
for hostname in hostnames {
scan_futures.push(check_dns_resolution(hostname, semaphore.clone()));
}
let results = future::join_all(scan_futures).await;
let unresolved_count = results.iter().filter(|r| !r.resolved).count();
info!("DNS scan complete: {} unresolved out of {} hosts", unresolved_count, results.len());
Ok(HttpResponse::Ok().json(json!({
"results": results,
"total": results.len(),
"unresolved": unresolved_count
})))
}
// API endpoint to bulk deprecate multiple servers
pub async fn bulk_deprecate_servers(
flows: web::Data<Flows>,
path: web::Path<String>,
request: web::Json<BulkDeprecateRequest>,
db_client: web::Data<Arc<ReconnectingDbClient>>,
allowed_flows: web::Data<Vec<String>>,
) -> Result<HttpResponse> {
let flow_id_str = path.into_inner();
info!("API request to bulk deprecate {} servers in flow '{}'", request.servers.len(), flow_id_str);
if !allowed_flows.contains(&flow_id_str) {
return Ok(HttpResponse::Forbidden().json(json!({
"error": "Flow ID not allowed"
})));
}
// Use single bulk operation instead of loop
let total_deprecated = match db_client
.bulk_deprecate_keys_by_servers_reconnecting(request.servers.clone(), flow_id_str.clone())
.await
{
Ok(count) => {
info!("Bulk deprecated {} key(s) for {} servers", count, request.servers.len());
count
}
Err(e) => {
return Ok(HttpResponse::InternalServerError().json(json!({
"error": format!("Failed to bulk deprecate keys: {}", e)
})));
}
};
// Refresh the in-memory flows
let updated_flows = match db_client.get_keys_from_db_reconnecting().await {
Ok(flows) => flows,
Err(e) => {
return Ok(HttpResponse::InternalServerError().json(json!({
"error": format!("Failed to refresh flows: {}", e)
})));
}
};
let mut flows_guard = flows.lock().unwrap();
*flows_guard = updated_flows;
let response = json!({
"message": format!("Successfully deprecated {} key(s) for {} server(s)", total_deprecated, request.servers.len()),
"deprecated_count": total_deprecated,
"servers_processed": request.servers.len()
});
Ok(HttpResponse::Ok().json(response))
}
// API endpoint to bulk restore multiple servers
pub async fn bulk_restore_servers(
flows: web::Data<Flows>,
path: web::Path<String>,
request: web::Json<BulkDeprecateRequest>,
db_client: web::Data<Arc<ReconnectingDbClient>>,
allowed_flows: web::Data<Vec<String>>,
) -> Result<HttpResponse> {
let flow_id_str = path.into_inner();
info!("API request to bulk restore {} servers in flow '{}'", request.servers.len(), flow_id_str);
if !allowed_flows.contains(&flow_id_str) {
return Ok(HttpResponse::Forbidden().json(json!({
"error": "Flow ID not allowed"
})));
}
// Use single bulk operation
let total_restored = match db_client
.bulk_restore_keys_by_servers_reconnecting(request.servers.clone(), flow_id_str.clone())
.await
{
Ok(count) => {
info!("Bulk restored {} key(s) for {} servers", count, request.servers.len());
count
}
Err(e) => {
return Ok(HttpResponse::InternalServerError().json(json!({
"error": format!("Failed to bulk restore keys: {}", e)
})));
}
};
// Refresh the in-memory flows
let updated_flows = match db_client.get_keys_from_db_reconnecting().await {
Ok(flows) => flows,
Err(e) => {
return Ok(HttpResponse::InternalServerError().json(json!({
"error": format!("Failed to refresh flows: {}", e)
})));
}
};
let mut flows_guard = flows.lock().unwrap();
*flows_guard = updated_flows;
let response = json!({
"message": format!("Successfully restored {} key(s) for {} server(s)", total_restored, request.servers.len()),
"restored_count": total_restored,
"servers_processed": request.servers.len()
});
Ok(HttpResponse::Ok().json(response))
}
// API endpoint to deprecate a specific key by server name
pub async fn delete_key_by_server(
flows: web::Data<Flows>,