diff --git a/Cargo.lock b/Cargo.lock index 80ec6f8..d170730 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1111,7 +1111,7 @@ checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" [[package]] name = "rexec" -version = "1.2.1" +version = "1.3.0" dependencies = [ "brace-expand", "clap 4.3.4", diff --git a/Cargo.toml b/Cargo.toml index 0feca86..0dbd0e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rexec" -version = "1.2.1" +version = "1.3.0" readme = "https://github.com/house-of-vanity/rexec#readme" edition = "2021" description = "Parallel SSH executor" diff --git a/src/main.rs b/src/main.rs index 250b4ce..eec4185 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,13 @@ -#[macro_use] extern crate log; use std::collections::HashMap; use std::fs::read_to_string; use std::hash::Hash; +use std::io::{BufRead, BufReader}; use std::net::IpAddr; -use std::process; +use std::process::{self, Command, Stdio}; use std::sync::{Arc, Mutex}; +use std::thread; use clap::Parser; use colored::*; @@ -19,13 +20,15 @@ use question::{Answer, Question}; use rayon::prelude::*; use regex::Regex; -// Define args +// Define command-line arguments using the clap library #[derive(Parser, Debug)] #[command(author = "AB ab@hexor.ru", version, about = "Parallel SSH executor in Rust", long_about = None)] struct Args { + /// Username for SSH connections (defaults to current system user) #[arg(short, long, default_value_t = whoami::username())] username: String, + /// Flag to use known_hosts file for server discovery instead of pattern expansion #[arg( short, long, @@ -33,6 +36,8 @@ struct Args { )] known_hosts: bool, + /// Server name patterns with expansion syntax + /// Examples: 'web-[1:12]-io-{prod,dev}' expands to multiple servers #[arg( short, long, @@ -41,12 +46,15 @@ struct Args { )] expression: Vec, + /// Command to execute on each server #[arg(short, long, help = "Command to execute on servers")] command: String, + /// Display only exit codes without command output #[arg(long, default_value_t = false, help = "Show exit code ONLY")] code: bool, + /// Skip confirmation prompt before executing commands #[arg( short = 'f', long, @@ -55,21 +63,112 @@ struct Args { )] noconfirm: bool, + /// Maximum number of parallel SSH connections #[arg(short, long, default_value_t = 100)] parallel: i32, + + /// Use the embedded SSH client library instead of system SSH command + #[arg( + long, + help = "Use embedded SSH client instead of system SSH. Does not support 'live output'.", + default_value_t = false, + )] + embedded_ssh: bool, } -// Represent line from known_hosts file +/// Host representation for both known_hosts entries and expanded patterns #[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] struct Host { + /// Hostname or IP address as a string name: String, + /// Resolved IP address (if available) ip: Option, } -// Read known_hosts file +/// Find common domain suffix across all hostnames to simplify output display +/// +/// This function analyzes all hostnames to identify a common domain suffix +/// which can be shortened during display to improve readability. +/// +/// # Arguments +/// * `hostnames` - A slice of strings containing all server hostnames +/// +/// # Returns +/// * `Option` - The common suffix if found, or None +fn find_common_suffix(hostnames: &[String]) -> Option { + if hostnames.is_empty() { + return None; + } + + // Don't truncate if only one host + if hostnames.len() == 1 { + return None; + } + + let first = &hostnames[0]; + + // Start with assumption that the entire first hostname is the common suffix + let mut common = first.clone(); + + // Iterate through remaining hostnames, reducing the common part + for hostname in hostnames.iter().skip(1) { + // Exit early if no common part remains + if common.is_empty() { + return None; + } + + // Find common suffix with current hostname + let mut new_common = String::new(); + + // Search for common suffix by comparing characters from right to left + let mut common_chars = common.chars().rev(); + let mut hostname_chars = hostname.chars().rev(); + + loop { + match (common_chars.next(), hostname_chars.next()) { + (Some(c1), Some(c2)) if c1 == c2 => new_common.insert(0, c1), + _ => break, + } + } + + common = new_common; + } + + // Ensure the common part is a valid domain suffix (starts with a dot) + if common.is_empty() || !common.starts_with('.') { + return None; + } + + // Return the identified common suffix + Some(common) +} + +/// Shorten hostname by removing the common suffix and replacing with an asterisk +/// +/// # Arguments +/// * `hostname` - The original hostname +/// * `common_suffix` - Optional common suffix to remove +/// +/// # Returns +/// * `String` - Shortened hostname or original if no common suffix +fn shorten_hostname(hostname: &str, common_suffix: &Option) -> String { + match common_suffix { + Some(suffix) if hostname.ends_with(suffix) => { + let short_name = hostname[..hostname.len() - suffix.len()].to_string(); + format!("{}{}", short_name, "*") + }, + _ => hostname.to_string(), + } +} + +/// Read and parse the SSH known_hosts file to extract server names +/// +/// # Returns +/// * `Vec` - List of hosts found in the known_hosts file fn read_known_hosts() -> Vec { let mut result: Vec = Vec::new(); + // Read known_hosts file from the user's home directory for line in read_to_string(format!("/home/{}/.ssh/known_hosts", whoami::username())) .unwrap() .lines() @@ -84,18 +183,45 @@ fn read_known_hosts() -> Vec { result } +/// Expand a numeric range in the format [start:end] to a list of strings +/// +/// # Arguments +/// * `start` - Starting number (inclusive) +/// * `end` - Ending number (inclusive) +/// +/// # Returns +/// * `Vec` - List of numbers as strings fn expand_range(start: i32, end: i32) -> Vec { (start..=end).map(|i| i.to_string()).collect() } +/// Expand a comma-separated list in the format {item1,item2,item3} to a list of strings +/// +/// # Arguments +/// * `list` - Comma-separated string to expand +/// +/// # Returns +/// * `Vec` - List of expanded items fn expand_list(list: &str) -> Vec { list.split(',').map(|s| s.to_string()).collect() } +/// Expand a server pattern string with range and list notation into individual hostnames +/// +/// Supports two expansion types: +/// - Range expansion: server-[1:5] → server-1, server-2, server-3, server-4, server-5 +/// - List expansion: server-{prod,dev} → server-prod, server-dev +/// +/// # Arguments +/// * `s` - Pattern string to expand +/// +/// # Returns +/// * `Vec` - List of expanded Host objects fn expand_string(s: &str) -> Vec { let mut hosts: Vec = Vec::new(); let mut result = vec![s.to_string()]; + // First expand all range expressions [start:end] while let Some(r) = result.iter().find(|s| s.contains('[')) { let r = r.clone(); let start = r.find('[').unwrap(); @@ -122,6 +248,7 @@ fn expand_string(s: &str) -> Vec { } } + // Then expand all list expressions {item1,item2} while let Some(r) = result.iter().find(|s| s.contains('{')) { let r = r.clone(); let start = r.find('{').unwrap(); @@ -140,6 +267,7 @@ fn expand_string(s: &str) -> Vec { } } + // Convert all expanded strings to Host objects for hostname in result { hosts.push(Host { name: hostname.to_string(), @@ -149,14 +277,263 @@ fn expand_string(s: &str) -> Vec { hosts } +/// Execute a command on a single host using the system SSH client +/// +/// This function runs an SSH command using the system's SSH client, +/// capturing and displaying output in real-time with proper formatting. +/// +/// # Arguments +/// * `hostname` - Target server hostname +/// * `username` - SSH username +/// * `command` - Command to execute +/// * `common_suffix` - Optional common suffix for hostname display formatting +/// +/// # Returns +/// * `Result` - Exit code on success or error message +fn execute_ssh_command(hostname: &str, username: &str, command: &str, common_suffix: &Option) -> Result { + let display_name = shorten_hostname(hostname, common_suffix); + + // Display execution start message with shortened hostname + println!("\n{} - STARTED", display_name.yellow().bold()); + + // Build the SSH command with appropriate options + let mut ssh_cmd = Command::new("ssh"); + ssh_cmd.arg("-o").arg("StrictHostKeyChecking=no") + .arg("-o").arg("BatchMode=yes") + .arg(format!("{}@{}", username, hostname)) + .arg(command) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + // Execute the command + let mut child = match ssh_cmd.spawn() { + Ok(child) => child, + Err(e) => return Err(format!("Failed to start SSH process: {}", e)), + }; + + // Capture and display stdout in real-time using a dedicated thread + let stdout = child.stdout.take().unwrap(); + let display_name_stdout = display_name.clone(); + let stdout_thread = thread::spawn(move || { + let reader = BufReader::new(stdout); + let prefix = format!("{}", "│".green()); + + for line in reader.lines() { + match line { + Ok(line) => println!("{} {} - {}", prefix, display_name_stdout.yellow(), line), + Err(_) => break, + } + } + }); + + // Capture and display stderr in real-time using a dedicated thread + let stderr = child.stderr.take().unwrap(); + let display_name_stderr = display_name.clone(); + let stderr_thread = thread::spawn(move || { + let reader = BufReader::new(stderr); + let prefix = format!("{}", "║".red()); + + for line in reader.lines() { + match line { + Ok(line) => println!("{} {} - {}", prefix, display_name_stderr.yellow(), line), + Err(_) => break, + } + } + }); + + // Wait for command to complete + let status = match child.wait() { + Ok(status) => status, + Err(e) => return Err(format!("Failed to wait for SSH process: {}", e)), + }; + + // Wait for stdout and stderr threads to complete + stdout_thread.join().unwrap(); + stderr_thread.join().unwrap(); + + // Format exit code with color (green for success, red for failure) + let exit_code = status.code().unwrap_or(-1); + let code_string = if exit_code == 0 { + format!("{}", exit_code.to_string().green()) + } else { + format!("{}", exit_code.to_string().red()) + }; + + // Display completion message + println!("{} - COMPLETED (Exit code: [{}])", display_name.yellow().bold(), code_string); + + Ok(exit_code) +} + +/// Execute commands on multiple hosts using the massh library (embedded SSH) +/// +/// This function handles batch processing of hosts to maintain the original order +/// while executing commands in parallel using the massh library. +/// +/// # Arguments +/// * `hosts` - Vector of (hostname, IP address, original index) tuples +/// * `username` - SSH username +/// * `command` - Command to execute +/// * `parallel` - Maximum number of parallel connections +/// * `code_only` - Whether to display only exit codes +/// * `common_suffix` - Optional common suffix for hostname display formatting +fn execute_with_massh(hosts: &[(String, IpAddr, usize)], username: &str, command: &str, parallel: i32, code_only: bool, common_suffix: &Option) { + // Create a lookup table for host data using IP addresses as keys + let mut hosts_and_ips: HashMap = HashMap::new(); + let mut massh_hosts: Vec = Vec::new(); + + for (hostname, ip, idx) in hosts { + hosts_and_ips.insert(*ip, (hostname.clone(), *idx)); + massh_hosts.push(MasshHostConfig { + addr: *ip, + auth: None, + port: None, + user: None, + }); + } + + // Process hosts in batches to respect parallelism setting while maintaining order + let batch_size = parallel as usize; + let mut processed = 0; + + while processed < massh_hosts.len() { + let end = std::cmp::min(processed + batch_size, massh_hosts.len()); + + // Create a new config and vector for this batch + let mut batch_hosts = Vec::new(); + for host in &massh_hosts[processed..end] { + batch_hosts.push(MasshHostConfig { + addr: host.addr, + auth: None, + port: None, + user: None, + }); + } + + // Create a new MasshClient for this batch with appropriate configuration + let batch_config = MasshConfig { + default_auth: SshAuth::Agent, + default_port: 22, + default_user: username.to_string(), + threads: batch_hosts.len() as u64, + timeout: 0, + hosts: batch_hosts, + }; + + let batch_massh = MasshClient::from(&batch_config); + + // Execute the command on all hosts in this batch + let rx = batch_massh.execute(command.to_string()); + + // Collect all results from this batch before moving to the next + let mut batch_results = Vec::new(); + + while let Ok((host, result)) = rx.recv() { + // Extract IP address from the massh result + let ip: String = host.split('@').collect::>()[1] + .split(':') + .collect::>()[0] + .to_string(); + let ip = ip.parse::().unwrap(); + + // Lookup the original hostname and index + if let Some((hostname, idx)) = hosts_and_ips.get(&ip) { + batch_results.push((hostname.clone(), ip, result, *idx)); + } else { + error!("Unexpected IP address in result: {}", ip); + } + } + + // Sort results by original index to maintain consistent display order + batch_results.sort_by_key(|(_, _, _, idx)| *idx); + + // Display results for each host in the batch + for (hostname, _ip, result, _) in batch_results { + let display_name = shorten_hostname(&hostname, common_suffix); + + // Display hostname with consistent formatting + println!("\n{}", display_name.yellow().bold().to_string()); + + // Handle execution result + let output = match result { + Ok(output) => output, + Err(e) => { + error!("Can't access server: {}", e); + continue; + } + }; + + // Format exit code with color + let code_string = if output.exit_status == 0 { + format!("{}", output.exit_status.to_string().green()) + } else { + format!("{}", output.exit_status.to_string().red()) + }; + + // Display summary of command execution + println!( + "{}", + format!( + "Exit code [{}] / stdout {} bytes / stderr {} bytes", + code_string, + output.stdout.len(), + output.stderr.len() + ) + .bold() + ); + + // Display command output if not in code-only mode + if !code_only { + // Display stdout with appropriate formatting + match String::from_utf8(output.stdout) { + Ok(stdout) => match stdout.as_str() { + "" => {} + _ => { + let prefix = if output.exit_status != 0 { + format!("{}", "│".cyan()) + } else { + format!("{}", "│".green()) + }; + for line in stdout.lines() { + println!("{} {} - {}", prefix, display_name.yellow(), line); + } + } + }, + Err(_) => {} + } + // Display stderr with appropriate formatting + match String::from_utf8(output.stderr) { + Ok(stderr) => match stderr.as_str() { + "" => {} + _ => { + for line in stderr.lines() { + println!("{} {} - {}", "║".red(), display_name.yellow(), line); + } + } + }, + Err(_) => {} + } + } + } + + processed = end; + } +} + +/// Main entry point for the application fn main() { + // Initialize logging with minimal formatting (no timestamp, no target) env_logger::Builder::from_env(Env::default().default_filter_or("info")) .format_timestamp(None) .format_target(false) .init(); + + // Parse command-line arguments let args = Args::parse(); + // Build the list of target hosts based on user selection method let hosts = if args.known_hosts { + // Use regex pattern matching against known_hosts file info!("Using ~/.ssh/known_hosts to build server list."); let known_hosts = read_known_hosts(); let mut all_hosts = Vec::new(); @@ -177,6 +554,7 @@ fn main() { } all_hosts } else { + // Use pattern expansion syntax (ranges and lists) info!("Using string expansion to build server list."); let mut all_hosts = Vec::new(); for expression in args.expression.iter() { @@ -185,10 +563,10 @@ fn main() { all_hosts }; - // Dedup hosts from known_hosts file but preserve original order + // Remove duplicate hosts while preserving original order let matched_hosts: Vec<_> = hosts.into_iter().unique().collect(); - // Build MasshHostConfig hostnames list + // Log parallelism setting if not using the default if args.parallel != 100 { warn!("Parallelism: {} thread{}", &args.parallel, { if args.parallel != 1 { @@ -199,7 +577,7 @@ fn main() { }); } - // Store hosts with their indices to preserve order + // Store hosts with their original indices to preserve ordering let mut host_with_indices: Vec<(Host, usize)> = Vec::new(); for (idx, host) in matched_hosts.iter().enumerate() { host_with_indices.push((host.clone(), idx)); @@ -207,7 +585,8 @@ fn main() { info!("Matched hosts:"); - // Do DNS resolution in parallel but store results for ordered display later + // Perform DNS resolution for all hosts in parallel + // Results are stored with original indices to maintain order let resolved_ips_with_indices = Arc::new(Mutex::new(Vec::<(String, IpAddr, usize)>::new())); host_with_indices.par_iter().for_each(|(host, idx)| { @@ -228,11 +607,11 @@ fn main() { } }); - // Sort by original index to ensure hosts are displayed in order + // Sort hosts by original index to maintain consistent display order let mut resolved_hosts = resolved_ips_with_indices.lock().unwrap().clone(); resolved_hosts.sort_by_key(|(_, _, idx)| *idx); - // Now print the hosts in the correct order + // Display all matched hosts with their resolved IPs for (hostname, ip, _) in &resolved_hosts { if ip.is_unspecified() { error!("DNS resolve failed: {}", hostname.red()); @@ -241,156 +620,85 @@ fn main() { } } - // Create massh_hosts in the correct order - let mut hosts_and_ips: HashMap = HashMap::new(); - let mut massh_hosts: Vec = Vec::new(); + // Filter out hosts that couldn't be resolved + let valid_hosts: Vec<(String, IpAddr, usize)> = resolved_hosts + .into_iter() + .filter(|(_, ip, _)| !ip.is_unspecified()) + .collect(); - for (hostname, ip, idx) in resolved_hosts { - // Skip hosts that couldn't be resolved - if !ip.is_unspecified() { - hosts_and_ips.insert(ip, (hostname.clone(), idx)); - massh_hosts.push(MasshHostConfig { - addr: ip, - auth: None, - port: None, - user: None, - }); + // Exit if no valid hosts remain + if valid_hosts.is_empty() { + error!("No valid hosts to connect to"); + process::exit(1); + } + + // Find common domain suffix to optimize display + let hostnames: Vec = valid_hosts.iter().map(|(hostname, _, _)| hostname.clone()).collect(); + let common_suffix = find_common_suffix(&hostnames); + + // Inform user about display optimization if common suffix found + if let Some(suffix) = &common_suffix { + info!("Common domain suffix found: '{}' (will be displayed as '*')", suffix); + } + + // Ask for confirmation before proceeding (unless --noconfirm is specified) + if !args.noconfirm + && match Question::new(&*format!( + "Continue on following {} servers?", + &valid_hosts.len() + )) + .confirm() + { + Answer::YES => true, + Answer::NO => { + warn!("Stopped"); + process::exit(0); + } + _ => unreachable!(), } + { + info!("Run command on {} servers.", &valid_hosts.len()); } - // Process hosts in batches to maintain order - let batch_size = args.parallel as usize; - - // Ask for confirmation - if !massh_hosts.is_empty() - && (args.noconfirm - || match Question::new(&*format!( - "Continue on following {} servers?", - &massh_hosts.len() - )) - .confirm() - { - Answer::YES => true, - Answer::NO => false, - _ => unreachable!(), - }) - { - info!("Run command on {} servers.", &massh_hosts.len()); - + // Execute commands using selected method (system SSH or embedded library) + if !args.embedded_ssh { + // Use system SSH client (default behavior) + let batch_size = args.parallel as usize; let mut processed = 0; - - while processed < massh_hosts.len() { - let end = std::cmp::min(processed + batch_size, massh_hosts.len()); + + while processed < valid_hosts.len() { + let end = std::cmp::min(processed + batch_size, valid_hosts.len()); + let batch = &valid_hosts[processed..end]; - // Create a new config and vector for this batch - let mut batch_hosts = Vec::new(); - for host in &massh_hosts[processed..end] { - batch_hosts.push(MasshHostConfig { - addr: host.addr, - auth: None, - port: None, - user: None, + // Create a thread for each host in the current batch + let mut handles = Vec::new(); + + for (hostname, _, _) in batch { + let hostname = hostname.clone(); + let username = args.username.clone(); + let command = args.command.clone(); + let common_suffix_clone = common_suffix.clone(); + + // Execute SSH command in a separate thread + let handle = thread::spawn(move || { + match execute_ssh_command(&hostname, &username, &command, &common_suffix_clone) { + Ok(_) => (), + Err(e) => error!("Error executing command on {}: {}", hostname, e), + } }); + + handles.push(handle); } - // Create a new MasshClient for this batch - let batch_config = MasshConfig { - default_auth: SshAuth::Agent, - default_port: 22, - default_user: args.username.clone(), - threads: batch_hosts.len() as u64, - timeout: 0, - hosts: batch_hosts, - }; - - let batch_massh = MasshClient::from(&batch_config); - - // Run commands on this batch - let rx = batch_massh.execute(args.command.clone()); - - // Collect all results from this batch before moving to the next - let mut batch_results = Vec::new(); - - while let Ok((host, result)) = rx.recv() { - let ip: String = host.split('@').collect::>()[1] - .split(':') - .collect::>()[0] - .to_string(); - let ip = ip.parse::().unwrap(); - - if let Some((hostname, idx)) = hosts_and_ips.get(&ip) { - batch_results.push((hostname.clone(), ip, result, *idx)); - } else { - error!("Unexpected IP address in result: {}", ip); - } - } - - // Sort the batch results by index to ensure they're displayed in order - batch_results.sort_by_key(|(_, _, _, idx)| *idx); - - // Display the results - for (hostname, _ip, result, _) in batch_results { - println!("\n{}", hostname.yellow().bold().to_string()); - - let output = match result { - Ok(output) => output, - Err(e) => { - error!("Can't access server: {}", e); - continue; - } - }; - - let code_string = if output.exit_status == 0 { - format!("{}", output.exit_status.to_string().green()) - } else { - format!("{}", output.exit_status.to_string().red()) - }; - - println!( - "{}", - format!( - "Exit code [{}] / stdout {} bytes / stderr {} bytes", - code_string, - output.stdout.len(), - output.stderr.len() - ) - .bold() - ); - - if !args.code { - match String::from_utf8(output.stdout) { - Ok(stdout) => match stdout.as_str() { - "" => {} - _ => { - let prefix = if output.exit_status != 0 { - format!("{}", "│".cyan()) - } else { - format!("{}", "│".green()) - }; - for line in stdout.lines() { - println!("{} {}", prefix, line); - } - } - }, - Err(_) => {} - } - match String::from_utf8(output.stderr) { - Ok(stderr) => match stderr.as_str() { - "" => {} - _ => { - for line in stderr.lines() { - println!("{} {}", "║".red(), line); - } - } - }, - Err(_) => {} - } - } + // Wait for all threads in this batch to complete + for handle in handles { + handle.join().unwrap(); } processed = end; } } else { - warn!("Stopped"); + // Use the embedded massh library implementation + execute_with_massh(&valid_hosts, &args.username, &args.command, args.parallel, args.code, &common_suffix); } } \ No newline at end of file