use axum::{
    body::Body,
    extract::{Path, Query, State},
    http::{header, HeaderMap, HeaderValue, StatusCode},
    response::{IntoResponse, Response},
};
use serde::Deserialize;
use tokio::io::{AsyncReadExt, AsyncSeekExt};

use super::{
    browse::{is_audio_file, needs_transcode},
    WebState,
};
use crate::security::sanitize_path;

/// Query-string options accepted by [`handler`].
#[derive(Deserialize)]
pub struct StreamQuery {
    /// `?transcode=1` forces the transcoded stream; any other value (or absence)
    /// leaves the choice to `needs_transcode`.
    #[serde(default)]
    pub transcode: Option<String>,
}

/// Serves one audio file located under `state.root`.
///
/// - `400 Bad Request` when `sanitize_path` rejects the requested path.
/// - `403 Forbidden` when the file name is not recognised by `is_audio_file`.
/// - Otherwise the file is sent transcoded (forced via `?transcode=1`, or when
///   `needs_transcode` says so) or as-is with `Range` support.
pub async fn handler(
    State(state): State<WebState>,
    Path(path): Path<String>,
    Query(query): Query<StreamQuery>,
    headers: HeaderMap,
) -> impl IntoResponse {
    let Ok(relative) = sanitize_path(&path) else {
        return bad_request("invalid path");
    };
    let target = state.root.join(&relative);

    // Non-UTF-8 or missing file names collapse to "", which `is_audio_file` then rejects.
    let name = match target.file_name().and_then(|os| os.to_str()) {
        Some(n) => n.to_owned(),
        None => String::new(),
    };
    if !is_audio_file(&name) {
        return (StatusCode::FORBIDDEN, "not an audio file").into_response();
    }

    let wants_transcode = matches!(query.transcode.as_deref(), Some("1"));
    if !wants_transcode && !needs_transcode(&name) {
        return stream_native(target, &name, &headers).await;
    }
    stream_transcoded(target).await
}
/// Stream a file as-is with Range support.
async fn stream_native(file_path: std::path::PathBuf, filename: &str, req_headers: &HeaderMap) -> Response { let mut file = match tokio::fs::File::open(&file_path).await { Ok(f) => f, Err(e) => { let status = if e.kind() == std::io::ErrorKind::NotFound { StatusCode::NOT_FOUND } else { StatusCode::INTERNAL_SERVER_ERROR }; return (status, e.to_string()).into_response(); } }; let file_size = match file.metadata().await { Ok(m) => m.len(), Err(e) => return (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), }; let content_type = guess_content_type(filename); // Parse Range header let range_header = req_headers .get(header::RANGE) .and_then(|v| v.to_str().ok()) .and_then(parse_range); if let Some((start, end)) = range_header { let end = end.unwrap_or(file_size - 1).min(file_size - 1); if start > end || start >= file_size { return (StatusCode::RANGE_NOT_SATISFIABLE, "invalid range").into_response(); } let length = end - start + 1; if let Err(e) = file.seek(std::io::SeekFrom::Start(start)).await { return (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(); } let limited = file.take(length); let stream = tokio_util::io::ReaderStream::new(limited); let body = Body::from_stream(stream); let mut resp_headers = HeaderMap::new(); resp_headers.insert(header::CONTENT_TYPE, content_type.parse().unwrap()); resp_headers.insert(header::ACCEPT_RANGES, HeaderValue::from_static("bytes")); resp_headers.insert(header::CONTENT_LENGTH, length.to_string().parse().unwrap()); resp_headers.insert( header::CONTENT_RANGE, format!("bytes {}-{}/{}", start, end, file_size).parse().unwrap(), ); (StatusCode::PARTIAL_CONTENT, resp_headers, body).into_response() } else { // Full file let stream = tokio_util::io::ReaderStream::new(file); let body = Body::from_stream(stream); let mut resp_headers = HeaderMap::new(); resp_headers.insert(header::CONTENT_TYPE, content_type.parse().unwrap()); resp_headers.insert(header::ACCEPT_RANGES, HeaderValue::from_static("bytes")); 
resp_headers.insert(header::CONTENT_LENGTH, file_size.to_string().parse().unwrap()); (StatusCode::OK, resp_headers, body).into_response() } } /// Stream a transcoded (Ogg/Opus) version of the file. async fn stream_transcoded(file_path: std::path::PathBuf) -> Response { let ogg_data = match tokio::task::spawn_blocking(move || { super::transcoder::transcode_to_ogg_opus(file_path) }) .await { Ok(Ok(data)) => data, Ok(Err(e)) => { return (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(); } Err(e) => { return (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(); } }; let len = ogg_data.len(); let mut resp_headers = HeaderMap::new(); resp_headers.insert(header::CONTENT_TYPE, "audio/ogg".parse().unwrap()); resp_headers.insert(header::CONTENT_LENGTH, len.to_string().parse().unwrap()); resp_headers.insert(header::ACCEPT_RANGES, HeaderValue::from_static("none")); (StatusCode::OK, resp_headers, Body::from(ogg_data)).into_response() } /// Parse `Range: bytes=-` header. fn parse_range(s: &str) -> Option<(u64, Option)> { let s = s.strip_prefix("bytes=")?; let mut parts = s.splitn(2, '-'); let start: u64 = parts.next()?.parse().ok()?; let end: Option = parts.next().and_then(|e| { if e.is_empty() { None } else { e.parse().ok() } }); Some((start, end)) } fn guess_content_type(filename: &str) -> &'static str { let ext = filename.rsplit('.').next().unwrap_or("").to_lowercase(); match ext.as_str() { "mp3" => "audio/mpeg", "flac" => "audio/flac", "ogg" => "audio/ogg", "opus" => "audio/ogg; codecs=opus", "aac" => "audio/aac", "m4a" => "audio/mp4", "wav" => "audio/wav", _ => "application/octet-stream", } } fn bad_request(msg: &'static str) -> Response { (StatusCode::BAD_REQUEST, msg).into_response() }